服务端 flask_server.py
三步走:
- 加载模型
- 数据预处理
- 开启服务
import io
import json
import flask
import torch
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
#from torchvision import transforms as T
from torchvision import transforms, models, datasets
from torch.autograd import Variable
# 初始化Flask app
app = flask.Flask(__name__)
model = None
use_gpu = False
# 加载模型进来
def load_model():
"""Load the pre-trained model, you can use your model just as easily.
"""
global model
#这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc