使用MobileNetV3训练水果分类模型并用Flask部署

前言

在计算机视觉领域,图像分类是一个基础且重要的任务。本文将介绍如何使用MobileNetV3预训练模型来训练一个水果分类模型,并通过Flask框架进行部署。MobileNetV3作为轻量级网络,在保持较高精度的同时,具有较快的推理速度,非常适合实际应用场景。

环境准备

首先,我们需要准备以下环境:

# 主要依赖包
torch>=1.7.0
torchvision>=0.8.0
flask>=2.0.0
pillow>=8.0.0
numpy>=1.19.0
requests>=2.25.0  # 用于数据采集
matplotlib>=3.3.0  # 用于绘制训练曲线

数据集准备

1. 数据采集

我们使用百度图片API来采集水果图片数据。以下是数据采集的代码实现:

import requests
import os

def get_images(keyword, page_num):
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/69.0.3497.81 Safari/537.36'
    }
    url = 'https://image.baidu.com/search/acjson?'
    
    # 设置图片保存路径
    download_path = os.path.join("./data", keyword)
    if not os.path.exists(download_path):
        os.makedirs(download_path)
    
    # 构造请求参数
    params = {
        'tn': 'resultjson_com',
        'word': keyword,
        'pn': 0,
        'rn': 30,
        # ... 其他参数
    }
    
    # 下载图片
    for i in range(page_num):
        params["pn"] = i*30
        response = requests.get(url, params=params, headers=headers)
        # 处理返回结果并保存图片

2. 数据集组织

将采集到的图片按照以下结构组织:

data/
    ├── apple/
    │   ├── 0.jpg
    │   ├── 1.jpg
    │   └── ...
    ├── banana/
    │   ├── 0.jpg
    │   ├── 1.jpg
    │   └── ...
    └── ...

模型训练

1. 数据加载和预处理

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader

# 图像预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225]),
])

# 加载数据集
dataset = ImageFolder("data", transform=transform)

# 保存类别标签
with open("label.txt", "w", encoding="UTF-8") as f:
    for line in dataset.classes:
        f.write(line + "\n")

# 划分训练集和测试集
train_ratio = 0.8
train_size = int(len(dataset) * train_ratio)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

2. 模型定义

from torchvision import models
import torch.nn as nn

# 使用MobileNetV3-Small预训练模型
model = models.mobilenet_v3_small(pretrained=True)
# 修改最后的分类层
model.classifier[3] = nn.Linear(in_features=1024, out_features=5)  # 5个类别

# 如果有GPU则使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

3. 训练过程

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练参数
num_epochs = 20
best_valid_acc = 0
best_model = None

# 记录训练过程
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []

for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    total = 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_acc += (predicted == labels).sum().item()
        total += len(labels)
    

Flask部署

1. 创建Flask应用

2. 实现预测接口

@app.route('/predict', methods=['POST'])
def predict():
    if 'image' not in request.files:
        return render_template('index.html', prediction=None)
    
    image_file = request.files['image']
    image_data = image_file.read()
    
    # 图像预处理
    img = Image.open(io.BytesIO(image_data))
    img = transform(img)
    img = torch.unsqueeze(img, dim=0)
    
    # 模型预测
    with torch.no_grad():
        prediction = model(img)
        prediction = F.softmax(prediction, dim=1)
    
    # 获取预测结果
    pred_label = class_labels[torch.argmax(prediction).item()]
    confidence = torch.max(prediction).item()
    
    return render_template('index.html', 
                         prediction=pred_label,
                         confidence=confidence)

部署步骤

  1. 确保服务器已安装Python环境
  2. 安装所需依赖包:
pip install -r requirements.txt

  1. 将模型文件、Flask应用和模板文件上传到服务器
  2. 运行Flask应用:
python app.py

总结

本文详细介绍了使用MobileNetV3训练水果分类模型并用Flask部署的完整流程。通过使用预训练模型,我们可以在较小的数据集上获得不错的分类效果。Flask框架的轻量级特性使得部署变得简单快捷。在实际应用中,可以根据具体需求进行进一步的优化和改进。

参考资料

  1. MobileNetV3论文:Searching for MobileNetV3
  2. Flask官方文档:https://flask.palletsprojects.com/
  3. PyTorch官方文档:https://pytorch.org/docs/stable/index.html
  4. 百度图片API文档
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王小葱鸭

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值