_keras_shape和Tensor的shape不匹配从而引发的一系列问题

本文详细解析了在使用Keras自定义MLP层时遇到的一个关于维度不匹配的错误,提供了完整的代码示例及解决方案,指出通过直接传参而非重写__call__方法可解决问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

问题描述

刚用keras给实验模型写了一个MLP层,是定义成一个类的形式,通过call方法进行调用。这是tensorflow常用的一种模型定义方法。不料却出现了一个见所未见的错误:

ValueError: Dimensions must be equal, but are 4 and 20 for 'batch_normalization_3/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,16,5,4], [20], [20], [0], [0].

MLP部分代码如下:

from keras.backend import relu
from keras.layers import Dense, Layer, BatchNormalization, Lambda


class MLP(Layer):
    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
        super(MLP, self).__init__()

        self.linear_or_not = True
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        if num_layers < 1:
            raise ValueError("number of layers should be positive")
        elif num_layers == 1:
            self.linear = Linear_model()
        else:
            self.linear_or_not = False
            self.multi = Multi_model(layers=self.num_layers, hidden_dim=self.hidden_dim,
                                     output_dim=self.output_dim)

    def get_config(self):
        config = {'num_layers': self.num_layers, 'hidden_dim': self.hidden_dim, 'output_dim': self.output_dim}
        base_config = super(MLP, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, input_feature, **kwargs):
        if self.linear_or_not:
            return self.linear(input_feature)
        else:
            return self.multi(input_feature)

class Linear_model(Layer):
    def __init__(self, output_dim):
        super(Linear_model, self).__init__()
        self.output_layer = Dense(units=output_dim)

    def call(self, input_feature, **kwargs):
        return self.output_layer(input_feature)


class Multi_model(Layer):
    def __init__(self, layers, hidden_dim, output_dim):
        super(Multi_model, self).__init__()
        self.layers = layers
        self.dense_list = []
        self.batch_list = []

        for i in range(layers-1):
            self.dense_list.append(Dense(units=hidden_dim))
            self.batch_list.append(BatchNormalization())

        self.dense_list.append(Dense(units=output_dim))

    def call(self, input_feature):
        for i in range(self.layers-1):
            densed = self.dense_list[i](input_feature)
            batched = self.batch_list[i](densed)
            input_feature = Lambda(lambda x: relu(x))(batched)
            input_feature = relu(batched)

        multi_result = self.dense_list[-1](input_feature)
        return multi_result

模型调用处方法图下:

pooled_rep = self.mlps[layer](pooled)
h = self.batches[layer](pooled_rep)

问题解决方案

查阅文档可知,keras对BatchNormalization层的input并没有限制条件。
于是我们之前去debug界面看BatchNormalization层的输入,也就是pooled_rep的值。可以很清楚的看到pooled_rep这个Tensor的shape与其_keras_shape是不匹配的。
在这里插入图片描述
我们再去看__call__的源码:

output_ls = to_list(output)
        inputs_ls = to_list(inputs)
        output_ls_copy = []
        for x in output_ls:
            if x in inputs_ls:
                x = K.identity(x)
            output_ls_copy.append(x)
        output = unpack_singleton(output_ls_copy)
        ......
 
        # 调用_add_inbound_node创建层间连接并保存历史
        self._add_inbound_node(input_tensors=inputs,
                               output_tensors=output,
                               input_masks=previous_mask,
                               output_masks=output_mask,
                               input_shapes=input_shape,
                               output_shapes=output_shape,
                               arguments=user_kwargs)

通过debug可知,直到这一步,output的shape都是正确的:

output = unpack_singleton(output_ls_copy)

直到调用了最后一个self._add_inbound_node方法,output的shape发生了变化。我个人认为这是keras使用tf式建模方法时框架本身的存在缺陷。
最后,我们可以通过将自定义的MLP层改写为直接传参进行调用,而不是通过复写__call__的方式,问题才能够得以解决。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值