在现代深度学习和计算机视觉应用中,高性能推理是关键。本文将介绍如何使用 TensorRT 和 Python 构建一个高性能的图像推理服务器。该服务器能够接收客户端发送的图像数据,使用 TensorRT 进行推理,并将结果返回给客户端。
1. 概述
1.1 项目目标
-
构建一个基于 TCP 协议的图像推理服务器。
-
使用 TensorRT 加速深度学习模型的推理。
-
支持多客户端并发处理。
1.2 技术栈
-
TensorRT:NVIDIA 的高性能深度学习推理库。
-
PyCUDA:用于在 Python 中操作 CUDA。
-
OpenCV 和 PIL:用于图像处理。
-
Socket:用于实现网络通信。
-
多进程:使用
ProcessPoolExecutor
实现并发处理。
2. 代码实现
2.1 依赖库
首先,确保安装了以下 Python 库:
pip install numpy opencv-python pillow pycuda tensorrt torch
2.2 核心代码
2.2.1 图像预处理
def normalize(image: np.ndarray):"""图像归一化处理"""resize_image = cv2.resize(image, (700, 800))image = resize_image.astype(np.float32)mean = (0.485, 0.456, 0.406)std = (0.229, 0.224, 0.225)image /= 255.0image -= meanimage /= stdreturn image
2.2.2 加载 TensorRT 引擎
def load_engine(engine_path):"""加载 TensorRT 引擎"""TRT_LOGGER = trt.Logger(trt.Logger.WARNING)with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())
2.2.3 TensorRT 推理
def ai_predict(input_data,engine,batch_size=1):with engine.create_execution_context() as context:stream = cuda.Stream()# 分配输入和输出的内存print(input_data.dtype,input_data.nbytes,input_data.shape)h_output = cuda.pagelocked_empty((batch_size, 5), dtype=np.float32) # 假设输出是 5 类d_input = cuda.mem_alloc(input_data.nbytes)d_output = cuda.mem_alloc(h_output.nbytes)# 获取输入和输出的名称tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]input_name = tensor_names[0]output_name = tensor_names[1]# 设置输入张量的形状# context.set_input_shape(input_name, (batch_size, 3, 800, 700))context.set_tensor_address(input_name, int(d_input))context.set_tensor_address(output_name, int(d_output))# 将输入数据复制到 GPUcuda.memcpy_htod_async(d_input, input_data, stream)# 设置输入和输出的地址# 执行推理context.execute_async_v3(stream.handle)h_output = cuda.aligned_zeros((batch_size, 5), dtype=input_data.dtype)cuda.memcpy_dtoh_async(h_output, d_output, stream)# context.execute_v2([d_input,d_output])# cuda.memcpy_dtoh(h_output,d_output)stream.synchronize()# 返回输出数据return h_output
2.2.4 图像处理与推理
def process_image(image_data, engine, class_indict):"""处理图像并进行推理"""try:image = Image.open(io.BytesIO(image_data))image_array = np.array(image)if len(image_array.shape) < 3:image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)normalized_image = normalize(image_array)normalized_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)batch_array = normalized_image.astype(np.float32)batch_array = np.ascontiguousarray(batch_array)start_time = time.time()pro = ai_predict(batch_array, engine, batch_size=1)predict_time = time.time() - start_timeprint("AI预测时间:", predict_time)pro_val = torch.softmax(torch.from_numpy(pro), dim=1)outputs = torch.argmax(pro_val, dim=1).numpy()for index, index_pro in enumerate(pro_val):maxpro_index = torch.argmax(index_pro).to("cpu").numpy()if maxpro_index == 4 and index_pro[-1].to("cpu").numpy() < 0.60:outputs[index] = 0pre_label = [class_indict[str(num)] for num in outputs]print("label:{},pro:{}".format(pre_label, np.max(pro_val.numpy())))cv2.putText(image_array, pre_label[0], (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 1, cv2.LINE_AA)with io.BytesIO() as output:processed_pil_image = Image.fromarray(image_array)processed_pil_image.save(output, format="JPEG")processed_image_data = output.getvalue()return pre_label, processed_image_dataexcept Exception as e:print(f"处理图像时出错: {e}")return None
2.2.5 客户端处理
def handle_client(conn, addr, engine_file_path, class_indict):"""处理客户端连接"""engine = load_engine(engine_file_path)with conn:print(f"已连接 {addr}")file_size_data = conn.recv(4)file_size = int.from_bytes(file_size_data, byteorder='big')image_data = b''remaining_size = file_sizewhile remaining_size > 0:chunk = conn.recv(min(65536, remaining_size))if not chunk:breakimage_data += chunkremaining_size -= len(chunk)print(f"从 {addr} 接收图像数据完成")start_time = time.time()pro_label, processed_image_data = process_image(image_data, engine, class_indict)img_time = time.time() - start_timeprint("图像处理时间:", img_time)if processed_image_data and len(pro_label) > 0:file_size = len(processed_image_data)conn.sendall(file_size.to_bytes(4, byteorder='big'))conn.sendall(processed_image_data)data_bytes = json.dumps(pro_label).encode('utf-8')conn.sendall(data_bytes)print(f"处理后的图像已发送回 {addr}")else:print(f"无法处理图像,未发送数据到 {addr}")
2.2.6 服务器启动
复制
def server_process(host, port, engine_file_path, class_indict, max_workers=4):"""启动服务器"""server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)server_socket.bind((host, port))server_socket.listen(5)print(f"Server listening on {host}:{port}")with ProcessPoolExecutor(max_workers=max_workers) as pool:try:while True:client_socket, addr = server_socket.accept()print(f"Accepted connection from {addr}")pool.submit(handle_client, client_socket, addr, engine_file_path, class_indict)except KeyboardInterrupt:print("Server is shutting down...")finally:server_socket.close()
2.2.7 主程序
复制
if __name__ == "__main__":engine_file_path = "best_weight/resnet34.trt"json_label_path = './class_indices.json'assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)json_file = open(json_label_path, 'r')class_indict = json.load(json_file)HOST = '0.0.0.0' # 服务器IP地址PORT = 4099 # 监听的端口server_process(HOST, PORT, engine_file_path, class_indict)
3. 运行与测试
-
启动服务器:
bash
复制
python server.py
-
使用客户端发送图像数据并接收推理结果。
4. 总结
本文介绍了如何使用 TensorRT 和 Python 构建一个高性能的图像推理服务器。通过多进程和 TensorRT 的加速,服务器能够高效处理客户端请求并返回推理结果。希望本文对您有所帮助!