使用 TensorRT 和 Python 实现高性能图像推理服务器

news/2025/2/8 19:49:04/

在现代深度学习和计算机视觉应用中,高性能推理是关键。本文将介绍如何使用 TensorRT 和 Python 构建一个高性能的图像推理服务器。该服务器能够接收客户端发送的图像数据,使用 TensorRT 进行推理,并将结果返回给客户端。

1. 概述

1.1 项目目标

  • 构建一个基于 TCP 协议的图像推理服务器。

  • 使用 TensorRT 加速深度学习模型的推理。

  • 支持多客户端并发处理。

1.2 技术栈

  • TensorRT:NVIDIA 的高性能深度学习推理库。

  • PyCUDA:用于在 Python 中操作 CUDA。

  • OpenCV 和 PIL:用于图像处理。

  • Socket:用于实现网络通信。

  • 多进程:使用 ProcessPoolExecutor 实现并发处理。


2. 代码实现

2.1 依赖库

首先,确保安装了以下 Python 库:

pip install numpy opencv-python pillow pycuda tensorrt torch

2.2 核心代码

2.2.1 图像预处理
def normalize(image: np.ndarray):"""图像归一化处理"""resize_image = cv2.resize(image, (700, 800))image = resize_image.astype(np.float32)mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)image /= 255.0image -= meanimage /= stdreturn image
2.2.2 加载 TensorRT 引擎
def load_engine(engine_path):"""加载 TensorRT 引擎"""TRT_LOGGER = trt.Logger(trt.Logger.WARNING)with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())
2.2.3 TensorRT 推理
def ai_predict(input_data,engine,batch_size=1):with engine.create_execution_context() as context:stream = cuda.Stream()# 分配输入和输出的内存print(input_data.dtype,input_data.nbytes,input_data.shape)h_output = cuda.pagelocked_empty((batch_size, 5), dtype=np.float32)  # 假设输出是 5 类d_input = cuda.mem_alloc(input_data.nbytes)d_output = cuda.mem_alloc(h_output.nbytes)# 获取输入和输出的名称tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]input_name = tensor_names[0]output_name = tensor_names[1]# 设置输入张量的形状# context.set_input_shape(input_name, (batch_size, 3, 800, 700))context.set_tensor_address(input_name, int(d_input))context.set_tensor_address(output_name, int(d_output))# 将输入数据复制到 GPUcuda.memcpy_htod_async(d_input, input_data, stream)# 设置输入和输出的地址# 执行推理context.execute_async_v3(stream.handle)h_output = cuda.aligned_zeros((batch_size, 5), dtype=input_data.dtype)cuda.memcpy_dtoh_async(h_output, d_output, stream)# context.execute_v2([d_input,d_output])# cuda.memcpy_dtoh(h_output,d_output)stream.synchronize()# 返回输出数据return h_output
2.2.4 图像处理与推理
def process_image(image_data, engine, class_indict):"""处理图像并进行推理"""try:image = Image.open(io.BytesIO(image_data))image_array = np.array(image)if len(image_array.shape) < 3:image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)normalized_image = normalize(image_array)normalized_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)batch_array = normalized_image.astype(np.float32)batch_array = np.ascontiguousarray(batch_array)start_time = time.time()pro = ai_predict(batch_array, engine, batch_size=1)predict_time = time.time() - start_timeprint("AI预测时间:", predict_time)pro_val = torch.softmax(torch.from_numpy(pro), dim=1)outputs = torch.argmax(pro_val, dim=1).numpy()for index, index_pro in enumerate(pro_val):maxpro_index = torch.argmax(index_pro).to("cpu").numpy()if maxpro_index == 4 and index_pro[-1].to("cpu").numpy() < 0.60:outputs[index] = 0pre_label = [class_indict[str(num)] for num in outputs]print("label:{},pro:{}".format(pre_label, np.max(pro_val.numpy())))cv2.putText(image_array, pre_label[0], (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1, cv2.LINE_AA)with io.BytesIO() as output:processed_pil_image = Image.fromarray(image_array)processed_pil_image.save(output, format="JPEG")processed_image_data = output.getvalue()return pre_label, processed_image_dataexcept Exception as e:print(f"处理图像时出错: {e}")return None
2.2.5 客户端处理
def handle_client(conn, addr, engine_file_path, class_indict):"""处理客户端连接"""engine = load_engine(engine_file_path)with conn:print(f"已连接 {addr}")file_size_data = conn.recv(4)file_size = int.from_bytes(file_size_data, byteorder='big')image_data = b''remaining_size = file_sizewhile remaining_size > 0:chunk = conn.recv(min(65536, remaining_size))if not chunk:breakimage_data += chunkremaining_size -= len(chunk)print(f"从 {addr} 接收图像数据完成")start_time = time.time()pro_label, processed_image_data = process_image(image_data, engine, class_indict)img_time = time.time() - start_timeprint("图像处理时间:", img_time)if processed_image_data and len(pro_label) > 0:file_size = len(processed_image_data)conn.sendall(file_size.to_bytes(4, byteorder='big'))conn.sendall(processed_image_data)data_bytes = json.dumps(pro_label).encode('utf-8')conn.sendall(data_bytes)print(f"处理后的图像已发送回 {addr}")else:print(f"无法处理图像,未发送数据到 {addr}")
2.2.6 服务器启动

python

复制

def server_process(host, port, engine_file_path, class_indict, max_workers=4):"""启动服务器"""server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)server_socket.bind((host, port))server_socket.listen(5)print(f"Server listening on {host}:{port}")with ProcessPoolExecutor(max_workers=max_workers) as pool:try:while True:client_socket, addr = server_socket.accept()print(f"Accepted connection from {addr}")pool.submit(handle_client, client_socket, addr, engine_file_path, class_indict)except KeyboardInterrupt:print("Server is shutting down...")finally:server_socket.close()
2.2.7 主程序

python

复制

if __name__ == "__main__":engine_file_path = "best_weight/resnet34.trt"json_label_path = './class_indices.json'assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)json_file = open(json_label_path, 'r')class_indict = json.load(json_file)HOST = '0.0.0.0'  # 服务器IP地址PORT = 4099  # 监听的端口server_process(HOST, PORT, engine_file_path, class_indict)

3. 运行与测试

  1. 启动服务器:

    bash

    复制

    python server.py
  2. 使用客户端发送图像数据并接收推理结果。


4. 总结

本文介绍了如何使用 TensorRT 和 Python 构建一个高性能的图像推理服务器。通过多进程和 TensorRT 的加速,服务器能够高效处理客户端请求并返回推理结果。希望本文对您有所帮助!


http://www.ppmy.cn/news/1570395.html

相关文章

大模型做导师之方案版本比较

背景&#xff1a; 在阅读lightRAG项目时&#xff0c;利用LLM辅助进行理解&#xff0c;当询问LLM如何自定义一个符合项目要求的大模型调用函数时&#xff0c;LLM给出了两个不同的版本。借此想提升一下自己的编程质量&#xff0c;于是让LLM对两个版本进行点评比较。 实现建议 基础…

2. 【.NET Aspire 从入门到实战】--理论入门与环境搭建--.NET Aspire 概览

在当今快速发展的软件开发领域&#xff0c;构建高效、可靠且易于维护的云原生应用程序已成为开发者和企业的核心需求。.NET Aspire 作为一款专为云原生应用设计的开发框架&#xff0c;旨在简化分布式系统的构建和管理&#xff0c;提供了一整套工具、模板和集成包&#xff0c;帮…

maven不能导入依赖和插件Cannot resolve plugin org.apache.maven.plugins:maven-xxx

新建了个工程&#xff0c;新设置了一个仓库地址&#xff0c;maven导入报错&#xff1a; 连最基础的maven自带的插件都无法导入&#xff1a; Cannot resolve plugin org.apache.maven.plugins:maven-install-plugin:2.4Try to run Maven import with -U flag (force update sna…

深入理解Docker:为你的爬虫项目提供隔离环境

1. 明确目标 前置知识 在本教程中&#xff0c;我们的目标是利用Docker构建一个隔离环境&#xff0c;运行一个Python爬虫项目。该项目将采集小红书目标视频页面中的简介和评论&#xff0c;主要涵盖以下技术点&#xff1a; Docker隔离环境&#xff1a;通过Docker容器运行爬虫&…

2024.1版android studio创建Java语言项目+上传gitee

1.在gitee上创建仓库 Gitee 创建仓库并邀请成员指南_gitee创建仓库邀请成员-CSDN博客 见1 2.新建android studio项目 3.在Android studio配置gitee Android Studio提交代码到gitee仓库_android log in to gitee-CSDN博客 其中的一二步 p.s.添加gitee账户选择password时&a…

Qt实现简易音乐播放器

使用Qt6实现简易音乐播放器&#xff0c;效果如下&#xff1a; github&#xff1a; Gabriel-gxb/MusicPlayer: qt6实现简易音乐播放器 一、整体架构 基于Qt框架构建 整个音乐播放器程序以Qt框架为基础进行开发。Qt提供了丰富的类库和工具&#xff0c;方便开发者构建图形用户界…

Redis企业开发实战(二)——点评项目之商户缓存查询

目录 一、缓存介绍 二、缓存更新策略 三、如何保证redis与数据库一致性 1.解决方案概述 2.双写策略 3.双删策略 3.1延迟双删的目的 4.数据重要程度划分 四、缓存穿透 (一)缓存穿透解决方案 (二)缓存穿透示意图 五、缓存雪崩 (一)缓存雪崩解决方案 (二)缓存雪崩…

解锁 DeepSeek 模型高效部署密码:蓝耘平台全解析

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…