一.引言
使用 Java API 调用 Tensorflow API 期间报错 Expects arg[0] to be int64 but int32 is provided:
java.lang.IllegalArgumentException: Expects arg[0] to be int64 but int32 is provided
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:326)
at org.tensorflow.Session$Runner.run(Session.java:276)
根据提示可以看到是数据要求为 int64 但是传输了 int32 的类型。
二.异常排除
1.修改 placeholder ❌
既然类型不匹配,那就把 dtype=int64 修改为 int32,修改后导出模型继续报错:
ValueError: Tensor conversion requested dtype int64 for Tensor with dtype int32: ...
原来是 SparseTensor 构建时 index 必须采用 int64 作为 dtype:
故这里修改 placeholder 不生效,不过其他场景如果不受源码 dtype 影响的话,大家可以使用该方法。
2.修改 Tensor 类型 ✔️
使用 tensor.create 方法时选择 long 对应 python 内 int64 类型,之前报错是因为传入了 Integer,所以对应了 int32,而 SparseTensor 构造需要 int64:
long[] temp_value = new long[size];
// int[] temp_value = new int[size];
return Tensor.create(temp_value);
修改后 Session 运行正常。