基于Flask的深度学习模型部署技术文档
1. 概述
本文档介绍如何使用Flask框架部署深度学习模型,实现一个完整的Web服务。Flask是一个轻量级的Python Web框架,非常适合快速构建RESTful API服务,将训练好的深度学习模型提供给客户端调用。
2. 系统架构
客户端 (Web/App) <-- HTTP/HTTPS --> Flask Web服务 (REST API) <--> 深度学习模型
3. 环境准备
3.1 软件依赖
- Python 3.6+
- Flask 2.0+
- 深度学习框架 (PyTorch/TensorFlow/Keras等)
- 其他依赖库: numpy, pillow, opencv-python等
3.2 安装
pip install flask torch torchvision pillow
4. 项目结构
project/
├── app.py                  # Flask主应用
├── model/                  # 模型目录
│   ├── model_weights.pth   # 模型权重
│   └── model.py            # 模型定义
├── static/                 # 静态文件
├── templates/              # 模板文件
├── utils/                  # 工具函数
│   └── preprocess.py       # 数据预处理
└── requirements.txt        # 依赖列表
5. Flask应用实现
5.1 基本Flask应用
from flask import Flask, request, jsonify
app = Flask(__name__)


@app.route('/')
def home():
    """Return a plain-text banner identifying the service."""
    return "Deep Learning Model Serving API"


if __name__ == '__main__':
    import os

    # The interactive Werkzeug debugger allows arbitrary code execution, so it
    # must not be switched on unconditionally while listening on all
    # interfaces. Opt in explicitly (FLASK_DEBUG=1) for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)
5.2 加载深度学习模型
import torch
from model.model import MyModel
def load_model(weights_path='model/model_weights.pth'):
    """Build ``MyModel`` and load its trained weights for inference.

    Args:
        weights_path: Path to the saved ``state_dict`` file.

    Returns:
        The model instance, switched to evaluation mode.
    """
    model = MyModel()
    # map_location='cpu' lets weights that were saved on a GPU machine load on
    # a CPU-only server instead of failing during deserialization.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # switch to evaluation mode
    return model


model = load_model()
5.3 预测API端点
from utils.preprocess import preprocess_image
from werkzeug.utils import secure_filename
import os
# Extensions accepted by the upload endpoint (compared in lower case).
ALLOWED_EXTENSIONS = {'jpeg', 'jpg', 'png'}
# Directory where uploaded images are written.
UPLOAD_FOLDER = 'static/uploads'
app.config.update(UPLOAD_FOLDER=UPLOAD_FOLDER)
def allowed_file(filename):
    """Return True if *filename* has an extension in ``ALLOWED_EXTENSIONS``.

    The extension is the text after the last dot, compared case-insensitively.
    Names without a dot are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the predicted class index.

    Expects a multipart upload in the ``file`` form field. On success the
    JSON response carries ``prediction`` (arg-max class index) and
    ``image_url``; every rejected upload gets an ``error`` message with
    HTTP status 400.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'File type not allowed'}), 400

    filename = secure_filename(file.filename)
    if not filename:
        # secure_filename() can strip a hostile name down to nothing; saving
        # to the bare upload directory path would then fail.
        return jsonify({'error': 'Invalid file name'}), 400

    # The upload directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    # Preprocess the image. A file with an allowed extension may still not be
    # a decodable image; report that as a client error instead of a 500.
    try:
        input_tensor = preprocess_image(filepath)
    except OSError:
        os.remove(filepath)  # do not keep serving an invalid upload
        return jsonify({'error': 'Invalid image file'}), 400

    # Run inference without tracking gradients.
    with torch.no_grad():
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    return jsonify({
        'prediction': prediction,
        'image_url': f'/static/uploads/{filename}'
    })
6. 模型预处理和后处理
6.1 预处理示例 (preprocess.py)
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
def preprocess_image(image_path):
    """Load an image file and convert it into a normalized one-image batch.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A tensor of shape (1, 3, 224, 224): the RGB image resized to 256,
        center-cropped to 224x224, scaled by ``ToTensor`` and normalized
        per channel, with a leading batch dimension.
    """
    # Use a context manager so the file handle is closed right away;
    # convert() returns a fully loaded image that outlives the file.
    with Image.open(image_path) as image_file:
        input_image = image_file.convert('RGB')

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # add the batch dimension
    return input_batch
6.2 后处理示例
# 在predict函数中添加后处理
class_names = ['cat', 'dog']  # example classes


@app.route('/predict', methods=['POST'])
def predict():
    """Return the most probable classes (at most three) for the uploaded image.

    The JSON response holds ``predictions``, a list of
    ``{'class', 'probability'}`` entries ordered from most to least likely,
    and ``image_url``.
    """
    # ... earlier code (upload validation, saving, preprocessing) ...
    with torch.no_grad():
        output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # torch.topk raises when k exceeds the number of classes (the example
    # above only has two), so cap k at the size of the output vector.
    k = min(3, probabilities.size(0))
    top_prob, top_catid = torch.topk(probabilities, k)
    results = [
        {'class': class_names[class_index], 'probability': probability}
        for probability, class_index in zip(top_prob.tolist(),
                                            top_catid.tolist())
    ]
    return jsonify({
        'predictions': results,
        'image_url': f'/static/uploads/{filename}'
    })
7. 性能优化
7.1 启用多线程
if __name__ == '__main__':
    # Start the built-in server with threaded request handling enabled.
    run_options = {'host': '0.0.0.0', 'port': 5000, 'threaded': True}
    app.run(**run_options)
7.2 使用生产服务器
# 使用Waitress
pip install waitress
waitress-serve --host=0.0.0.0 --port=5000 app:app
# 或使用Gunicorn
pip install gunicorn
gunicorn -w 4 -b 0.0.0.0:5000 app:app
7.3 模型缓存
from flask import Flask, g
import functools
@functools.lru_cache(maxsize=1)
def get_model():
    """Return the model, loading it only on the first call in this process.

    ``flask.g`` is created anew for every application context (in practice,
    every request), so storing the model there reloaded the weights on each
    request. Memoizing the loader keeps one instance alive for the whole
    process instead.
    """
    return load_model()


@app.teardown_appcontext
def teardown_model(exception):
    """Application-context teardown hook; intentionally leaves the model loaded.

    Args:
        exception: The unhandled exception that ended the context, if any.
    """
    # Nothing to release per request: the cached model must survive across
    # requests. Add per-request cleanup here if it ever becomes necessary.
8. 安全考虑
8.1 文件上传安全
# 在文件上传处理中添加安全检查
import magic
def is_valid_image(file_stream):
    """Check the actual content type of an upload with python-magic.

    Args:
        file_stream: A seekable binary stream holding the uploaded file.

    Returns:
        True if the detected MIME type is JPEG or PNG, otherwise False.
        The stream is rewound to the beginning before returning.
    """
    # Sniff from the start of the file regardless of what was read before.
    file_stream.seek(0)
    try:
        # python-magic recommends at least the first 2048 bytes; shorter
        # samples can be misidentified.
        header = file_stream.read(2048)
    finally:
        file_stream.seek(0)  # leave the stream ready for file.save()
    file_type = magic.from_buffer(header, mime=True)
    return file_type in ('image/jpeg', 'image/png')
8.2 API认证
from functools import wraps
from flask import request, jsonify
def token_required(f):
    """Decorator that rejects requests lacking a valid ``X-API-Key`` header.

    Args:
        f: The view function to protect.

    Returns:
        A wrapper that calls *f* when the header matches the expected key and
        otherwise responds with a JSON message and HTTP status 401.
    """
    import hmac

    # Placeholder secret from the example.
    # NOTE(review): load the real key from configuration or the environment.
    expected_key = 'your-secret-token'.encode('utf-8')

    @wraps(f)
    def decorated(*args, **kwargs):
        token = request.headers.get('X-API-Key')
        # compare_digest runs in constant time, so response timing does not
        # leak how much of the key matched. Comparing bytes also avoids the
        # TypeError it raises for non-ASCII str arguments.
        if not token or not hmac.compare_digest(token.encode('utf-8'),
                                                expected_key):
            return jsonify({'message': 'Token is missing or invalid'}), 401
        return f(*args, **kwargs)
    return decorated
@app.route('/predict', methods=['POST'])
@token_required
def predict():
    """Protected prediction endpoint (placeholder: the body is left empty)."""
9. 测试API
使用curl测试API:
curl -X POST -F "file=@test.jpg" http://localhost:5000/predict
使用Python requests测试:
import requests
url = 'http://localhost:5000/predict'
# Open the image in a context manager so the handle is closed once the upload
# has been sent. The timeout keeps the script from hanging forever if the
# server does not answer.
with open('test.jpg', 'rb') as image_file:
    files = {'file': image_file}
    response = requests.post(url, files=files, timeout=30)
print(response.json())
10. 部署建议
11. 常见问题解决
12. 扩展功能
通过本文档,您可以快速搭建一个基于Flask的深度学习模型服务,并根据实际需求进行扩展和优化。
评论前必须登录!
注册