之前写过一篇TensorFlow Java 环境的搭建 TensorFlow Java+eclipse下环境搭建,今天看看TensorFlow Java API 的简单说明 和操作。
TensorFlow是什么
由 Google 开源,是一个深度学习库, 是一套使用数据流图 (data flow graphics)进行数据计算的软件库(software library) 和应用接口(API),并以此作为基础加上其它功能的库和开发工具成为一套进行机器学习、特别是深度学习(deep learning)的应用程序开发框架 (framework)。 ---------------谷歌开发技术推广部 大中华区主管 栾跃 (Bill Luan)
支持CNN、RNN和LSTM算法,是目前在 Image,NLP (神经语言学)最流行的深度神经网络模型。
TensorFlow 优点
基于Python,写的很快并且具有可读性。
在多GPU系统上的运行更为顺畅。
代码编译效率较高。
社区发展的非常迅速并且活跃。
能够生成显示网络拓扑结构和性能的可视化图
TensorFlow 的工作原理
TensorFlow是用数据流图(data flow graphs)技术来进行数值计算的
边:用于传送节点之间的多维数组,即张量( tensor )
节点:表示数学运算操作符 用operation表示,简称op
TensorFlow Java API
public class HelloTF {
public static void main(String[] args) throws Exception {
try (Graph g = new Graph(); Session s = new Session(g)) {
// 使用占位符构造一个图,添加两个浮点型的张量
Output x = g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().output(0);//创建一个OP
Output y = g.opBuilder("Placeholder", "y").setAttr("dtype", DataType.FLOAT).build().output(0);
Output z = g.opBuilder("Add", "z").addInput(x).addInput(y).build().output(0);
System.out.println( " z= " + z);
// 多次执行,每次使用不同的x和y值
float[] X = new float[] { 1, 2, 3 };
float[] Y = new float[] { 4, 5, 6 };
for (int i = 0; i < X.length; i++) {
try (Tensor tx = Tensor.create(X[i]);
Tensor ty = Tensor.create(Y[i]);
Tensor tz = s.runner().feed("x", tx).feed("y", ty).fetch("z").run().get(0)) {
System.out.println(X[i] + " + " + Y[i] + " = " + tz.floatValue());
}
}
}
Graph graph = new Graph();
Tensor tensor = Tensor.create(2);
Tensor tensor2 = tensor.create(3);
Output output = graph.opBuilder("Const", "mx").setAttr("dtype", tensor.dataType()).setAttr("value", tensor).build().output(0);
Output output2 = graph.opBuilder("Const", "my").setAttr("dtype", tensor2.dataType()).setAttr("value", tensor2).build().output(0);
Output output3 =graph.opBuilder("Sub", "mz").addInput(output).addInput(output2).build().output(0);
Session session = new Session(graph);
Tensor ttt= session.runner().fetch("mz").run().get(0);
System.out.println(ttt.intValue());
Tensor t= session.runner().feed("mx", tensor).feed("my", tensor2).fetch("mz").run().get(0);
System.out.println(t.intValue());
session.close();
tensor.close();
tensor2.close();
graph.close();
}
}
复制代码
z= <Add 'z:0' shape=<unknown> dtype=FLOAT>
1.0 + 4.0 = 5.0
2.0 + 5.0 = 7.0
3.0 + 6.0 = 9.0
-1
-1
复制代码