PyTorch翻译官网教程-DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK

news/2024/11/17 19:02:56/

官网链接

Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation

通过flask的rest API在python中部署pytorch

在本教程中,我们将使用Flask部署PyTorch模型,并开放用于模型推断的REST API。特别是,我们将部署一个预训练的DenseNet 121模型来检测图像。

这是关于在生产环境中部署PyTorch模型的系列教程中的第一篇。使用Flask这种方式是迄今为止部署PyTorch模型的最简单方法,但它不适用于具有高性能要求的用例。

  • 如果你已经熟悉了TorchScript,你可以直接跳到我们的加载一个TorchScript模型在c++教程。(Loading a TorchScript Model in C++ )
  • 如果你需要对TorchScript进行复习,请查看我们的TorchScript入门教程。(Intro a TorchScript )

API定义

我们将首先定义API 路径、请求和响应类型。我们的API路径是 /predict它接受带有包含图像的文件参数的HTTP POST请求。响应将是JSON响应,其中包含预测结果。

{"class_id": "n02124075", "class_name": "Egyptian_cat"}{"class_id": "n02124075", "class_name": "Egyptian_cat"}


依赖项

运行以下命令安装所需的依赖项:

$ pip install Flask==2.0.1 torchvision==0.10.0

简单Web服务器

下面是一个简单的web服务器,摘自Flask的文档

from flask import Flask
app = Flask(__name__)@app.route('/')
def hello():return 'Hello World!'

将上面的代码片段保存在一个名为app.py的文件中,现在你可以通过输入以下命令来运行Flask开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在web浏览器中访问http://localhost:5000/时,您将看到Hello World!文本

我们将对上面的代码片段做一些修改,使其适合我们的API定义。首先,我们将把方法重命名为predict。我们将把请求路径更新为/predict。由于图像文件将通过HTTP POST请求发送,我们将更新它,使其也只接受POST请求。

@app.route('/predict', methods=['POST'])
def predict():return 'Hello World!'

我们还将更改响应类型,以便它返回一个包含ImageNet类id和名称的JSON响应。更新后的app.py文件现在将是:

from flask import Flask, jsonify
app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在下一节中,我们将重点讨论如何编写推理代码。这将涉及两个部分,一个是我们准备图像,以便它可以馈送到DenseNet,接下来,我们将编写代码以从模型中获得实际预测。

准备图像

DenseNet模型要求图像为3通道RGB图像,大小为224 x 224。我们还将用所需的均值和标准差值对图像张量进行归一化。你可以在这里读到更多(here)。

我们将使用torchvision 库中的 transforms ,并构建一个变换管道,它可以根据要求变换我们的图像。你可以在这里关于变换的内容(here)。

import ioimport torchvision.transforms as transforms
from PIL import Imagedef transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)

上述方法接受字节数据的图像,应用一些列的transforms 并返回一个张量。为了测试上述方法,以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径)并查看是否返回一个张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()tensor = transform_image(image_bytes=image_bytes)print(tensor)


预测

现在将使用预训练的DenseNet 121模型来预测图像类别。我们将使用torchvision库,加载模型并获得推理结果。虽然我们将在本例中使用预训练模型,但您可以对自己的模型使用相同的方法。了解更多关于加载模型的信息(tutorial)。

from torchvision import models# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)return y_hat


张量y_hat 将包含预测类别的索引id, 然而,我们需要一个人类可读的类名。为此,我们需要一个类别id和命名的映射。下载 imagenet_class_index.json 这个文件( this file),并记住保存它的位置(或者,如果您遵循本教程中的确切步骤,将其保存在教程/_static中)。这个文件包含ImageNet类别id到ImageNet类名的映射。我们将加载这个JSON文件并获取预测类别索引的类名。

import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]


在使用imagenet_class_index字典之前,首先我们将把张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()print(get_prediction(image_bytes=image_bytes))

你应该得到这样的返回:

['n02124075', 'Egyptian_cat']

数组中的第一项是ImageNet类别id,第二项是人类可读的名称。

注意

您是否注意到model变量不是get_prediction方法的局部变量,或者说为什么model是一个全局变量?就内存和计算而言,加载模型可能是一项昂贵的操作。如果我们在get_prediction方法中加载模型,那么每次调用该方法时都会不必要地加载模型。因为我们正在构建一个web服务器,每秒可能有数千个请求,我们不应该浪费时间为每个推理加载模型。因此,我们只将模型加载到内存中一次。在生产系统中,为了能够大规模地处理请求,必须高效地使用计算,因此通常应该在处理请求之前加载模型。


在我们的API服务器中集成模型

在最后一部分中,我们将把模型添加到Flask API服务器中。由于我们的API服务器应该接受一个图像文件,我们将更新我们的预测方法来从请求中读取文件:

from flask import request@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# we will get the file from the requestfile = request.files['file']# convert that to bytesimg_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现在已经完成。以下是完整版本;将路径替换为您保存文件的路径,它应该运行:

import io
import jsonfrom torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, requestapp = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()def transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':file = request.files['file']img_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__':app.run()

让我们测试一下我们的web服务器!运行:

$ FLASK_ENV=development FLASK_APP=app.py flask run


我们可以使用requests库向我们的应用发送POST请求:

import requestsresp = requests.post("http://localhost:5000/predict",files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印rep .json()将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

下一个步骤

我们编写的服务器非常简单,可能无法完成生产应用程序所需的所有功能。所以,这里有一些你可以做的事情来让它变得更好:

  • 请求路径 /predict假定请求中总是有一个图像文件。这可能并不适用于所有请求。我们的用户可以发送带有不同参数的图像或根本不发送图像。
  • 用户也可以发送非图像类型的文件。由于我们不处理错误,这将破坏我们的服务器。显式添加异常的错误处理路径,将使我们能够更好地处理错误输入
  • 尽管该模型可以识别大量的图像类别,但它可能无法识别所有的图像。优化实现以处理模型无法识别图像中的任何内容的情况。
  • 我们以开发模式运行Flask服务器,这种模式不适合部署到生产环境中。您可以查看本教程,了解如何在生产环境中部署Flask服务器。(this tutorial
  • 您还可以通过创建带有表单的页面来添加UI,该表单接受图像并显示预测结果。请查看类似项目的演示及其源代码。(source code.
  • 在本教程中,我们只展示了如何构建一个每次可以返回单个图像预测的服务。我们可以修改我们的服务,使其能够一次返回多个图像的预测结果。此外,service-streamer库会自动将请求排队到您的服务中,并将它们采样到可以馈送到模型中的小批量中。您可以查看本教程(this tutorial.)。
  • 最后,我们鼓励您查看页面顶部链接的关于部署PyTorch模型的其他教程.


http://www.ppmy.cn/news/897839.html

相关文章

注册微信小程序并开通微信支付流程

首先我们要梳理清楚&#xff0c;微信到底有哪些平台&#xff0c;以及每个平台的作用。 微信开放平台&#xff08;如果单单只做小程序的话&#xff0c;这个平台可以不管&#xff0c;我们公司因为还有App所以也开通了&#xff0c;并且可以在开放平台绑定小程序【目前来看绑定没什…

怎么免费注册微信小程序-微信小程序开发-视频教程1

自从去年微信发布小程序以来&#xff0c; 受到很多人的关注&#xff0c; 今年微信小程序不断的推出新功能&#xff0c; 组合线下活动&#xff0c; 让人有更多期待&#xff0c; 今天我们来看看怎么注册微信小程序。 什么人和机构可以注册微信小程序 微信公众平台允许以下5个主体…

如何正确的注册微信开发测试号

注册测试号 注册的地址在 这里 要进行微信公众号的开发&#xff0c;那就需要一个本地的开发环境来进行开发。而微信测试号就正好提供了这样的一个development环境。每个微信号只能对应一个测试号&#xff0c;但是每个测试号可以开发多个微信公众号项目。微信号与测试号是一一…

微信扫码登陆或注册设计流程

一、整体流程 1、点击微信登录按钮&#xff0c;跳转微信扫码页面 2、用户扫描登录码&#xff0c;匹配用户信息 3、未匹配到用户信息&#xff0c;做注册操作 4、匹配到用户信息&#xff0c;做登录操作 微信授权流程说明&#xff0c;参照微信开放平台 获取access_token时序图…

微信小程序:注册微信小程序

注册小程序帐号 在微信公众平台官网首页&#xff08;mp.weixin.qq.com&#xff09;点击右上角的“立即注册”按钮。 选择注册的帐号类型 显示了4选项&#xff0c;我们现在是注册小程序&#xff0c;所以选择小程序 填写邮箱和密码 请填写未注册过公众平台、开放平台、企业号、…

java微信授权登录回调地址,微信开发者工具,注册微信公共平台

最近在做授权登录的时候遇到一个大坑&#xff0c;回调的地址的问题。 微信登录授权首先要在微信公众平台注册一个账号&#xff0c;然后获取 appID和appsecret 然后点击授权的 切记一定不能加上http://这些协议的东西&#xff1b; 这个域名是内网穿透获取用来通过微信授权的地…

【微信公众号】2. 微信公众号申请注册流程

目录 1. 内容概要 1.1 注意事项 1.2 注册订阅号步骤 1.3 订阅号包含功能 1. 内容概要 注册订阅号 1.1 注意事项 一个邮箱只能注册一个公众号&#xff08;订阅号/服务号/小程序&#xff09;注册订阅号成功后&#xff0c;3天内修改默认名字&#xff0c;否则账号会被回收 1.…

5.3 Python高级特性之-列表生成式、生成器、迭代器

一、 列表生成式 是Python内置的非常简单却强大的可以用来创建list的生成式 具体可根据如下案例理解&#xff0c;且代码也是可用的""" 1、 生成[0,1,2,3,4,5,6]这样列表 """ print(list(range(0, 7))) """ 2、 生成[0&#xff0…