RNN
Forward pass:
- $h_t = g(U h_{t-1} + W x_t + b_h)$
- $y_t = g(W_y h_t + b_y)$
PyTorch implementation:
```python
import torch
import torch.nn as nn


class RNNCell(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(RNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        # linear1 is the hidden-to-hidden map U, linear2 is the input-to-hidden map W.
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(input_size, hidden_dim)

    def forward(self, x, h_pre):
        """
        :param x: (batch, input_size)
        :param h_pre: (batch, hidden_dim)
        :return: h_next (batch, hidden_dim)
        """
        # h_t = tanh(U h_{t-1} + W x_t + b_h)
        h_next = torch.tanh(self.linear1(h_pre) + self.linear2(x))
        return h_next


class RNN(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(RNN, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.rnn_cell = RNNCell(input_size, hidden_dim)

    def forward(self, x):
        """
        :param x: (seq_len, batch, input_size)
        :return:
            output (seq_len, batch, hidden_dim)
            h_n (1, batch, hidden_dim)
        """
        seq_len, batch, _ = x.shape
        # new_zeros keeps the state on the same device/dtype as the input.
        h = x.new_zeros(batch, self.hidden_dim)
        output = x.new_zeros(seq_len, batch, self.hidden_dim)
        for i in range(seq_len):
            inp = x[i, :, :]
            h = self.rnn_cell(inp, h)
            output[i, :, :] = h
        h_n = output[-1:, :, :]
        return output, h_n
```
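As a quick sanity check, a minimal usage sketch (the tensor sizes below are arbitrary assumptions for the demo, not from the original):

```python
# Hypothetical sizes, chosen only for this demo.
seq_len, batch, input_size, hidden_dim = 5, 3, 10, 20

rnn = RNN(input_size, hidden_dim)
x = torch.randn(seq_len, batch, input_size)
output, h_n = rnn(x)
print(output.shape)  # torch.Size([5, 3, 20])
print(h_n.shape)     # torch.Size([1, 3, 20])
```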
LSTM
Forward pass:
- Input gate: $i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i)$
- Forget gate: $f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$
- Output gate: $o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o)$
- Candidate cell state: $\hat{c}_t = \tanh(W_c x_t + U_c h_{t-1} + b_c)$
- Cell state: $c_t = f_t \odot c_{t-1} + i_t \odot \hat{c}_t$
- Hidden state: $h_t = o_t \odot \tanh(c_t)$
PyTorch implementation:
```python
import torch
import torch.nn as nn
import copy


class Gate(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(Gate, self).__init__()
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = nn.Linear(input_size, hidden_dim)

    def forward(self, x, h_pre, active_func):
        # active_func(U h_{t-1} + W x_t + b), where active_func is sigmoid or tanh.
        h_next = active_func(self.linear1(h_pre) + self.linear2(x))
        return h_next


def clones(module, N):
    "Produce N identical layers (deepcopy gives each its own parameters)."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        # One Gate each for f_t, i_t, the candidate \hat{c}_t, and o_t.
        self.gate = clones(Gate(input_size, hidden_dim), 4)

    def forward(self, x, h_pre, c_pre):
        """
        :param x: (batch, input_size)
        :param h_pre: (batch, hidden_dim)
        :param c_pre: (batch, hidden_dim)
        :return: h_next (batch, hidden_dim), c_next (batch, hidden_dim)
        """
        f_t = self.gate[0](x, h_pre, torch.sigmoid)  # forget gate
        i_t = self.gate[1](x, h_pre, torch.sigmoid)  # input gate
        g_t = self.gate[2](x, h_pre, torch.tanh)     # candidate cell state
        o_t = self.gate[3](x, h_pre, torch.sigmoid)  # output gate
        c_next = f_t * c_pre + i_t * g_t
        h_next = o_t * torch.tanh(c_next)
        return h_next, c_next


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.lstm_cell = LSTMCell(input_size, hidden_dim)

    def forward(self, x):
        """
        :param x: (seq_len, batch, input_size)
        :return:
            output (seq_len, batch, hidden_dim)
            h_n (1, batch, hidden_dim)
            c_n (1, batch, hidden_dim)
        """
        seq_len, batch, _ = x.shape
        h = x.new_zeros(batch, self.hidden_dim)
        c = x.new_zeros(batch, self.hidden_dim)
        output = x.new_zeros(seq_len, batch, self.hidden_dim)
        for i in range(seq_len):
            inp = x[i, :, :]
            h, c = self.lstm_cell(inp, h, c)
            output[i, :, :] = h
        h_n = h.unsqueeze(0)
        c_n = c.unsqueeze(0)
        return output, (h_n, c_n)
```
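And the corresponding usage sketch (sizes again arbitrary assumptions for the demo):

```python
# Hypothetical sizes, chosen only for this demo.
seq_len, batch, input_size, hidden_dim = 5, 3, 10, 20

lstm = LSTM(input_size, hidden_dim)
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([5, 3, 20])
print(h_n.shape)     # torch.Size([1, 3, 20])
print(c_n.shape)     # torch.Size([1, 3, 20])
```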
GRU
Forward pass:
Reset gate and update gate:
- $r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)$
- $z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)$
Candidate hidden state:
- $\hat{h}_t = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1} + b_h)$
Hidden state:
- $h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \hat{h}_t$
Output:
- $y_t = \mathrm{softmax}(W_{hy} h_t + b_y)$
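No PyTorch implementation is given for the GRU. Following the same cell/wrapper pattern as the RNN and LSTM code above, a minimal sketch of these equations might look like the following; the softmax output layer $y_t$ is omitted to match the RNN and LSTM implementations, and all class and variable names here are my own:

```python
import torch
import torch.nn as nn


class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        # Separate input-to-hidden / hidden-to-hidden maps for r_t, z_t, \hat{h}_t.
        self.linear_xr = nn.Linear(input_size, hidden_dim)
        self.linear_hr = nn.Linear(hidden_dim, hidden_dim)
        self.linear_xz = nn.Linear(input_size, hidden_dim)
        self.linear_hz = nn.Linear(hidden_dim, hidden_dim)
        self.linear_xh = nn.Linear(input_size, hidden_dim)
        self.linear_hh = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, h_pre):
        """
        :param x: (batch, input_size)
        :param h_pre: (batch, hidden_dim)
        :return: h_next (batch, hidden_dim)
        """
        r_t = torch.sigmoid(self.linear_xr(x) + self.linear_hr(h_pre))
        z_t = torch.sigmoid(self.linear_xz(x) + self.linear_hz(h_pre))
        # Reset gate scales the recurrent term inside the candidate state.
        h_hat = torch.tanh(self.linear_xh(x) + r_t * self.linear_hh(h_pre))
        # Update gate interpolates between the old state and the candidate,
        # as in the h_t equation above.
        h_next = z_t * h_pre + (1 - z_t) * h_hat
        return h_next


class GRU(nn.Module):
    def __init__(self, input_size, hidden_dim):
        super(GRU, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.gru_cell = GRUCell(input_size, hidden_dim)

    def forward(self, x):
        """
        :param x: (seq_len, batch, input_size)
        :return:
            output (seq_len, batch, hidden_dim)
            h_n (1, batch, hidden_dim)
        """
        seq_len, batch, _ = x.shape
        h = x.new_zeros(batch, self.hidden_dim)
        output = x.new_zeros(seq_len, batch, self.hidden_dim)
        for i in range(seq_len):
            h = self.gru_cell(x[i, :, :], h)
            output[i, :, :] = h
        h_n = h.unsqueeze(0)
        return output, h_n
```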