如何快速看懂并修改神经网络

news/2025/4/1 6:23:56/

前言:个人之见,一个神经网络网络源码出现,你先看数据集的输入和输出,而这数据集肯定要包括数据增加和制作数据集,第二 看模型的输入和输出(至于模型内部可以自己看论文 无非就是加了几个组件),然后根据输出选择的损失函数。至于学习率和优化器 差不多都是余弦退火和admw的优化器


1.数据集

直接实战,首先你看它的readme,它一般由标注文件的格式(一般都是 文件路径 + 对应的标签数字)(要求自己制作)
输入一般都是这个标注文件,输出一般都是元组或者字典。
数据增强一般包含在数据集的制作当中

actionclip

数据增强(空间剪裁)

数据增强源码

from datasets.transforms_ss import *
from RandAugment import RandAugmentclass GroupTransform(object):def __init__(self, transform):self.worker = transformdef __call__(self, img_group):return [self.worker(img) for img in img_group]def get_augmentation(training, config):input_mean = [0.48145466, 0.4578275, 0.40821073]input_std = [0.26862954, 0.26130258, 0.27577711]scale_size = config.data.input_size * 256 // 224if training:unique = torchvision.transforms.Compose([GroupMultiScaleCrop(config.data.input_size, [1, .875, .75, .66]),GroupRandomHorizontalFlip(is_sth='some' in config.data.dataset),GroupRandomColorJitter(p=0.8, brightness=0.4, contrast=0.4,saturation=0.2, hue=0.1),GroupRandomGrayscale(p=0.2),GroupGaussianBlur(p=0.0),GroupSolarization(p=0.0)])else:unique = torchvision.transforms.Compose([GroupScale(scale_size),GroupCenterCrop(config.data.input_size)])common = torchvision.transforms.Compose([Stack(roll=False),ToTorchFormatTensor(div=True),GroupNormalize(input_mean,input_std)])return torchvision.transforms.Compose([unique, common])def randAugment(transform_train,config):print('Using RandAugment!')transform_train.transforms.insert(0, GroupTransform(RandAugment(config.data.randaug.N, config.data.randaug.M)))return transform_train

这个数据增强 你可以直接 参考()
一般直接蕴含在数据集

 def __init__(self, list_file, labels_file,num_segments=1, new_length=1,image_tmpl='img_{:05d}.jpg', transform=None,random_shift=True, test_mode=False, index_bias=1):def get(self, record, indices):images = list()for i, seg_ind in enumerate(indices):p = int(seg_ind)try:seg_imgs = self._load_image(record.path, p)except OSError:print('ERROR: Could not read image "{}"'.format(record.path))print('invalid indices: {}'.format(indices))raiseimages.extend(seg_imgs)process_data = self.transform(images)return process_data, record.label
  • 空间剪裁 无疑就是进行多少词crop 你得了解一手 ranaugment函数
数据集的制作(时间剪裁以及帧数实现)
  • 输入
    actionclip的标注文件为:
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/HfI4vN2vbHU_000000_000010 289 31
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/B8FXlmO5zk4_000079_000089 240 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/XsEw1vd32l8_000052_000062 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/r61D2lDCHsM_000268_000278 240 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/4sCQ-EX6cIg_000021_000031 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/N9mQC7MeZCk_000008_000018 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/fzVhIrMnY-E_000322_000332 250 1
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/6dLNI2BPTY0_000057_000067 250 23
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/othYtMhFdOU_000020_000030 250 29
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/JVSxlojnBYk_000047_000057 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/8jO9DeYLruU_000003_000013 300 1
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/pU12_c-XvU_000045_000055 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/x6rP9b1V7sQ_000060_000070 250 18
/public/datasets/kinetics400/data2/extracted_train_frames/blasting_sand/jqC2SnFAvoM_000092_000102 300 23
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri6AwOp59yA_000009_000019 250 31
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/wRaacvxMoc8_000014_000024 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/7kbO0v4hag_000107_000117 300 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/GjtR9KZbV3Y_000494_000504 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/hwUQqFadvE_000048_000058 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/vXmgE41UnBk_000844_000854 300 29
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/dglCzcubsw_000246_000256 159 1
/public/datasets/kinetics400/data2/extracted_train_frames/bowling/ri1H0ygN3Us_000768_000778 300 31
/public/datasets/kinetics400/data2/extracted_train_frames/belly_dancing/n24zV9OtorU_000257_000267 300 18
/public/datasets/kinetics400/data2/extracted_train_frames/abseiling/nKoqxSJcZn8_000071_000081 250 0
/public/datasets/kinetics400/data2/extracted_train_frames/air_drumming/pT2byS0qiZM_000001_000011 150 1
/public/datasets/kinetics400/data2/extracted_train_frames/bookbinding/CMo6AJhtZo_000075_000085 250 29

视频提起帧 视频总帧数 对应的标签数字

  • 输出
    一般看__getitem_
    def __getitem__(self, index):record = self.video_list[index]segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)return self.get(record, segment_indices)def __call__(self, img_group):return [self.worker(img) for img in img_group]def get(self, record, indices):images = list()for i, seg_ind in enumerate(indices):p = int(seg_ind)try:seg_imgs = self._load_image(record.path, p)except OSError:print('ERROR: Could not read image "{}"'.format(record.path))print('invalid indices: {}'.format(indices))raiseimages.extend(seg_imgs)process_data = self.transform(images)return process_data, record.label

返回元组 (images,labes)

  • 帧数 一般num_segment由这个决定 为什么?
    因为我看顶刊 基本上 一个片段抽一政数,这个无疑由片段决定
  • 时间剪裁
    时间剪裁指的是从视频的时间维度上选取特定的帧(验证数据集)
  def _get_val_indices(self, record):if self.num_segments == 1:return np.array([record.num_frames //2], dtype=np.int) + self.index_biasif record.num_frames <= self.total_length:if self.loop:return np.mod(np.arange(self.total_length), record.num_frames) + self.index_biasreturn np.array([i * record.num_frames // self.total_lengthfor i in range(self.total_length)], dtype=np.int) + self.index_biasoffset = (record.num_frames / self.num_segments - self.seg_length) / 2.0return np.array([i * record.num_frames / self.num_segments + offset + jfor i in range(self.num_segments)for j in range(self.seg_length)], dtype=np.int) + self.index_bias

帧数不足时
当 self.loop 为 True 时,通过 np.mod(np.arange(self.total_length), record.num_frames) 循环选取视频帧,确保选取的帧数达到 self.total_length,这是一种时间剪裁方式,通过循环利用现有帧来满足所需的帧数。
当 self.loop 为 False 时,使用 i * record.num_frames // self.total_length 均匀地从视频中选取 self.total_length 帧,同样实现了时间维度上的剪裁。

在视频帧数充足的情况下,先根据 self.num_segments 划分片段,然后在每个片段内选取连续的 self.seg_length 帧。offset 确保每个片段内选取的帧在片段中处于相对居中的位置,通过这种方式实现了在每个片段内的时间剪裁。

x-clip

数据集

1.参考一下这一篇 关于数据集的输入输出
2 讲一下时间剪裁

val_pipeline = [dict(type='DecordInit'),dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, test_mode=True),dict(type='DecordDecode'),dict(type='Resize', scale=(-1, scale_resize)),dict(type='CenterCrop', crop_size=config.DATA.INPUT_SIZE),dict(type='Normalize', **img_norm_cfg),dict(type='FormatShape', input_format='NCHW'),dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),dict(type='ToTensor', keys=['imgs'])]if config.TEST.NUM_CROP == 3:val_pipeline[3] = dict(type='Resize', scale=(-1, config.DATA.INPUT_SIZE))val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)if config.TEST.NUM_CLIP > 1:val_pipeline[1] = dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=config.DATA.NUM_FRAMES, multiview=config.TEST.NUM_CLIP)

multiview=config.TEST.NUM_CLIP)无疑是控制为时间剪裁的数量
3 空间剪裁
val_pipeline[4] = dict(type='ThreeCrop', crop_size=config.DATA.INPUT_SIZE)
这个更加直观了直接剪了三次 所以为3

2 模型

action-clip

从输入而言:

  • 文本
    classes, num_text_aug, text_dict = text_prompt(train_data)
    class为( num_text_augxnum_class,context)
    text_dict为(num_class,context)
    num_text_aug为填充内容长度
text_id = numpy.random.randint(num_text_aug,size=len(list_id))texts = torch.stack([text_dict[j][i,:] for i,j in zip(list_id,text_id)])

分为了(B,context)

  • 图片
images = images.view((-1,config.data.num_segments,3)+images.size()[-2:])b,t,c,h,w = images.size()images= images.to(device).view(-1,c,h,w ) 

这个论文严格意义上 是借用 clip的编码器 所以它压缩了

输出也简单

  • 文件
    text_embedding = model_text(texts)(b,d)
  • 图片
 image_embedding = model_image(images)image_embedding = image_embedding.view(b,t,-1)image_embedding = fusion_model(image_embedding)

关于这个fusion输出x.mean(dim=1, keepdim=False)
会把t压缩 x 变成了 (b,d)

x-clip

  • 文本
    text_labels = generate_text(train_data) 这个为(num_class(k),77)
    (和上面同理),但是它没有转为样本数
  • 图片
    images = images.view((-1, config.DATA.NUM_FRAMES, 3) + images.size()[-2:])
    它内部实现了一个编码器
    def encode_video(self, image):b,t,c,h,w = image.size()image = image.reshape(-1,c,h,w)cls_features, img_features = self.encode_image(image)img_features = self.prompts_visual_ln(img_features)img_features = img_features @ self.prompts_visual_projcls_features = cls_features.view(b, t, -1)img_features = img_features.view(b,t,-1,cls_features.shape[-1])video_features = self.mit(cls_features)return video_features, img_features

image = image.reshape(-1,c,h,w) 内部化了

输出:

logit_scale = self.logit_scale.exp()logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features)return logits

返回了一个b k 相似度得分


http://www.ppmy.cn/news/1584110.html

相关文章

第十四届蓝桥杯大赛软件赛省赛C/C++ 大学 B 组(部分题解)

文章目录 前言日期统计题意&#xff1a; 冶炼金属题意&#xff1a; 岛屿个数题意&#xff1a; 子串简写题意&#xff1a; 整数删除题意&#xff1a; 总结 前言 一年一度的&#x1f3c0;杯马上就要开始了&#xff0c;为了取得更好的成绩&#xff0c;好名字写了下前年2023年蓝桥…

【Linux】System V共享内存:零拷贝加速进程通信!

引言 本文深入探讨System V IPC中的共享内存技术&#xff0c;涵盖其原理、操作步骤、实现细节及与其他IPC机制的关系&#xff0c;助力读者全面掌握这一高效进程间通信方式。 &#x1f4dd; 文章总结&#xff1a; 共享内存原理 System V共享内存通过让多个进程共享同一物理内存区…

Avro 批量转换成 Json 文件

环境准备 1. java 运行环境 2. avro-tools.jar (版本不关心&#xff0c;演示使用 avro-tools-1.10.2.jar)目录 avro&#xff08;要转换的avro文件&#xff09; json&#xff08;转换后的json&#xff09; avro-tools-1.10.2.jar 批量转换处理.bat (创建脚本并将下面的代码粘入…

Python命名规范与代码最优结构规范:提升PyCharm中的可读性与健壮性

Python代码规范指南&#xff1a;提升PyCharm中的可读性与健壮性 一个函数只做一件事&#xff0c;不超过150行&#xff0c;函数之间空两行&#xff0c;不要有报黄波浪线&#xff0c;命名规范&#xff0c;注意命名规范&#xff0c;不要想当然认为代码出什么问题要以实测为核心找…

qwen2.5vl技术报告解读

一. 首先qwen2.5vl模型特点 全能文档解析能力 升级文本识别至全场景文档解析,擅长处理多场景、多语种及复杂版式文档(含手写体、表格、图表、化学方程式、乐谱等),实现跨类型文档的精准解析。 跨格式精准目标定位 突破格式限制,大幅提升对象检测、坐标定位与数量统计精度,…

全链路压测:性能测试的流量录制和回放

全链路压测是一种模拟真实用户操作场景&#xff0c;对整个系统进行压力测试的方法&#xff0c;旨在评估系统在高负载下的性能表现。​在全链路压测中&#xff0c;流量录制与回放技术起着关键作用&#xff0c;能够捕获并重现真实的用户流量&#xff0c;帮助发现潜在的性能瓶颈和…

Open GL ES ->模型矩阵、视图矩阵、投影矩阵等变换矩阵数学推导以及方法接口说明

Open GL ES 变换矩阵详解 一、坐标空间变换流程 局部空间 ->Model Matrix(模型矩阵)-> 世界空间 世界空间->View Matrix(视图矩阵)->观察空间 观察空间 ->Projection Matrix(投影矩阵)->裁剪空间 裁剪空间 ->ViewPort Transform(视口变换)>屏幕空间 …

【环路补偿】环路补偿的九种类型-mathcad计算书免费下载

环路补偿的九种类型-mathcad计算书免费下载 通过网盘分享的文件&#xff1a;环路补偿的9种类型.xmcd 链接: https://pan.baidu.com/s/1QIwsKsbv-WyyYgGc4P1eqg?pwd4sar 提取码: 4sar --来自百度网盘超级会员v3的分享