LSTM Study Notes with PyTorch


Video reference: PyTorch30, "LSTM and LSTMP: principles and a line-by-line code implementation"

Imports

import torch
import torch.nn as nn

Define the parameters and call the official LSTM (everything here is a unidirectional LSTM)

# T is the number of time steps (sequence length)
batch_size, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, T, i_size)  # input sequence
c0 = torch.randn(batch_size, h_size)
h0 = torch.randn(batch_size, h_size)

# call the official nn.LSTM
lstm_layer = nn.LSTM(input_size=i_size, hidden_size=h_size, batch_first=True)
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)

output

Output:

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])

tensor([[[-0.1593,  0.0920, -0.1179, -0.1437, -0.1078],
         [-0.2673,  0.0440, -0.3039, -0.2347, -0.1679],
         [-0.0064, -0.0189, -0.0040, -0.1203, -0.0929]],

        [[ 0.0719, -0.0904, -0.2322,  0.1735, -0.4898],
         [ 0.2594,  0.0302,  0.0551,  0.1185, -0.0740],
         [-0.0121, -0.0488, -0.2208, -0.1240, -0.1958]]],
       grad_fn=<...>)
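
The 20 rows printed for weight_ih_l0 are the four gate weight matrices (input, forget, cell, output) stacked along dim 0, each contributing h_size = 5 rows; a minimal sketch splitting them back apart with torch.chunk:

# split the stacked input-to-hidden weights into the four per-gate matrices
w_ii, w_if, w_ig, w_io = torch.chunk(lstm_layer.weight_ih_l0, 4, dim=0)
print(w_ii.shape, w_if.shape, w_ig.shape, w_io.shape)  # each is [h_size, i_size] = [5, 4]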

Hand-written LSTM
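
The function below implements the standard LSTM cell update that nn.LSTM uses, with gates ordered input (i), forget (f), cell (g), output (o):

i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})
c_t = f_t \odot c_{t-1} + i_t \odot g_t
h_t = o_t \odot \tanh(c_t)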

def lstm_forward(input, initial_state, w_ih, w_hh, b_ih, b_hh):
    h0, c0 = initial_state
    batch_size, T, i_size = input.shape
    h_size = w_ih.shape[0] // 4

    prev_h = h0
    prev_c = c0

    # w_ih has shape [4 * h_size, i_size]; torch.bmm needs a batch dimension,
    # so unsqueeze(0) adds one at dim 0 and tile() repeats it batch_size times
    # -> [batch_size, 4 * h_size, i_size]
    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1)
    # likewise w_hh: [4 * h_size, h_size] -> [batch_size, 4 * h_size, h_size]
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1)

    output_size = h_size
    output = torch.zeros(batch_size, T, output_size)  # output sequence

    # run the recurrence one time step at a time
    for t in range(T):
        # current input vector x_t: [batch_size, i_size]
        x = input[:, t, :]

        # W_ih @ x_t: x -> [batch_size, i_size, 1]; bmm -> [batch_size, 4 * h_size, 1]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))
        w_times_x = w_times_x.squeeze(-1)  # [batch_size, 4 * h_size]

        # W_hh @ h_{t-1}
        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))
        w_times_h_prev = w_times_h_prev.squeeze(-1)

        # input gate (i), forget gate (f), cell gate (g), output gate (o);
        # the weights and biases are the four gates stacked along dim 0,
        # so each gate reads its own h_size-wide slice
        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size]
                            + b_ih[:h_size] + b_hh[:h_size])
        f_t = torch.sigmoid(w_times_x[:, h_size:2 * h_size] + w_times_h_prev[:, h_size:2 * h_size]
                            + b_ih[h_size:2 * h_size] + b_hh[h_size:2 * h_size])
        g_t = torch.tanh(w_times_x[:, 2 * h_size:3 * h_size] + w_times_h_prev[:, 2 * h_size:3 * h_size]
                         + b_ih[2 * h_size:3 * h_size] + b_hh[2 * h_size:3 * h_size])
        o_t = torch.sigmoid(w_times_x[:, 3 * h_size:] + w_times_h_prev[:, 3 * h_size:]
                            + b_ih[3 * h_size:] + b_hh[3 * h_size:])

        # cell state and hidden state updates
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)


output_custom, (h_final_custom, c_final_custom) = lstm_forward(
    input, (h0, c0),
    lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)
output_custom

Output (with the same set of parameters, the result is identical):

tensor([[[-0.1593,  0.0920, -0.1179, -0.1437, -0.1078],
         [-0.2673,  0.0440, -0.3039, -0.2347, -0.1679],
         [-0.0064, -0.0189, -0.0040, -0.1203, -0.0929]],

        [[ 0.0719, -0.0904, -0.2322,  0.1735, -0.4898],
         [ 0.2594,  0.0302,  0.0551,  0.1185, -0.0740],
         [-0.0121, -0.0488, -0.2208, -0.1240, -0.1958]]], grad_fn=<...>)
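
To confirm the match numerically rather than by eye, a quick check (a sketch continuing from the cells above) compares the two implementations with torch.allclose; note that the hand-written final states have shape [batch_size, h_size], so they need unsqueeze(0) to line up with the official [1, batch_size, h_size]:

print(torch.allclose(output, output_custom, atol=1e-6))                 # expected: True
print(torch.allclose(h_final, h_final_custom.unsqueeze(0), atol=1e-6))  # expected: True
print(torch.allclose(c_final, c_final_custom.unsqueeze(0), atol=1e-6))  # expected: True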

NOTE

On squeeze (removing a size-1 dimension) and unsqueeze (adding a dimension)

batch_size, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, T, i_size)  # input sequence
c0 = torch.randn(batch_size, h_size)
h0 = torch.randn(batch_size, h_size, 1)
h0, h0.squeeze(-1)

Output:

(tensor([[[ 1.0476],
          [ 1.5756],
          [ 0.3098],
          [-0.5701],
          [-1.1242]],
 
         [[ 0.7007],
          [-1.5588],
          [-0.4844],
          [-1.0391],
          [ 1.5587]]]),
 tensor([[ 1.0476,  1.5756,  0.3098, -0.5701, -1.1242],
         [ 0.7007, -1.5588, -0.4844, -1.0391,  1.5587]]))

batch_size, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, T, i_size)  # input sequence
c0 = torch.randn(batch_size, h_size)
h0 = torch.randn(batch_size, h_size)
h0, h0.unsqueeze(0)

Output:

(tensor([[-0.3069, -0.1296, -0.6759,  0.8653, -0.3475],
         [-1.2103, -0.1932, -0.4064,  0.7696,  0.6389]]),
 tensor([[[-0.3069, -0.1296, -0.6759,  0.8653, -0.3475],
          [-1.2103, -0.1932, -0.4064,  0.7696,  0.6389]]]))
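
One more detail worth remembering (a small supplementary sketch): squeeze only removes dimensions of size 1; asking it to squeeze a larger dimension is a no-op, and calling it without arguments removes every size-1 dimension.

x = torch.randn(2, 1, 5)
print(x.squeeze(1).shape)  # torch.Size([2, 5])    -> the size-1 dim is removed
print(x.squeeze(0).shape)  # torch.Size([2, 1, 5]) -> dim 0 has size 2, so nothing changes
print(x.squeeze().shape)   # torch.Size([2, 5])    -> no argument: all size-1 dims removed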