RELATEED CONSULTING
相关咨询
选择下列产品马上在线沟通
服务时间:8:30-17:00
你可能遇到了下面的问题
关闭右侧工具栏

新闻中心

这里有您想知道的互联网营销解决方案
Pytorch的使用技巧有哪些

本篇内容主要讲解“Pytorch的使用技巧有哪些”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch的使用技巧有哪些”吧!

创新互联专注于洮北网站建设服务及定制,我们拥有丰富的企业做网站经验。 热诚为您提供洮北营销型网站建设,洮北网站制作、洮北网页设计、洮北网站官网定制、微信平台小程序开发服务,打造洮北网络公司原创品牌,更为您提供洮北网站排名全网营销落地服务。

一、初级“法宝”,sys.stdout

训练模型,最常看的指标就是 Loss。我们可以根据 Loss 的收敛情况,初步判断模型训练的好坏。

如果,Loss 值突然上升了,那说明训练有问题,需要检查数据和代码。

如果,Loss 值趋于稳定,那说明训练完毕了。

观察 Loss 情况,最直观的方法,就是绘制 Loss 曲线图。

Pytorch的使用技巧有哪些

通过绘图,我们可以很清晰的看到,左图还有收敛空间,而右图已经完全收敛。

通过 Loss 曲线,我们可以分析模型训练的好坏,模型是否训练完成,起到一个很好的“监控”作用。

绘制 Loss 曲线图,第一步就是需要保存训练过程中的 Loss 值。

一个最简单的方法是使用,sys.stdout 标准输出重定向,简单好用,实乃“炼丹”必备“良宝”。

import os
import sys
class Logger():
    def __init__(self, filename="log.txt"):
        self.terminal = sys.stdout
        self.log = open(filename, "w")
 
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
 
    def flush(self):
        pass
 
sys.stdout = Logger()
 
print("Jack Cui")
print("https://cuijiahua.com")
print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代码很简单,创建一个 log.py 文件,自己写一个 Logger 类,并采用 sys.stdout 重定向输出。

在 Terminal 中,不仅可以使用 print 打印结果,同时也会将结果保存到 log.txt 文件中。

Pytorch的使用技巧有哪些

运行 log.py,打印 print 内容的同时,也将内容写入了 log.txt 文件中。

使用这个代码,就可以在打印 Loss 的同时,将结果保存到指定的 txt 中,比如保存上篇文章训练 UNet 的 Loss。

Pytorch的使用技巧有哪些

二、中级“法宝”,matplotlib

Matplotlib 是一个 Python 的绘图库,简单好用。

简单几行命令,就可以绘制曲线图、散点图、条形图、直方图、饼图等等。

在深度学习中,一般就是绘制曲线图,比如 Loss 曲线、Acc 曲线。

举一个,简单的例子。

使用 sys.stdout 保存的 train_loss.txt,绘制 Loss 曲线。

Pytorch的使用技巧有哪些

train_loss.txt 下载地址:点击查看

思路非常简单,读取 txt 内容,解析 txt 内容,使用 Matplotlib 绘制曲线。

import matplotlib.pyplot as plt
# Jupyter notebook 中开启
# %matplotlib inline
with open('train_loss.txt', 'r') as f:
    train_loss = f.readlines()
    train_loss = list(map(lambda x:float(x.strip()), train_loss))
x = range(len(train_loss))
y = train_loss
plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

指定 x 和 y 对应的值,就可以绘制。

Pytorch的使用技巧有哪些

是不是很简单?

三、中级“法宝”,Logging

说到保存日志,那不得不提 Python 的内置标准模块 Logging,它主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等,同时,我们也可以设置日志的输出格式。

import logging

def get_logger(LEVEL, log_file = None):
    head = '[%(asctime)-15s] [%(levelname)s] %(message)s'
    if LEVEL == 'info':
        logging.basicConfig(level=logging.INFO, format=head)
    elif LEVEL == 'debug':
        logging.basicConfig(level=logging.DEBUG, format=head)
    logger = logging.getLogger()
    if log_file != None:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)
    return logger

logger = get_logger('info')

logger.info('Jack Cui')
logger.info('https://cuijiahua.com')
logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要几行代码,进行一个简单的封装使用。使用函数 get_logger 创建一个级别为 info 的 logger,如果指定 log_file,则会对日志进行保存。

Pytorch的使用技巧有哪些

logging 默认支持的日志一共有 5 个等级:

Pytorch的使用技巧有哪些

日志级别等级 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默认的日志级别设置为 WARNING,也就是说如果不指定日志级别,只会显示大于等于 WARNING 级别的日志。

例如:

import logging
logging.debug("debug_msg")
logging.info("info_msg")
logging.warning("warning_msg")
logging.error("error_msg")
logging.critical("critical_msg")

运行结果:

WARNING:root:warning_msg
ERROR:root:error_msg
CRITICAL:root:critical_msg

可以看到 info 和 debug 级别的日志不会输出,默认的日志格式也比较简单。

默认的日志格式为日志级别:Logger名称:用户输出消息

当然,我们可以通过,logging.basicConfig 的 format 参数,设置日志格式。

Pytorch的使用技巧有哪些

字段有很多,可谓应有尽有,足以满足我们定制化的需求。

四、高级“法宝”,TensorboardX

上文介绍的“法宝”,并非针对深度学习“炼丹”使用的工具。

而 TensorboardX 则不同,它是专门用于深度学习“炼丹”的高级“法宝”。

早些时候,很多人更喜欢用 Tensorflow 的原因之一,就是 Tensorflow 框架有个一个很好的可视化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起来费劲儿不说,还有很多 Bug。

Pytorch 1.1.0 版本发布后,打破了这个局面,TensorBoard 成为了 Pytorch 的正式可用组件。

在 Pytorch 中,这个可视化工具叫做 TensorBoardX,其实就是针对 Tensorboard 的一个封装,使得 PyTorch 用户也能够调用 Tensorboard。

TensorboardX 安装也非常简单,使用 pip 即可安装。

pip install tensorboardX

tensorboardX 使用也很简单,编写如下代码。

from tensorboardX import SummaryWriter

# 创建 writer1 对象
# log 会保存到 runs/exp 文件夹中
writer1 = SummaryWriter('runs/exp')

# 使用默认参数创建 writer2 对象
# log 会保存到 runs/日期_用户名 格式的文件夹中
writer2 = SummaryWriter()

# 使用 commet 参数,创建 writer3 对象
# log 会保存到 runs/日期_用户名_resnet 格式的文件中
writer3 = SummaryWriter(comment='_resnet')

使用的时候,创建一个 SummaryWriter 对象即可,以上展示了三种初始化 SummaryWriter 的方法:

  • 提供一个路径,将使用该路径来保存日志

  • 无参数,默认将使用 runs/日期_用户名 路径来保存日志

  • 提供一个 comment 参数,将使用 runs/日期_用户名+comment 路径来保存日志

运行结果:

Pytorch的使用技巧有哪些

有了 writer 我们就可以往日志里写入数字、图片、甚至声音等数据。

数字 (scalar)

这个是最简单的,使用 add_scalar 方法来记录数字常量。

add_scalar(tag, scalar_value, global_step=None, walltime=None)

总共 4 个参数。

  • tag (string): 数据名称,不同名称的数据使用不同曲线展示

  • scalar_value (float): 数字常量值

  • global_step (int, optional): 训练的 step

  • walltime (float, optional): 记录发生的时间,默认为 time.time()

需要注意,这里的 scalar_value 一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item() 方法获取其数值。我们一般会使用 add_scalar 方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。

运行如下代码:

from tensorboardX import SummaryWriter    
writer = SummaryWriter('runs/scalar_example')
for i in range(10):
    writer.add_scalar('quadratic', i**2, global_step=i)
    writer.add_scalar('exponential', 2**i, global_step=i)
writer.close()

通过 add_scalar 往日志里写入数字,日志保存到 runs/scalar_example中,writer 用完要记得 close,否则无法保存数据。

在 cmd 中使用如下命令:

tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口号,在浏览器中,就可以使用如下地址,打开 Tensorboad。

http://localhost:8088/

省去了我们自己写代码可视化的麻烦。

Pytorch的使用技巧有哪些

图片 (image)

使用 add_image 方法来记录单个图像数据。注意,该方法需要 pillow 库的支持

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

  • tag (string):数据名称

  • img_tensor (torch.Tensor / numpy.array):图像数据

  • global_step (int, optional):训练的 step

  • walltime (float, optional):记录发生的时间,默认为 time.time()

  • dataformats (string, optional):图像数据的格式,默认为 'CHW',即 Channel x Height x Width,还可以是 'CHW'、'HWC' 或 'HW' 等

我们一般会使用 add_image 来实时观察生成式模型的生成效果,或者可视化分割、目标检测的结果,帮助调试模型。

from tensorboardX import SummaryWriter
from urllib.request import urlretrieve
import cv2

urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')

writer = SummaryWriter('runs/image_example')
for i in range(1, 4):
    writer.add_image('UNet_Seg',
                     cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),
                     global_step=i,
                     dataformats='HWC')
writer.close()

代码就是下载上篇文章数据集里的三张图片,然后使用 Tensorboard 可视化处理来,使用 8088 端口开打 Tensorboard:

tensorboard --logdir=runs/image_example --port=8088

运行结果:

Pytorch的使用技巧有哪些

试想一下,一边训练,一边输出图片结果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方图、运行图、嵌入向量等,可以查看官方手册进行学习,方法都是类似的,简单好用。

到此,相信大家对“Pytorch的使用技巧有哪些”有了更深的了解,不妨来实际操作一番吧!这里是创新互联网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!


当前标题:Pytorch的使用技巧有哪些
网站路径:http://scyingshan.cn/article/gjjcde.html