DL - 图像分割

news/2024/10/11 5:44:12/

from transformers import SegformerFeatureExtractor
import PIL.Image#一个把图像转换为数据的工具类
# Utility that turns images (and segmentation maps) into model-ready numeric arrays
feature_extractor = SegformerFeatureExtractor()

# Fake a tiny batch: two solid-colour 200x100 RGB images ...
pixel_values = [
    PIL.Image.new('RGB', (200, 100), 'blue'),
    PIL.Image.new('RGB', (200, 100), 'red'),
]
# ... and two matching single-channel ('L') label maps filled with 150 and 200
value = [
    PIL.Image.new('L', (200, 100), 150),
    PIL.Image.new('L', (200, 100), 200),
]

# Dry run: inspect what the extractor produces
out = feature_extractor(pixel_values, value)
print('keys=', out.keys())
print('type=', type(out['pixel_values']), type(out['labels']))
print('len=', len(out['pixel_values']), len(out['labels']))

first_image, first_label = out['pixel_values'][0], out['labels'][0]
print('type0=', type(first_image), type(first_label))
print('shape0=', first_image.shape, first_label.shape)

# Notebook cell result: display the extractor's configuration
feature_extractor

keys= dict_keys(['pixel_values', 'labels'])
type= <class 'list'> <class 'list'>
len= 2 2
type0= <class 'numpy.ndarray'> <class 'numpy.ndarray'>
shape0= (3, 512, 512) (512, 512)
SegformerFeatureExtractor {"do_normalize": true,"do_resize": true,"feature_extractor_type": "SegformerFeatureExtractor","image_mean": [0.485,0.456,0.406],"image_std": [0.229,0.224,0.225],"reduce_labels": false,"resample": 2,"size": 512
}

from torchvision.transforms import ColorJitter#能对图像进行亮度,对比度,饱和度,色相变换的工具类。其实就是数据增强
# Colour augmentation: brightness/contrast/saturation factors drawn from [0.75, 1.25],
# hue shifted within [-0.1, 0.1]
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)
print(jitter)

# Notebook cell result: show a jittered solid-blue sample image
jitter(PIL.Image.new('RGB', (200, 100), 'blue'))

ColorJitter(brightness=[0.75, 1.25], contrast=[0.75, 1.25], saturation=[0.75, 1.25], hue=[-0.1, 0.1])


加载数据

from datasets import load_dataset, load_from_disk#一个道路分类数据集
# Sidewalk (road-scene) semantic segmentation dataset from the Hugging Face Hub
dataset = load_dataset(path='segments/sidewalk-semantic')
# Offline alternative, if a copy was saved locally:
# dataset = load_from_disk('datas/segments/sidewalk-semantic')


def transforms(data):
    """Encode a batch of examples into numeric model inputs.

    Applies the ColorJitter augmentation (``jitter``) to every image in
    ``data['pixel_values']``, then runs ``feature_extractor`` on the augmented
    images together with the segmentation maps in ``data['label']``.
    """
    augmented = [jitter(image) for image in data['pixel_values']]
    return feature_extractor(augmented, data['label'])


# Shuffle, hold out 10% of the original 'train' split as a test set, and attach
# the encoding transform to the training portion only
dataset = dataset.shuffle(seed=1)['train'].train_test_split(test_size=0.1)
dataset['train'] = dataset['train'].with_transform(transforms)
print(dataset['train'][0])

# Notebook cell result: show the split summary
dataset

import torch


def collate_fn(data):
    """Stack a list of encoded examples into batch tensors.

    Each example is a dict with 'pixel_values' (an array such as [3, 512, 512])
    and 'labels' (an array such as [512, 512]), as produced by the feature
    extractor.

    Returns:
        dict with 'pixel_values' as a float32 tensor [B, 3, H, W] and 'labels'
        as an int64 tensor [B, H, W].
    """
    # Stack per-sample tensors rather than calling torch.FloatTensor(list_of_ndarrays):
    # building a tensor from a Python list of ndarrays is very slow (PyTorch warns about it).
    pixel_values = torch.stack([torch.as_tensor(i['pixel_values']) for i in data]).float()
    labels = torch.stack([torch.as_tensor(i['labels']) for i in data]).long()
    return {'pixel_values': pixel_values, 'labels': labels}


loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=4,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# Peek at one batch to check the shapes
data = next(iter(loader))

# Notebook cell result: number of batches and batch tensor shapes
len(loader), data['pixel_values'].shape, data['labels'].shape

# The model's logits come out at 1/4 of the input resolution, so they are upsampled
# back to the original size before computing accuracy and other per-pixel metrics.
# Demo: [4, 35, 128, 128] -> [4, 35, 512, 512]
torch.nn.functional.interpolate(
    torch.randn(4, 35, 128, 128),
    size=(512, 512),
    mode='bilinear',
    align_corners=False,
).shape

from transformers import SegformerForSemanticSegmentation, SegformerModel#加载模型
# The dataset has 35 road-scene classes in total (where they come from doesn't matter here).
# Ready-made alternative to the hand-written model below:
# model = SegformerForSemanticSegmentation.from_pretrained('nvidia/mit-b0', num_labels=35)


class Model(torch.nn.Module):
    """SegFormer (mit-b0) encoder plus an explicit all-MLP decode head for 35 classes.

    The decode head mirrors ``SegformerForSemanticSegmentation.decode_head``: one
    linear projection per encoder stage to 256 channels, upsampling to the
    highest-resolution stage, concatenation (deepest stage first), then a 1x1
    conv fuse / BatchNorm / ReLU / Dropout / 1x1 conv classifier.
    """

    def __init__(self):
        super().__init__()
        self.pretrained = SegformerModel.from_pretrained('nvidia/mit-b0')
        # One projection per encoder stage: 32 / 64 / 160 / 256 channels -> 256
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(32, 256),
            torch.nn.Linear(64, 256),
            torch.nn.Linear(160, 256),
            torch.nn.Linear(256, 256),
        ])
        self.classifier = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Conv2d(256, 35, kernel_size=1),
        )

        # Copy the decode-head weights from the reference HF model so both start identical.
        # NOTE(review): 'nvidia/mit-b0' looks like an encoder-only checkpoint, in which case
        # these decode-head weights are freshly initialised rather than pretrained — confirm.
        head = SegformerForSemanticSegmentation.from_pretrained(
            'nvidia/mit-b0', num_labels=35
        ).decode_head
        for linear, mlp in zip(self.linears, head.linear_c):
            linear.load_state_dict(mlp.proj.state_dict())
        self.classifier[0].load_state_dict(head.linear_fuse.state_dict())
        self.classifier[1].load_state_dict(head.batch_norm.state_dict())
        self.classifier[4].load_state_dict(head.classifier.state_dict())

        # NOTE(review): the loss ignores label 255, while this file's metrics ignore class 0,
        # so class 0 is still trained on — confirm that is intended.
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    def forward(self, pixel_values, labels=None):
        """Run the encoder and decode head; optionally compute the loss.

        Args:
            pixel_values: image batch [B, 3, H, W], e.g. [4, 3, 512, 512].
            labels: optional class-index maps [B, H, W]. When None, no loss is computed.

        Returns:
            dict with 'loss' (None when ``labels`` is None) and 'logits' at the
            resolution of the first encoder stage, e.g. [4, 35, 128, 128] for 512x512 input.
        """
        # Multi-scale encoder features, e.g. for 512x512 input:
        # [B, 32, 128, 128], [B, 64, 64, 64], [B, 160, 32, 32], [B, 256, 16, 16]
        hidden_states = self.pretrained(
            pixel_values=pixel_values, output_hidden_states=True
        ).hidden_states
        # Every stage is upsampled to the spatial size of the first (largest) stage.
        # Sizes are taken from the tensors instead of being hard-coded for 512x512 input.
        target_size = hidden_states[0].shape[2:]

        projected = []
        for feature, linear in zip(hidden_states, self.linears):
            batch, _, height, width = feature.shape
            # [B, C, h, w] -> [B, h*w, C] so the Linear acts on the channel dimension
            x = feature.flatten(2).transpose(1, 2)
            # [B, h*w, C] -> [B, h*w, 256]
            x = linear(x)
            # Back to an image: [B, 256, h, w]
            x = x.permute(0, 2, 1).reshape(batch, -1, height, width)
            # Same spatial size for every stage: [B, 256, h0, w0]
            x = torch.nn.functional.interpolate(
                x, size=target_size, mode='bilinear', align_corners=False
            )
            projected.append(x)

        # Deepest stage first, then concatenate on channels: [B, 1024, h0, w0]
        fused = torch.cat(projected[::-1], dim=1)
        # 1024 -> 256 -> 35 with 1x1 convolutions: [B, 35, h0, w0]
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            # The loss needs logits at the label resolution, e.g. [B, 35, 512, 512]
            upsampled = torch.nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode='bilinear', align_corners=False
            )
            loss = self.criterion(upsampled, labels)
        return {'loss': loss, 'logits': logits}


model = Model()
# Parameter count, reported in units of 10,000
total_params = sum(p.numel() for p in model.parameters())
print(total_params / 10000)

# Trial forward pass on the batch peeked from the training loader
out = model(**data)

# Notebook cell result: the loss and the logits shape
out['loss'], out['logits'].shape

from datasets import load_metric#加载评价指标
# Mean-IoU evaluation metric.
# NOTE(review): datasets.load_metric is deprecated (removed in datasets 3.x); newer setups
# load the same metric through the separate `evaluate` package — confirm installed version.
metric = load_metric('mean_iou')

# Dry run on dummy all-ones predictions and references
metric.compute(
    predictions=torch.ones(4, 10, 10),
    references=torch.ones(4, 10, 10),
    # 35 road-scene classes in total (where they come from doesn't matter here)
    num_labels=35,
    # ignore class 0 (background) when scoring
    ignore_index=0,
    reduce_labels=False,
)

from matplotlib import pyplot as plt


def show(image, out, label):
    """Plot an input image, a predicted mask and the ground-truth mask side by side.

    Args:
        image: normalised image tensor [3, H, W]; min-max rescaled to 0..255 for display.
        out: predicted class-index map [H, W].
        label: ground-truth class-index map [H, W].
    """
    plt.figure(figsize=(15, 5))

    # detach().cpu() so tensors with grad or on a GPU can be converted with .numpy()
    image = image.detach().cpu().permute(1, 2, 0)
    image = image - image.min().item()
    peak = image.max().item()
    if peak > 0:  # a constant image would otherwise divide 0 by 0 and produce NaNs
        image = image / peak
    image = image * 255
    # An [H, W, 3] uint8 array is read as RGB; the `mode=` argument is deprecated in Pillow
    image = PIL.Image.fromarray(image.numpy().astype('uint8'))
    image = image.resize((512, 512))

    panels = [
        image,
        PIL.Image.fromarray(out.detach().cpu().numpy().astype('uint8')),
        PIL.Image.fromarray(label.detach().cpu().numpy().astype('uint8')),
    ]
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel)
        plt.axis('off')
    plt.show()


show(data['pixel_values'][0], torch.ones(512, 512), data['labels'][0])

测试

def test(max_batches=5):
    """Evaluate ``model`` on a random sample of the test split and visualise predictions.

    Encodes the test split without ColorJitter augmentation and runs up to
    ``max_batches`` batches of 8 images through the model. It plots the first
    image of each batch, then prints the mean-IoU summary and the pixel
    accuracy; both metrics exclude class 0.

    Args:
        max_batches: maximum number of test batches to evaluate (previously fixed at 5).
    """
    import itertools

    model.eval()

    # No random augmentation at test time; otherwise the metrics change from run to run.
    def encode(batch):
        return feature_extractor(batch['pixel_values'], batch['label'])

    # Shuffle a new view of the split each call instead of overwriting global dataset['test'].
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'].shuffle().with_transform(encode),
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=True,
    )

    labels = []
    outs = []
    correct = 0
    total = 0
    for data in itertools.islice(loader_test, max_batches):
        with torch.no_grad():
            logits = model(**data)['logits']
        label = data['labels']
        # Logits are at 1/4 resolution: upsample to the label size, then take the best class
        out = torch.nn.functional.interpolate(
            logits, size=label.shape[-2:], mode='bilinear', align_corners=False
        ).argmax(dim=1)
        outs.append(out)
        labels.append(label)

        # Pixel accuracy excludes class 0
        select = label != 0
        correct += (label[select] == out[select]).sum().item()
        total += int(select.sum().item())

        show(data['pixel_values'][0], out[0], label[0])

    if not outs:
        print('no test batches evaluated')
        return

    metric_out = metric.compute(
        predictions=torch.cat(outs, dim=0),
        references=torch.cat(labels, dim=0),
        num_labels=35,
        ignore_index=0,
    )
    # Per-class breakdowns are too verbose to print
    metric_out.pop('per_category_iou')
    metric_out.pop('per_category_accuracy')
    print(metric_out)
    # max(..., 1) avoids division by zero without biasing the denominator
    print(correct / max(total, 1))


test()

训练

from transformers import AdamW
from transformers.optimization import get_scheduler


def train(epochs=1):
    """Fine-tune ``model`` on ``loader`` and save it to ``models/9.抠图.model``.

    Uses AdamW (lr 5e-5) with a linear decay schedule and clips gradients at norm 1.0.
    Every 10 steps it prints the step index, loss, current learning rate, the
    mean-IoU summary and the pixel accuracy for the current batch (both excluding class 0).

    Args:
        epochs: number of passes over ``loader``. The LR schedule spans exactly this
            many passes, so the learning rate decays to 0 at the end of training.
            (Previously the schedule assumed 3 epochs while the loop ran only one.)
    """
    import os

    # torch.optim.AdamW configured like the deprecated transformers.AdamW defaults
    # (eps=1e-6, no weight decay).
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
    scheduler = get_scheduler(
        name='linear',
        num_warmup_steps=0,
        num_training_steps=len(loader) * epochs,
        optimizer=optimizer,
    )

    model.train()
    for _ in range(epochs):
        for i, data in enumerate(loader):
            out = model(**data)
            loss = out['loss']
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            # All model parameters are in the optimizer, so this clears every gradient
            optimizer.zero_grad()

            if i % 10 == 0:
                label = data['labels']
                with torch.no_grad():
                    # Logits are at 1/4 resolution: upsample to the label size before argmax
                    pred = torch.nn.functional.interpolate(
                        out['logits'], size=label.shape[-2:], mode='bilinear', align_corners=False
                    ).argmax(dim=1)
                metric_out = metric.compute(
                    predictions=pred, references=label, num_labels=35, ignore_index=0
                )
                # Per-class breakdowns are too verbose to print
                metric_out.pop('per_category_iou')
                metric_out.pop('per_category_accuracy')
                # Pixel accuracy excluding class 0. max(..., 1) avoids division by zero
                # without the downward bias of the old "+ 1" in the denominator.
                select = label != 0
                hits = (label[select] == pred[select]).sum().item()
                accuracy = hits / max(int(select.sum().item()), 1)
                lr = optimizer.param_groups[0]['lr']
                print(i, loss.item(), lr, metric_out, accuracy)

    # torch.save does not create missing directories
    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/9.抠图.model')


train()

# Reload the whole pickled model saved by train() and evaluate it again.
# Since torch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle
# an nn.Module. The checkpoint is our own file, so full unpickling is acceptable here;
# never do this with checkpoints from untrusted sources.
model = torch.load('models/9.抠图.model', weights_only=False)
test()


http://www.ppmy.cn/news/1436861.html

相关文章

Python | Leetcode Python题解之第39题组合总和

题目：题解：from typing import List class Solution: def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]: def dfs(candidates, begin, size, path, res, target): if target < 0: return if target == 0: res.append(p…

【论文推导】基于有功阻尼的转速环PI参数整定分析

前言：在学习电机控制的路上，PMSM的PI电流控制是不可避免的算法之一，其核心在于内环电流环、外环转速环的设置，来保证转速可调且稳定，并且保证较好的动态性能。整个算法仿真在《现代永磁同步电机控制原理及matlab仿真》中已详细给出…

Postman 工具发送请求的技巧与实践

在开发和测试 API 时，发送 JSON 格式的请求是一个常见需求。在 Postman 中构建和发送 JSON 请求：创建一个新的请求。首先，在 Postman 启动界面上找到并点击 “New” 按钮，选择 “HTTP Request” 来开始新建一个请求。这一步骤允许你定义请…

K8s: 最佳实践经验之谈

最佳实践 1）普通配置：定义配置时，请指定最新的稳定 API 版本。在推送到集群之前，配置文件应存储在版本控制中，这允许您在必要时快速回滚配置更改，它还有助于集群重新创建和恢复。使用 YAML 而不是 JSON 编写配置文件，虽然这些格式几乎可以在所有…

doc转html参考

参考：https://github.com/mwilliamson/mammoth.js?tab=readme-ov-file 参考：前端玩Word：Word文档解析成浏览器认识的HTML_前端解析word成html-CSDN博客

我独自升级崛起怎么下载 一文分享我独自升级崛起游戏下载教程

我独自升级崛起怎么下载 一文分享我独自升级崛起游戏下载教程 我独自升级：崛起是一款由韩国漫画改编而成的热门多人网络在线联机游戏，这款游戏是一款角色扮演类游戏，游戏有着独一无二的剧情模式。小伙伴们在游戏中可以体验到独特的成长系…

《MATLAB科研绘图与学术图表绘制从入门到精通》示例:绘制伊甸火山3D曲面图

伊甸火山（Mount Eden）是新西兰奥克兰市的一座火山，也是一处受欢迎的旅游景点。数据来自R内置volcano数据，笔者导出为volcano.csv文件，这个数据集用于演示3D曲面图和地形建模。购书地址：https://ite…

某某志蓝队初级一面分享

某某志蓝队初级一面分享 所面试的公司：某某志 薪资待遇：待定 所在城市：河北 面试职位：蓝队初级 面试过程：我感觉面试官的语速有点点快，就像两个字读成一个字的那种，再加上我耳朵不太好，…