茄子的个人空间

电磁比赛总结

字数统计: 993阅读时长: 4 min
2022/09/05
loading

电磁比赛总结

代码总结

  1. 要会静态分析资源占有率,特别是当服务器内存资源不足的时候,提前做好静态分析,设置合理的运行参数,才能提升效率。比如本次实验过程中,做数据增强预处理数据时需要占用大量内存资源,参数设置过大,会导致运行一半后因为内存不足,进程被killed掉,参数设置过小效率又变得很低。
  2. 当数据很多,需要占用大量内存时,不要将数据转换为pandas的DataFrame对象,因为它会吃掉更多的内存,此外使用apply方法对数据进行逐行处理的时候,即使使用了加速方法,也没有将数据存储为list然后使用多进程方法处理高效。在本次实验中,后者的速度至少是前者的5倍。
  3. 尽量使用class对代码进行封装,而不是使用一个个单独的函数
  4. 尽量保证函数的功能单一,这样的函数更容易被复用,多写几个函数没有关系
  5. 几种常用的保存/读取数据方式:
1
2
3
4
5
6
7
8
9
10
  import pickle
  import json
 
 # 使用pickle
 pickle.dump(data, file=open(file_name,"wb"))
 data = pickle.load(file=open(file_name, "rb"))
 
 # 使用json
 json.dump(data, file=open(file_name,"w"))
 data = json.load(file=open(file_name, "r"))
  1. 使用coment来可视化训练过程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
  from comet_ml import Experiment
 
 experiment = Experiment(project_name=args.project_name,
                         api_key=args.api_key)
 
 experiment.log_parameters(vars(args))
 
 with experiment.train():
     experiment.log_metric('epoch_loss', train_loss, step=epoch)
 
 with experiment.validate():
     experiment.log_metric('epoch_loss', val_loss, step=epoch)
 
 experiment.end()
  1. 使用argparse来构建超参数管理入口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  import argparse
  parser = argparse.ArgumentParser()
 
   parser.add_argument("--drop_data_rate", default=0.9, type=float,required=False)
 
   args = parser.parse_args() #将参数变成可调用对象
 
   pp.pprint(vars(args)) # 打印参数
 
   def str2bool(v):
     if v.lower() in ('yes', 'true', 't', 'y', '1'):
         return True
     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
         return False
     else:
         raise argparse.ArgumentTypeError('Unsupported value encountered.')

注: 使用add_argument()添加参数的时候有一个大坑,=当添加参数的type为bool的时候,不能设置type=bool,需要自定义str2bool函数,然后设置type=str2bool。

  1. 使用multiprocessing提高处理速度:凡是能够并行处理且数据量巨大的任务,应尽量使用多进程编程
1
2
3
4
5
 from multiprocessing import Pool
 from tqdm import tqdm
 
  with Pool(n_workers) as p: # n_workers表示进程数
     result = list(tqdm(p.imap(function, data), total=len(data))) # 这里使用tqdm显示进度条,需要调用imap,如果不用显示进度条调用map函数
  1. 使用plt将图片保存到内存中,提高处理效率
1
2
3
4
5
 from io import BytesIO
 import matplotlib.pyplot as plt
 
 p_bytes = BytesIO() #申请内存
 plt.savefig(p_bytes, format='png')
  1. 关于采样的函数
1
2
3
4
5
6
7
 import numpy as np
 import random
 
 indents = np.random.choice(2, n,replace=True,p=[pro1, pro2]) # 有放回的从[0,1]中取n个值,取0的概率为pro1,取1的概率为pro2
 
 random.shuffle(data)
 sample_resulut = random.sample(data, n)

11.PIL库读取图片

1
2
3
4
 import PIL
 
 img = PIL.Image.open(image_byte)
 img = img.convert('RGB')
  1. plt 画散点图和图片画布设置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
 import matplotlib.pyplot as plt
 
 plt.rcParams['figure.figsize'] = (4, 4) # 设置图片大小
 
 plt.rcParams['savefig.dpi'] = 56 #设置像素
 plt.rcParams['figure.dpi'] = 56 #设置像素
 plt.axis('off') #关闭坐标
 
 # 多张子图
 fig, axs = plt.subplots(n, 1)
 for idx, dim in enumerate(range(n)):
     axs[idx].axis('off')
     axs[idx].scatter(x, y, s=0.1)
 
 # 单张子图
 fig, axs = plt.subplots(1, 1)
 axs.axis('off')
 axs.scatter(x, y, s=0.1)
 
 plt.savefig(file_name) # 保存到磁盘
 p_bytes = BytesIO()
 plt.savefig(p_bytes, format='png') # 保存到内存
 plt.cla()
 plt.close('all') # 清空缓存
CATALOG