Android设备上部署Pytorch,实现性别识别,男女分类

本文详细介绍将Pytorch训练的性别识别模型移植到Android设备的过程,包括在Android项目中引入Pytorch库、将模型文件放置于assets目录及编写代码读取模型并封装预测功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上一篇文章《Pytorch实现性别识别,男女分类》
我们用pytorch实现了性别识别神经网络的训练和测试,这篇文章我们来介绍如何把训练好的模型迁移到Android设备上。

一、Android上引入pytorch

在app module下的build.gradle上加上

    implementation 'org.pytorch:pytorch_android:1.3.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'

二、把训练好的模型net.pt放到assets目录下

在这里插入图片描述

三、编写代码

3.1、读取assets下的模型文件

private String assetFilePath(Context context, String assetName) {
        File file = new File(context.getFilesDir(), assetName);
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        } catch (IOException e) {
            Log.e("pytorchandroid", "Error process asset " + assetName + " to file path");
        }
        return null;
    }

3.2、封装pytorch工具类


import android.graphics.Bitmap;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class Classifier {
    //类别
    public static final String[] SEXS = new String[]{"男","女"};

    Module model;
    float[] mean = {0.485f, 0.456f, 0.406f};
    float[] std = {0.229f, 0.224f, 0.225f};

    /**
     * 加载assets中的模型
     * @param modelPath
     */
    public Classifier(String modelPath){
        model = Module.load(modelPath);
    }

    /**
     * 传入图片预测性别
     * @param bitmap
     * @param size 规定传入的图片要符合一个大小标准,这里是32*32
     * @return
     */
    public String predict(Bitmap bitmap, int size){
        Tensor tensor = preprocess(bitmap,size);
        IValue inputs = IValue.from(tensor);
        Tensor outputs = model.forward(inputs).toTensor();
        float[] scores = outputs.getDataAsFloatArray();
        int classIndex = argMax(scores);
        return SEXS[classIndex];
    }

    /**
     * 调整图片大小
     * @param bitmap
     * @param size
     * @return
     */
    public Tensor preprocess(Bitmap bitmap, int size){
        bitmap = Bitmap.createScaledBitmap(bitmap,size,size,false);
        return TensorImageUtils.bitmapToFloat32Tensor(bitmap,this.mean,this.std);
    }

    /**
     * 计算最大的概率
     * @param inputs
     * @return
     */
    public int argMax(float[] inputs){
        int maxIndex = -1;
        float maxvalue = 0.0f;
        for (int i = 0; i < inputs.length; i++){
            if(inputs[i] > maxvalue) {
                maxIndex = i;
                maxvalue = inputs[i];
            }
        }
        return maxIndex;
    }

}

其中调整Bitmap大小的方法很重要,否则会报错Caused by: java.lang.RuntimeException: shape '[-1, 400]' is invalid for input of size 150544 The above operation failed in interpreter.

3.3、调用模型预测

//这里的size要根据模型的需要进行改变,本模型需要32*32大小的图片
String pred = classifier.predict(bitmap, 32);

最终运行效果如下:
在这里插入图片描述
如果您想要完整代码移步这里:https://download.csdn.net/download/zhangdongren/12358642

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值