目标检测多模态大模型实践:貌似是全网唯一Shikra的部署和测试教程,内含各种踩坑以及demo代码

ops/2025/1/16 1:45:11/

在这里插入图片描述
原文:
Shikra: Unleashing Multimodal LLM’s Referential Dialogue Magic
代码:
https://github.com/shikras/shikra
模型:
https://huggingface.co/shikras/shikra-7b-delta-v1
https://huggingface.co/shikras/shikra7b-delta-v1-0708
第一个是论文用的,第二个会有迭代。

本人的shikra论文解读,逐行解读,非常详细!

多模态大模型目标检测,精读,Shikra

部署:

  1. 下载GitHub工程,和shikras的模型参数,注意,还要下载LLaMA-7b的模型;
  2. 创建环境:
conda create -n shikra python=3.10
conda activate shikra
pip install -r requirements.txt

后面我运行的时候报缺包了,又pip install了以下包,不过每个人情况不同:

pip install uvicorn
pip install mmengine

然后还会报错:

File "/usr/local/lib/python3.10/dist-packages/cv2/typing/__init__.py", line 171, in <module>LayerId = cv2.dnn.DictValue
AttributeError: module 'cv2.dnn' has no attribute 'DictValue'

解决方案:
修改/usr/local/lib/python3.10/dist-packages/cv2/typing/init.py
注释掉LayerId = cv2.dnn.DictValue这行即可。

  1. 权重下载和合并
    shikra官方提供的模型权重需要和llama1-7b合并之后才能用,然而llama1需要申请,比较麻烦,在hf上找到了平替(这一步我走了好久QwQ):
    https://huggingface.co/huggyllama/llama-7b
    大家自己下载,然后运行官方提供的合并代码:
python mllm/models/shikra/apply_delta.py \--base /path/to/llama-7b \--target /output/path/to/shikra-7b-merge \--delta shikras/shikra-7b-delta-v1

得到了可用的模型参数shikra-7b-merge。
注意要把参数文件夹里config里的模型路径改成merge版的。
此外还需要下载clip模型参数:
https://huggingface.co/openai/clip-vit-large-patch14

代码和配置文件中有多处调用/openai/clip-vit-large-patch14,要改成本地版本。如果不预先下载,应该会在运行时自动下载,大家看网络情况自行选择。

  1. 我写的demo文件,用于在命令行测试模型效果,主要是为了不用gradiofastapi这些东西。
import argparse
import os
import sys
import base64
import logging
import time
from pathlib import Path
from io import BytesIOimport torch
import uvicorn
import transformers
from PIL import Image
from mmengine import Config
from transformers import BitsAndBytesConfigsys.path.append(str(Path(__file__).parent.parent.parent))from mllm.dataset.process_function import PlainBoxFormatter
from mllm.dataset.builder import prepare_interactive
from mllm.models.builder.build_shikra import load_pretrained_shikra
from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square# Set up logging
log_level = logging.DEBUG
transformers.logging.set_verbosity(log_level)
transformers.logging.enable_default_handler()
transformers.logging.enable_explicit_format()# prompt for coco# Argument parsing
parser = argparse.ArgumentParser("Shikra Local Demo")
parser.add_argument('--model_path', default = "xxx/shikra-merge", help="Path to the model")
parser.add_argument('--load_in_8bit', action='store_true', help="Load model in 8-bit precision")
parser.add_argument('--image_path', default = "xxx/shikra-main/mllm/demo/assets/ball.jpg", help="Path to the image file")
parser.add_argument('--text', default="What do you see in this image? Please mention the objects and their locations using the format [x1,y1,x2,y2].", help="Text prompt")
parser.add_argument('--boxes_value', nargs='+', type=int, default=[], help="Bounding box values (x1, y1, x2, y2)")
parser.add_argument('--boxes_seq', nargs='+', type=int, default=[], help="Sequence of bounding boxes")
parser.add_argument('--do_sample', action='store_true', help="Use sampling during generation")
parser.add_argument('--max_length', type=int, default=512, help="Maximum length of the output")
parser.add_argument('--top_p', type=float, default=1.0, help="Top-p value for sampling")
parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for sampling")args = parser.parse_args()
model_name_or_path = args.model_path
# Model initialization
model_args = Config(dict(type='shikra',version='v1',# checkpoint configcache_dir=None,model_name_or_path=model_name_or_path,vision_tower=r'xxx/clip-vit-large-patch14',pretrain_mm_mlp_adapter=None,# model configmm_vision_select_layer=-2,model_max_length=2048,# finetune configfreeze_backbone=False,tune_mm_mlp_adapter=False,freeze_mm_mlp_adapter=False,# data process configis_multimodal=True,sep_image_conv_front=False,image_token_len=256,mm_use_im_start_end=True,target_processor=dict(boxes=dict(type='PlainBoxFormatter'),),process_func_args=dict(conv=dict(type='ShikraConvProcess'),target=dict(type='BoxFormatProcess'),text=dict(type='ShikraTextProcess'),image=dict(type='ShikraImageProcessor'),),conv_args=dict(conv_template='vicuna_v1.1',transforms=dict(type='Expand2square'),tokenize_kwargs=dict(truncation_size=None),),gen_kwargs_set_pad_token_id=True,gen_kwargs_set_bos_token_id=True,gen_kwargs_set_eos_token_id=True,
))
training_args = Config(dict(bf16=False,fp16=True,device='cuda',fsdp=None,
))quantization_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_8bit=args.load_in_8bit,)
) if args.load_in_8bit else dict()model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs)# Convert the model and vision tower to float16
if not getattr(model, 'is_quantized', False):model.to(dtype=torch.float16, device=torch.device('cuda'))
if not getattr(model.model.vision_tower[0], 'is_quantized', False):model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda'))preprocessor['target'] = {'boxes': PlainBoxFormatter()}
tokenizer = preprocessor['text']# Load and preprocess the image
pil_image = Image.open(args.image_path).convert("RGB")
ds = prepare_interactive(model_args, preprocessor)image = expand2square(pil_image)
boxes_value = [box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height) for box in zip(args.boxes_value[::2], args.boxes_value[1::2], args.boxes_value[2::2], args.boxes_value[3::2])]ds.set_image(image)
ds.append_message(role=ds.roles[0], message=args.text, boxes=boxes_value, boxes_seq=args.boxes_seq)
model_inputs = ds.to_model_input()
model_inputs['images'] = model_inputs['images'].to(torch.float16)# Generate
gen_kwargs = dict(use_cache=True,do_sample=args.do_sample,pad_token_id=tokenizer.pad_token_id,bos_token_id=tokenizer.bos_token_id,eos_token_id=tokenizer.eos_token_id,max_new_tokens=args.max_length,top_p=args.top_p,temperature=args.temperature,
)input_ids = model_inputs['input_ids']
st_time = time.time()
with torch.inference_mode():with torch.autocast(device_type='cuda', dtype=torch.float16):output_ids = model.generate(**model_inputs, **gen_kwargs)
print(f"Generated in {time.time() - st_time} seconds")input_token_len = input_ids.shape[-1]
response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0]
print(f"Response: {response}")

这么良心,点个关注吧,会持续更新多模态大模型相关内容。


http://www.ppmy.cn/ops/97880.html

相关文章

Linux:Linux多线程

目录 线程概念 什么是线程 二级页表 线程的优点 线程的缺点 线程异常 线程用途 Linux进程VS线程 进程和线程 进程的多个线程共享 进程和线程的关系 Linux线程控制 POSIX线程库 线程创建 线程等待 线程终止 分离线程 线程ID及进程地址空间布局 线程概念 什么…

XSS和DOM破坏案例

XSS案例 环境地址&#xff1a;XSS Game - Learning XSS Made Simple! | Created by PwnFunction 1.Ma Spaghet! 源码&#xff1a; <!-- Challenge --> <h2 id"spaghet"></h2> <script>spaghet.innerHTML (new URL(location).searchParam…

苹果上架没有iphone、没有ipad也可以生成截屏

使用flutter、uniapp或其他跨平台框架开发ios的APP&#xff0c;上架的时候都会遇到一个问题&#xff0c;上架的时候需要各种尺寸的设备来做ios截屏。 比如目前最新的要求是&#xff0c;iphone需要三种不同尺寸的设备的截屏&#xff0c;假如支持ipad则还需要使用ipad 2代和ipad…

HTML静态网页成品作业(HTML+CSS)——非遗昆曲介绍设计制作(1个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有1个页面。 二、作品演示 三、代…

封装了一个iOS评论弹窗

封装了一个iOS类似抖音效果的评论弹窗&#xff0c;可以跟手滑动的效果 主要有下面两需要注意的点 双手势响应 因为我们的弹窗既要支持拖动整体上下滑动&#xff0c;还要支持内容列表的滑动 &#xff0c;所以&#xff0c;我们需要在内容视图中添加一个滑动的手势&#xff0c;以…

页面设计任务 个人信息页面

目录 成品: 任务要求&#xff1a; 1. 创建一个基本的个人简介网页 2. 样式和布局要求 3. 详细样式要求 源码&#xff1a; 详细讲解&#xff1a; 1.导航栏部分&#xff1a; 2.头像和介绍部分: 3.技能列表部分 4.作品集部分 成品: 任务要求&#xff1a; 1. 创建一个基本…

Python异常处理

在Python中&#xff0c;异常处理是一种重要的编程结构&#xff0c;它允许你在代码运行时检测并响应错误或异常情况。异常处理使得程序在遇到错误时能够优雅地处理这些错误&#xff0c;而不是直接崩溃或终止执行。 下面是根据代码示例来说明&#xff1a; input_str1 input(&q…

分析 Runtime.getRuntime() 执行阻塞原因

1、起因 线上系统通过 git 命令执行的方式获取远程仓库分支&#xff0c;一直运行正常的接口&#xff0c;突然出现超时&#xff0c;接口无法响应&#xff0c;分析验证发现只有个别仓库获取分支会出现这种情况&#xff0c;其他都还是可以正常获取到分支结果信息。 2、分析异常原…