模型保存
import tensorflow as tf
#%%
#模型保存
with tf.variable_scope('a',reuse=tf.AUTO_REUSE):
v1 = tf.get_variable(name='v1',initializer=tf.constant(3.0))
v2 = tf.get_variable(name='v2',initializer=tf.constant(4.0))
result = v1+ v2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#将模型保存到model文件夹下,文件的前缀为model.ckpt
print(sess.run(result))
saver.save(sess,'./model/model.ckpt')
模型加载(定义模型,完全加载)
#模型的提取 方式一 (完全加载,需要完整的恢复保存之前的数据格式)
with tf.variable_scope('a',reuse=tf.AUTO_REUSE):
v1 = tf.get_variable(name='v1',initializer=tf.constant(3.0))
v2 = tf.get_variable(name='v2',initializer=tf.constant(4.0))
result = v1+ v2
saver = tf.train.Saver()
with tf.Session() as sess:
#将模型保存到model文件夹下,文件的前缀为model.ckpt
#会从保存的图中加载变量
saver.restore(sess,'./model/model.ckpt')
print(sess.run(result))
模型加载(定义模型,映射关系加载),使用多个模型时用
#模型的提取 方式三 (给定映射关系)
#如果模型的编写和模型的使用时不同的人,在使用多个模型时,可能出现使用了同一个变量的情况
a = tf.Variable(tf.constant(1.0),name='a')
b = tf.Variable(tf.constant(2.0),name='b')
result = a+b
saver =tf.train.Saver({"v1":a,'v2':b})
with tf.Session() as sess:
saver.restore(sess,'./model/model.ckpt')
print(sess.run(result))
模型加载(不定义模型,直接加载图)
#模型的提取 方式二 (直接加载图,不需要定义变量了)
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'./model/model.ckpt')
print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
print(sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
print(sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))