电磁比赛总结
代码总结
要会静态分析资源占有率,特别是当服务器内存资源不足的时候,提前做好静态分析,设置合理的运行参数,才能提升效率。比如本次实验过程中,做数据增强预处理数据时需要占用大量内存资源,参数设置过大,会导致运行一半后因为内存不足,进程被killed掉,参数设置过小效率又变得很低。
当数据很多,需要占用大量内存时,不要将数据转换为pandas的DataFrame对象,因为它会吃掉更多的内存,此外使用apply方法对数据进行逐行处理的时候,即使使用了加速方法,也没有将数据存储为list然后使用多进程方法处理高效。在本次实验中,后者的速度至少是前者的5倍。
尽量使用class对代码进行封装,而不是使用一个个单独的函数
尽量保证函数的功能单一,这样的函数更容易被复用,多写几个函数没有关系
几种常用的保存/读取数据方式:
1 2 3 4 5 6 7 8 9 10 import pickle import json pickle.dump(data, file=open (file_name,"wb" )) data = pickle.load(file=open (file_name, "rb" )) json.dump(data, file=open (file_name,"w" )) data = json.load(file=open (file_name, "r" ))
使用coment来可视化训练过程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 from comet_ml import Experiment experiment = Experiment(project_name=args.project_name, api_key=args.api_key) experiment.log_parameters(vars (args)) with experiment.train(): experiment.log_metric('epoch_loss' , train_loss, step=epoch) with experiment.validate(): experiment.log_metric('epoch_loss' , val_loss, step=epoch) experiment.end()
使用argparse来构建超参数管理入口
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import argparse parser = argparse.ArgumentParser() parser.add_argument("--drop_data_rate" , default=0.9 , type =float ,required=False ) args = parser.parse_args() pp.pprint(vars (args)) def str2bool (v ): if v.lower() in ('yes' , 'true' , 't' , 'y' , '1' ): return True elif v.lower() in ('no' , 'false' , 'f' , 'n' , '0' ): return False else : raise argparse.ArgumentTypeError('Unsupported value encountered.' )
注: 使用add_argument()添加参数的时候有一个大坑,=当添加参数的type为bool的时候,不能设置type=bool,需要自定义str2bool函数,然后设置type=str2bool。
使用multiprocessing提高处理速度:凡是能够并行处理且数据量巨大的任务,应尽量使用多进程编程
1 2 3 4 5 from multiprocessing import Pool from tqdm import tqdm with Pool(n_workers) as p: result = list (tqdm(p.imap(function, data), total=len (data)))
使用plt将图片保存到内存中,提高处理效率
1 2 3 4 5 from io import BytesIO import matplotlib.pyplot as plt p_bytes = BytesIO() plt.savefig(p_bytes, format ='png' )
关于采样的函数
1 2 3 4 5 6 7 import numpy as np import random indents = np.random.choice(2 , n,replace=True ,p=[pro1, pro2]) random.shuffle(data) sample_resulut = random.sample(data, n)
11.PIL库读取图片
1 2 3 4 import PIL img = PIL.Image.open (image_byte) img = img.convert('RGB' )
plt 画散点图和图片画布设置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 import matplotlib.pyplot as plt plt.rcParams['figure.figsize' ] = (4 , 4 ) plt.rcParams['savefig.dpi' ] = 56 plt.rcParams['figure.dpi' ] = 56 plt.axis('off' ) fig, axs = plt.subplots(n, 1 ) for idx, dim in enumerate (range (n)): axs[idx].axis('off' ) axs[idx].scatter(x, y, s=0.1 ) fig, axs = plt.subplots(1 , 1 ) axs.axis('off' ) axs.scatter(x, y, s=0.1 ) plt.savefig(file_name) p_bytes = BytesIO() plt.savefig(p_bytes, format ='png' ) plt.cla() plt.close('all' )