python-tensorflow模型保存、提取

本文介绍了如何在TensorFlow中保存和加载模型。包括三种不同的加载方式:完全加载、通过映射关系加载以及直接加载图。这些方法适用于不同的场景,如多人协作开发时避免变量冲突等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

模型保存

import tensorflow as tf

#%%
#模型保存
with tf.variable_scope('a',reuse=tf.AUTO_REUSE):
    v1 = tf.get_variable(name='v1',initializer=tf.constant(3.0))
    v2 = tf.get_variable(name='v2',initializer=tf.constant(4.0))

result = v1+ v2

saver = tf.train.Saver()

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   #将模型保存到model文件夹下,文件的前缀为model.ckpt
   print(sess.run(result))
   saver.save(sess,'./model/model.ckpt')

模型加载(定义模型,完全加载)

#模型的提取 方式一 (完全加载,需要完整的恢复保存之前的数据格式)
with tf.variable_scope('a',reuse=tf.AUTO_REUSE):
    v1 = tf.get_variable(name='v1',initializer=tf.constant(3.0))
    v2 = tf.get_variable(name='v2',initializer=tf.constant(4.0))

result = v1+ v2

saver = tf.train.Saver()

with tf.Session() as sess:
   #将模型保存到model文件夹下,文件的前缀为model.ckpt
   #会从保存的图中加载变量
   saver.restore(sess,'./model/model.ckpt')
   print(sess.run(result))

模型加载(定义模型,映射关系加载),使用多个模型时用

#模型的提取 方式三 (给定映射关系)
#如果模型的编写和模型的使用时不同的人,在使用多个模型时,可能出现使用了同一个变量的情况
a = tf.Variable(tf.constant(1.0),name='a')
b = tf.Variable(tf.constant(2.0),name='b')
result = a+b

saver =tf.train.Saver({"v1":a,'v2':b})
with tf.Session() as sess:
    saver.restore(sess,'./model/model.ckpt')
    print(sess.run(result))

 模型加载(不定义模型,直接加载图)

#模型的提取 方式二 (直接加载图,不需要定义变量了)
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')

with tf.Session() as sess:
   saver.restore(sess,'./model/model.ckpt')
   print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
   print(sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
   print(sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))

  

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值