A new lane detection method from 2021.
Official video, paper, and source code:
https://www.bilibili.com/video/BV1664y1o7wg
https://arxiv.org/abs/2008.13719
https://github.com/ZJULearning/resa
(The video categorizes and summarizes existing lane detection methods quite thoroughly; recommended viewing.)
Deep-learning-based lane detection methods can be grouped as follows:
- Instance segmentation, e.g. LaneNet
- Semantic segmentation, e.g. SCNN
- Grid-based classification.
- Polynomial regression.
- Anchor-based.
This method follows the line of SCNN and treats lane detection as a semantic segmentation problem.
Abstract
Why do semantic segmentation methods built for general scenes fall short on lane detection?
- Severe occlusion and worn lane markings
- The inherent sparsity of lane annotations (lane pixels are few compared with the rest of the image)
The proposed method (RESA: Recurrent Feature-Shift Aggregator for Lane Detection) enriches lane features after a plain CNN has extracted preliminary features.
RESA exploits the strong shape prior of lanes and captures spatial relationships between pixels across rows and columns.
It shifts slices of the feature map recurrently in the vertical and horizontal directions, so that every pixel can gather global information.
The paper also proposes a bilateral up-sampling decoder that combines coarse-grained and fine-grained features in the up-sampling stage; it refines the low-resolution feature map into pixel-level predictions.
The method is validated on the CULane and Tusimple datasets.
Introduction
Lane detection matters, but it still faces many challenges.
The general recipe when casting lane detection as semantic segmentation is an encoder-decoder framework: a CNN encoder extracts semantic information into a feature map, an up-sampling decoder restores the feature map to its original size, and pixel-wise prediction is performed at the end.
The problem with these methods: because lanes are thin and long, annotated lane pixels are far fewer than background pixels (class imbalance). Such methods usually struggle to extract subtle lane features, and may ignore the strong shape prior and the high correlation between lanes, leading to poor detection performance.
When occlusion is severe, we can only infer the lane from common sense; the low-quality features extracted by a plain CNN therefore tend to miss subtle lane cues.
SCNN proposed spatial convolution, which tries to pass information between adjacent rows or columns of the feature map. However, this RNN-like architecture is time-consuming; passing information sequentially between adjacent rows or columns requires many iterations, and information may be lost during long-distance propagation.
RESA gathers information within the feature map and passes spatial information more directly and efficiently. As shown in Figure 1, RESA aggregates information vertically and horizontally by recurrently shifting slices of the feature map. It first slices the feature map along the vertical and horizontal directions, then lets each slice receive the features of another slice a certain stride away. All pixels are updated simultaneously over several steps, so that in the end every location has gathered information over the whole space. In this way, information propagates between pixels across the feature map.
RESA has three advantages (a toy sketch of the shift-and-add idea follows this list):
- Information is passed in parallel, so the time cost is low.
- With different stride settings, slices of the feature map are aggregated during propagation without information loss.
- The module can be dropped into other networks easily.
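To make the shift-and-add idea concrete, here is a minimal sketch (mine, not the authors') of a single "down" direction in PyTorch; the real module in the repo runs four directions over K iterations, as the code walkthrough later shows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

C, H, W, K = 128, 46, 80, 4  # toy sizes matching a 1/8 feature map of 368x640
convs = nn.ModuleList(
    nn.Conv2d(C, C, kernel_size=(1, 9), padding=(0, 4), bias=False)
    for _ in range(K)
)

def resa_down(x):
    """One vertical pass: each row adds features from a row s_k below (circular)."""
    for k, conv in enumerate(convs):
        s = H // 2 ** (K - k)                  # stride grows each iteration
        idx = (torch.arange(H) + s) % H        # circular row shift
        x = x + F.relu(conv(x[:, :, idx, :]))  # aggregate the shifted slice
    return x

out = resa_down(torch.randn(1, C, H, W))       # shape preserved: [1, 128, 46, 80]
```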
The paper also proposes the Bilateral Up-Sampling Decoder (BUSD). BUSD has two branches: one captures coarse-grained features, the other captures fine details. The coarse branch applies bilinear up-sampling directly and produces a blurry image; in contrast, the fine branch up-samples with a transposed convolution, followed by two non-bottleneck blocks that repair the subtle losses. Combining the two branches, the decoder recovers the low-resolution feature map into precise pixel-wise predictions.
Related Work
Lane Detection
Two families: traditional methods and deep-learning-based methods.
The pros and cons of traditional methods need no further discussion here.
The deep-learning-based methods are broadly covered in the video above; FastDraw is another one worth studying.
Exploiting Spatial Information
Some attempts to exploit spatial information in neural networks:
Spatial Recurrent Neural Networks (RNNs): these RNNs pass spatially varying contextual information horizontally and vertically across the image.
Graph LSTM: provides information propagation routes for semantic object parsing.
SCNN: generalizes the traditional layer-by-layer convolution to slice-by-slice convolution within the feature map, so that messages pass between pixels across rows and columns in the same layer. SCNN propagates the message as a residual and is easier to train than earlier work, but it still suffers from expensive computation and information loss during long-distance propagation.
RESA: more computationally efficient than SCNN, and it gathers information from sliced features with different strides to avoid information loss.
Method
Network Architecture
See Figure 2(a). Three components: encoder, aggregator, decoder.
- Encoder: a backbone network (VGG, ResNet, etc.) for feature extraction; the input image becomes a feature map at 1/8 of the original size.
- RESA: gathers spatial features. In each iteration, slices of the feature map shift recurrently in 4 directions, passing information vertically and horizontally. RESA runs K iterations in total so that every location can receive information from the whole feature map.
- Decoder: bilateral up-sampling blocks. Each block up-samples by 2x, eventually restoring the 1/8 feature map to the original size. The bilateral up-sampling decoder consists of a coarse-grained branch and a fine-grained branch.
After decoder up-sampling, the output feature map is used to predict the existence and the probability distribution of each lane. Existence prediction is a binary classification handled by a fully connected layer that follows; the lane probability distribution is predicted pixel-wise, the same as in a semantic segmentation task.
RESA
Let X be a 3D feature map tensor of size C × H × W (channels, rows, columns).
$X^k_{c,i,j}$: the element of the feature map at iteration k, at channel c, row i, column j.
RESA's computation is defined as follows (Eq. 1 is the vertical pass, Eq. 2 the horizontal pass; reconstructed from the paper):

$$Z^k_{c,i,j}=\sum_{m=0}^{C-1}\sum_{n=0}^{w-1}F_{c,m,n}\cdot X^k_{m,\,(i+s_k)\bmod H,\;j+n-\lfloor w/2\rfloor}\tag{1}$$

$$Z^k_{c,i,j}=\sum_{m=0}^{C-1}\sum_{n=0}^{w-1}F_{c,m,n}\cdot X^k_{m,\;i+n-\lfloor w/2\rfloor,\,(j+s_k)\bmod W}\tag{2}$$

$$X'^{\,k}_{c,i,j}=X^k_{c,i,j}+f\big(Z^k_{c,i,j}\big)\tag{3}$$

where:
- k is the iteration index; L stands for H in Eq. 1 and W in Eq. 2.
- f: a nonlinear activation function (ReLU).
- X with a prime: the updated element.
- $s_k$: the stride in the k-th iteration, $s_k = L/2^{K-k}$ per the repo code (it depends on k, so the information-passing distance is dynamic).
- F: a group of 1D convolution kernels of size $N_{in} \times N_{out} \times w$, where $N_{in}$ is the number of input channels, $N_{out}$ the number of output channels, and w the kernel width; both $N_{in}$ and $N_{out}$ equal C.
- Z in Eqs. 1 and 2 is the intermediate result of the information passing.
As shown in Figures 2(b) and 2(c), the feature map X is split into H slices horizontally and W slices vertically.
The recurrent feature-shift information passing is implemented purely with index computation, with no other complicated operations.
Information passes in 4 directions (see Figure 2). Convolution weights with the same shift stride are shared across all slices in the same direction.
Take "right-to-left" passing as an example, as in Figure 3 (a small verification script follows this walkthrough):
At iteration k = 0, s = 1: each column X_i receives the transformed features of X_{i+1}. Because the shift is circular, the last column also receives features from the other side (i.e. X_{W−1} receives the transformed features of X_0).
At iteration k = 1, s = 2: each column X_i receives the transformed features of X_{i+2}.
Take X_0 as an example: X_0 receives information from X_2 in the second iteration. Since X_0 already received X_1's information in the first iteration, and X_2 likewise received X_3's, X_0 has gathered information from X_0, X_1, X_2, X_3 after just two iterations.
Subsequent iterations proceed in the same way.
After all K iterations, every X_i has aggregated information across the whole feature map.
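A small script (my own check, assuming W is a power of two so the strides are 1, 2, 4, ...) can verify the walkthrough above by tracking which source columns X_0 has received information from:

```python
W = 16
received = {0}                    # X_0 starts with only its own features
for k in range(4):                # K = log2(16) = 4 iterations
    s = 2 ** k                    # s_k: 1, 2, 4, 8
    # X_i receives X_{(i+s) mod W}: X_0 gains everything X_s had gathered so far
    received |= {(j + s) % W for j in received}
    print(f'k={k}, s={s}: X0 has seen {sorted(received)}')
# k=1 prints {0, 1, 2, 3}, matching the X_0 example above;
# after k=3, X_0 has seen all 16 columns
```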
Analysis
RESA applies the feature-shift operation recurrently in 4 directions, letting every location perceive and aggregate all spatial information in the same feature map. Lane detection is a task that depends heavily on surrounding cues; for example, a lane may be occluded by several cars, yet we can still infer it from the other lanes, the car orientations, the road shape, or other visual evidence.
RESA's advantages:
- Computationally efficient. Traditional message-passing schemes (Markov random fields / conditional random fields) are fully connected, which is computation-heavy and redundant. SCNN is an RNN-like structure whose complexity grows linearly with the spatial size, and its sequential propagation cannot exploit the hardware fully. RESA's complexity is logarithmic in the spatial size: in each iteration all locations are updated in parallel, and after K = ⌈log₂ L⌉ iterations every location has aggregated spatial information from the whole feature map.
- Effective feature aggregation. Every location gathers spatial information from the entire feature map, with no information loss along the way. As Figure 5 shows, RESA outperforms SCNN, because SCNN only passes features to adjacent elements and loses information during propagation.
- Easy to plug into other networks. 1. The implementation is simple: only indexing operations on the feature map. 2. It does not change the shape of the input feature map (the ideal position is right after the feature-extraction CNN, e.g. VGG, ResNet, MobileNet). 3. Its computation time is almost negligible.
Bilateral Up-Sampling Decoder
The decoder up-samples the feature map back to the input size. Most decoders rely on bilinear up-sampling to recover the final pixel-wise prediction, which easily yields a coarse result but may lose details; other methods stack convolution and deconvolution operations to obtain refined up-sampling results. BUSD (Bilateral Up-Sampling Decoder) combines the two: one branch recovers coarse-grained features, the other repairs the subtle losses. As shown in Figure 4, the input goes through both branches and produces a 2x up-sampled output with half the channels. After passing through these stacked decoder blocks, the 1/8 feature map produced by RESA is restored to the same size as the input image.
Coarse-Grained Branch
The coarse-grained branch quickly outputs roughly up-sampled features from the previous layer, possibly missing details, so it is designed as a simple and shallow path. A 1×1 convolution first halves the channel count of the input feature map, followed by BN; bilinear interpolation then up-samples the feature map directly; finally ReLU is applied.
(For the finer points, it still helps to read this together with the code.)
Fine-Detail Branch
The fine-detail branch compensates for the information lost in the coarse-grained branch, and its path is deeper than the other one. A transposed convolution with stride 2 up-samples the feature map while halving the channel count; mirroring the design of the coarse branch, BN and ReLU follow the up-sampling. A non-bottleneck block consists of four 3×1 and 1×3 convolutions with BN and ReLU; it keeps the feature map shape and extracts information efficiently in a factorized way. Two non-bottleneck blocks are stacked after the up-sampling operation.
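A sketch of one decoder block under the structure printed later in this post (ConvTranspose2d plus BN/ReLU on the fine branch, 1×1 conv plus bilinear interpolation on the coarse branch; the two trailing non-bottleneck blocks are omitted for brevity, and the exact op ordering is my assumption):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsamplerBlockSketch(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        # fine branch: learned 2x up-sampling
        self.deconv = nn.ConvTranspose2d(c_in, c_out, 3, stride=2,
                                         padding=1, output_padding=1)
        self.bn = nn.BatchNorm2d(c_out, eps=1e-3)
        # coarse branch: cheap 2x up-sampling
        self.interp_conv = nn.Conv2d(c_in, c_out, 1, bias=False)
        self.interp_bn = nn.BatchNorm2d(c_out, eps=1e-3)

    def forward(self, x):
        fine = F.relu(self.bn(self.deconv(x)))
        coarse = F.interpolate(self.interp_bn(self.interp_conv(x)),
                               scale_factor=2, mode='bilinear',
                               align_corners=False)
        return F.relu(fine + coarse)  # combine the two branches

y = UpsamplerBlockSketch(128, 64)(torch.randn(1, 128, 46, 80))
print(y.shape)  # torch.Size([1, 64, 92, 160])
```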
Experiments
Datasets
Two datasets: TuSimple and CULane. TuSimple (see my earlier introduction) contains only simple scenes: highways under stable lighting. CULane is more complex: 55 hours of video covering nine different scenarios in urban and highway scenes, including normal, crowded, curve, dazzle light, night, no line, shadow, arrow, and crossroad.
CULane
Each lane is rendered as a 30-pixel-wide line, and IoU is computed between predictions and ground truth. Predicted lanes whose IoU is **greater than a threshold (0.5)** are counted as true positives (TP).
The F1-measure serves as the evaluation metric, computed as:

$$F_1 = \frac{2 \cdot Precision \cdot Recall}{Precision + Recall},\quad Precision = \frac{TP}{TP + FP},\quad Recall = \frac{TP}{TP + FN}$$

where FP and FN denote false positives and false negatives, respectively.
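A hedged helper mirroring the formula above (the official evaluation computes IoU between 30-pixel-wide lane masks in C++; this only illustrates the arithmetic):

```python
def f1_measure(tp: int, fp: int, fn: int) -> float:
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

print(f1_measure(tp=90, fp=10, fn=30))  # 0.818...
```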
Tusimple
The evaluation metric is accuracy, defined as:

$$accuracy = \frac{\sum_{clip} C_{clip}}{\sum_{clip} S_{clip}}$$

where $C_{clip}$ is the number of correctly predicted lane points (prediction within a certain range of the ground truth), and $S_{clip}$ is the number of ground-truth points in each clip. FP and FN are computed as well.
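Again purely to illustrate the arithmetic (the per-clip point counts here are made up):

```python
def tusimple_accuracy(correct_per_clip, total_per_clip):
    # sum of C_clip over sum of S_clip across clips
    return sum(correct_per_clip) / sum(total_per_clip)

print(tusimple_accuracy([45, 50], [48, 56]))  # 0.913...
```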
Parameter Settings
As in the SAD (self-distillation) paper, the original images are first resized: 288 × 800 for CULane and 368 × 640 for Tusimple.
SGD: momentum 0.9, weight decay 1e-4.
Learning rate: 2.5e-2 for CULane; 2.0e-2 for Tusimple.
A warmup strategy is used for the first 500 batches, followed by a polynomial learning-rate decay with power set to 0.9.
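A sketch of this schedule, reusing the lr_lambda from the repo's config shown later (polynomial decay, power 0.9) together with the pytorch_warmup package the Runner uses; note the Runner below actually passes warmup_period=5000, while the text says 500 batches:

```python
import math
import torch
import pytorch_warmup as warmup

total_iter = 80000
model = torch.nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=2.5e-2,
                            momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda it: math.pow(1 - it / total_iter, 0.9))
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=500)

for step in range(total_iter):
    # ... forward / backward / optimizer.step() ...
    scheduler.step()
    warmup_scheduler.dampen()  # same call order as in the Runner below
```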
Loss function: similar to SCNN, a segmentation cross-entropy loss plus an existence-classification BCE loss. (To account for the class imbalance between background and lanes, the background class's segmentation loss is weighted by 0.4.)
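A hedged sketch of this loss, using the bg_weight (0.4), num_classes (6 + 1) and ignore_label (255) values from the config shown later; the sigmoid-then-BCE split matches the ExistHead code below, but the exact reduction details are my assumption:

```python
import torch
import torch.nn as nn

num_classes = 7
seg_weight = torch.ones(num_classes)
seg_weight[0] = 0.4                      # down-weight the background class
seg_loss_fn = nn.CrossEntropyLoss(weight=seg_weight, ignore_index=255)
exist_loss_fn = nn.BCELoss()             # ExistHead already applies sigmoid

def total_loss(seg_logits, seg_label, exist_prob, exist_label):
    # seg_logits: [N, 7, H, W]; seg_label: [N, H, W] (long)
    # exist_prob: [N, 6] in (0, 1); exist_label: [N, 6] (float)
    return seg_loss_fn(seg_logits, seg_label) + exist_loss_fn(exist_prob, exist_label)
```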
Batch size: 8 for CULane; 4 for Tusimple.
Total training epochs: 50 for TuSimple and 12 for CULane.
Training runs on Ubuntu with 4 NVIDIA 2080Ti GPUs (11 GB memory each), PyTorch 1.1.
ResNet and VGG serve as backbones; for ResNet an extra 1×1 convolution reduces the output channels to 128, and the VGG modifications are the same as in SCNN.
Main Results
RESA with a ResNet-50 backbone is denoted RESA-50. It reaches 36 fps.
Ablation Study
Effect of each component
Baseline: ResNet-34 as the backbone; after extracting the feature map from the backbone, bilinear interpolation up-samples it 8x directly, as in SCNN. The output is treated as a regression problem, finally obtaining the probability distribution of each lane.
The bilinear interpolation is then replaced with the bilateral up-sampling decoder, and RESA is inserted between the backbone and the decoder step by step. See Table 4.
Effectiveness of feature aggregation
Effect of RESA directions.
Number of RESA iterations
In theory, with more iterations each slice of the feature map can aggregate more information, which helps performance; but more iterations also cost more computation time. It is a trade-off between performance and computing resources; to balance the two, iteration = 4 is the final choice.
Comparison between RESA and SCNN
SCNN already showed that a message-passing scheme improves lane detection performance, but its extra parameters bring only a small improvement. RESA is therefore compared against SCNN to verify the effectiveness of the method.
Computational efficiency
Experiments also compare the runtime of RESA against LSTM and SCNN, reported as the average over 1000 runs, with different convolution kernel widths (7, 9, 11). SCNN propagates information sequentially: a slice does not pass information to the next slice until it has received information from the previous one, so the message passing carries a large computation cost. RESA, in contrast, passes information in parallel.
Code Walkthrough
Setup
- Clone the RESA repository:

```
git clone https://github.com/zjulearning/resa.git
```

This directory is denoted $RESA_ROOT below.
- Create a conda virtual environment and activate it (conda is optional):

```
conda create -n resa python=3.8 -y
conda activate resa
```
- Install dependencies:

```
# Install PyTorch first; the cudatoolkit version should match your system.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Alternatively, install via pip
pip install torch torchvision

# Install the Python packages
pip install -r requirements.txt
```
- Data preparation
Download CULane and Tusimple, extract them to $CULANEROOT and $TUSIMPLEROOT respectively, and create links to the data directory:

```
cd $RESA_ROOT
mkdir -p data
ln -s $CULANEROOT data/CULane
ln -s $TUSIMPLEROOT data/tusimple
```
For CULane you should have the following structure:

```
$CULANEROOT/driver_xx_xxframe    # data folders x6
$CULANEROOT/laneseg_label_w16    # lane segmentation labels
$CULANEROOT/list                 # data lists
```

For Tusimple you should have the following structure:

```
$TUSIMPLEROOT/clips                   # data folders
$TUSIMPLEROOT/label_data_xxxx.json    # the README says there are 4 such json files, but after re-downloading the dataset I still found only 3
$TUSIMPLEROOT/test_tasks_0627.json    # test-task json file
$TUSIMPLEROOT/test_label.json         # test-label json file
```
Tusimple does not ship semantic segmentation labels, so we need to generate them from the json annotations:

```
python scripts/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this generates the segmentation labels
```

(The script actually seems to live in the tools folder, i.e. it should be:

```
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
```
)
- Install the CULane evaluation tool
This tool requires OpenCV C++. Follow the official instructions to install OpenCV C++, or simply run `sudo apt-get install libopencv-dev`.
Then build CULane's evaluation tool:

```
cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
make
cd -
```

Note that the default OpenCV version is 3. If you use OpenCV 2, change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the Makefile.
Training
Run the following script:

```
python main.py [configs/path_to_your_config] --gpus [gpu_ids]
```

For example:

```
python main.py configs/culane.py --gpus 0 1 2 3
```
Testing
Run the following script:

```
python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
```

For example:

```
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3
python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
```
Two trained ResNet models are provided, one for CULane and one for Tusimple.
Download the best-performing models (Tusimple: GoogleDrive / BaiduDrive (code: s5ii); CULane: GoogleDrive / BaiduDrive (code: rlwj)).
(The Google Drive link seems to require access permission; the Baidu drive still works.)
Citation
```
@misc{zheng2020resa,
    title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
    author={Tu Zheng and Hao Fang and Yi Zhang and Wenjian Tang and Zheng Yang and Haifeng Liu and Deng Cai},
    year={2020},
    eprint={2008.13719},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
Code Walkthrough: the Tusimple label generator
This is the final operation of step 4 of the data preparation above:

```
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
```

The script reads as follows:
```python
import json
import numpy as np
import cv2
import os
import argparse

TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json']
VAL_SET = ['label_data_0531.json']
TRAIN_VAL_SET = TRAIN_SET + VAL_SET
TEST_SET = ['test_label.json']


def gen_label_for_json(args, image_set):
    H, W = 720, 1280
    SEG_WIDTH = 30
    save_dir = args.savedir

    os.makedirs(os.path.join(args.root, args.savedir, "list"), exist_ok=True)
    list_f = open(os.path.join(args.root, args.savedir, "list",
                               "{}_gt.txt".format(image_set)), "w")

    json_path = os.path.join(args.root, args.savedir,
                             "{}.json".format(image_set))
    with open(json_path) as f:
        for line in f:
            label = json.loads(line)

            # ---------- clean and sort lanes -------------
            lanes = []
            _lanes = []
            slope = []  # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope
            for i in range(len(label['lanes'])):
                # the points of one valid lane
                l = [(x, y) for x, y in zip(label['lanes'][i],
                                            label['h_samples']) if x >= 0]
                if (len(l) > 1):
                    _lanes.append(l)
                    # compute the angle of the lane
                    slope.append(np.arctan2(l[-1][1] - l[0][1],
                                            l[0][0] - l[-1][0]) / np.pi * 180)
            # np.argsort sorts and returns the indices
            # (https://blog.csdn.net/qq_38486203/article/details/80967696);
            # here the lanes are sorted by slope
            _lanes = [_lanes[i] for i in np.argsort(slope)]
            slope = [slope[i] for i in np.argsort(slope)]

            idx = [None for i in range(6)]
            for i in range(len(slope)):
                if slope[i] <= 90:
                    idx[2] = i
                    idx[1] = i - 1 if i > 0 else None
                    idx[0] = i - 2 if i > 1 else None
                else:
                    idx[3] = i
                    idx[4] = i + 1 if i + 1 < len(slope) else None
                    idx[5] = i + 2 if i + 2 < len(slope) else None
                    break
            for i in range(6):
                lanes.append([] if idx[i] is None else _lanes[idx[i]])
            # ---------------------------------------------

            img_path = label['raw_file']
            seg_img = np.zeros((H, W, 3))
            list_str = []  # str to be written to list.txt
            for i in range(len(lanes)):
                coords = lanes[i]
                if len(coords) < 4:
                    list_str.append('0')
                    continue
                for j in range(len(coords) - 1):
                    cv2.line(seg_img, coords[j], coords[j + 1],
                             (i + 1, i + 1, i + 1), SEG_WIDTH // 2)
                list_str.append('1')

            seg_path = img_path.split("/")
            seg_path, img_name = os.path.join(args.root, args.savedir,
                                              seg_path[1], seg_path[2]), seg_path[3]
            os.makedirs(seg_path, exist_ok=True)
            seg_path = os.path.join(seg_path, img_name[:-3] + "png")
            cv2.imwrite(seg_path, seg_img)

            seg_path = "/".join([args.savedir,
                                 *img_path.split("/")[1:3],
                                 img_name[:-3] + "png"])
            if seg_path[0] != '/':
                seg_path = '/' + seg_path
            if img_path[0] != '/':
                img_path = '/' + img_path
            list_str.insert(0, seg_path)
            list_str.insert(0, img_path)
            list_str = " ".join(list_str) + "\n"
            list_f.write(list_str)


def generate_json_file(save_dir, json_file, image_set):
    # note: relies on the global `args` for the root path
    with open(os.path.join(save_dir, json_file), "w") as outfile:
        for json_name in (image_set):
            with open(os.path.join(args.root, json_name)) as infile:
                for line in infile:
                    outfile.write(line)


def generate_label(args):
    save_dir = os.path.join(args.root, args.savedir)
    os.makedirs(save_dir, exist_ok=True)
    generate_json_file(save_dir, "train_val.json", TRAIN_VAL_SET)
    generate_json_file(save_dir, "test.json", TEST_SET)

    print("generating train_val set...")
    gen_label_for_json(args, 'train_val')
    print("generating test set...")
    gen_label_for_json(args, 'test')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', required=True,
                        help='The root of the Tusimple dataset')
    parser.add_argument('--savedir', type=str, default='seg_label',
                        help='The root of the Tusimple dataset')
    args = parser.parse_args()
    generate_label(args)
```
This script generates the semantic segmentation labels.
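Each line written to `{image_set}_gt.txt` is the image path, the segmentation-label path, and then six 0/1 lane-existence flags; a hypothetical example:

```
/clips/0313-1/6040/20.jpg /seg_label/0313-1/6040/20.png 1 1 1 1 0 0
```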
Test Run
Running the code throws the following error:

```
.local/lib/python3.8/site-packages/torch/cuda/__init__.py:104: UserWarning:
GeForce RTX 3060 Ti with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75.
If you want to use the GeForce RTX 3060 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/
warnings.warn(incompatible_device_warn.format(device_name, capability, " ".join(arch_list), device_name))
```

The reason: 30-series GPUs use a new architecture, and the new driver supports neither CUDA 9 nor CUDA 10, so CUDA 11 must be installed. PyTorch therefore has to be reinstalled from the `-nightly` channel.
Details:
https://blog.csdn.net/weixin_43896241/article/details/108979744
```python
import os
import os.path as osp
import time
import shutil
import torch
import torchvision  # if this raises ModuleNotFoundError: No module named 'torchvision',
                    # see https://github.com/pytorch/pytorch/issues/12525
                    # fix: conda install torchvision -c pytorch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader


def main():
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)

    cfg = Config.fromfile(args.config)  # build the config object
    # attribute writes like these go through Config's __getattr__/__setattr__
    # magic methods (https://blog.csdn.net/weixin_42233629/article/details/85723073)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    # chained attribute access like cfg.dataset.train.type invokes the magic
    # methods repeatedly
    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type

    # cudnn.benchmark speeds things up for fixed input sizes
    # (https://blog.csdn.net/Ibelievesunshine/article/details/99471258)
    cudnn.benchmark = True
    cudnn.fastest = True

    runner = Runner(cfg)

    if args.validate:
        val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)
        runner.validate(val_loader)
    else:
        runner.train()


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    # positional argument: the config file path is mandatory
    # (https://docs.python.org/zh-cn/3/library/argparse.html?highlight=add_argument#argparse.ArgumentParser.add_argument)
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dirs', type=str, default='work_dirs',
                        help='work dirs')
    parser.add_argument('--load_from', default=None,
                        help='the checkpoint file to resume from')
    parser.add_argument('--finetune_from', default=None,
                        help='whether to finetune from the checkpoint')
    # action='store_true': the flag is False when omitted and True when given
    # (store_false is the opposite)
    parser.add_argument('--validate', action='store_true',
                        help='whether to evaluate the checkpoint during training')
    # nargs='+' accepts one or more values
    parser.add_argument('--gpus', nargs='+', type=int, default='0')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    main()
```
The config file `configs/tusimple.py` imported by main.py:
```python
net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=5,
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',
    thresh=0.60
)

optimizer = dict(
    type='sgd',
    lr=0.020,
    weight_decay=1e-4,
    momentum=0.9
)

total_iter = 80000
import math
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter/total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)

loss_type = 'cross_entropy'
seg_loss_weight = 1.0

batch_size = 4
workers = 12
num_classes = 6 + 1
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''
```
The `Config` class imported in main.py via `from utils.config import Config`:
```python
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode

BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


class ConfigDict(Dict):

    # on the __missing__ magic method:
    # https://blog.csdn.net/qq_43168521/article/details/103150464 and
    # https://www.cnblogs.com/geeklove01/p/8747653.html (still not fully clear to me...)
    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


def add_args(parser, cfg, prefix=''):
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument('--' + prefix + k)
        elif isinstance(v, int):
            parser.add_argument('--' + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument('--' + prefix + k, type=float)
        elif isinstance(v, bool):
            parser.add_argument('--' + prefix + k, action='store_true')
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser


class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The
    interface is the same as a dict object and also allows access config
    values as attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()  # read the contents of configs/tusimple.py
        try:
            ast.parse(content)  # parse the source into an AST to check syntax
        except SyntaxError:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}')

    @staticmethod
    def _file2dict(filename):
        # osp.expanduser expands '~' (Python does not understand '~' in paths:
        # https://www.zhihu.com/question/48161511); osp.abspath gives the
        # absolute path (https://blog.csdn.net/liuskyter/article/details/99936955)
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        # endswith checks the suffix ('.py'):
        # https://www.w3school.com.cn/python/ref_string_endswith.asp
        if filename.endswith('.py'):
            # tempfile creates temporary files/directories on all platforms,
            # named with a 6-character random string instead of the process id.
            # TemporaryDirectory() creates and returns a temp directory; used as
            # a context manager, the directory and its contents are deleted on exit.
            with tempfile.TemporaryDirectory() as temp_config_dir:
                # NamedTemporaryFile: create a temp file; the two arguments give
                # its directory and suffix
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                # os.path.basename(path): the final component of the path
                temp_config_name = osp.basename(temp_config_file.name)
                # shutil.copyfile(src, dst): copy the contents (no metadata)
                # of src to dst
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                # os.path.splitext(path): split into (root, ext)
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                # import the copied config as a module
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                # collect the module's top-level variables (the contents of
                # configs/tusimple.py) into a dict
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
                # close temp file
                temp_config_file.close()
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        # cfg_dict: a plain dict of the config; cfg_text: the same content with
        # one extra leading line, the file path (e.g. /home/wenqiang/resa/configs/tusimple.py)
        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        # merge dict `a` into dict `b` (non-inplace). values in `a` will
        # overwrite `b`.
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k])
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)"""
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        # call the parent class's __setattr__ directly, bypassing Config's own
        # override; this gives Config a _cfg_dict attribute holding the
        # cfg_dict wrapped in a ConfigDict
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        if self.filename.endswith('.py'):
            if file is None:
                return self.pretty_text
            else:
                with open(file, 'w') as f:
                    f.write(self.pretty_text)
        else:
            import mmcv
            if file is None:
                file_format = self.filename.split('.')[-1]
                return mmcv.dump(cfg_dict, file_format=file_format)
            else:
                mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options):
        """Merge list into cfg_dict.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))


class DictAction(Action):
    """argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, i.e KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return True if val.lower() == 'true' else False
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
```
If the config machinery is still unclear, this blog post may help:
https://blog.csdn.net/wulele2/article/details/113870217
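A minimal usage sketch of this facility (paths assume the repo layout):

```python
from utils.config import Config

cfg = Config.fromfile('configs/tusimple.py')
print(cfg.resa.iter)      # 5  -- attribute access via ConfigDict/__getattr__
print(cfg['batch_size'])  # 4  -- dict-style access via __getitem__
```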
The model structure logged at startup is as follows:
2021-04-08 19:37:20,150 - resa - INFO - Network:
```
DataParallel(
(module): RESANet(
(backbone): ResNetWrapper(
(model): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(out): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(resa): RESA(
(conv_d0): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_u0): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_r0): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_l0): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_d1): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_u1): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_r1): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_l1): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_d2): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_u2): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_r2): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_l2): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_d3): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_u3): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_r3): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_l3): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_d4): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_u4): Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
(conv_r4): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
(conv_l4): Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
)
(decoder): BUSD(
(layers): ModuleList(
(0): UpsamplerBlock(
(conv): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(follows): ModuleList(
(0): non_bottleneck_1d(
(conv3x1_1): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
(1): non_bottleneck_1d(
(conv3x1_1): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
)
(interpolate_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(interpolate_bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(1): UpsamplerBlock(
(conv): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(follows): ModuleList(
(0): non_bottleneck_1d(
(conv3x1_1): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
(1): non_bottleneck_1d(
(conv3x1_1): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(32, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(32, 32, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
)
(interpolate_conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(interpolate_bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(2): UpsamplerBlock(
(conv): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(follows): ModuleList(
(0): non_bottleneck_1d(
(conv3x1_1): Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(16, 16, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(16, 16, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
(1): non_bottleneck_1d(
(conv3x1_1): Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_1): Conv2d(16, 16, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn1): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(conv3x1_2): Conv2d(16, 16, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
(conv1x3_2): Conv2d(16, 16, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
(bn2): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout2d(p=0, inplace=False)
)
)
(interpolate_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(interpolate_bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
)
(output_conv): Conv2d(16, 7, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(heads): ExistHead(
(dropout): Dropout2d(p=0.1, inplace=False)
(conv8): Conv2d(128, 7, kernel_size=(1, 1), stride=(1, 1))
(fc9): Linear(in_features=6440, out_features=128, bias=True)
(fc10): Linear(in_features=128, out_features=6, bias=True)
)
)
)
```
The `Runner` class imported in main.py via `from runner.runner import Runner`:
```python
import time
import torch
import numpy as np
from tqdm import tqdm
import pytorch_warmup as warmup

from models.registry import build_net
from .registry import build_trainer, build_evaluator
from .optimizer import build_optimizer
from .scheduler import build_scheduler
from datasets import build_dataloader
from .recorder import build_recorder
from .net_utils import save_model, load_network


class Runner(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.recorder = build_recorder(self.cfg)
        self.net = build_net(self.cfg)
        # data parallelism; useful when multiple GPUs are available
        self.net = torch.nn.parallel.DataParallel(
            self.net, device_ids=range(self.cfg.gpus)).cuda()
        # this logs the model structure shown above
        self.recorder.logger.info('Network: \n' + str(self.net))
        self.resume()  # resume training (load a checkpoint)
        self.optimizer = build_optimizer(self.cfg, self.net)
        self.scheduler = build_scheduler(self.cfg, self.optimizer)
        self.evaluator = build_evaluator(self.cfg)
        self.warmup_scheduler = warmup.LinearWarmup(
            self.optimizer, warmup_period=5000)
        self.metric = 0.

    def resume(self):
        if not self.cfg.load_from and not self.cfg.finetune_from:
            return
        load_network(self.net, self.cfg.load_from,
                     finetune_from=self.cfg.finetune_from,
                     logger=self.recorder.logger)

    def to_cuda(self, batch):
        for k in batch:
            if k == 'meta':
                continue
            # move everything except 'meta' to the GPU: img, label, exist
            batch[k] = batch[k].cuda()
        return batch

    def train_epoch(self, epoch, train_loader):
        self.net.train()
        end = time.time()
        max_iter = len(train_loader)
        for i, data in enumerate(train_loader):
            if self.recorder.step >= self.cfg.total_iter:
                break
            date_time = time.time() - end
            self.recorder.step += 1
            data = self.to_cuda(data)
            output = self.trainer.forward(self.net, data)
            self.optimizer.zero_grad()
            loss = output['loss']
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            self.warmup_scheduler.dampen()
            batch_time = time.time() - end
            end = time.time()
            self.recorder.update_loss_stats(output['loss_stats'])
            self.recorder.batch_time.update(batch_time)
            self.recorder.data_time.update(date_time)

            if i % self.cfg.log_interval == 0 or i == max_iter - 1:
                lr = self.optimizer.param_groups[0]['lr']
                self.recorder.lr = lr
                self.recorder.record('train')

    def train(self):
        self.recorder.logger.info('start training...')
        self.trainer = build_trainer(self.cfg)
        train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True)
        val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False)

        for epoch in range(self.cfg.epochs):
            self.recorder.epoch = epoch
            self.train_epoch(epoch, train_loader)
            if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1:
                self.save_ckpt()
            if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1:
                self.validate(val_loader)
            if self.recorder.step >= self.cfg.total_iter:
                break

    def validate(self, val_loader):
        self.net.eval()
        # desc sets the progress-bar title
        for i, data in enumerate(tqdm(val_loader, desc=f'Validate')):
            data = self.to_cuda(data)  # one batch (batch_size=4 images here)
            with torch.no_grad():
                output = self.net(data['img'])
                self.evaluator.evaluate(output, data)  # score the predictions

        metric = self.evaluator.summarize()
        if not metric:
            return
        if metric > self.metric:  # keep the best accuracy
            self.metric = metric
            self.save_ckpt(is_best=True)
        # save the best model and log it
        self.recorder.logger.info('Best metric: ' + str(self.metric))

    def save_ckpt(self, is_best=False):
        save_model(self.net, self.optimizer, self.scheduler,
                   self.recorder, is_best)
```
The recorder, imported in runner.runner via `from .recorder import build_recorder`:
```python
from collections import deque, defaultdict
import torch
import os
import datetime
from .logger import get_logger


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class Recorder(object):
    def __init__(self, cfg):
        self.cfg = cfg
        self.work_dir = self.get_work_dir()
        cfg.work_dir = self.work_dir
        self.log_path = os.path.join(self.work_dir, 'log.txt')

        self.logger = get_logger('resa', self.log_path)
        self.logger.info('Config: \n' + cfg.text)

        # scalars
        self.epoch = 0
        self.step = 0
        # a normal dict raises KeyError for a missing key; defaultdict supplies
        # a default value (here a fresh SmoothedValue) instead
        # (https://blog.csdn.net/yangsong95/article/details/82319675)
        self.loss_stats = defaultdict(SmoothedValue)
        self.batch_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.max_iter = self.cfg.total_iter
        self.lr = 0.

    def get_work_dir(self):
        now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size)
        work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str)
        if not os.path.exists(work_dir):
            os.makedirs(work_dir)
        return work_dir

    def update_loss_stats(self, loss_dict):
        for k, v in loss_dict.items():
            self.loss_stats[k].update(v.detach().cpu())

    def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
        self.logger.info(self)
        # self.write(str(self))

    def write(self, content):
        with open(self.log_path, 'a+') as f:
            f.write(content)
            f.write('\n')

    def state_dict(self):
        scalar_dict = {}
        scalar_dict['step'] = self.step
        return scalar_dict

    def load_state_dict(self, scalar_dict):
        self.step = scalar_dict['step']

    def __str__(self):
        loss_state = []
        for k, v in self.loss_stats.items():
            loss_state.append('{}: {:.4f}'.format(k, v.avg))
        loss_state = ' '.join(loss_state)

        recording_state = ' '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}',
                                    '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}'])
        eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        return recording_state.format(self.epoch, self.step, self.lr, loss_state,
                                      self.data_time.avg, self.batch_time.avg, eta_string)


def build_recorder(cfg):
    return Recorder(cfg)
```
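A quick sketch of how the SmoothedValue class above behaves (windowed average vs. global average):

```python
sv = SmoothedValue(window_size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    sv.update(v)
print(sv.avg)         # 3.0 -> mean of the last 3 values (2, 3, 4)
print(sv.global_avg)  # 2.5 -> 10 / 4 over the whole series
```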
runner/logger.py, imported above via `from .logger import get_logger`:
```python
import logging

logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    logger_initialized[name] = True

    return logger
```
`build_net`, imported in runner.runner via `from models.registry import build_net`:
```python
from torch import nn  # needed by build() when cfg is a list

from utils import Registry, build_from_cfg

NET = Registry('net')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_net(cfg):
    return build(cfg.net, NET, default_args=dict(cfg=cfg))
```
`Registry` and `build_from_cfg`, imported in models/registry.py via `from utils import Registry, build_from_cfg`:
```python
import inspect

import six


# borrow from mmdetection
def is_str(x):
    """Whether the input is an string instance."""
    return isinstance(x, six.string_types)


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            'but got {}'.format(type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = {}
    obj_type = cfg.type
    if is_str(obj_type):  # is the type given as a string?
        # dict-style lookup in the registry, e.g. obj_type == 'RESANet'
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))

    if default_args is not None:
        for name, value in default_args.items():
            # dict.setdefault: insert the key with the default value if it is
            # missing, otherwise return the existing value
            # (https://www.w3school.com.cn/python/ref_dictionary_setdefault.asp)
            args.setdefault(name, value)
            # here args ends up as {'cfg': Config(path: configs/tusimple.py)},
            # i.e. the full parsed config shown earlier
    # pretty neat: passing args to obj_cls (e.g. models.resa.RESANet)
    # instantiates the class; execution jumps to the code below
    return obj_cls(**args)
```
models/resa.py:
import torch.nn as nn
import torch
import torch.nn.functional as Ffrom models.registry import NET
from .resnet import ResNetWrapper
from .decoder import BUSD, PlainDecoder

class RESA(nn.Module):  # the RESA module itself, the key component of this paper
    def __init__(self, cfg):
        super(RESA, self).__init__()
        self.iter = cfg.resa.iter                    # number of iterations: 5
        chan = cfg.resa.input_channel                # number of channels: 128
        fea_stride = cfg.backbone.fea_stride         # backbone feature stride: 8
        self.height = cfg.img_height // fea_stride   # number of row slices: 368 // 8 = 46
        self.width = cfg.img_width // fea_stride     # number of column slices: 640 // 8 = 80
        self.alpha = cfg.resa.alpha                  # 2.0
        conv_stride = cfg.resa.conv_stride           # width of the 1D conv kernel: 9

        for i in range(self.iter):
            conv_vert1 = nn.Conv2d(chan, chan, (1, conv_stride),
                                   padding=(0, conv_stride//2), groups=1, bias=False)
            # Conv2d(128, 128, kernel_size=(1, 9), stride=(1, 1), padding=(0, 4), bias=False)
            conv_vert2 = nn.Conv2d(chan, chan, (1, conv_stride),
                                   padding=(0, conv_stride//2), groups=1, bias=False)

            setattr(self, 'conv_d'+str(i), conv_vert1)
            # docs: setattr(x, 'y', v) is equivalent to ``x.y = v``
            setattr(self, 'conv_u'+str(i), conv_vert2)

            conv_hori1 = nn.Conv2d(chan, chan, (conv_stride, 1),
                                   padding=(conv_stride//2, 0), groups=1, bias=False)
            # Conv2d(128, 128, kernel_size=(9, 1), stride=(1, 1), padding=(4, 0), bias=False)
            conv_hori2 = nn.Conv2d(chan, chan, (conv_stride, 1),
                                   padding=(conv_stride//2, 0), groups=1, bias=False)

            setattr(self, 'conv_r'+str(i), conv_hori1)
            setattr(self, 'conv_l'+str(i), conv_hori2)

            idx_d = (torch.arange(self.height) + self.height // 2**(self.iter - i)) % self.height
            setattr(self, 'idx_d'+str(i), idx_d)
            # e.g. at i = 0 the shift is 46 // 2**5 = 1:
            # tensor([ 1, 2, 3, ..., 44, 45, 0])

            idx_u = (torch.arange(self.height) - self.height // 2**(self.iter - i)) % self.height
            setattr(self, 'idx_u'+str(i), idx_u)
            # e.g. at i = 0 (shift 1): tensor([45, 0, 1, ..., 43, 44])

            idx_r = (torch.arange(self.width) + self.width // 2**(self.iter - i)) % self.width
            setattr(self, 'idx_r'+str(i), idx_r)
            # e.g. at i = 0 the shift is 80 // 2**5 = 2:
            # tensor([ 2, 3, 4, ..., 78, 79, 0, 1])

            idx_l = (torch.arange(self.width) - self.width // 2**(self.iter - i)) % self.width
            setattr(self, 'idx_l'+str(i), idx_l)
            # e.g. at i = 0 (shift 2): tensor([78, 79, 0, 1, ..., 76, 77])

    def forward(self, x):
        x = x.clone()

        for direction in ['d', 'u']:          # vertical passes (down, up)
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx, :])))

        for direction in ['r', 'l']:          # horizontal passes (right, left)
            for i in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(i))
                idx = getattr(self, 'idx_' + direction + str(i))
                x.add_(self.alpha * F.relu(conv(x[..., idx])))

        return x

class ExistHead(nn.Module):
    def __init__(self, cfg=None):
        super(ExistHead, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)  # spatial dropout (zeroes whole channels) for regularization
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

        stride = cfg.backbone.fea_stride * 2
        self.fc9 = nn.Linear(
            int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128)
        # 7 * (640 / 16) * (368 / 16) = 6440 inputs
        self.fc10 = nn.Linear(128, cfg.num_classes - 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)

        x = F.softmax(x, dim=1)
        x = F.avg_pool2d(x, 2, stride=2, padding=0)
        x = x.view(-1, x.numel() // x.shape[0])  # flatten to [batch, 6440]
        x = self.fc9(x)
        x = F.relu(x)
        x = self.fc10(x)
        x = torch.sigmoid(x)

        return x

@NET.register_module
class RESANet(nn.Module):
    def __init__(self, cfg):
        super(RESANet, self).__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        self.decoder = eval(cfg.decoder)(cfg)
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        fea = self.backbone(batch)  # input [4, 3, 368, 640] -> features [4, 128, 46, 80]
        fea = self.resa(fea)        # [4, 128, 46, 80]
        seg = self.decoder(fea)     # [4, 7, 368, 640]: four images, each pixel scored over 7 classes
        exist = self.heads(fea)     # [4, 6]

        output = {'seg': seg, 'exist': exist}
        return output
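The whole "recurrent feature-shift" is carried by the precomputed modular indices: x[..., idx, :] is a circular shift of the feature-map rows, which then goes through a 1x9 conv, a ReLU, and is added back in place. A minimal sketch of just the shift (toy shapes, not from the repo):

import torch

H = 6                                          # toy feature-map height
x = torch.arange(H).float().view(1, 1, H, 1)   # one channel, H rows

stride = 2
idx_d = (torch.arange(H) + stride) % H         # tensor([2, 3, 4, 5, 0, 1])

shifted = x[..., idx_d, :]                     # row i now holds row (i + stride) % H
print(shifted.view(-1))                        # tensor([2., 3., 4., 5., 0., 1.])
# In RESA, this shifted map is convolved and added back onto x, so every
# row aggregates features from a row `stride` away, with wrap-around at the edges.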
The import from .resnet import ResNetWrapper in resa.py
is annotated as follows:
import torch
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url

# This code is borrowed from torchvision.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # Unlike stock torchvision, dilation > 1 is allowed here
        # (the "Dilation > 1 not supported" check is commented out).
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNetWrapper(nn.Module):
    def __init__(self, cfg):
        super(ResNetWrapper, self).__init__()
        self.cfg = cfg
        self.in_channels = [64, 128, 256, 512]
        if 'in_channels' in cfg.backbone:
            self.in_channels = cfg.backbone.in_channels
        self.model = eval(cfg.backbone.resnet)(
            pretrained=cfg.backbone.pretrained,
            replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation,
            in_channels=self.in_channels)
        self.out = None
        if cfg.backbone.out_conv:
            out_channel = 512
            for chan in reversed(self.in_channels):
                if chan < 0:
                    continue
                out_channel = chan
                break
            self.out = conv1x1(out_channel * self.model.expansion, 128)

    def forward(self, x):
        x = self.model(x)    # [4, 512, 46, 80]
        if self.out:
            x = self.out(x)  # [4, 128, 46, 80]
        return x

class ResNet(nn.Module):
    def __init__(self, block, layers, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.in_channels = in_channels
        self.layer1 = self._make_layer(block, in_channels[0], layers[0])
        self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        if in_channels[3] > 0:
            self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
        self.expansion = block.expansion
        # The classification head of the stock ResNet is dropped:
        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():  # self.modules() iterates over all submodules recursively
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                # Kaiming init: https://blog.csdn.net/weixin_36670529/article/details/101776253
                # presumably these initial values make training easier
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
        # e.g. layer1 of resnet34 is a Sequential of three BasicBlocks, each
        # Conv2d(64, 64, 3x3) -> BatchNorm2d -> ReLU -> Conv2d(64, 64, 3x3) -> BatchNorm2d.

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)     # [4, 64, 184, 320]
        x = self.maxpool(x)  # [4, 64, 92, 160]

        x = self.layer1(x)   # [4, 64, 92, 160]
        x = self.layer2(x)   # [4, 128, 46, 80]
        x = self.layer3(x)   # [4, 256, 46, 80] (stride replaced by dilation)
        if self.in_channels[3] > 0:  # self.in_channels: [64, 128, 256, 512]
            x = self.layer4(x)       # [4, 512, 46, 80]

        return x

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    # e.g. arch: 'resnet34'; block: BasicBlock; layers: [3, 4, 6, 3];
    # pretrained: True; progress: True;
    # kwargs: {'replace_stride_with_dilation': [False, True, True],
    #          'in_channels': [64, 128, 256, 512]}
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model

# Standard torchvision factory functions (docstrings shortened here); pretrained=True
# downloads ImageNet weights, progress=True shows a download progress bar.
def resnet18(pretrained=False, progress=True, **kwargs):
    """ResNet-18, https://arxiv.org/pdf/1512.03385.pdf"""
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)

def resnet34(pretrained=False, progress=True, **kwargs):
    """ResNet-34, https://arxiv.org/pdf/1512.03385.pdf"""
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)

def resnet50(pretrained=False, progress=True, **kwargs):
    """ResNet-50, https://arxiv.org/pdf/1512.03385.pdf"""
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

def resnet101(pretrained=False, progress=True, **kwargs):
    """ResNet-101, https://arxiv.org/pdf/1512.03385.pdf"""
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)

def resnet152(pretrained=False, progress=True, **kwargs):
    """ResNet-152, https://arxiv.org/pdf/1512.03385.pdf"""
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)

def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    """ResNeXt-50 32x4d, https://arxiv.org/pdf/1611.05431.pdf"""
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    """ResNeXt-101 32x8d, https://arxiv.org/pdf/1611.05431.pdf"""
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)

def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    """Wide ResNet-50-2, https://arxiv.org/pdf/1605.07146.pdf
    Same as ResNet-50 except the bottleneck channel count is doubled in every block;
    the outer 1x1 convolutions keep their width (last block: 2048-1024-2048 vs 2048-512-2048)."""
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)

def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    """Wide ResNet-101-2, https://arxiv.org/pdf/1605.07146.pdf (see wide_resnet50_2)."""
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
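The config sets replace_stride_with_dilation = [False, True, True], so layer3 and layer4 keep the 1/8 resolution and use dilated 3x3 convolutions instead of striding. A quick sketch of why that preserves spatial size, assuming the conv3x3 defined above is in scope (shapes chosen for illustration):

import torch

# conv3x3 sets padding = dilation, so with stride == 1 the output spatial size
# always equals the input size, regardless of the dilation rate.
x = torch.randn(1, 64, 46, 80)
conv = conv3x3(64, 64, stride=1, dilation=2)  # padding=2, dilation=2
print(conv(x).shape)  # torch.Size([1, 64, 46, 80])
# With dilate=True, _make_layer converts a stride-2 stage into stride-1 plus a
# doubled dilation, which is why the layer3/layer4 outputs stay at 46x80.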
In resa.py,
from .decoder import BUSD, PlainDecoder
is annotated as follows:
from torch import nn
import torch.nn.functional as F

class PlainDecoder(nn.Module):
    def __init__(self, cfg):
        super(PlainDecoder, self).__init__()
        self.cfg = cfg

        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)

    def forward(self, x):
        x = self.dropout(x)
        x = self.conv8(x)
        x = F.interpolate(x, size=[self.cfg.img_height, self.cfg.img_width],
                          mode='bilinear', align_corners=False)
        return x

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class non_bottleneck_1d(nn.Module):
    def __init__(self, chann, dropprob, dilated):
        super().__init__()

        self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True)
        self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True)
        self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)

        self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1,
                                   padding=(1 * dilated, 0), bias=True, dilation=(dilated, 1))
        self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1,
                                   padding=(0, 1 * dilated), bias=True, dilation=(1, dilated))
        self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)

        self.dropout = nn.Dropout2d(dropprob)

    def forward(self, input):
        output = self.conv3x1_1(input)
        output = F.relu(output)
        output = self.conv1x3_1(output)
        output = self.bn1(output)
        output = F.relu(output)

        output = self.conv3x1_2(output)
        output = F.relu(output)
        output = self.conv1x3_2(output)
        output = self.bn2(output)

        if (self.dropout.p != 0):
            output = self.dropout(output)

        # output + input = identity (residual connection)
        return F.relu(output + input)

class UpsamplerBlock(nn.Module):
    def __init__(self, ninput, noutput, up_width, up_height):
        super().__init__()

        # fine branch: transposed conv, then two non-bottleneck blocks to repair detail
        self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2,
                                       padding=1, output_padding=1, bias=True)
        self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)
        self.follows = nn.ModuleList()
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))
        self.follows.append(non_bottleneck_1d(noutput, 0, 1))

        # coarse branch: 1x1 conv followed by bilinear interpolation
        self.up_width = up_width
        self.up_height = up_height
        self.interpolate_conv = conv1x1(ninput, noutput)
        self.interpolate_bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        out = F.relu(output)
        for follow in self.follows:
            out = follow(out)

        interpolate_output = self.interpolate_conv(input)
        interpolate_output = self.interpolate_bn(interpolate_output)
        interpolate_output = F.relu(interpolate_output)

        interpolate = F.interpolate(interpolate_output, size=[self.up_height, self.up_width],
                                    mode='bilinear', align_corners=False)

        return out + interpolate  # sum of the fine and coarse branches

class BUSD(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        img_height = cfg.img_height    # 368
        img_width = cfg.img_width      # 640
        num_classes = cfg.num_classes  # 7

        self.layers = nn.ModuleList()
        self.layers.append(UpsamplerBlock(ninput=128, noutput=64,
                           up_height=int(img_height)//4, up_width=int(img_width)//4))
        self.layers.append(UpsamplerBlock(ninput=64, noutput=32,
                           up_height=int(img_height)//2, up_width=int(img_width)//2))
        self.layers.append(UpsamplerBlock(ninput=32, noutput=16,
                           up_height=int(img_height)//1, up_width=int(img_width)//1))

        self.output_conv = conv1x1(16, num_classes)

    def forward(self, input):
        output = input
        for layer in self.layers:
            output = layer(output)
        output = self.output_conv(output)
        return output
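A quick shape check of BUSD; the Cfg class below is a hypothetical stand-in for the real Config object, with values taken from configs/tusimple.py:

import torch

class Cfg:
    img_height, img_width, num_classes = 368, 640, 7

decoder = BUSD(Cfg())
fea = torch.randn(4, 128, 46, 80)  # 1/8-resolution features coming out of RESA
print(decoder(fea).shape)          # torch.Size([4, 7, 368, 640])
# Each UpsamplerBlock doubles the resolution and halves the channel count:
# 46x80 -> 92x160 -> 184x320 -> 368x640, with channels 128 -> 64 -> 32 -> 16.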
In runner.runner,
from .optimizer import build_optimizer
is annotated as follows:

import torch

_optimizer_factory = {
    'adam': torch.optim.Adam,
    'sgd': torch.optim.SGD
}

def build_optimizer(cfg, net):
    params = []
    lr = cfg.optimizer.lr
    weight_decay = cfg.optimizer.weight_decay

    for key, value in net.named_parameters():
        # net.named_parameters() yields every parameter tensor (weights and biases) with its name
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if 'adam' in cfg.optimizer.type:
        optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay)
    else:
        optimizer = _optimizer_factory[cfg.optimizer.type](
            params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum)

    return optimizer  # SGD for the TuSimple config
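Note that build_optimizer puts each parameter tensor into its own group, which leaves room for per-layer learning rates later; for this config the net effect is the same as a single global SGD group. A toy sketch of the pattern (the nn.Linear is just a stand-in for RESANet):

import torch
from torch import nn

net = nn.Linear(4, 2)  # toy model standing in for RESANet
optimizer = torch.optim.SGD(
    [{"params": [p], "lr": 0.02, "weight_decay": 1e-4} for p in net.parameters()],
    lr=0.02, weight_decay=1e-4, momentum=0.9)
print(len(optimizer.param_groups))  # 2: one group per parameter tensor (weight, bias)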
In runner.runner,
from .net_utils import save_model, load_network
is annotated as follows:
import torch
import os
from torch import nn
import numpy as np
import torch.nn.functional
from termcolor import colored
from .logger import get_logger

def save_model(net, optim, scheduler, recorder, is_best=False):
    model_dir = os.path.join(recorder.work_dir, 'ckpt')
    os.system('mkdir -p {}'.format(model_dir))
    epoch = recorder.epoch
    ckpt_name = 'best' if is_best else epoch
    torch.save({
        'net': net.state_dict(),
        'optim': optim.state_dict(),
        'scheduler': scheduler.state_dict(),
        'recorder': recorder.state_dict(),
        'epoch': epoch
    }, os.path.join(model_dir, '{}.pth'.format(ckpt_name)))

def load_network_specified(net, model_dir, logger=None):
    pretrained_net = torch.load(model_dir)['net']
    net_state = net.state_dict()
    state = {}
    for k, v in pretrained_net.items():
        # skip weights whose name or shape does not match the current network
        if k not in net_state.keys() or v.size() != net_state[k].size():
            if logger:
                logger.info('skip weights: ' + k)
            continue
        state[k] = v
    net.load_state_dict(state, strict=False)

def load_network(net, model_dir, finetune_from=None, logger=None):
    if finetune_from:
        if logger:
            logger.info('Finetune model from: ' + finetune_from)
        load_network_specified(net, finetune_from, logger)
        return
    pretrained_model = torch.load(model_dir)
    net.load_state_dict(pretrained_model['net'], strict=True)
    # loading model parameters: https://blog.csdn.net/t20134297/article/details/110533007
Loading the model in the code above raised an error at:
pretrained_model = torch.load(model_dir)
The traceback was:
Traceback (most recent call last):
  File "/home/wenqiang/.conda/envs/wqf/lib/python3.8/site-packages/torch/serialization.py", line 593, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/wenqiang/.conda/envs/wqf/lib/python3.8/site-packages/torch/serialization.py", line 779, in _legacy_load
    deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
RuntimeError: storage has wrong size: expected 7667828046330154142 got 147456
terminate called without an active exception
Process finished with exit code 134 (interrupted by signal 6: SIGABRT)
This looked like a corrupted checkpoint file; re-downloading the model fixed it.
In runner.runner,
from .scheduler import build_scheduler
is annotated as follows:

import torch
import math

_scheduler_factory = {
    'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
}

def build_scheduler(cfg, optimizer):
    assert cfg.scheduler.type in _scheduler_factory
    cfg_cp = cfg.scheduler.copy()
    cfg_cp.pop('type')

    scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp)

    return scheduler  # LambdaLR for the TuSimple config
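The config supplies lr_lambda as an anonymous function whose body is not visible in the dump. A hedged sketch of how LambdaLR consumes such a function; the polynomial-decay lambda below is only an illustrative guess, not necessarily the one in configs/tusimple.py:

import torch
from torch import nn

net = nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.9)

total_iter = 80000  # from the config dump
# LambdaLR multiplies the base lr by lr_lambda(step) at every scheduler.step().
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: (1 - step / total_iter) ** 0.9)

for step in range(3):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # base lr scaled by the lambda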
In runner.runner.py,
from .registry import build_trainer, build_evaluator
is annotated as follows:

from utils import Registry, build_from_cfg

TRAINER = Registry('trainer')
EVALUATOR = Registry('evaluator')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
        # registry here: Registry(name=evaluator, items=['Tusimple', 'CULane'])

def build_trainer(cfg):
    return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))

def build_evaluator(cfg):
    return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
At the end of utils.registry.py,
return obj_cls(**args)
when execution reaches this line again, obj_cls is <class 'runner.evaluator.tusimple.tusimple.Tusimple'>; passing the cfg in constructs that class the same way, i.e.:
runner.evaluator.tusimple.tusimple
is annotated as follows:
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.logger import get_logger
from runner.registry import EVALUATOR
from .getLane import prob2lines_tusimple
import json
import os
from .lane import LaneEval

def split_path(path):
    """split path tree into list"""
    folders = []
    while True:
        path, folder = os.path.split(path)
        if folder != "":
            folders.insert(0, folder)
        else:
            if path != "":
                folders.insert(0, path)
            break
    return folders

@EVALUATOR.register_module
class Tusimple(nn.Module):
    def __init__(self, cfg):
        super(Tusimple, self).__init__()
        self.cfg = cfg
        exp_dir = os.path.join(self.cfg.work_dir, "output")
        # e.g. 'work_dirs/TuSimple/20210409_091938_lr_2e-02_b_4/output'
        if not os.path.exists(exp_dir):
            os.mkdir(exp_dir)
        self.out_path = os.path.join(exp_dir, "coord_output")
        # e.g. 'work_dirs/TuSimple/20210409_091938_lr_2e-02_b_4/output/coord_output'
        if not os.path.exists(self.out_path):
            os.mkdir(self.out_path)
        self.dump_to_json = []
        self.thresh = cfg.evaluator.thresh  # 0.6
        self.logger = get_logger('resa')

    def evaluate_pred(self, seg_pred, exist_pred, img_name, thr):
        for b in range(len(seg_pred)):
            seg = seg_pred[b]
            exist = [1 if exist_pred[b, i] > 0.5 else 0
                     for i in range(self.cfg.num_classes - 1)]
            lane_coords = prob2lines_tusimple(
                seg, exist, resize_shape=(720, 1280), y_px_gap=10, pts=56,
                thresh=thr, cfg=self.cfg)
            for i in range(len(lane_coords)):
                lane_coords[i] = sorted(lane_coords[i], key=lambda pair: pair[1])

            path_tree = split_path(img_name[b])
            save_dir, save_name = path_tree[-3:-1], path_tree[-1]
            save_dir = os.path.join(self.out_path, *save_dir)
            save_name = save_name[:-3] + "lines.txt"
            save_name = os.path.join(save_dir, save_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)

            with open(save_name, "w") as f:
                for l in lane_coords:
                    for (x, y) in l:
                        print("{} {}".format(x, y), end=" ", file=f)
                    print(file=f)

            json_dict = {}
            json_dict['lanes'] = []
            json_dict['h_sample'] = []
            json_dict['raw_file'] = os.path.join(*path_tree[-4:])
            json_dict['run_time'] = 0
            for l in lane_coords:
                if len(l) == 0:
                    continue
                json_dict['lanes'].append([])
                for (x, y) in l:
                    json_dict['lanes'][-1].append(int(x))
            for (x, y) in lane_coords[0]:
                json_dict['h_sample'].append(y)
            self.dump_to_json.append(json.dumps(json_dict))

    def evaluate(self, output, batch):
        seg_pred, exist_pred = output['seg'], output['exist']
        seg_pred = F.softmax(seg_pred, dim=1)
        seg_pred = seg_pred.detach().cpu().numpy()      # [4, 7, 368, 640]
        exist_pred = exist_pred.detach().cpu().numpy()  # [4, 6]
        img_name = batch['meta']['file_name']
        self.evaluate_pred(seg_pred, exist_pred, img_name, self.thresh)

    def summarize(self):
        best_acc = 0
        output_file = os.path.join(self.out_path, 'predict_test.json')
        with open(output_file, "w+") as f:
            for line in self.dump_to_json:
                print(line, end="\n", file=f)

        eval_result, acc = LaneEval.bench_one_submit(output_file, self.cfg.test_json_file)
        self.logger.info(eval_result)
        self.dump_to_json = []
        best_acc = max(acc, best_acc)
        return best_acc
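Each line appended to predict_test.json is one JSON record per image. Based on the code above, a record would look roughly like the sketch below (coordinates are invented for illustration; -1 marks rows where the lane has no point):

# Hypothetical record written by evaluate_pred (values invented):
json_dict = {
    "lanes": [[-1, -1, 612, 605, 598],        # one x list per lane, sampled every 10 px of y
              [-1, 410, 403, 397, 390]],
    "h_sample": [160, 170, 180, 190, 200],    # y of each sampled row, ascending after sorting
    "raw_file": "clips/0530/1492626760788443246_0/20.jpg",
    "run_time": 0,
}
# With pts=56 and y_px_gap=10 the real h_sample runs from 160 up to 710.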
In main.py,
from datasets import build_dataloader
is annotated as follows. The __init__.py under the datasets folder reads:

from .registry import build_dataset, build_dataloader
from .tusimple import TuSimple
from .culane import CULane

so execution jumps to .registry.py, annotated as follows:

from utils import Registry, build_from_cfg
import torch

DATASETS = Registry('datasets')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        # note: nn is not imported in this file; this branch is never taken for these configs
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)

def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    # e.g. {'type': 'TuSimple', 'img_path': './data/tusimple', 'data_list': 'test_gt.txt'}
    # dict.copy(): https://www.runoob.com/python/att-dictionary-copy.html
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    if is_train:
        shuffle = True
    else:
        shuffle = False

    dataset = build_dataset(split_cfg, cfg)  # e.g. a TuSimple dataset with 2782 samples

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle,
        num_workers=cfg.workers, pin_memory=False, drop_last=False)

    return data_loader  # e.g. a DataLoader with 696 batches (2782 samples / batch_size 4)
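A hedged usage sketch of how the training loop would consume this loader; the loop body is illustrative rather than quoted from runner.py, and it assumes cfg is the loaded Config from configs/tusimple.py with the TuSimple data prepared under ./data/tusimple:

train_loader = build_dataloader(cfg.dataset.train, cfg, is_train=True)

for batch in train_loader:
    img = batch['img']      # [4, 3, 368, 640] float tensor
    label = batch['label']  # [4, 368, 640] long tensor of per-pixel lane ids
    exist = batch['exist']  # [4, 6] per-lane existence flags
    # forward / loss / backward would happen here in the real trainer
    break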
This goes through utils.registry.py again, and then into datasets.tusimple.py,
annotated as follows:
import os.path as osp
import numpy as np
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS

@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)
        # img_path: './data/tusimple'; data_list: 'test_gt.txt'

    def transform_train(self):
        input_mean = self.cfg.img_norm['mean']
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(
                mean=(self.cfg.img_norm['mean'], (0, )),
                std=(self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")  # turn one line into a list
                self.img.append(line_split[0])
                self.img_list.append(self.img_path + line_split[0])
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(np.array([
                    int(line_split[2]), int(line_split[3]),
                    int(line_split[4]), int(line_split[5]),
                    int(line_split[6]), int(line_split[7])]))
Jumping to .base_dataset.py,
annotated as follows:
import os.path as osp
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS

@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_testing = ('test' in data_list)

        self.img = []
        self.img_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_val() if self.is_testing else self.transform_train()

        self.init()

    def transform_train(self):
        raise NotImplementedError()

    def transform_val(self):
        # torchvision.transforms provides common image transforms:
        # https://blog.csdn.net/ai_faker/article/details/115320418
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),  # resize the image
            tf.GroupNormalize(
                mean=(self.cfg.img_norm['mean'], (0, )),
                std=(self.cfg.img_norm['std'], (1, ))),                  # normalize the batch
        ])
        return val_transform

    def init(self):
        raise NotImplementedError()
        # Not implemented in the parent class; the subclass's implementation is called
        # instead. https://blog.csdn.net/Strive_For_Future/article/details/103587350

    def __len__(self):
        return len(self.img_list)  # e.g. 2782 for the TuSimple test list

    def __getitem__(self, idx):
        img = cv2.imread(self.img_list[idx]).astype(np.float32)
        label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
        if len(label.shape) > 2:
            label = label[:, :, 0]
        label = label.squeeze()

        img = img[self.cfg.cut_height:, :, :]   # crop off the top cut_height rows
        label = label[self.cfg.cut_height:, :]
        exist = self.exist_list[idx]

        if self.transform:
            img, label = self.transform((img, label))  # resize and normalization
            # img:   ndarray (560, 1280, 3) -> (368, 640, 3)
            # label: ndarray (560, 1280)    -> (368, 640)

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        # permute reorders tensor dimensions: https://zhuanlan.zhihu.com/p/76583143
        # contiguous makes the tensor's storage contiguous in memory: https://zhuanlan.zhihu.com/p/64551412
        label = torch.from_numpy(label).contiguous().long()

        meta = {'file_name': self.img[idx]}
        data = {'img': img, 'label': label, 'exist': exist, 'meta': meta}
        return data
        # Example of one data item:
        # 'img':   float tensor [3, 368, 640] of normalized pixel values
        # 'label': long tensor [368, 640] of per-pixel lane ids (0 = background)
        # 'exist': array([0, 0, 1, 1, 1, 0])
        # 'meta':  {'file_name': '/clips/0530/1492626480958429729_0/20.jpg'}
The import utils.transforms as tf in the code above
is annotated as follows:
import random
import cv2
import numpy as np
import numbers
import collections

# copied from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py

__all__ = ['GroupRandomCrop', 'GroupCenterCrop', 'GroupRandomPad', 'GroupCenterPad',
           'GroupRandomScale', 'GroupRandomHorizontalFlip', 'GroupNormalize']

class SampleResize(object):
    def __init__(self, size):
        # assert aborts execution if its condition is false:
        # https://www.runoob.com/w3cnote/c-assert.html
        assert (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size

    def __call__(self, sample):
        # __call__ is a magic method (http://c.biancheng.net/view/2380.html): it overloads
        # the () operator so an instance can be invoked like a plain function.
        out = list()
        out.append(cv2.resize(sample[0], self.size, interpolation=cv2.INTER_CUBIC))
        # cv2.INTER_CUBIC: https://blog.csdn.net/Dontla/article/details/107017375
        # img:   (560, 1280, 3) -> (368, 640, 3); label: (560, 1280) -> (368, 640)
        out.append(cv2.resize(sample[1], self.size, interpolation=cv2.INTER_NEAREST))
        return out

class GroupRandomCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size
        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)
        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images

class GroupRandomCropRatio(object):
    # same as GroupRandomCrop except size is given as (width, height)
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        tw, th = self.size
        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)
        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images

class GroupCenterCrop(object):
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size
        out_images = list()
        h1 = max(0, int((h - th) / 2))
        w1 = max(0, int((w - tw) / 2))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)
        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images

class GroupRandomPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size
        out_images = list()
        h1 = random.randint(0, max(0, th - h))
        w1 = random.randint(0, max(0, tw - w))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)
        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupCenterPad(object):
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size
        out_images = list()
        h1 = max(0, int((th - h) / 2))
        w1 = max(0, int((tw - w) / 2))
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)
        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupConcerPad(object):
    # pads only toward the bottom-right corner (h1 = w1 = 0)
    def __init__(self, size, padding):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img_group):
        assert (len(self.padding) == len(img_group))
        h, w = img_group[0].shape[0:2]
        th, tw = self.size
        out_images = list()
        h1 = 0
        w1 = 0
        h2 = max(th - h - h1, 0)
        w2 = max(tw - w - w1, 0)
        for img, padding in zip(img_group, self.padding):
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(cv2.copyMakeBorder(
                img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupRandomScaleNew(object):
    def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale_w, fy=scale_h,
                                         interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupRandomScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scale = random.uniform(self.size[0], self.size[1])
        out_images = list()
        for img, interpolation in zip(img_group, self.interpolation):
            out_images.append(cv2.resize(img, None, fx=scale, fy=scale,
                                         interpolation=interpolation))
            if len(img.shape) > len(out_images[-1].shape):
                out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupRandomMultiScale(object):
    def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        scales = [0.5, 1.0, 1.5]  # random.uniform(self.size[0], self.size[1])
        out_images = list()
        for scale in scales:
            for img, interpolation in zip(img_group, self.interpolation):
                out_images.append(cv2.resize(img, None, fx=scale, fy=scale,
                                             interpolation=interpolation))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
        return out_images

class GroupRandomScaleRatio(object):
    def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)):
        self.size = size
        self.interpolation = interpolation
        self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377,
                          880, 910, 1188, 128, 960, 1784, 1414, 1150, 512, 1162, 950,
                          750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524,
                          832, 1122, 1224, 2295]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        w_scale = random.randint(self.size[0], self.size[1])
        h_scale = random.randint(self.size[2], self.size[3])
        h, w, _ = img_group[0].shape
        out_images = list()
        out_images.append(cv2.resize(img_group[0], None, fx=w_scale * 1.0 / w,
                                     fy=h_scale * 1.0 / h, interpolation=self.interpolation[0]))
        ### process label map ###
        origin_label = cv2.resize(img_group[1], None, fx=w_scale * 1.0 / w,
                                  fy=h_scale * 1.0 / h, interpolation=self.interpolation[1])
        origin_label = origin_label.astype(int)
        label = origin_label[:, :, 0] * 5 + origin_label[:, :, 1] * 3 + origin_label[:, :, 2]
        new_label = np.ones(label.shape) * 100
        new_label = new_label.astype(int)
        for cnt in range(37):
            new_label = (label == self.origin_id[cnt]) * (cnt - 100) + new_label
        new_label = (label == self.origin_id[37]) * (36 - 100) + new_label
        assert (100 not in np.unique(new_label))
        out_images.append(new_label)
        return out_images

class GroupRandomRotation(object):
    def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None):
        self.degree = degree
        self.interpolation = interpolation
        self.padding = padding
        if self.padding is None:
            self.padding = [0, 0]

    def __call__(self, img_group):
        assert (len(self.interpolation) == len(img_group))
        v = random.random()
        if v < 0.5:  # rotate with probability 0.5
            degree = random.uniform(self.degree[0], self.degree[1])
            h, w = img_group[0].shape[0:2]
            center = (w / 2, h / 2)
            map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
            out_images = list()
            for img, interpolation, padding in zip(img_group, self.interpolation, self.padding):
                out_images.append(cv2.warpAffine(
                    img, map_matrix, (w, h), flags=interpolation,
                    borderMode=cv2.BORDER_CONSTANT, borderValue=padding))
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group

class GroupRandomBlur(object):
    def __init__(self, applied):
        self.applied = applied

    def __call__(self, img_group):
        assert (len(self.applied) == len(img_group))
        v = random.random()
        if v < 0.5:
            out_images = []
            for img, a in zip(img_group, self.applied):
                if a:
                    img = cv2.GaussianBlur(img, (5, 5), random.uniform(1e-6, 0.6))
                out_images.append(img)
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group

class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given numpy Image with a probability of 0.5"""
    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        v = random.random()
        if v < 0.5:
            out_images = [np.fliplr(img) for img in img_group]
            if self.is_flow:
                for i in range(0, len(out_images), 2):
                    # invert flow pixel values when flipping
                    out_images[i] = -out_images[i]
            return out_images
        else:
            return img_group

class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img_group):
        out_images = list()
        for img, m, s in zip(img_group, self.mean, self.std):
            if len(m) == 1:
                img = img - np.array(m)  # single channel image
                img = img / np.array(s)
            else:
                img = img - np.array(m)[np.newaxis, np.newaxis, ...]
                img = img / np.array(s)[np.newaxis, np.newaxis, ...]
            out_images.append(img)
        return out_images
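A small, hedged sketch of how these group transforms act on an (image, label) pair; the arrays are dummies (real inputs come from BaseDataset.__getitem__), and it assumes the repo's Python 3.8 environment, where collections.Iterable still resolves:

import numpy as np
import torchvision
import utils.transforms as tf

img = np.random.rand(560, 1280, 3).astype(np.float32)  # dummy image after the cut_height crop
label = np.zeros((560, 1280), dtype=np.uint8)          # dummy segmentation label

transform = torchvision.transforms.Compose([
    tf.SampleResize((640, 368)),  # note: (width, height), matching cv2.resize
    tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )),
                      std=([1.0, 1.0, 1.0], (1, ))),
])
img, label = transform((img, label))
print(img.shape, label.shape)  # (368, 640, 3) (368, 640)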
In runner.evaluator.tusimple.tusimple.py,
from .getLane import prob2lines_tusimple
is annotated as follows:
import cv2
import numpy as np

def isShort(lane):
    start = [i for i, x in enumerate(lane) if x > 0]
    if not start:
        return 1
    else:
        return 0

def fixGap(coordinate):
    if any(x > 0 for x in coordinate):
        start = [i for i, x in enumerate(coordinate) if x > 0][0]
        end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
        lane = coordinate[start:end + 1]
        if any(x < 0 for x in lane):
            gap_start = [i for i, x in enumerate(lane[:-1]) if x > 0 and lane[i + 1] < 0]
            gap_end = [i + 1 for i, x in enumerate(lane[:-1]) if x < 0 and lane[i + 1] > 0]
            gap_id = [i for i, x in enumerate(lane) if x < 0]
            if len(gap_start) == 0 or len(gap_end) == 0:
                return coordinate
            for id in gap_id:
                for i in range(len(gap_start)):
                    if i >= len(gap_end):
                        return coordinate
                    if id > gap_start[i] and id < gap_end[i]:
                        gap_width = float(gap_end[i] - gap_start[i])
                        # linearly interpolate x across the gap
                        lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] +
                                       (gap_end[i] - id) / gap_width * lane[gap_start[i]])
            if not all(x > 0 for x in lane):
                print("Gaps still exist!")
            coordinate[start:end + 1] = lane
    return coordinate

def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None):
    """
    Arguments:
    ----------
    prob_map: prob map for single lane, np array size (h, w)
    resize_shape: reshape size target, (H, W)

    Return:
    ----------
    coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape
    H -= cfg.cut_height

    coords = np.zeros(pts)
    coords[:] = -1.0
    for i in range(pts):
        y = int((H - 10 - i * y_px_gap) * h / H)
        # walk up the probability map roughly every 6-7 rows: rows 361, 354, 348, ...
        if y < 0:
            break
        line = prob_map[y, :]
        id = np.argmax(line)
        if line[id] > thresh:
            # if the row maximum exceeds the threshold, record its x coordinate
            # (mapped back to the original image width)
            coords[i] = int(id / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    fixGap(coords)
    return coords

def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True,
                        y_px_gap=10, pts=None, thresh=0.3, cfg=None):
    # typical arguments: seg_pred: ndarray (7, 368, 640); exist: [0, 1, 1, 1, 1, 0];
    # resize_shape: (720, 1280); y_px_gap: 10; pts: 56; thresh: 0.6
    """
    Arguments:
    ----------
    seg_pred: np.array size (5, h, w)
    resize_shape: reshape size target, (H, W)
    exist: list of existence, e.g. [0, 1, 1, 0]
    smooth: whether to smooth the probability or not
    y_px_gap: y pixel gap for sampling
    pts: how many points for one lane
    thresh: probability threshold

    Return:
    ----------
    coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]], [[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    # ascontiguousarray makes the transposed array contiguous in memory, which
    # speeds up the row slicing below; the shape is now (368, 640, 7)
    for i in range(cfg.num_classes - 1):
        prob_map = seg_pred[..., i + 1]  # channel i+1 detects lane i (channel 0 is background)
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)  # mean filter
        coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg)
        if isShort(coords):
            continue
        coordinates.append([[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0
                            else [-1, H - 10 - j * y_px_gap] for j in range(pts)])
        # record the lane coordinates

    if len(coordinates) == 0:
        coords = np.zeros(pts)
        coordinates.append([[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0
                            else [-1, H - 10 - j * y_px_gap] for j in range(pts)])

    return coordinates
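A worked example of the row sampling in getLane_tusimple under the TuSimple config (H = 720 - cut_height 160 = 560, h = 368, y_px_gap = 10):

# For each sampled point i, the probability map is read at row
# y = int((H - 10 - i * y_px_gap) * h / H), i.e. roughly every 6-7 rows:
H, h, y_px_gap = 560, 368, 10
for i in range(3):
    y = int((H - 10 - i * y_px_gap) * h / H)
    print(i, y)
# 0 361
# 1 354
# 2 348
# The matching output y that prob2lines_tusimple stores next to each x is
# H - 10 - i * y_px_gap, expressed in the original (resized) image coordinates.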
20210410. Finished.
----------20210505----------
While modifying the code, the following error came up:
AttributeError: 'NoneType' object has no attribute 'astype'
Adapted from: https://blog.csdn.net/weixin_43826242/article/details/90325955
I hit this error while running the model a few days ago; it is solved now, so here is a follow-up note.
Key to the fix: check that the images in the validation set really are in the format their extension claims.
In my case every image in the dataset carried a .jpg suffix, but some had merely been renamed without being re-encoded, so they failed when the model loaded them. Re-encoding every image as a real JPEG with PIL avoids the problem.
After some 2000 iterations of debugging, the root cause turned out to be cv2.imread():
when the suffix does not match the actual image format, imread returns None.
Below is a small script that fixes it:
# Convert every image to real JPEG format (to avoid cv2.imread() failures
# caused by a mismatched image format).
import PIL.Image as Image
import os

def start(Path):
    filelist = os.listdir(Path + 'JPEGImages/')
    for file in filelist:
        img = Image.open(Path + 'JPEGImages/' + file).convert('RGB')
        # print(img)
        img.save(Path + file)  # re-encoded copies land directly under Path
    print('Done!')
To be verified.