import torch
import torch.nn as nn
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
from typing import Tuple, List, Optional


class MyLSTM(nn.Module):
    """Single LSTM layer that keeps its recurrent state on the module.

    Accepts either one time step ``(batch, input_size)`` or a whole sequence
    ``(batch, seq_len, input_size)``. The hidden and cell states live in
    ``self.h_state`` / ``self.c_state`` and are carried across calls until
    :meth:`reset` is invoked. Inputs pass through ``AdaptiveSoftTanh`` before
    the gate projections; when ``input_size == hidden_size`` the raw
    (un-normalized) input is added back as a residual, and dropout is applied
    to the layer output.
    """

    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norm = AdaptiveSoftTanh(input_size)
        self.dropout = nn.Dropout(dropout)
        # Fused weights for the four gates, stacked in the order
        # input, forget, candidate, output (split again in ``_cell``).
        self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.init_weight()
        # Recurrent state; created lazily by ``init_hidden`` on first forward.
        self.h_state: Optional[torch.Tensor] = None
        self.c_state: Optional[torch.Tensor] = None
        self.initialized = False

    def init_weight(self) -> None:
        """Xavier-initialize the gate weights and zero the gate bias.

        Only this layer's own parameters are initialized. Iterating over
        ``self.parameters()`` would also reach the parameters of the
        ``self.norm`` submodule and overwrite the initialization that
        ``AdaptiveSoftTanh`` chose for itself (1-D ones would be zeroed).
        """
        nn.init.xavier_uniform_(self.W_i)
        nn.init.xavier_uniform_(self.W_h)
        nn.init.zeros_(self.bias)

    def init_hidden(
        self,
        batch_size: int,
        device: torch.device,
        h_state: Optional[torch.Tensor] = None,
        c_state: Optional[torch.Tensor] = None
    ) -> None:
        """Set the recurrent state, defaulting to zeros of shape (batch, hidden).

        Args:
            batch_size: Batch dimension used for zero-initialized states.
            device: Device on which zero-initialized states are created.
            h_state: Optional initial hidden state; used as-is when given.
            c_state: Optional initial cell state; used as-is when given.
        """
        self.h_state = h_state if h_state is not None else torch.zeros(
            batch_size, self.hidden_size, device=device
        )
        self.c_state = c_state if c_state is not None else torch.zeros(
            batch_size, self.hidden_size, device=device
        )
        self.initialized = True

    def _cell(self, x_t: torch.Tensor) -> torch.Tensor:
        """Advance the recurrent state by one step and return the new hidden state.

        ``x_t`` is the already-normalized input for a single time step.
        Updates ``self.c_state`` and ``self.h_state``.
        """
        # Fused linear transform for all gates: (batch, 4 * hidden_size).
        gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
        # Split into the four gates, each (batch, hidden_size).
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate cell values
        o_t = SoftSigmoid(o_gate)  # output gate
        self.c_state = f_t * self.c_state + i_t * g_t
        self.h_state = o_t * SoftTanh(self.c_state)
        return self.h_state

    def _finish(self, output: torch.Tensor, x_raw: torch.Tensor) -> torch.Tensor:
        """Apply the optional residual connection and dropout to a layer output."""
        # The residual is only shape-compatible when input and hidden widths match.
        if self.input_size == self.hidden_size:
            output = output + x_raw
        return self.dropout(output)

    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer over a full sequence ``(batch, seq_len, input_size)``.

        Returns the per-step outputs stacked to ``(batch, seq_len, hidden_size)``.
        """
        _, T, _ = x.shape
        x_raw = x
        x = self.norm(x)
        outputs = []
        for t in range(T):
            outputs.append(self._cell(x[:, t, :]))
        return self._finish(torch.stack(outputs, dim=1), x_raw)

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layer for a single time step ``(batch, input_size)``."""
        x_raw = x
        return self._finish(self._cell(self.norm(x)), x_raw)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Dispatch to step (2-D input) or sequence (3-D input) processing.

        Lazily creates zero states sized from ``x`` on the first call after
        construction or :meth:`reset`.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if not self.initialized:
            self.init_hidden(x.size(0), x.device)
        if x.dim() == 2:
            return self._forward_step(x)
        if x.dim() == 3:
            return self._forward_seq(x)
        raise ValueError("input dim must be 2(step) or 3(sequence)")

    def reset(self) -> None:
        """Drop the recurrent state so the next forward starts from zeros."""
        self.h_state = None
        self.c_state = None
        self.initialized = False


class LSTMEncoder(nn.Module):
    """Embedding followed by a stack of ``MyLSTM`` layers of width ``embedding_dim``."""

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Encode token ids and return each layer's final hidden and cell states.

        Every time step of ``x`` is processed; no padding mask is applied.
        NOTE(review): if sources are right-padded, the returned states also
        reflect the padding steps — confirm this is acceptable for callers.

        Returns:
            ``(h, c)``: lists with one state tensor per layer, bottom to top.
        """
        output = self.embedding(x)
        h, c = [], []
        for layer in self.layers:
            output = layer(output)
            h.append(layer.h_state)
            c.append(layer.c_state)
        return h, c

    def reset(self) -> None:
        """Reset the recurrent state of every layer."""
        for layer in self.layers:
            layer.reset()


class LSTMDecoder(nn.Module):
    """Embedding, a stack of ``MyLSTM`` layers, and a projection to vocabulary logits."""

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def init_hidden(
        self,
        batch_size: int,
        device: torch.device,
        h: List[torch.Tensor],
        c: List[torch.Tensor]
    ) -> None:
        """Seed layer ``i`` with ``h[i]`` / ``c[i]`` (e.g. the encoder's final states)."""
        for layer, h_i, c_i in zip(self.layers, h, c):
            layer.init_hidden(batch_size, device, h_i, c_i)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids, run all layers, and project to vocabulary logits.

        A ``(batch, seq_len)`` input takes each layer's sequence path; a
        ``(batch,)`` input embeds to 2-D and takes the single-step path.
        """
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)

    def reset(self) -> None:
        """Reset the recurrent state of every layer."""
        for layer in self.layers:
            layer.reset()


class Seq2SeqLSTM(nn.Module):
    """Encoder-decoder model whose decoder starts from the encoder's final states."""

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.encoder = LSTMEncoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            num_layers=num_layers,
            dropout=dropout
        )
        self.decoder = LSTMDecoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            num_layers=num_layers,
            dropout=dropout
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Teacher-forced pass: encode ``src`` then decode ``tgt`` into logits.

        All recurrent state is reset first, so calls are independent.
        """
        self.reset()
        h, c = self.encoder(src)
        self.decoder.init_hidden(src.size(0), src.device, h, c)
        return self.decoder(tgt)

    def reset(self) -> None:
        """Reset the recurrent state of both encoder and decoder."""
        self.encoder.reset()
        self.decoder.reset()

    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int,
        max_length: int
    ) -> torch.Tensor:
        """Greedily decode up to ``max_length`` tokens per source sequence.

        The train/eval mode of the module is not changed here; call
        ``eval()`` beforehand if dropout should be disabled.

        Args:
            src: Source token ids, batch-first.
            bos_token_id: Token fed to the decoder at the first step.
            eos_token_id: Token that marks a sequence as finished.
            pad_token_id: Filler for positions after a sequence's EOS and for
                steps skipped by early stopping; also fed as decoder input to
                finished sequences.
            max_length: Maximum number of decoding steps.

        Returns:
            Long tensor ``(batch, max_length)`` of generated token ids.
        """
        batch_size = src.size(0)
        device = src.device
        output_seq = torch.full(
            (batch_size, max_length), pad_token_id, dtype=torch.long, device=device
        )
        decoder_input = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        with torch.no_grad():
            # Encoding runs under no_grad as well: the result only feeds
            # argmax, so recording an autograd graph would just waste memory.
            self.reset()
            h, c = self.encoder(src)
            self.decoder.init_hidden(batch_size, device, h, c)

            for step in range(max_length):
                logits = self.decoder(decoder_input)  # (batch, 1, vocab)
                next_tokens = logits.argmax(dim=-1).squeeze(1)  # (batch,)
                # Sequences that already emitted EOS keep padding instead of
                # whatever the model predicts from its pad input.
                next_tokens = next_tokens.masked_fill(finished, pad_token_id)
                output_seq[:, step] = next_tokens
                finished = finished | (next_tokens == eos_token_id)
                if finished.all():
                    break
                # Unfinished sequences feed back their prediction; finished
                # ones (including those that just produced EOS) get padding.
                decoder_input = next_tokens.masked_fill(
                    finished, pad_token_id
                ).unsqueeze(1)
        return output_seq