官网链接
Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation
通过flask的rest API在python中部署pytorch
在本教程中,我们将使用Flask部署PyTorch模型,并开放用于模型推断的REST API。特别是,我们将部署一个预训练的DenseNet 121模型来检测图像。
这是关于在生产环境中部署PyTorch模型的系列教程中的第一篇。使用Flask这种方式是迄今为止部署PyTorch模型的最简单方法,但它不适用于具有高性能要求的用例。
- 如果你已经熟悉了TorchScript,你可以直接跳到我们的加载一个TorchScript模型在c++教程。(Loading a TorchScript Model in C++ )
- 如果你需要对TorchScript进行复习,请查看我们的TorchScript入门教程。(Intro a TorchScript )
API定义
我们将首先定义API 路径、请求和响应类型。我们的API路径是 /predict ,它接受带有包含图像的文件参数的HTTP POST请求。响应将是JSON响应,其中包含预测结果。
{"class_id": "n02124075", "class_name": "Egyptian_cat"}{"class_id": "n02124075", "class_name": "Egyptian_cat"}
依赖项
运行以下命令安装所需的依赖项:
$ pip install Flask==2.0.1 torchvision==0.10.0
简单Web服务器
下面是一个简单的web服务器,摘自Flask的文档
from flask import Flask
app = Flask(__name__)@app.route('/')
def hello():return 'Hello World!'
将上面的代码片段保存在一个名为app.py的文件中,现在你可以通过输入以下命令来运行Flask开发服务器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
当您在web浏览器中访问http://localhost:5000/时,您将看到Hello World!文本
我们将对上面的代码片段做一些修改,使其适合我们的API定义。首先,我们将把方法重命名为predict。我们将把请求路径更新为/predict。由于图像文件将通过HTTP POST请求发送,我们将更新它,使其也只接受POST请求。
@app.route('/predict', methods=['POST'])
def predict():return 'Hello World!'
我们还将更改响应类型,以便它返回一个包含ImageNet类id和名称的JSON响应。更新后的app.py文件现在将是:
from flask import Flask, jsonify
app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
推理
在下一节中,我们将重点讨论如何编写推理代码。这将涉及两个部分,一个是我们准备图像,以便它可以馈送到DenseNet,接下来,我们将编写代码以从模型中获得实际预测。
准备图像
DenseNet模型要求图像为3通道RGB图像,大小为224 x 224。我们还将用所需的均值和标准差值对图像张量进行归一化。你可以在这里读到更多(here)。
我们将使用torchvision 库中的 transforms ,并构建一个变换管道,它可以根据要求变换我们的图像。你可以在这里关于变换的内容(here)。
import ioimport torchvision.transforms as transforms
from PIL import Imagedef transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)
上述方法接受字节数据的图像,应用一些列的transforms 并返回一个张量。为了测试上述方法,以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径)并查看是否返回一个张量:
with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()tensor = transform_image(image_bytes=image_bytes)print(tensor)
预测
现在将使用预训练的DenseNet 121模型来预测图像类别。我们将使用torchvision库,加载模型并获得推理结果。虽然我们将在本例中使用预训练模型,但您可以对自己的模型使用相同的方法。了解更多关于加载模型的信息(tutorial)。
from torchvision import models# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)return y_hat
张量y_hat 将包含预测类别的索引id, 然而,我们需要一个人类可读的类名。为此,我们需要一个类别id和命名的映射。下载 imagenet_class_index.json 这个文件( this file),并记住保存它的位置(或者,如果您遵循本教程中的确切步骤,将其保存在教程/_static中)。这个文件包含ImageNet类别id到ImageNet类名的映射。我们将加载这个JSON文件并获取预测类别索引的类名。
import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]
在使用imagenet_class_index字典之前,首先我们将把张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。我们将测试上述方法:
with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()print(get_prediction(image_bytes=image_bytes))
你应该得到这样的返回:
['n02124075', 'Egyptian_cat']
数组中的第一项是ImageNet类别id,第二项是人类可读的名称。
注意
您是否注意到model变量不是get_prediction方法的局部变量,或者说为什么model是一个全局变量?就内存和计算而言,加载模型可能是一项昂贵的操作。如果我们在get_prediction方法中加载模型,那么每次调用该方法时都会不必要地加载模型。因为我们正在构建一个web服务器,每秒可能有数千个请求,我们不应该浪费时间为每个推理加载模型。因此,我们只将模型加载到内存中一次。在生产系统中,为了能够大规模地处理请求,必须高效地使用计算,因此通常应该在处理请求之前加载模型。
在我们的API服务器中集成模型
在最后一部分中,我们将把模型添加到Flask API服务器中。由于我们的API服务器应该接受一个图像文件,我们将更新我们的预测方法来从请求中读取文件:
from flask import request@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# we will get the file from the requestfile = request.files['file']# convert that to bytesimg_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})
app.py文件现在已经完成。以下是完整版本;将路径替换为您保存文件的路径,它应该运行:
import io
import jsonfrom torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, requestapp = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()def transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':file = request.files['file']img_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__':app.run()
让我们测试一下我们的web服务器!运行:
$ FLASK_ENV=development FLASK_APP=app.py flask run
我们可以使用requests库向我们的应用发送POST请求:
import requestsresp = requests.post("http://localhost:5000/predict",files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
打印rep .json()将显示以下内容:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
下一个步骤
我们编写的服务器非常简单,可能无法完成生产应用程序所需的所有功能。所以,这里有一些你可以做的事情来让它变得更好:
- 请求路径 /predict假定请求中总是有一个图像文件。这可能并不适用于所有请求。我们的用户可以发送带有不同参数的图像或根本不发送图像。
- 用户也可以发送非图像类型的文件。由于我们不处理错误,这将破坏我们的服务器。显式添加异常的错误处理路径,将使我们能够更好地处理错误输入
- 尽管该模型可以识别大量的图像类别,但它可能无法识别所有的图像。优化实现以处理模型无法识别图像中的任何内容的情况。
- 我们以开发模式运行Flask服务器,这种模式不适合部署到生产环境中。您可以查看本教程,了解如何在生产环境中部署Flask服务器。(this tutorial )
- 您还可以通过创建带有表单的页面来添加UI,该表单接受图像并显示预测结果。请查看类似项目的演示及其源代码。(source code.)
- 在本教程中,我们只展示了如何构建一个每次可以返回单个图像预测的服务。我们可以修改我们的服务,使其能够一次返回多个图像的预测结果。此外,service-streamer库会自动将请求排队到您的服务中,并将它们采样到可以馈送到模型中的小批量中。您可以查看本教程(this tutorial.)。
- 最后,我们鼓励您查看页面顶部链接的关于部署PyTorch模型的其他教程.