Pytorch深度学习实战教程(四):必知必会的炼丹法宝

news/2024/11/29 18:30:27/

本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善。

一、前言

训练深度学习模型,就像“炼丹”,模型可能需要训练很多天。

我们不可能像「太上老君」那样,拿着浮尘,24 小时全天守在「八卦炉」前,更何况人家还有炼丹童、天兵天将,轮流值守。

人手不够,“法宝”来凑。

本文就盘点一下,我们可以使用的「炼丹法宝」。

PS:文中出现的所有代码,均可在我的 Github 上下载:点击查看

二、初级“法宝”,sys.stdout

训练模型,最常看的指标就是 Loss。我们可以根据 Loss 的收敛情况,初步判断模型训练的好坏。

如果,Loss 值突然上升了,那说明训练有问题,需要检查数据和代码。

如果,Loss 值趋于稳定,那说明训练完毕了。

观察 Loss 情况,最直观的方法,就是绘制 Loss 曲线图。

通过绘图,我们可以很清晰的看到,左图还有收敛空间,而右图已经完全收敛。

通过 Loss 曲线,我们可以分析模型训练的好坏,模型是否训练完成,起到一个很好的“监控”作用。

绘制 Loss 曲线图,第一步就是需要保存训练过程中的 Loss 值。

一个最简单的方法是使用,sys.stdout 标准输出重定向,简单好用,实乃“炼丹”必备“良宝”。

import os
import sys
class Logger():def __init__(self, filename="log.txt"):self.terminal = sys.stdoutself.log = open(filename, "w")def write(self, message):self.terminal.write(message)self.log.write(message)def flush(self):passsys.stdout = Logger()print("Jack Cui")
print("https://cuijiahua.com")
print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代码很简单,创建一个 log.py 文件,自己写一个 Logger 类,并采用 sys.stdout 重定向输出。

在 Terminal 中,不仅可以使用 print 打印结果,同时也会将结果保存到 log.txt 文件中。

运行 log.py,打印 print 内容的同时,也将内容写入了 log.txt 文件中。

使用这个代码,就可以在打印 Loss 的同时,将结果保存到指定的 txt 中,比如保存上篇文章训练 UNet 的 Loss。

三、中级“法宝”,matplotlib

Matplotlib 是一个 Python 的绘图库,简单好用。

简单几行命令,就可以绘制曲线图、散点图、条形图、直方图、饼图等等。

在深度学习中,一般就是绘制曲线图,比如 Loss 曲线、Acc 曲线。

举一个,简单的例子。

使用 sys.stdout 保存的 train_loss.txt,绘制 Loss 曲线。

train_loss.txt 下载地址:点击查看

思路非常简单,读取 txt 内容,解析 txt 内容,使用 Matplotlib 绘制曲线。

import matplotlib.pyplot as plt
# Jupyter notebook 中开启
# %matplotlib inline
with open('train_loss.txt', 'r') as f:train_loss = f.readlines()train_loss = list(map(lambda x:float(x.strip()), train_loss))
x = range(len(train_loss))
y = train_loss
plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

指定 x 和 y 对应的值,就可以绘制。

是不是很简单?

关于 Matplotlib 更多的详细教程,可以查看官方手册:点击查看

四、中级“法宝”,Logging

说到保存日志,那不得不提 Python 的内置标准模块 Logging,它主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等,同时,我们也可以设置日志的输出格式。

import loggingdef get_logger(LEVEL, log_file = None):head = '[%(asctime)-15s] [%(levelname)s] %(message)s'if LEVEL == 'info':logging.basicConfig(level=logging.INFO, format=head)elif LEVEL == 'debug':logging.basicConfig(level=logging.DEBUG, format=head)logger = logging.getLogger()if log_file != None:fh = logging.FileHandler(log_file)logger.addHandler(fh)return loggerlogger = get_logger('info')logger.info('Jack Cui')
logger.info('https://cuijiahua.com')
logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要几行代码,进行一个简单的封装使用。使用函数 get_logger 创建一个级别为 info 的 logger,如果指定 log_file,则会对日志进行保存。

logging 默认支持的日志一共有 5 个等级:

日志级别等级 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默认的日志级别设置为 WARNING,也就是说如果不指定日志级别,只会显示大于等于 WARNING 级别的日志。

例如:

import logging
logging.debug("debug_msg")
logging.info("info_msg")
logging.warning("warning_msg")
logging.error("error_msg")
logging.critical("critical_msg")

运行结果:

WARNING:root:warning_msg
ERROR:root:error_msg
CRITICAL:root:critical_msg

可以看到 info 和 debug 级别的日志不会输出,默认的日志格式也比较简单。

默认的日志格式为日志级别:Logger名称:用户输出消息

当然,我们可以通过,logging.basicConfig 的 format 参数,设置日志格式。

字段有很多,可谓应有尽有,足以满足我们定制化的需求。

五、高级“法宝”,TensorboardX

上文介绍的“法宝”,并非针对深度学习“炼丹”使用的工具。

而 TensorboardX 则不同,它是专门用于深度学习“炼丹”的高级“法宝”。

早些时候,很多人更喜欢用 Tensorflow 的原因之一,就是 Tensorflow 框架有个一个很好的可视化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起来费劲儿不说,还有很多 Bug。

Pytorch 1.1.0 版本发布后,打破了这个局面,TensorBoard 成为了 Pytorch 的正式可用组件。

在 Pytorch 中,这个可视化工具叫做 TensorBoardX,其实就是针对 Tensorboard 的一个封装,使得 PyTorch 用户也能够调用 Tensorboard。

TensorboardX 安装也非常简单,使用 pip 即可安装。

pip install tensorboardX

tensorboardX 使用也很简单,编写如下代码。

from tensorboardX import SummaryWriter# 创建 writer1 对象
# log 会保存到 runs/exp 文件夹中
writer1 = SummaryWriter('runs/exp')# 使用默认参数创建 writer2 对象
# log 会保存到 runs/日期_用户名 格式的文件夹中
writer2 = SummaryWriter()# 使用 commet 参数,创建 writer3 对象
# log 会保存到 runs/日期_用户名_resnet 格式的文件中
writer3 = SummaryWriter(comment='_resnet')

使用的时候,创建一个 SummaryWriter 对象即可,以上展示了三种初始化 SummaryWriter 的方法:

  • 提供一个路径,将使用该路径来保存日志
  • 无参数,默认将使用 runs/日期_用户名 路径来保存日志
  • 提供一个 comment 参数,将使用 runs/日期_用户名+comment 路径来保存日志

运行结果:

有了 writer 我们就可以往日志里写入数字、图片、甚至声音等数据。

数字 (scalar)

这个是最简单的,使用 add_scalar 方法来记录数字常量。

add_scalar(tag, scalar_value, global_step=None, walltime=None)

总共 4 个参数。

  • tag (string): 数据名称,不同名称的数据使用不同曲线展示
  • scalar_value (float): 数字常量值
  • global_step (int, optional): 训练的 step
  • walltime (float, optional): 记录发生的时间,默认为 time.time()

需要注意,这里的 scalar_value 一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item() 方法获取其数值。我们一般会使用 add_scalar 方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。

运行如下代码:

from tensorboardX import SummaryWriter    
writer = SummaryWriter('runs/scalar_example')
for i in range(10):writer.add_scalar('quadratic', i**2, global_step=i)writer.add_scalar('exponential', 2**i, global_step=i)
writer.close()

通过 add_scalar 往日志里写入数字,日志保存到 runs/scalar_example中,writer 用完要记得 close,否则无法保存数据。

在 cmd 中使用如下命令:

tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口号,在浏览器中,就可以使用如下地址,打开 Tensorboad。

http://localhost:8088/

省去了我们自己写代码可视化的麻烦。

图片 (image)

使用 add_image 方法来记录单个图像数据。注意,该方法需要 pillow 库的支持

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

  • tag (string):数据名称
  • img_tensor (torch.Tensor / numpy.array):图像数据
  • global_step (int, optional):训练的 step
  • walltime (float, optional):记录发生的时间,默认为 time.time()
  • dataformats (string, optional):图像数据的格式,默认为 'CHW',即 Channel x Height x Width,还可以是 'CHW'、'HWC' 或 'HW' 等

我们一般会使用 add_image 来实时观察生成式模型的生成效果,或者可视化分割、目标检测的结果,帮助调试模型。

from tensorboardX import SummaryWriter
from urllib.request import urlretrieve
import cv2urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')writer = SummaryWriter('runs/image_example')
for i in range(1, 4):writer.add_image('UNet_Seg',cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),global_step=i,dataformats='HWC')
writer.close()

代码就是下载上篇文章数据集里的三张图片,然后使用 Tensorboard 可视化处理来,使用 8088 端口开打 Tensorboard:

tensorboard --logdir=runs/image_example --port=8088

运行结果:

试想一下,一边训练,一边输出图片结果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方图、运行图、嵌入向量等,可以查看官方手册进行学习,方法都是类似的,简单好用。

官方文档:点击查看

六、总结

工欲善其事,必先利其器。

本文讲解了深度学习中,常用的“炼丹法宝”的使用方法,sys.stdout、matplotlib、logging、tensorboardX 你更喜欢哪一款?

点赞再看,养成习惯,微信公众号搜索【JackCui-AI】关注一个在互联网摸爬滚打的潜行者

 


http://www.ppmy.cn/news/939740.html

相关文章

企业项目管理八大经典法宝

项目可大可小,大到一个跨年度的工程,小到一个办公室装修,对企业而言,都属于项目管理的范畴。项目可不可行,项目可不可控,项目与企业发展战略关联多大,项目如何推进,项目实施验收&…

unpivot行转列 oracle,oracle-行转列点评oracle11g sql新功能pivot/unpivot

摘要:(简要介绍Oracle11g sql的新功能 pivot/unpivot 的使用方法以及如何使用它们做到行列转换. 蓄势以久的Oracle 11g 终于七月敲锣打鼓隆重推出,接下来就是网上漫天盖地的新功能介绍。11g面向开发的新功能本来就不多,掰着手指头也就是pivot…

《SLA by Short brain》—学好英语口语的终极法宝!

你还在为学不会英语苦恼吗?现在就来告诉你一本秘籍《SLA by Short brain》,只要认真研读躬行,你就能事半功倍地学好英语。 我们普遍对外语学习方法的认识是: 1.要有语言环境,多跟外国人交流,最好是能出国&a…

linux ps 简书,Linux小白学习法宝-命令大全第一部分

命令后带(Mac)标记的,表示该命令在Mac OSX下测试,其它的在Debian下测试。 1. grep命令 文本查找命令, 能够使用正则表达式的方式搜索文本,其搜索对象可以是单个或则多个文件 基本格式 grep [option] [regex] [path] -o 只按行显示匹配的字符 …

LeetCode 1493. 删掉一个元素以后全为 1 的最长子数组 - 二分 + 滑动窗口

删掉一个元素以后全为 1 的最长子数组 提示 中等 90 相关企业 给你一个二进制数组 nums ,你需要从中删掉一个元素。 请你在删掉元素的结果数组中,返回最长的且只包含 1 的非空子数组的长度。 如果不存在这样的子数组,请返回 0 。 提示 1&a…

信息学奥赛IO三大法宝

freopen,,fopen,ifstream/ofstream 1. 开始的时候freopen很直观&#xff0c;特别是改熟悉的键盘输入in(屏幕监控)和屏幕输出out,再直接。 xInOut01.cpp # include<cstdio> int main( ){freopen("in.txt", "r", stdin);freopen("out.txt"…

五年级春期计算机教案,五年级下册信息技术教案

第八课 编辑你的文章 教材简析 本课是“文字世界”(上)单元的第二课。教材的内容主要是要求学生掌握文章编辑的基本方法&#xff0c;但这些看似简单的内容却在Office的操作中是至关重要的一部分&#xff0c;同时也是比较实用的内容&#xff0c;因此在课堂上必须要坚持精讲多做的…

2017年计算机二级有什么好处,2017年计算机二级考试备考的五大法宝

导语&#xff1a;导语&#xff1a;考试前如何调整好状态,每个人的一生都在经历各种各样的考试,我们无法逃避,只能沉着应对,调整好自己的状态,无忧考网。炎炎夏日无忧考网为您准备了考前的小绝招&#xff0c;轻松考试&#xff0c;圆满过关。 >>>2017年计算机等级考试备…