使用LoRA微调florence-2模型

devtools/2024/11/27 1:14:46/

1 环境

Kaggle,单GPU

2 数据

图片、索引和标签放在JSON文件中

文件目录如下:
在这里插入图片描述
logo是图片的文件夹,PNG-SVG是图片的文件夹,re.json是索引,florence2-weight是预训练的权重

JSON文件内容如下:
在这里插入图片描述
在这里插入图片描述
image是图片的地址,label是图片的标签

3 微调代码

3.1 安装所需要的环境

python"> !pip install peft einops flash_attn

3.2 微调代码

使用LoRA策略

python">import os
import json
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoProcessor, AutoModelForCausalLM
from tqdm import tqdm  
import time
from torch.cuda.amp import autocast, GradScaler
import warningsos.environ["TOKENIZERS_PARALLELISM"] = "false" # 关闭 HuggingFace Tokenizers 警告 
warnings.filterwarnings("ignore", category=UserWarning, module="PIL.Image") # 忽略 PIL 图像透明度相关警告# 数据集类
class IconLogoDataset(Dataset):def __init__(self, json_path, transform=None):with open(json_path, 'r', encoding='utf-8') as f:self.data = json.load(f)self.transform = transformdef __len__(self):return len(self.data)def __getitem__(self, idx):while True:  # 循环直到成功返回有效样本try:item = self.data[idx]image_path = item['image']label = item['label']image = Image.open(image_path).convert('RGB')  # 尝试打开图片if self.transform:image = self.transform(image)return image, label  # 成功返回样本except Exception as e:print(f"图片无法读取 {self.data[idx]['image']}: {e}")idx = (idx + 1) % len(self.data)  # 跳过当前样本,尝试下一个def prepare_dataloader(json_path, batch_size, num_workers=0):transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),  #  转化为张量,也会缩放到[0, 1] ])dataset = IconLogoDataset(json_path, transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)print('数据预处理完成')return dataloader# 配置 LoRA
def configure_lora(model_name):lora_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM,inference_mode=False,r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["q_proj", "v_proj"],)# 加载模型和权重processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)model = AutoModelForCausalLM.from_pretrained('/kaggle/input/florence2-weight/icon_caption_florence', torch_dtype=torch.float16, trust_remote_code=True)# 配置 LoRA 模型model = get_peft_model(model, lora_config)print('模型加载完毕')return model, processor# 训练函数
def train(model, dataloader, processor, optimizer, device, epochs=5):print('开始训练')start_time = time.time()  # 记录训练开始时间scaler = GradScaler()  # 初始化混合精度缩放器model.train()for epoch in range(epochs):epoch_loss = 0# 显示进度条progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")for images, labels in progress_bar:images = images.to(device, dtype=torch.float16)  # 转换到 float16# 同时处理图像和文本inputs = processor(text=labels,images=images,return_tensors="pt",padding=True,do_rescale=False # 因为前面已经0-1归一化了,所以这里不再做).to(device)optimizer.zero_grad()# 自动混合精度上下文with autocast(dtype=torch.float16):outputs = model(**inputs, labels=inputs["input_ids"])  # 模型前向传播loss = outputs.loss# 使用 GradScaler 处理损失scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()epoch_loss += loss.item()# 在进度条上显示当前的平均损失progress_bar.set_postfix(loss=epoch_loss / len(dataloader))print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")end_time = time.time()  # 记录训练结束时间print(f"训练总耗时: {end_time - start_time:.2f} 秒")# 主程序
def main():# 设置设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 数据加载re_dataloader = prepare_dataloader(json_path="/kaggle/input/index-of-data/re.json",batch_size=16)# 配置 LoRA 模型model_name = "microsoft/Florence-2-base"  # 替换为你的模型名称model, tokenizer = configure_lora(model_name)model.to(device)# 优化器optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)# 开始训练train(model, re_dataloader, tokenizer, optimizer, device, epochs=5)print('训练结束')# 保存模型model.save_pretrained("./lora_florence2")tokenizer.save_pretrained("./lora_florence2")if __name__ == "__main__":main()

训练结束后,压缩成一个文件进行下载

python">import shutil# 压缩文件夹
shutil.make_archive("lora_florence2", 'zip', "./lora_florence2")
print("压缩完成:生成了 lora_florence2.zip 文件")

http://www.ppmy.cn/devtools/137268.html

相关文章

unsloth vlm模型Qwen2-VL、Llama 3.2 Vision微调案例

T4卡15G显卡训练 参考: https://github.com/unslothai/unsloth 按自己显卡cuda版本安装 免费colab微调代码: Qwen2-VL: https://colab.research.google.com/drive/1whHb54GNZMrNxIsi2wm2EY_-Pvo2QyKh?usp=sharing from unsloth import FastVisionModel # NEW instead …

Python绘制太极八卦

文章目录 系列目录写在前面技术需求1. 图形绘制库的支持2. 图形绘制功能3. 参数化设计4. 绘制控制5. 数据处理6. 用户界面 完整代码代码分析1. rset() 函数2. offset() 函数3. taiji() 函数4. bagua() 函数5. 绘制过程6. 技术亮点 写在后面 系列目录 序号直达链接爱心系列1Pyth…

jQuery-Word-Export 使用记录及完整修正文件下载 jquery.wordexport.js

参考资料: jQuery-Word-Export导出word_jquery.wordexport.js下载-CSDN博客 近期又需要自己做个 Html2Doc 的解决方案,因为客户又不想要 Html2pdf 的下载了,当初还给我费尽心思解决Html转pdf时中文输出的问题(html转pdf文件下载之…

Maven学习笔记

Maven功能介绍 提供了一套标准化的项目结构提供了一套标准化的构建流程(编译、测试、打包、发布.....)提供了一套依赖管理机制 依赖管理其实就是管理你项目所依赖的第三方资源(jar包、插件...) ①Maven使用标准的坐标配置来管理…

Java文件上传解压

目录结构 工具类 枚举 定义文件类型 public enum FileType {// 未知UNKNOWN,// 压缩文件ZIP, RAR, _7Z, TAR, GZ, TAR_GZ, BZ2, TAR_BZ2,// 位图文件BMP, PNG, JPG, JPEG,// 矢量图文件SVG,// 影音文件AVI, MP4, MP3, AAR, OGG, WAV, WAVE}为了避免文件被修改后缀&#xff0…

手机无法连接服务器1302什么意思?

你有没有遇到过手机无法连接服务器,屏幕上显示“1302”这样的错误代码?尤其是在急需使用手机进行工作或联系朋友时,突然出现的连接问题无疑会带来不少麻烦。那么,什么是1302错误,它又意味着什么呢? 1302错…

常见排序算法总结 (二) - 不基于比较的排序

计数排序 算法思想 用哈希表记录每个不同元素出现的次数,然后再根据这个记录还原。 稳定性分析 计数排序是稳定的,如果待排序元素不是纯数值,那么用链地址法来解决冲突,遍历的过程中按链表元素的先后顺序还原元素就可以保证元…

【高阶数据结构】图论

> 作者:დ旧言~ > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:了解什么是图,并能掌握深度优先遍历和广度优先遍历。 > 毒鸡汤:有些事情,总是不明白,所以我不会坚持…