tensorflow的ckpt文件转为npy文件

本文介绍如何将TensorFlow的ckpt文件转换为npy格式,适用于VGG19等预训练模型。首先从ckpt文件中读取参数并保存为npy格式,接着如果使用了预训练模型,则需额外步骤将卷积层参数整合到新的npy文件中。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

我们迁移模型参数时,需要将ckpt文件转为npy文件,以VGG19为例:

"""
将ckpt文件转化为npy文件
"""

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow


def ckpt2npy():
    checkpoint_path = 'model/model.ckpt-4999'
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # 定义好ckpt模型的每一层
    vgg19 = {'conv1_1': [[], []], 'conv1_2': [[], []], 'conv2_1': [[], []], 'conv2_2': [[], []], 'conv3_1': [[], []],
             'conv3_2': [[], []], 'conv3_3': [[], []], 'conv3_4': [[], []], 'conv4_1': [[], []], 'conv4_2': [[], []],
             'conv4_3': [[], []], 'conv4_4': [[], []], 'conv5_1': [[], []], 'conv5_2': [[], []], 'conv5_3': [[], []],
             'conv5_4': [[], []], 'fc6': [[], []], 'fc7': [[], []], 'fc8': [[], []]}

    for key in var_to_shape_map:
        str_name = key
        print('tensor_name:', str_name)

        if str_name.find('/') > -1:
            names = str_name.split('/')
            # first layer name and weight, bias
            layer_name = names[0]
            layer_add_info = names[1]
        else:
            layer_name = str_name
            layer_add_info = None

        if layer_add_info == 'filter':
            vgg19[layer_name][0] = reader.get_tensor(key)
        if layer_add_info == 'weights':
            vgg19[layer_name][0] = reader.get_tensor(key)
        elif layer_add_info == 'bias':
            vgg19[layer_name][1] = reader.get_tensor(key)
        else:
            vgg19[layer_name] = reader.get_tensor(key)

    np.save('vgg19_2.npy', vgg19)

如果你的神经网络的参数每一层都是随机初始化,没有使用预训练模型,直接用上面的方法就行。
然而,我们经常使用预训练模型,例如VGG19,冻结前面的卷积层,只训练全连接层。如果你使用了预训练模型,且预训练的模型文件也是npy,那你生成的ckpt文件里面的卷积层参数是空的,因为你训练时是直接加载的原npy文件参数,如下图:
在这里插入图片描述
这时,我们还需要将原npy文件的卷积层参数,搬到新的npy文件中,方法如下:

#  先加载两个npy文件
vgg19 = np.load('vgg19.npy', encoding='latin1').item()  # 别人预训练的npy文件
vgg19_2 = np.load('vgg19_2.npy', encoding='latin1').item()  # ckpt2npy()中保存的npy文件(没有卷积层参数)
conv_layer = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4',
              'conv4_1', 'conv4_2', 'conv4_3', 'conv4_4', 'conv5_1', 'conv5_2', 'conv5_3', 'conv5_4']

for con in conv_layer:
    vgg19_2[con][0] = vgg19[con][0]
    vgg19_2[con][1] = vgg19[con][1]

np.save('vgg19_3.npy', vgg19_2)  # 保存新的npy文件,有卷积层和全连接层参数
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值