mxnet中ndarray*ndarray用来作为掩码进行与运算的用法
def batch_loss(encoder, decoder, X, Y, loss):
batch_size = X.shape[0]
enc_state = encoder.begin_state(batch_size=batch_size)
enc_outputs, enc_state = encoder(X, enc_state)
# 初始化解码器的隐藏状态
dec_state = decoder.begin_state(enc_state)
# ...
原创
2021-01-13 17:11:49 ·
166 阅读 ·
0 评论