import torch
import torch.nn as nn
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
from typing import Tuple, List, Optional


class MyLSTM(nn.Module):
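    """A single stateful LSTM layer built from fused gate weights.

    The input/forget/cell/output gates use the SoftSigmoid/SoftTanh activations
    from MyLayer, and the input is normalized with AdaptiveSoftTanh before the
    recurrence. Hidden and cell states persist across calls until reset(); a
    residual connection is applied when input_size == hidden_size.
    """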
    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norm = AdaptiveSoftTanh(input_size)
        self.dropout = nn.Dropout(dropout)
        # Fused weights and bias for the four gates (input, forget, cell, output)
        self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.init_weight()
        self.h_state = None
        self.c_state = None
        self.initialized = False

    def init_weight(self) -> None:
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

    def init_hidden(
        self, batch_size: int, device: torch.device,
        h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
    ) -> None:
        self.h_state = h_state if h_state is not None else torch.zeros(
            batch_size, self.hidden_size, device=device
        )
        self.c_state = c_state if c_state is not None else torch.zeros(
            batch_size, self.hidden_size, device=device
        )
        self.initialized = True

    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
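        """Run the recurrence over a full (batch, seq_len, input_size) sequence."""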
        _, T, _ = x.shape
        # Collect the outputs of all time steps
        output = []
        x_o = x
        x = self.norm(x)
        # Iterate over every time step of the sequence
        for t in range(T):
            x_t = x[:, t, :]  # input at the current time step (B, C)
            # Compute the linear transforms of all gates at once (batch_size, 4*hidden_size)
            gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
            # Split into the four gates (each is batch_size, hidden_size)
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
            i_t = SoftSigmoid(i_gate)  # input gate
            f_t = SoftSigmoid(f_gate)  # forget gate
            g_t = SoftTanh(g_gate)     # candidate gate
            o_t = SoftSigmoid(o_gate)  # output gate
            # Update the cell state
            self.c_state = f_t * self.c_state + i_t * g_t
            # Update the hidden state
            self.h_state = o_t * SoftTanh(self.c_state)
            # Collect the output of the current time step
            output.append(self.h_state)
        # Stack the per-step outputs into a (B, T, H) tensor
        output = torch.stack(output, dim=1)
        # Add a residual connection when the shapes match
        if self.input_size == self.hidden_size:
            output = output + x_o
        return self.dropout(output)

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
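        """Run a single recurrence step on a (batch, input_size) input."""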
        x_o = x
        x = self.norm(x)
        # Compute the linear transforms of all gates at once (batch_size, 4*hidden_size)
        gates = x @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
        # Split into the four gates (each is batch_size, hidden_size)
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate gate
        o_t = SoftSigmoid(o_gate)  # output gate
        # Update the cell state
        self.c_state = f_t * self.c_state + i_t * g_t
        # Update the hidden state
        self.h_state = o_t * SoftTanh(self.c_state)
        # Output of the current time step
        output = self.h_state
        # Add a residual connection when the shapes match
        if self.input_size == self.hidden_size:
            output = output + x_o
        return self.dropout(output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
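        """Dispatch on input rank: 2-D runs a single step, 3-D runs a full sequence.

        Hidden and cell states are lazily initialized to zeros on the first call.
        """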
        if not self.initialized:
            batch_size = x.size(0)
            device = x.device
            self.init_hidden(batch_size, device)
        if x.dim() == 2:
            return self._forward_step(x)
        elif x.dim() == 3:
            return self._forward_seq(x)
        else:
            raise ValueError("input must be 2-D (single step) or 3-D (sequence)")

    def reset(self) -> None:
        self.h_state = None
        self.c_state = None
        self.initialized = False


class LSTMEncoder(nn.Module):
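    """Embedding followed by a stack of MyLSTM layers.

    forward() returns the final hidden and cell states of every layer so that
    they can be used to initialize a decoder.
    """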
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        x = self.embedding(x)
        output = x
        # Collect the final hidden and cell states of every layer
        h, c = [], []
        for layer in self.layers:
            output = layer(output)
            h.append(layer.h_state)
            c.append(layer.c_state)
        return h, c

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()


class LSTMDecoder(nn.Module):
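    """Embedding, a stack of MyLSTM layers, and a linear projection to the vocabulary.

    init_hidden() seeds each layer with the corresponding encoder states.
    """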
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def init_hidden(
        self, batch_size: int, device: torch.device,
        h: List[torch.Tensor], c: List[torch.Tensor]
    ) -> None:
        for i, layer in enumerate(self.layers):
            layer.init_hidden(batch_size, device, h[i], c[i])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()


class Seq2SeqLSTM(nn.Module):
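    """Encoder-decoder wrapper that ties LSTMEncoder and LSTMDecoder together."""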
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.encoder = LSTMEncoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )
        self.decoder = LSTMDecoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        self.reset()
        h, c = self.encoder(src)
        batch_size = src.size(0)
        device = src.device
        self.decoder.init_hidden(batch_size, device, h, c)
        return self.decoder(tgt)

    def reset(self) -> None:
        self.encoder.reset()
        self.decoder.reset()

    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int,
        max_length: int
    ) -> torch.Tensor:
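        """Greedy batch decoding: start from BOS and stop early once every sequence has emitted EOS.

        Returns a (batch_size, max_length) tensor of token ids; steps never reached
        because of early stopping remain filled with pad_token_id.
        """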
        batch_size = src.size(0)
        device = src.device
        # Output sequence, initially filled with pad_token_id
        output_seq = torch.full((batch_size, max_length), pad_token_id, dtype=torch.long, device=device)
        # First decoder input is the BOS token
        decoder_input = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        # Mask of sequences that have already finished
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # Encode the source sequence
        self.reset()
        h, c = self.encoder(src)
        self.decoder.init_hidden(batch_size, device, h, c)
        with torch.no_grad():
            for step in range(max_length):
                # Run a single decoding step
                logits = self.decoder(decoder_input)  # (batch_size, 1, vocab_size)
                # Pick the predicted token for the current step
                next_tokens = logits.argmax(dim=-1)  # (batch_size, 1)
                # Write the prediction into the output sequence
                output_seq[:, step] = next_tokens.squeeze(1)
                # Update the mask of finished sequences
                finished = finished | (next_tokens.squeeze(1) == eos_token_id)
                # Stop early once every sequence has finished
                if finished.all():
                    break
                # Prepare the next input: unfinished sequences use the prediction,
                # finished sequences use pad_token_id
                decoder_input = next_tokens.masked_fill(
                    finished.unsqueeze(1),
                    pad_token_id
                )
        return output_seq
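

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the vocabulary size,
    # special-token ids, and tensor shapes below are assumed values chosen only for
    # illustration, and MyLayer (SoftTanh, SoftSigmoid, AdaptiveSoftTanh) must be importable.
    vocab_size, embedding_dim = 1000, 64
    pad_id, bos_id, eos_id = 0, 1, 2
    model = Seq2SeqLSTM(
        vocab_size=vocab_size, embedding_dim=embedding_dim,
        padding_idx=pad_id, num_layers=2, dropout=0.1
    )
    src = torch.randint(3, vocab_size, (4, 12))  # (batch, src_len)
    tgt = torch.randint(3, vocab_size, (4, 10))  # (batch, tgt_len)
    logits = model(src, tgt)                     # (batch, tgt_len, vocab_size)
    print(logits.shape)
    model.eval()
    generated = model.greedy_decode(
        src, bos_token_id=bos_id, eos_token_id=eos_id,
        pad_token_id=pad_id, max_length=20
    )                                            # (batch, max_length)
    print(generated.shape)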