上一篇文章《Pytorch实现性别识别,男女分类》
我们用pytorch实现了性别识别神经网络的训练和测试,这篇文章我们来介绍如何把训练好的模型迁移到Android设备上。
一、Android上引入pytorch
在app module下的build.gradle上加上
implementation 'org.pytorch:pytorch_android:1.3.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
二、把训练好的模型net.pt放到assets目录下
三、编写代码
3.1、读取assets下的模型文件
private String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e("pytorchandroid", "Error process asset " + assetName + " to file path");
}
return null;
}
3.2、封装pytorch工具类
import android.graphics.Bitmap;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
public class Classifier {
//类别
public static final String[] SEXS = new String[]{"男","女"};
Module model;
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
/**
* 加载assets中的模型
* @param modelPath
*/
public Classifier(String modelPath){
model = Module.load(modelPath);
}
/**
* 传入图片预测性别
* @param bitmap
* @param size 规定传入的图片要符合一个大小标准,这里是32*32
* @return
*/
public String predict(Bitmap bitmap, int size){
Tensor tensor = preprocess(bitmap,size);
IValue inputs = IValue.from(tensor);
Tensor outputs = model.forward(inputs).toTensor();
float[] scores = outputs.getDataAsFloatArray();
int classIndex = argMax(scores);
return SEXS[classIndex];
}
/**
* 调整图片大小
* @param bitmap
* @param size
* @return
*/
public Tensor preprocess(Bitmap bitmap, int size){
bitmap = Bitmap.createScaledBitmap(bitmap,size,size,false);
return TensorImageUtils.bitmapToFloat32Tensor(bitmap,this.mean,this.std);
}
/**
* 计算最大的概率
* @param inputs
* @return
*/
public int argMax(float[] inputs){
int maxIndex = -1;
float maxvalue = 0.0f;
for (int i = 0; i < inputs.length; i++){
if(inputs[i] > maxvalue) {
maxIndex = i;
maxvalue = inputs[i];
}
}
return maxIndex;
}
}
其中调整Bitmap大小的方法很重要,否则会报错Caused by: java.lang.RuntimeException: shape '[-1, 400]' is invalid for input of size 150544 The above operation failed in interpreter.
3.3、调用模型预测
//这里的size要根据模型的需要进行改变,本模型需要32*32大小的图片
String pred = classifier.predict(bitmap, 32);
最终运行效果如下:
如果您想要完整代码移步这里:https://download.csdn.net/download/zhangdongren/12358642