[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - DataSet篇

[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - DataSet篇

  • 前情提要
  • 源码解读
    • 完整代码
    • 逐行解读
      • 导包
      • readjson函数
      • data_collate函数
      • ImageCaptionDataset类(init函数)
      • ImageCaptionDataset类(readImage函数)

在这里插入图片描述
参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE

前情提要

有关多模态大模型架构中的语言模型部分(MQwen.py)的代码请看(多模态大模型源码阅读 - 1、 多模态大模型源码阅读 - 2, 多模态大模型源码阅读 - 3,多模态大模型源码阅读 - 4)
多模态大模型架构中的视觉模型(visual/CLIP-VIT.py)部分请看多模态大模型源码阅读 - 5
多模态大模型架构中的trainer(trainer.py)部分请看多模态大模型源码阅读 - 6
多模态大模型架构中的MultiModal融合部分(MultiModal.py)部分请看多模态大模型源码阅读 - MultiModal篇。
观前提醒,本文中介绍的多模态模型架构来源于github项目WatchTower-Liu/VLM-learning,对Qwen模型的前向传播代码进行重写,并通过中间投影层将视觉特征与文本映射到同一向量空间。投影层原理参考LLAVA
在这里插入图片描述
本节将介绍多模态模型架构中的dataset部分,该部分主要用于处理图片和文本数据,使其能够用于image captioning(图像字幕生成)任务。

源码解读

完整代码

import torch
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPProcessor, SiglipProcessor
from PIL import Image
import numpy as np
from tqdm import tqdmfrom qwen.qwen_generation_utils import make_contextdef readJson(filePath):with open(filePath, 'r', encoding="utf-8") as f:data = json.load(f)return datadef data_collate(example, tokenizer, black_token_length):images = []captions = []labels = []max_length = np.max([len(e[1]) for e in example]) + 1for e in example:img, caption, L = eL = L + 1caption = caption + [tokenizer.eod_id]images.append(img)caption_labels = [-100]*(black_token_length + (len(caption)-L) - 1) + caption[-L:] + [-100]*(max_length - len(caption))captions.append(torch.tensor(caption + [tokenizer.eod_id]*(max_length - len(caption))))labels.append(torch.tensor(caption_labels))labels = torch.stack(labels, dim=0).long()captions = torch.stack(captions, dim=0).long()images = torch.stack(images, dim=0).to(torch.float16)return {"images": images, "input_ids": captions, "labels": labels}class ImageCaptionDataset(Dataset):def __init__(self, tokenizer, image_map_file, captions_file, Vconfig, return_caption_num=1, max_train_data_item=None):super().__init__()self.tokenizer = tokenizerself.return_caption_num = return_caption_numself.max_train_data_item = max_train_data_itemmean = [0.485, 0.456, 0.406]  # RGBstd = [0.229, 0.224, 0.225]  # RGBself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std),transforms.Resize([224, 224])])self.image_map = readJson(image_map_file)self.captions = readJson(captions_file)# self.image_processor = CLIPProcessor.from_pretrained(Vconfig.model_path)self.image_processor = SiglipProcessor.from_pretrained(Vconfig.model_path)self.readImage()  # 一次性读入内存def readImage(self):self.data_list = []number = 0image_map_keys = list(self.image_map.keys())np.random.shuffle(image_map_keys)for IM in tqdm(image_map_keys):number += 1if self.max_train_data_item is not None and number > self.max_train_data_item:returntry:image_file_path = self.image_map[IM]["path"] + self.image_map[IM]["image_file"]self.data_list.append([image_file_path, self.image_map[IM]["ID"]])except Exception as e:print(f"Error loading image {IM}: {e}")continue# Debug informationprint(f"Total images loaded: {len(self.data_list)}")def __getitem__(self, index):image_path, ID = self.data_list[index]try:image = Image.open(image_path).convert("RGB")image = self.image_processor(images=image, return_tensors="pt")["pixel_values"][0]except Exception as e:print(f"Error processing image {image_path}: {e}")raisecaptions_data = self.captions.get(str(ID), {})captions = captions_data.get("a", [])# Ensure captions is a listif isinstance(captions, str):captions = [captions]elif isinstance(captions, dict):# Handle the case where captions is a dictionarycaptions = [captions.get("value", "")]if not isinstance(captions, list):raise ValueError(f"Captions for ID {ID} are not in the expected format: {captions}")if not captions:raise ValueError(f"No captions found for ID {ID}")prompt = captions_data.get("q", "")# Debug information# print(f"Captions for ID {ID}: {captions}")select_idx = np.random.choice(len(captions))# More debug information# print(f"Selected index: {select_idx}, Selected caption: {captions[select_idx]}")messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]prompt_raw, context_tokens = make_context(self.tokenizer,prompt,history=[],system="你是一位图像理解助手。")choice_captions = self.tokenizer(prompt_raw)["input_ids"]answer = self.tokenizer(captions[select_idx])["input_ids"]choice_captions = choice_captions + answerreturn image, choice_captions, len(answer)def __len__(self):return len(self.data_list)

逐行解读

导包

import torch
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPProcessor, SiglipProcessor
from PIL import Image
import numpy as np
from tqdm import tqdmfrom qwen.qwen_generation_utils import make_context

torch:深度学习的核心出装,无需赘述。
json:主要用于处理json格式文件。json简洁易用,并且被多种语言支持,有良好的跨平台兼容性,所以在很多项目里我们都能看到json文件的影子。在python里,json可以很容易地转换为字典和列表对象,字典和猎豹对象也可以存储为json文件。
Dataset:Dataset是一个用于数据处理的抽象类,可以通过继承它自定义自己的数据处理方式。
DataLoader:封装已有的数据集对象,可以进行批处理和多进程加载。
transforms:主要用于图像预处理,可以对图像进行旋转,裁剪,缩放等操作。
CLIPProcessor, SiglipProcessor:在这里主要用于将图像转换为像素值。
Image:主要用于图像的打开,处理和保存,多模态中非常常用的一个模块。
make_context:用于生成文本和上下文,将文本输入转化为模型可以理解的格式。这个方法来自于Qwen模型原始项目中的模块,在github的transforemers仓库中可以找到。

readjson函数

def readJson(filePath):with open(filePath, 'r', encoding="utf-8") as f:data = json.load(f)return data

根据传入的文件路径打开json文件,指定文件的编码类型为‘utf-8’,防止文件内部可能有非ASCII字符,以只读模式打开文件。
通过json.load()函数将json数据格式的内容转换为python对象,例如字典或列表。并将转换后的值返回。

data_collate函数

def data_collate(example, tokenizer, black_token_length):images = []captions = []labels = []max_length = np.max([len(e[1]) for e in example]) + 1for e in example:img, caption, L = eL = L + 1caption = caption + [tokenizer.eod_id]images.append(img)caption_labels = [-100]*(black_token_length + (len(caption)-L) - 1) + caption[-L:] + [-100]*(max_length - len(caption))captions.append(torch.tensor(caption + [tokenizer.eod_id]*(max_length - len(caption))))labels.append(torch.tensor(caption_labels))labels = torch.stack(labels, dim=0).long()captions = torch.stack(captions, dim=0).long()images = torch.stack(images, dim=0).to(torch.float16)return {"images": images, "input_ids": captions, "labels": labels}

data_collate函数用于数据预处理。传入参数分别为example,tokenizer,black_token_length。
example:包含多个样本的列表,每个元素都是一个三元组,包含图像、文本和一个自定义长度。
tokenizer:分词器,用于处理文本数据。
black_token_length:black_token_length指定应该屏蔽的token数量,对这部分token不计算其损失。
初始化三个列表,分别存储与处理后的图片、字幕(也可以称作图片描述)和标签信息。
首先遍历example中的每个元素,len(e[1])代表字幕长度,max_length初始化为最长字幕的长度加一,这里的加一是加上了eos_token的长度。
接着遍历example中的每个元素,并将e解包赋值给img,caption和L,需要注意的是这里的L和字幕长度不相等,具体数值取决于实际需求。
L的长度加一,添加上结束符eos_token的长度。
在字幕变量结尾加上eos_token
images数组添加入example列表中每个元组内的img数据。
初始化caption_labels,开头是长度为(black_token_length + 当前字幕长度 - L - 1)的掩码,中间为字幕的倒数L个token,结尾为长度是(max_lenght - 当前字幕长度)的掩码。这样操作用于忽略序列的开始和填充部分。
用eos_token对当前字幕进行填充,确保其长度为max_length,并将填充后的数据转换为浮点数长点,添加入captions数组中。
将caption_labels添加入labels数组中。
利用stack函数对labels,captions,images进行堆叠。其中labels和captions转换为长整型,images转换为单精度浮点数。
将images,captions,labels打包为一个字典返回。值得注意的是input_ids键对应的是字幕(captions)。

ImageCaptionDataset类(init函数)

class ImageCaptionDataset(Dataset):def __init__(self, tokenizer, image_map_file, captions_file, Vconfig, return_caption_num=1, max_train_data_item=None):super().__init__()self.tokenizer = tokenizerself.return_caption_num = return_caption_numself.max_train_data_item = max_train_data_itemmean = [0.485, 0.456, 0.406]  # RGBstd = [0.229, 0.224, 0.225]  # RGBself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std),transforms.Resize([224, 224])])self.image_map = readJson(image_map_file)self.captions = readJson(captions_file)# self.image_processor = CLIPProcessor.from_pretrained(Vconfig.model_path)self.image_processor = SiglipProcessor.from_pretrained(Vconfig.model_path)self.readImage()  # 一次性读入内存

ImageCaptionDataset类继承自DataSet类,重构了部分方法。
image_map_file:这个参数是图像和索引的映射。
captions_file:包含了对应图像的字幕信息。
Vconfig:视觉模型的通用配置
return_caption_num:这一参数的数值代表每个图片返回的字幕数量,假设每一个图片都有k个字幕,如果这个参数的数值为n,n<=k,那么就会从k个字幕中随机选取n个返回。
max_train_data_item参数限制了训练数据的最大数量。

        self.tokenizer = tokenizerself.return_caption_num = return_caption_numself.max_train_data_item = max_train_data_item

将部分参数存储为成员变量。

        mean = [0.485, 0.456, 0.406]  # RGBstd = [0.229, 0.224, 0.225]  # RGB

mean和std用于图像数据标准化,其中mean为RGB数据的均值,std为RGB数据的标准差,这些数据会在后续代码中对图像数据进行处理时用到。

        self.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std),transforms.Resize([224, 224])])

创建一个transpose对象,用于对图像进行转换,该对象执行三个操作。1.将图像数据转换为浮点数张量类型,2.使用之前初始化的均值和标准差对图像数据进行标准化。3,将图像数据重塑为(224,224)像素。

        self.image_map = readJson(image_map_file)self.captions = readJson(captions_file)

使用之前创建的readjson方法读取图像索引映射数据和图像字幕数据,并转换为python对象(字典)。

        self.image_processor = CLIPProcessor.from_pretrained(Vconfig.model_path)self.readImage()  # 一次性读入内存

从视觉模型的模型路径中初始化一个图像处理器,用于将图片转换为像素信息,最后使用self.readImage()将有关图像的各类信息一次性读入内存。

ImageCaptionDataset类(readImage函数)

    def readImage(self):self.data_list = []number = 0image_map_keys = list(self.image_map.keys())np.random.shuffle(image_map_keys)for IM in tqdm(image_map_keys):number += 1if self.max_train_data_item is not None and number > self.max_train_data_item:returntry:image_file_path = self.image_map[IM]["path"] + self.image_map[IM]["image_file"]self.data_list.append([image_file_path, self.image_map[IM]["ID"]])except Exception as e:print(f"Error loading image {IM}: {e}")continue# Debug informationprint(f"Total images loaded: {len(self.data_list)}")

这个函数主要用于将图像的路径信息和图像id信息成对存储为成员变量。

    def readImage(self):self.data_list = []number = 0image_map_keys = list(self.image_map.keys())np.random.shuffle(image_map_keys)

首先是一系列初始化工作。初始化一个data_list列表用于存储相关数据,初始化number用于计数,防止超过最大训练数据上限。将图片索引映射数据的键转换为列表,并赋值给image_map_keys变量。利用np.random.shuffle打乱图像键的顺序,保证训练数据的随机性。

        for IM in tqdm(image_map_keys):number += 1if self.max_train_data_item is not None and number > self.max_train_data_item:returntry:image_file_path = self.image_map[IM]["path"] + self.image_map[IM]["image_file"]self.data_list.append([image_file_path, self.image_map[IM]["ID"]])except Exception as e:print(f"Error loading image {IM}: {e}")continue

循环遍历所有的图像键,每次循环代表我们处理了一个数据,number计数加一。如果当前number超过了最大训练数据上限则直接退出循环并返回。
用try except体防止在读取图像时发生错误。首先读取图像的文件路径,将文件夹路径和图片路径进行拼接。并将其与图片的ID数据成对存入data_list中。
如果报错则打印错误信息,并继续遍历后续数据,这样可以避免循环中途被打断。

    def __getitem__(self, index):image_path, ID = self.data_list[index]try:image = Image.open(image_path).convert("RGB")image = self.image_processor(images=image, return_tensors="pt")["pixel_values"][0]except Exception as e:print(f"Error processing image {image_path}: {e}")raise

get_item魔法方法允许类的实例对象通过[]对象符进行索引操作。如Multimodal[1],其中Multimodal为类的实例对象,1为传入的index参数值。
根据传入的索引值,将成员变量data_list对应所以值下的列表进行解包,解包为image_path和ID。
用try_except格式捕捉报错和报错类型,提高代码的鲁棒性,如果try内部的代码出现错误,就会立即报错,这个应该是防止图片地址不存在。
根据图片存储地址打开图片并用之前初始化好的图片处理对象将图片转换为像素值,返回类型为浮点数张量。这里用[0]是因为图片处理对象默认的返回值size为(batchsize, …),即使只有一个图片。也会返回批次为1的返回值,因此需要用索引操作从批次中获取数据。

        captions_data = self.captions.get(str(ID), {})captions = captions_data.get("a", [])# Ensure captions is a listif isinstance(captions, str):captions = [captions]elif isinstance(captions, dict):# Handle the case where captions is a dictionarycaptions = [captions.get("value", "")]

根据解包后获得的ID值获取字幕相关数据,这里用get是为了防止没有ID值对应的字幕相关数据,出现报错,get函数的第二个传入参数{}表示当ID值对应数据不存在时,返回一个空集合。
进一步用get函数从字幕相关数据中获取字幕数据,这里的’a’代表字幕数据存放在键‘a’对应的值中。
接下来的操作是为了将字幕数据转换为列表格式。
首先判断字幕数据是否为字符串,如果是,则转换为列表。如果是字典,则进一步获取字典的值,并转换为列表。

        if not isinstance(captions, list):raise ValueError(f"Captions for ID {ID} are not in the expected format: {captions}")if not captions:raise ValueError(f"No captions found for ID {ID}")

这段代码是为了确保字幕数据为列表格式。如果不是列表格式,则报错并输出自定义的报错信息。
如果字幕数据为空,则报错并输出自定义的报错信息。

        prompt = captions_data.get("q", "")select_idx = np.random.choice(len(captions))

从字幕相关数据中获取键’q’下对应的值,并赋值给变量prompt,这里的’q’含义为query。
随机从选择一个idx,idx的范围为[0, len(captions) - 1],表示从captions中随机抽取一个图像字幕数据。

        messages = [{"role": "system", "content": ""}, {"role": "user", "content": prompt}]prompt_raw, context_tokens = make_context(self.tokenizer,prompt,history=[],system="你是一位图像理解助手。")

创建一个消息对象,其中第一个字典代表系统的角色的消息,初始化内容为空,另一个字典代表用户的消息,初始化为之前提取的prompt。
使用导入的创建上下文方法,传入初始化的成员变量self.tokennizer分词器,prompt,初始化历史信息为空,代表当前没有历史对话,system作为描述信息,描述了系统功能。
返回值为初步处理后的提示文本和上下文token,用于后续处理以适配模型输入

        choice_captions = self.tokenizer(prompt_raw)["input_ids"]answer = self.tokenizer(captions[select_idx])["input_ids"]choice_captions = choice_captions + answerreturn image, choice_captions, len(answer)

使用tokenizer将原始提示和选取的字幕转换为input_ids,并拼接在一起作为模型的文本输入信息。
最后返回图像信息,对应的文本输入信息和字幕信息的长度。
这里的answer相当于对图像信息(image)的描述。
例如我的返回值image是一个黄色的花朵的像素信息,那么answer就是对这个图像的描述(一个黄色的花朵),choice_caption将提示信息文本(‘你是一位图像理解助手’)和answer打包在一起用于模型训练。
至此,模型的输入数据处理部分讲解完毕,后续应该会更新模型的正式训练代码阅读,trainer.py部分。


http://www.ppmy.cn/news/1517047.html

相关文章

Java中Objecy类

没有成员变量 也就只有无参 的构造方法 /*** ClassName Test* author gyf* Date 2024/8/28 10:32* Version V1.0* Description : */ public class Test {public static void main(String[] args) {// toString()Object object new Object();System.out.println(object);String…

网络安全新视角:人工智能在防御中的最新应用

人工智能在网络安全中的最新应用 概述 人工智能&#xff08;AI&#xff09;在网络安全领域的应用正日益成熟&#xff0c;它通过机器学习和深度学习技术&#xff0c;为网络安全带来了革命性的变革。AI技术不仅能够自动化、智能化地检测、分析和应对安全威胁&#xff0c;还能够…

Jenkins:自动化的魔法师,打造无缝CI/CD流水线

标题&#xff1a;“Jenkins&#xff1a;自动化的魔法师&#xff0c;打造无缝CI/CD流水线” 在当今快速发展的软件开发领域&#xff0c;持续集成&#xff08;Continuous Integration, CI&#xff09;和持续部署&#xff08;Continuous Deployment, CD&#xff09;已经成为提升开…

Docker续1:

一、打包传输 1.打包 [rootlocalhost ~]# systemctl start docker [rootlocalhost ~]# docker save -o centos.tar centos:latest [rootlocalhost ~]# ls anaconda-ks.cfg centos.tar 2.传输 [rootlocalhost ~]# scp centos.tar root192.168.1.100:/root 3.删除镜像 [r…

总结:Python语法

Python中的字典、列表和数组是三种常用的数据结构&#xff0c;它们各自有不同的用途和特性。 字典&#xff08;Dictionary&#xff09; 字典是一种无序的、可变的数据结构&#xff0c;它存储键值对&#xff08;key-value pairs&#xff09;。字典中的每个元素都是一个键值对&…

flink--会话模式与应用模式

flink-会话模式部署 会话情况&#xff1a; 添加依赖 <properties><flink.version>1.17.2</flink.version> </properties> ​ <dependencies><dependency><groupId>org.apache.flink</groupId><artifactId>flink-strea…

CSS属性

一、CSS列表样式 1、list-style-type属性&#xff08;列表项标记&#xff09; CSS列表属性允许我们设置不同的列表项标记。 在HTML中&#xff0c;有​两种类型​的列表&#xff1a; ​无序列表​&#xff08;<ul>&#xff09; - 列表项目用​项目符号​标记​有序列表…

【Linux】自动化构建工具makefile

目录 背景 makefile简单编写 .PHONY makefile中常用选项 makefile的自动推导 背景 会不会写makefile&#xff0c;从一个侧面说明了一个人是否具备完成大型工程的能力 ​ ◉ 一个工程中的源文件不计数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…

开放式耳机怎么戴?佩戴舒适在线的几款开放式耳机分享

开放式耳机的佩戴方式与传统的入耳式耳机有所不同&#xff0c;它采用了一种挂耳式的设计&#xff0c;提供了一种新颖的佩戴体验&#xff0c;以下是开放式耳机的佩戴方式。 1. 开箱及外观&#xff1a;首先&#xff0c;从包装盒中取出耳机及其配件&#xff0c;包括耳机本体、充电…

使用 FinalShell 链接 Centos

1. 安装 FinalShell 下载地址&#xff1a;https://www.hostbuf.com/t/988.html 2. 查看 IP地址。 2.1 通过命令查询IP 输入 ip addr show 查询&#xff0c;输出效果如下截图&#xff0c;其中的 192.168.1.5 就是 IP 地址。 2.2 通过可视化界面查询IP 点击右上角的网络图标…

LoadBalancer负载均衡

一、概述 1.1、Ribbon目前也进入维护模式 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具。 简单的说&#xff0c;Ribbon是Netflix发布的开源项目&#xff0c;主要功能是提供客户端的软件负载均衡算法和服务调用。Ribbon客户端组件提供一系列完善的…

企业中需要哪些告警Rules

文章目录 企业中需要哪些告警Rules前言定义告警规则企业中的告警rulesNode.rulesprometheus.ruleswebsite.rulespod.rulesvolume.rulesprocess.rules 总结 企业中需要哪些告警Rules 前言 Prometheus中的告警规则允许你基于PromQL表达式定义告警触发条件&#xff0c;Prometheus…

poi word 添加水印

poi word 添加水印 依赖DocxUtil调用遇到的问题部分客户给的word无法添加水印水印文案 过长会导致字变小变形 超过一定长度就会显示异常。消失等情况 依赖 <!--poi-tl--><dependency><groupId>com.deepoove</groupId><artifactId>poi-tl</art…

捕获神经网络的精髓:深入探索PyTorch的torch.jit.trace方法

标题&#xff1a;捕获神经网络的精髓&#xff1a;深入探索PyTorch的torch.jit.trace方法 在深度学习领域&#xff0c;模型的部署和优化是至关重要的环节。PyTorch作为最受欢迎的深度学习框架之一&#xff0c;提供了多种工具来帮助开发者优化和部署模型。torch.jit.trace是PyTo…

设计模式 10 外观模式

设计模式 10 创建型模式&#xff08;5&#xff09;&#xff1a;工厂方法模式、抽象工厂模式、单例模式、建造者模式、原型模式结构型模式&#xff08;7&#xff09;&#xff1a;适配器模式、桥接模式、组合模式、装饰者模式、外观模式、享元模式、代理模式行为型模式&#xff…

ansible的tags标签

1、tags模块 可以给任务定义标签&#xff0c;可以根据标签来运行指定的任务 2、标签的类型 always&#xff1a;设定了标签名为always&#xff0c;除非指定跳过这个标签&#xff0c;否则该任务将始终会运行&#xff0c;即使指定了标签还会运行never&#xff1a;始终不运行的任…

CPU、MPU、MCU、SOC分别是什么?

CPU、MPU、MCU和SoC都是与微电子和计算机科学相关的术语&#xff0c;它们在功能定位、应用场景以及处理能力等方面有所区别。具体如下&#xff1a; CPU&#xff1a;CPU是中央处理单元的缩写&#xff0c;它通常指计算机内部负责执行程序指令的芯片。CPU是所有类型计算机&#x…

java 读取mysql中的表并按照指定格式导出excel

在Java中读取MySQL中的数据表并将其导出到Excel文件中&#xff0c;你需要以下几个步骤&#xff1a; 连接MySQL数据库&#xff1a;使用JDBC驱动程序连接到MySQL数据库。执行SQL查询&#xff1a;获取表数据。使用Apache POI库生成Excel文件&#xff1a;将数据写入Excel格式。保存…

SpringBoot文档之构建包的阅读笔记

Packaging Spring Boot Applications Efficient Deployments Efficient Deployments 默认情况下&#xff0c;基于SpringBoot框架开发应用时&#xff0c;构建插件spring-boot-maven-plugin将项目打包为fat jar。 执行如下命令&#xff0c;解压构建得到的jar文件。 java -Djarmo…

Python 程序设计基础教程

Python 程序设计基础教程 撰稿人&#xff1a;南星六月雪 第 一 章 变量与简单数据类型 1.1 变量 先来观察以下程序&#xff1a; world "Hello Python!" print(world)world "Hello Python,I love you!" print(world)运行这个程序&#xff0c;将看到两…