Chinese-Clip实现以文搜图和以图搜图(transformers版)

ops/2024/12/19 12:24:37/

本文不生产技术,只做技术的搬运工!

前言

        作者昨天使用cn_clip库实现了一版,但是觉得大家复现配置环境可能有点复杂,因此有使用transformers库实现了一版,提供大家选择,第一篇参考链接如下:

Chinese-Clip实现以文搜图和以图搜图-CSDN博客文章浏览阅读728次,点赞9次,收藏17次。使用clip实现以文搜图和以图搜图的图文检索功能https://blog.csdn.net/qq_44908396/article/details/144537426

 环境配置

transformers:

pip install transformers

milvus:

pip install -U pymilvus

pytorch:

pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117

源码

数据入库

python">from PIL import Image
import requests
from transformers import ChineseCLIPProcessor, ChineseCLIPModel
import torch
import os
import numpy as np
from pymilvus import MilvusClient
client = MilvusClient("BlingPic.db")
if client.has_collection(collection_name="text_image"):client.drop_collection(collection_name="text_image")
client.create_collection(collection_name="text_image",dimension=512,  # The vectors we will use in this demo has 768 dimensionsmetric_type="COSINE"
)def getFileList(dir, Filelist, ext=None):"""获取文件夹及其子文件夹中文件列表输入 dir:文件夹根目录输入 ext: 扩展名返回: 文件路径列表"""newDir = dirif os.path.isfile(dir):if ext is None:Filelist.append(dir)else:if ext in dir:Filelist.append(dir)elif os.path.isdir(dir):for s in os.listdir(dir):newDir = os.path.join(dir, s)getFileList(newDir, Filelist, ext)return Filelistif __name__ == "__main__":device = "cuda" if torch.cuda.is_available() else "cpu"model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")model.to(device)preprocess = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")model.eval()img_dir = r"/home/turing/图片/BlingPic"image_path_list = []image_path_list = getFileList(img_dir, image_path_list, '.jpg')data = []i = 0for image_path in image_path_list:temp = {}image = Image.open(image_path)with torch.no_grad():inputs = preprocess(images=image, return_tensors="pt").to(device)image_features = model.get_image_features(**inputs)image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # normalizeimage_features = image_features.cpu().numpy().astype(np.float32).flatten()# 将特征向量转换为字符串temp['id'] = itemp['image_path'] = image_pathtemp['vector'] = image_featuresdata.append(temp)i = i + 1print(i)res = client.insert(collection_name="text_image", data=data)

上述代码会在指定路径生成一个BlingPic.db的文件,这就说明数据完成了入库,我们接下来进行调用

数据查询

python">from PIL import Image,ImageDraw,ImageFont
from transformers import ChineseCLIPProcessor, ChineseCLIPModel,AutoTokenizer
import torch
import numpy as np
from pymilvus import MilvusClient
client = MilvusClient("BlingPic.db")
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']def display_single_image_with_text(image_path):with Image.open(image_path) as img:draw = ImageDraw.Draw(img)# 设置字体和字号,这里假设你有一个可用的字体文件,例如 Arial.ttf# 如果没有,可以使用系统默认字体try:font = ImageFont.truetype("Arial.ttf", 30)except IOError:font = ImageFont.load_default()# 文本内容和颜色text = "Example image"text_color = (255, 0, 0)  # 红色# 文本位置text_position = (10, 10)# 绘制文本draw.text(text_position, text, fill=text_color, font=font)# 显示图像img.show()def display_images_in_grid(image_paths, images_per_row=3):# 计算需要的行数num_images = len(image_paths)num_rows = (num_images + images_per_row - 1) // images_per_row# 打开所有图像并调整大小images = []for path in image_paths:with Image.open(path) as img:img = img.resize((200, 200))  # 调整图像大小以适应画布images.append(img)# 创建一个空白画布canvas_width = images_per_row * 200canvas_height = num_rows * 200canvas = Image.new('RGB', (canvas_width, canvas_height), (255, 255, 255))# 将图像粘贴到画布上for idx, img in enumerate(images):row = idx // images_per_rowcol = idx % images_per_rowposition = (col * 200, row * 200)canvas.paste(img, position)# 显示画布canvas.show()def load_model(device):model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")model.to(device)preprocess = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")model.eval()return model, preprocessdef text_encode(model,text,device):tokenizer = AutoTokenizer.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")inputs = tokenizer(text, return_tensors="pt").to(device)with torch.no_grad():text_features = model.get_text_features(**inputs)text_features /= text_features.norm(dim=-1, keepdim=True)text_features = text_features.cpu().numpy().astype(np.float32)return text_featuresdef image_encode(model,preprocess,image_path,device):image = Image.open(image_path)with torch.no_grad():inputs = preprocess(images=image, return_tensors="pt").to(device)image_features = model.get_image_features(**inputs)image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # normalizeimage_features = image_features.cpu().numpy().astype(np.float32)return image_featuresif __name__ == "__main__":search_text = "大象"search_image_path = "/home/project_python/Chinese-CLIP/my_dataset/coco/val2017/000000000285.jpg"device = "cuda" if torch.cuda.is_available() else "cpu"model, preprocess = load_model(device)text_flag = Falseif text_flag:text_features = text_encode(model,search_text,device)results = client.search("text_image",data=text_features,output_fields=["image_path"],search_params={"metric_type": "COSINE"},limit=36)else:display_single_image_with_text(search_image_path)image_features = image_encode(model,preprocess,search_image_path,device)results = client.search("text_image",data=image_features,output_fields=["image_path"],search_params={"metric_type": "COSINE"},limit=36)image_list = []for i,result in enumerate(results[0]):image_list.append(result["entity"]["image_path"])display_images_in_grid(image_list,9)

上述代码使用text_flag控制是以文搜图还是以图搜图,True时为以文搜图,False时为以图搜图

实现效果

以文搜图

以图搜图

示例图像:

搜索结果:

附加

权重下载遇到问题参考如下链接:

解决OSError: We couldn‘t connect to ‘https://huggingface.co‘ to load this file-CSDN博客文章浏览阅读1.4k次,点赞6次,收藏2次。解决hugging face无法下载模型的问题https://blog.csdn.net/qq_44908396/article/details/142516867


http://www.ppmy.cn/ops/143177.html

相关文章

OpenCV--图像拼接

OpenCV--图像拼接 代码和笔记 代码和笔记 import cv2 import numpy as np""" 图像拼接: 1. 读取图片,灰度化 2. 计算各自的特征点和描述子 3. 匹配特征 4. 计算单应性矩阵 5. 透视变换 6. 创建一个大图,放图两张图 "&qu…

基于yolov10的遥感影像目标检测系统,支持图像检测,视频检测和实时摄像检测功能(pytorch框架,python源码)

更多目标检测、图像分类识别、目标检测等其他项目可看我主页其他文章 功能演示: 基于yolov10的遥感影像目标检测系统,既支持图像检测,也支持视频和摄像实时检测【pytorch框架、python源码】_哔哩哔哩_bilibili (一)…

AtomGit 开源生态应用开发赛报名开始啦

目录 1、赛项背景2、赛项信息3、报名链接4、赛题一:开发者原创声明(DCO)应用开发赛题要求目标核心功能 5、赛题二:基于 OpenHarmony 的开源社区应用开发简介赛题要求 6、参赛作品提交初赛阶段决赛阶段 7、参赛作品提交方式 1、赛项…

Linux——Shell

if 语句 格式:if list; then list; [ elif list; then list; ] ... [ else list; ] fi 单分支 if 条件表达式; then 命令 fi 示例: #!/bin/bash N10 if [ $N -gt 5 ]; then echo yes fi # bash test.sh yes 双分支 if 条件表达式; then 命令 else 命令…

Linux安装部署Redis(超级详细)

前言 网上搜索了一筐如何在Linux下安装部署Redis的文章,各种文章混搭在一起勉强安装成功了。自己也记录下,方便后续安装时候有个借鉴之处。 Redis版本 5.0.4服务器版本 Linux CentOS 7.6 64位 下载Redis 进入官网找到下载地址 https://redis.io/down…

爬虫抓取的数据如何有效存储和管理?

在现代数据驱动的世界中,爬虫技术已成为获取网络数据的重要手段。然而,如何有效地存储和管理这些数据是一个关键问题。本文将详细介绍几种有效的数据存储和管理方法,并提供相应的Java代码示例。 1. 数据存储方式 1.1 文件存储 文件存储是最…

access数据库代做/mysql代做/Sql server数据库代做辅导设计服务

针对Access数据库、MySQL以及SQL Server数据库的代做和辅导设计服务,以下是一些关键信息和建议: 一、服务概述 这些服务通常包括数据库的设计、创建、优化、维护以及相关的编程和查询编写等。无论是Access这样的桌面关系数据库管理系统(RDB…

【漏洞分析】DDOS攻防分析(四)——TCP篇

0x00 TCP DDOS攻击案例 政治因素一直是黑客发动网络攻击的一个重要动机。2015年12月,著名黑客组织匿名者(Anonymous)发布视频谴责土耳其支持ISIS,并向土耳其发动了史上最大规模的DDoS攻击。 2015年12月14日开始,大规模网络攻击导致土耳其银…