Pytorch实验常用代码段汇总

1. 大幅度提升 Pytorch 的训练速度

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

但加了这一行,似乎运行结果不一样了。

2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖

# from co-teaching train codetxtfile = save_dir + "/" + model_str + "_%s.txt"%str(args.optimizer)  ## good job!
nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
  os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) # bakeup 备份文件

3. 计算 Accuracy 返回list, 调用函数时,直接提取值,而非提取list

# from co-teaching code but MixMatch_pytorch code also has itdef accuracy(logit, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  output = F.softmax(logit, dim=1) # but actually not need it
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True) # _, pred = logit.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # it seems this is a bug, when not all batch has same size, the mean of accuracy of each batch is not the mean of accu of all dataset
  return res

prec1, = accuracy(logit, labels, topk=(1,)) # , indicate tuple unpackage
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))

4. 善于利用 logger 文件来记录每一个 epoch 的实验值

# from Pytorch_MixMatch codeclass Logger(object):
  '''Save training process to log file with simple plot function.'''
  def __init__(self, fpath, title=None, resume=False):
    self.file = None
    self.resume = resume
    self.title = '' if title == None else title
    if fpath is not None:
      if resume:
        self.file = open(fpath, 'r')
        name = self.file.readline()
        self.names = name.rstrip().split('\t')
        self.numbers = {}
        for _, name in enumerate(self.names):
          self.numbers[name] = []

        for numbers in self.file:
          numbers = numbers.rstrip().split('\t')
          for i in range(0, len(numbers)):
            self.numbers[self.names[i]].append(numbers[i])
        self.file.close()
        self.file = open(fpath, 'a')
      else:
        self.file = open(fpath, 'w')

  def set_names(self, names):
    if self.resume:
      pass
    # initialize numbers as empty list
    self.numbers = {}
    self.names = names
    for _, name in enumerate(self.names):
      self.file.write(name)
      self.file.write('\t')
      self.numbers[name] = []
    self.file.write('\n')
    self.file.flush()

  def append(self, numbers):
    assert len(self.names) == len(numbers), 'Numbers do not match names'
    for index, num in enumerate(numbers):
      self.file.write("{0:.4f}".format(num))
      self.file.write('\t')
      self.numbers[self.names[index]].append(num)
    self.file.write('\n')
    self.file.flush()

  def plot(self, names=None):
    names = self.names if names == None else names
    numbers = self.numbers
    for _, name in enumerate(names):
      x = np.arange(len(numbers[name]))
      plt.plot(x, np.asarray(numbers[name]))
    plt.legend([self.title + '(' + name + ')' for name in names])
    plt.grid(True)

  def close(self):
    if self.file is not None:
      self.file.close()
# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
  logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()

5. 利用 argparser 命令行工具来进行代码重构,使用不同参数适配不同数据集,不同优化方式,不同setting, 避免多个高度冗余的重复代码

# argparser 命令行工具有一个坑的地方是,无法设置 bool 变量, flag=FALSE, 然后会解释为 字符串,仍然当做 True

发现可以使用如下命令来进行修补,来自 ICML-19-SGC github 上代码

parser.add_argument('--test', action='store_true', default=False, help='inductive training.')

当命令行出现 test 字样时,则为 args.test = true

若未出现 test 字样,则为 args.test = false

6. 使用shell 变量来设置所使用的显卡, 便于利用shell 脚本进行程序的串行,从而挂起来跑。或者多开几个 screen 进行同一张卡上多个程序并行跑,充分利用显卡的内存。

命令行中使用如下语句,或者把语句写在 shell 脚本中 # 不要忘了 export

export CUDA_VISIBLE_DEVICES=1 #设置当前可用显卡为编号为1的显卡(从 0 开始编号),即不在 0 号上跑
export CUDA_VISIBlE_DEVICES=0,1 # 设置当前可用显卡为 0,1 显卡,当 0 用满后,就会自动使用 1 显卡

一般经验,即使多个程序并行跑时,即使显存完全足够,单个程序的速度也会变慢,这可能是由于还有 cpu 和内存的限制。

这里显存占用不是阻碍,应该主要看GPU 利用率(也就是计算单元的使用,如果达到了 99% 就说明程序过多了。)

使用 watch nvidia-smi 来监测每个程序当前是否在正常跑。

7. 使用 python 时间戳来保存并进行区别不同的 result 文件

  参照自己很早之前写的 co-training 的代码

8. 把训练时 命令行窗口的 print 输出全部保存到一个 log 文件:(参照 DIEN)

mkdir dnn_save_path
mkdir dnn_best_model
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &

并且使用如下命令 | tee 命令则可以同时保存到文件并且写到命令行输出:

python script/train.py train DIEN | tee train_dein2.log

9. git clone 可以用来下载 github 上的代码,更快。(由 DIEN 的下载)

git clone https://github.com/mouna99/dien.git 使用这个命令可以下载 github 上的代码库

10. (来自 DIEN ) 对于命令行参数不一定要使用 argparser 来读取,也可以直接使用 sys.argv 读取,不过这样的话,就无法指定关键字参数,只能使用位置参数。

### run.sh ###
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
#############

if __name__ == '__main__':
  if len(sys.argv) == 4:
    SEED = int(sys.argv[3]) # 0,1,2,3
  else:
    SEED = 3
  tf.set_random_seed(SEED)
  numpy.random.seed(SEED)
  random.seed(SEED)
  if sys.argv[1] == 'train':
    train(model_type=sys.argv[2], seed=SEED)
  elif sys.argv[1] == 'test':
    test(model_type=sys.argv[2], seed=SEED)
  else:
    print('do nothing...')

11.代码的一种逻辑:time_point 是一个参数变量,可以有两种方案来处理

一种直接在外面判断:

#适用于输出变量的个数不同的情况
if time_point:
  A, B, C = f1(x, y, time_point=True)
else:
  A, B = f1(x, y, time_point=False)
# 适用于输出变量个数和类型相同的情况
C, D = f2(x, y, time_point=time_point)

12. 写一个 shell 脚本文件来进行调节超参数, 来自 [NIPS-20 Grand]

mkdir cora
for num in $(seq 0 99) do
  python train_grand.py --hidden 32 --lr 0.01 --patience 200 --seed $num --dropnode_rate 0.5 > cora/"$num".txt
done

13. 使用 或者 不使用 cuda 运行结果可能会不一样,有细微差别。

cuda 也有一个相关的随机数种子的参数,当不使用 cuda 时,这一个随机数种子没有起到作用,因此可能会得到不同的结果。

来自 NIPS-20 Grand (2020.11.18)的实验结果发现。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

(0)

相关推荐

  • 使用anaconda安装pytorch的实现步骤

    使用anaconda安装pytorch过程中出现的问题 在使用anaconda安装pytorch的过程中,出现了很多问题,也在网上查了很多相关的资料,但是都没有奏效.在很多次尝试之后才发现是要先装numpy的原因-下面开始记录一下过程中的一些尝试和错误经验,供大家参考学习.先按照正常步骤一步一步来安装. 使用anaconda直接从网上下载 首先,打开anaconda navigator,然后创建一个环境来放pytorch. 先点击下面的create,然后创建一个新环境. 选择你的python版本

  • 使用Pytorch搭建模型的步骤

    本来是只用Tenorflow的,但是因为TF有些Numpy特性并不支持,比如对数组使用列表进行切片,所以只能转战Pytorch了(pytorch是支持的).还好Pytorch比较容易上手,几乎完美复制了Numpy的特性(但还有一些特性不支持),怪不得热度上升得这么快. 1  模型定义 和TF很像,Pytorch也通过继承父类来搭建自定义模型,同样也是实现两个方法.在TF中是__init__()和call(),在Pytorch中则是__init__()和forward().功能类似,都分别是初始化

  • Anaconda+spyder+pycharm的pytorch配置详解(GPU)

    第一步 : 从清华大学开源软件镜像站下载Anaconda:https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D 安装过程中需要勾选如下图 装好后测试是否装好,先配置环境变量(可能anaconda安装好后自己就有了) 打开CMD,输入代码 conda list 回车出现包的信息则说明安装完成 打开Anaconda Navigator(桌面没有的话就点击左下角看最近添加)可以看到spyder已经下好了 第二步:下载CUDA(GP

  • pytorch学习教程之自定义数据集

    自定义数据集 在训练深度学习模型之前,样本集的制作非常重要.在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,下面完整的试验自定义样本集的整个流程. 开发环境 Ubuntu 18.04 pytorch 1.0 pycharm 实验目的 掌握pytorch中数据集相关的API接口和类 熟悉数据集制作的整个流程 实验过程 1.收集图像样本 以简单的猫狗二分类为例,可以在网上下载一些猫狗图片.创建以下目录: data-------------根目录 data/test-------测

  • pytorch简介

    一.Pytorch是什么?   Pytorch是torch的python版本,是由Facebook开源的神经网络框架,专门针对 GPU 加速的深度神经网络(DNN)编程.Torch 是一个经典的对多维矩阵数据进行操作的张量(tensor )库,在机器学习和其他数学密集型应用有广泛应用.与Tensorflow的静态计算图不同,pytorch的计算图是动态的,可以根据计算需要实时改变计算图.但由于Torch语言采用 Lua,导致在国内一直很小众,并逐渐被支持 Python 的 Tensorflow

  • 详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行.squeeze(a)就是将a中所有为1的维度删掉.不为1的维度没有影响.a.squeeze(N) 就是去掉a中指定的维数为一的维度.还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度. 再看torch.unsque

  • 简述python&pytorch 随机种子的实现

    随机数广泛应用在科学研究, 但是计算机无法产生真正的随机数, 一般成为伪随机数. 它的产生过程: 给定一个随机种子(一个正整数), 根据随机算法和种子产生随机序列. 给定相同的随机种子, 计算机产生的随机数列是一样的(这也许是伪随机的原因). 随机种子是什么? 随机种子是针对随机方法而言的. 随机方法:常见的随机方法有 生成随机数,以及其他的像 随机排序 之类的,后者本质上也是基于生成随机数来实现的.在深度学习中,比较常用的随机方法的应用有:网络的随机初始化,训练集的随机打乱等. 随机种子的取值

  • 如何使用Pytorch搭建模型

    1  模型定义 和TF很像,Pytorch也通过继承父类来搭建模型,同样也是实现两个方法.在TF中是__init__()和call(),在Pytorch中则是__init__()和forward().功能类似,都分别是初始化模型内部结构和进行推理.其它功能比如计算loss和训练函数,你也可以继承在里面,当然这是可选的.下面搭建一个判别MNIST手写字的Demo,首先给出模型代码: import numpy as np import matplotlib.pyplot as plt import

  • 详解anaconda离线安装pytorchGPU版

    在网速不好的情况下,如何用离线的方式安装pytorch.这里默认大家已经安装了anaconda了. 安装Nvidia驱动.cuda.cudnn等依赖 首先安装vs社区版,如果已经安装过可以跳过这一步,下载地址 安装以下两个组件即可,不用全部装上. 之后安装nvidia驱动,注意自己显卡和驱动的对应关系,下载地址 我的显卡是940M,对应如下选项: 安装cuda 这里要注意查看驱动和cuda的对应关系,首先查看自己下载的驱动文件名, 可以看到最开始有个数字,这个就是驱动版本,和cuda会有下图类似

  • Pytorch实验常用代码段汇总

    1. 大幅度提升 Pytorch 的训练速度 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cudnn.benchmark = True 但加了这一行,似乎运行结果不一样了. 2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖 # from co-teaching train codetxtfile = save_dir + &q

  • PHP操作MySQL的常用代码段梳理与总结

    这篇文章为大家介绍,实用的PHP网站实际开发中常用到的操作mysql数据库的代码段,所有代码均可靠执行,此文将持续更新!!! 1.向数据库插入数据表 <?php $con = mysql_connect("[数据库地址]","[数据库用户名]","[数据库密码]");//创建MySQL连接 mysql_select_db("[数据库名]", $con);//选择MySQL数据库 $sql = "CREATE T

  • js常用代码段整理

    每段代码前边都有功能注解和参数要求等说明文字,难度不大也就没做更多注释. 为看得清楚,这里依先后顺序做个小目录: 重写window.setTimeout, 理解递归程序的返回规律, 截取长字符串, 取得元素在页面中的绝对位置, 统计.去除重复字符(多种方法实现), 把有序的数组元素随机打乱(多种方法实现). 复制代码 代码如下: /* 功能:修改 window.setTimeout,使之可以传递参数和对象参数 (同样可用于setInterval) 使用方法: setTimeout(回调函数,时间

  • js常用代码段收集

    每段代码前边都有功能注解和参数要求等说明文字,难度不大也就没做更多注释. 为看得清楚,这里依先后顺序做个小目录: 重写window.setTimeout, 理解递归程序的返回规律, 截取长字符串, 取得元素在页面中的绝对位置, 统计.去除重复字符(多种方法实现), 把有序的数组元素随机打乱(多种方法实现). 复制代码 代码如下: /* 功能:修改 window.setTimeout,使之可以传递参数和对象参数 (同样可用于setInterval) 使用方法: setTimeout(回调函数,时间

  • oracle表空单清理常用代码段整理

    1.查询表空间使用情况: sqlplus system/manager@topprod 复制代码 代码如下: SQL>@q_tbsFREE 2.查询temp使用方法: sqlplus system/manager@topprod 复制代码 代码如下: SQL>SELECT d.tablespace_name tablespace_name , d.status tablespace_status , NVL(a.bytes, 0) tablespace_size , NVL(t.bytes,

  • javascript常用代码段搜集

    1.json转字符串 复制代码 代码如下: function json2str(o) {     var arr = [];     var fmt = function (s) {         if (typeof s == 'object' && s != null) return json2str(s);         return /^(string|number)$/.test(typeof s) ? "'" + s + "'" :

  • python 网络编程常用代码段

    服务器端代码: # -*- coding: cp936 -*- import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)#初始化socket sock.bind(("127.0.0.1", 8001))#绑定本机地址,8001端口 sock.listen(5)#等待客户连接 while True: print "waiting client connection..." connec

  • Android开发常用经典代码段集锦

    本文实例总结了Android开发常用经典代码段.分享给大家供大家参考,具体如下: 1.图片旋转 Bitmap bitmapOrg = BitmapFactory.decodeResource(this.getContext().getResources(), R.drawable.moon); Matrix matrix = new Matrix(); matrix.postRotate(-90);//旋转的角度 Bitmap resizedBitmap = Bitmap.createBitma

  • ASP.NET程序中常用代码汇总

    1. 打开新的窗口并传送参数: //传送参数: response.write("<script>window.open('*.aspx?id="+this.DropDownList1.SelectIndex+"&id1="++"')</script>") //接收参数: string a = Request.QueryString("id"); string b = Request.QueryS

  • PHP常用的小程序代码段

    本文实例讲述了PHP常用的小程序代码段.分享给大家供大家参考,具体如下: 1.计算两个时间的相差几天 $startdate=strtotime("2009-12-09"); $enddate=strtotime("2009-12-05"); 上面的php时间日期函数strtotime已经把字符串日期变成了时间戳,这样只要让两数值相减,然后把秒变成天就可以了,比较的简单,如下: $days=round(($enddate-$startdate)/3600/24) ;

随机推荐