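1. Basics
Building a network: remove the fc layer from resnet18, keep the rest of the network as a feature extractor, and attach a custom fc head.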
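import torch.nn as nn
from torchvision import models

class test(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        net = models.resnet18(pretrained=True)
        net.fc = nn.Sequential()  # empty the classification layer (fc)
        self.features = net  # the backbone now returns a flat 512-dim vector per image
        self.fc = nn.Sequential(  # custom fully connected head
            nn.Linear(512, 512),  # 512, not 512*7*7: resnet18's avgpool already reduces the 7x7 map to 1x1
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        x = self.fc(x)
        return x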
net = models.resnet18()
print(net)
print(net) prints the network structure; each layer can then be accessed and modified through its name (for example net.fc).
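Building with nn.Sequential() plus the layers in the nn package has one drawback: every layer is identified only by a numeric index, which makes it hard to recognize. Named layers can be added with add_module instead:
self.conv = torch.nn.Sequential()
self.conv.add_module("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1))
Equivalent method (requires from collections import OrderedDict):
self.dense = torch.nn.Sequential(OrderedDict([
    ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
    ("relu2", torch.nn.ReLU()),
    ("dense2", torch.nn.Linear(128, 10)),
]))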
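The difference between index-based and named layers shows up in the parameter keys. The sketch below compares the two; the variable names unnamed and named are illustrative, and the layer sizes are copied from the dense example.

```python
from collections import OrderedDict

import torch.nn as nn

# Positional construction: layers are keyed "0", "1", "2"
unnamed = nn.Sequential(
    nn.Linear(32 * 3 * 3, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# OrderedDict construction: layers keep the names given to them
named = nn.Sequential(OrderedDict([
    ("dense1", nn.Linear(32 * 3 * 3, 128)),
    ("relu2", nn.ReLU()),
    ("dense2", nn.Linear(128, 10)),
]))

unnamed_keys = [name for name, _ in unnamed.named_parameters()]
named_keys = [name for name, _ in named.named_parameters()]

print(unnamed_keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']
print(named_keys)    # ['dense1.weight', 'dense1.bias', 'dense2.weight', 'dense2.bias']

# A named layer is still reachable by index as well as by attribute
print(named[0] is named.dense1)  # True
```

ReLU has no parameters, which is why index 1 and relu2 do not appear in the key lists.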
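for i, param in enumerate(test_net.parameters()):
    print(param)
print(param) prints the parameter tensor of each layer (test_net is assumed here to be an instance of the network defined above).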
for param in test_net.named_parameters():
    print(param[0])
Each item is a (name, tensor) tuple: param[0] is the key (name) of each layer's parameter, and param[1] is the parameter value.
The keys of every layer can be stored in a list, in preparation for freezing layers in the next step:
name_space = []
for name, weight in test_net.named_parameters():
    name_space.append(name)
    print(name)
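Method 1: set requires_grad = False on the parameters to freeze:
for i, param in enumerate(test_net.parameters()):
    if i < 45:
        param.requires_grad = False
This freezing method only stops the selected parameters from being updated: no gradients are accumulated for them, while the forward computation runs as usual and backpropagation still passes through the frozen layers to any trainable parameters before them.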
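To see the effect of requires_grad = False on a training step, here is a minimal sketch on a small stand-in model rather than resnet18. The model, the random data, and the choice to freeze by name prefix (instead of by index) are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in network: layer "0" plays the frozen backbone, layer "2" the trainable head
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze every parameter whose key starts with "0."
for name, param in model.named_parameters():
    if name.startswith("0."):
        param.requires_grad = False

frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

# Hand only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

x = torch.randn(16, 4)
target = torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(x), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(model[0].weight, frozen_before)
head_changed = not torch.equal(model[2].weight, head_before)

print(model[0].weight.grad)  # None: no gradient was accumulated for the frozen layer
print(frozen_unchanged)      # True
print(head_changed)
```

Filtering the parameters passed to the optimizer is optional here, since parameters without gradients are skipped anyway, but it makes the intent explicit.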
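Method 2: wrap the forward pass of the layers to freeze in with torch.no_grad():. No computation graph is recorded inside the block, so backpropagation stops there, and every layer feeding into it receives no gradient either.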
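The torch.no_grad() approach can be sketched as follows. The two linear layers named backbone and head are hypothetical stand-ins for the frozen and trainable parts of a network.

```python
import torch
import torch.nn as nn

backbone = nn.Linear(4, 8)  # stand-in for the part to freeze
head = nn.Linear(8, 2)      # stand-in for the part to train

x = torch.randn(5, 4)

# Nothing computed inside this block is recorded for autograd
with torch.no_grad():
    feats = backbone(x)

out = head(feats)
out.sum().backward()

print(feats.requires_grad)           # False: the graph is cut at this tensor
print(backbone.weight.grad)          # None: no gradient reached the backbone
print(head.weight.grad is not None)  # True: the head still receives gradients
```

Unlike requires_grad = False, this cuts the graph at a point in the forward pass, so it suits freezing a leading block of layers rather than individual layers in the middle of a network.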