import random
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from MyLayer import AdaptiveSoftTanh, SoftSigmoid, SoftTanh


class MyLSTM(nn.Module):
    """Single LSTM layer with custom soft activations, input normalization,
    dropout, and an optional residual connection."""

    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norm = AdaptiveSoftTanh(input_size)
        self.dropout = nn.Dropout(dropout)
        # Fused weight matrices and bias for the input, forget, candidate and output gates.
        self.W_i = nn.Parameter(torch.empty(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.empty(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
        self.init_weight()
        self.h_state = None
        self.c_state = None
        self.initialized = False

    def init_weight(self) -> None:
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

    def init_hidden(
        self,
        batch_size: int,
        device: torch.device,
        h_state: Optional[torch.Tensor] = None,
        c_state: Optional[torch.Tensor] = None
    ) -> None:
        # Fall back to zero states when no initial states are provided.
        self.h_state = torch.zeros(batch_size, self.hidden_size, device=device) \
            if h_state is None else h_state
        self.c_state = torch.zeros(batch_size, self.hidden_size, device=device) \
            if c_state is None else c_state
        self.initialized = True

    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
        _, T, _ = x.shape
        # Collect the output of every time step.
        output = []
        normed = self.norm(x)
        # Walk through the sequence one time step at a time.
        for t in range(T):
            x_t = normed[:, t, :]  # current time-step input (B, C)
            # Compute the linear transform of all gates at once (batch_size, 4*hidden_size).
            gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
            # Split into the four gates (each of shape batch_size, hidden_size).
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
            i_t = SoftSigmoid(i_gate)  # input gate
            f_t = SoftSigmoid(f_gate)  # forget gate
            g_t = SoftTanh(g_gate)     # candidate gate
            o_t = SoftSigmoid(o_gate)  # output gate
            # i_t = torch.sigmoid(i_gate)  # input gate
            # f_t = torch.sigmoid(f_gate)  # forget gate
            # g_t = torch.tanh(g_gate)     # candidate gate
            # o_t = torch.sigmoid(o_gate)  # output gate
            # Update the cell state.
            self.c_state = f_t * self.c_state + i_t * g_t
            # Update the hidden state.
            self.h_state = o_t * SoftTanh(self.c_state)
            # self.h_state = o_t * torch.tanh(self.c_state)
            # Output of the current time step.
            output.append(self.h_state)
        # Stack the per-step outputs into a single tensor.
        output = torch.stack(output, dim=1)
        # Add the residual connection when the shapes match.
        if self.input_size == self.hidden_size:
            output = output + x
        return self.dropout(output)

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        # Compute the linear transform of all gates at once (batch_size, 4*hidden_size).
        gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
        # Split into the four gates (each of shape batch_size, hidden_size).
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate gate
        o_t = SoftSigmoid(o_gate)  # output gate
        # i_t = torch.sigmoid(i_gate)  # input gate
        # f_t = torch.sigmoid(f_gate)  # forget gate
        # g_t = torch.tanh(g_gate)     # candidate gate
        # o_t = torch.sigmoid(o_gate)  # output gate
        # Update the cell state.
        self.c_state = f_t * self.c_state + i_t * g_t
        # Update the hidden state.
        self.h_state = o_t * SoftTanh(self.c_state)
        # self.h_state = o_t * torch.tanh(self.c_state)
        # Output of the current time step.
        output = self.h_state
        # Add the residual connection when the shapes match.
        if self.input_size == self.hidden_size:
            output = output + x
        return self.dropout(output)

    def forward(self, x) -> torch.Tensor:
        if not self.initialized:
            B = x.size(0)
            self.init_hidden(B, x.device)
        if x.dim() == 2:
            return self._forward_step(x)
        elif x.dim() == 3:
            return self._forward_seq(x)
        else:
            raise ValueError("input dim must be 2 (step) or 3 (sequence)")

    def reset(self) -> None:
        self.h_state = None
        self.c_state = None
        self.initialized = False
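
# Usage sketch (not part of the model, never called): illustrates the two call modes of
# MyLSTM under the assumption that MyLayer's SoftSigmoid/SoftTanh are element-wise
# callables and AdaptiveSoftTanh preserves the input shape. Sizes below are arbitrary.
def _mylstm_usage_sketch() -> None:
    cell = MyLSTM(input_size=64, hidden_size=64, dropout=0.1)
    seq_out = cell(torch.randn(8, 20, 64))   # (B, T, C) sequence call -> (8, 20, 64)
    step_out = cell(torch.randn(8, 64))      # (B, C) single-step call reuses the stored h/c state
    assert seq_out.shape == (8, 20, 64) and step_out.shape == (8, 64)
    cell.reset()                             # clear h_state / c_state before the next sequence
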
class LSTMEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        x = self.embedding(x)
        output = x
        # Keep the final hidden and cell state of every layer.
        h, c = [], []
        for layer in self.layers:
            output = layer(output)
            h.append(layer.h_state)
            c.append(layer.c_state)
        return h, c

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()


class LSTMDecoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.initialized = False

    def init_hidden(
        self,
        batch_size: int,
        device: torch.device,
        h: List[torch.Tensor],
        c: List[torch.Tensor]
    ) -> None:
        for i, layer in enumerate(self.layers):
            layer.init_hidden(batch_size, device, h[i], c[i])
        self.initialized = True

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        output = x
        for layer in self.layers:
            output = layer(output)
        return output

    def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        _, T, _ = x.shape
        outputs = []
        for t in range(T):
            # Teacher forcing: feed the ground-truth embedding with probability tfr,
            # otherwise feed the embedding of the previously predicted token.
            if t == 0 or random.random() < tfr:
                outputs.append(self._forward_step(x[:, t, :]))
            else:
                logits = torch.argmax(self.fc(outputs[-1].clone().detach()), dim=-1)
                previous = self.embedding(logits)
                outputs.append(self._forward_step(previous))
        outputs = torch.stack(outputs, dim=1)
        return outputs

    def forward(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        x = self.embedding(x)
        if x.dim() == 3:
            return self.fc(self._forward_seq(x, tfr=tfr))
        elif x.dim() == 2:
            return self.fc(self._forward_step(x))
        else:
            raise ValueError("input dim must be 2 (step) or 3 (sequence)")

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()
        self.initialized = False
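
# Wiring sketch (hypothetical sizes, never called): shows how the per-layer final states
# returned by LSTMEncoder seed LSTMDecoder before a teacher-forced pass. tfr=1.0 always
# feeds the ground-truth target token; tfr=0.0 feeds back the decoder's own previous
# prediction (except at t=0, which always uses the ground-truth first token).
def _encoder_decoder_sketch() -> None:
    enc = LSTMEncoder(vocab_size=100, embedding_dim=32, padding_idx=0, num_layers=2, dropout=0.1)
    dec = LSTMDecoder(vocab_size=100, embedding_dim=32, padding_idx=0, num_layers=2, dropout=0.1)
    src = torch.randint(1, 100, (4, 7))   # (batch, src_len)
    tgt = torch.randint(1, 100, (4, 5))   # (batch, tgt_len)
    h, c = enc(src)                       # lists with one (batch, hidden) state per layer
    dec.init_hidden(src.size(0), src.device, h, c)
    logits = dec(tgt, tfr=1.0)            # (4, 5, 100) logits over the vocabulary
    assert logits.shape == (4, 5, 100)
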
class Seq2SeqLSTM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        padding_idx: int,
        num_layers: int,
        dropout: float
    ) -> None:
        super().__init__()
        self.encoder = LSTMEncoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            num_layers=num_layers,
            dropout=dropout
        )
        self.decoder = LSTMDecoder(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            num_layers=num_layers,
            dropout=dropout
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        self.reset()
        h, c = self.encoder(src)
        batch_size = src.size(0)
        device = src.device
        self.decoder.init_hidden(batch_size, device, h, c)
        return self.decoder(tgt, tfr=tfr)

    def reset(self) -> None:
        self.encoder.reset()
        self.decoder.reset()

    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int,
        max_length: int
    ) -> torch.Tensor:
        """Sequence generation for inference (supports batching).

        Args:
            src: source sequences with BOS/EOS added and padding applied (batch_size, src_len)
            bos_token_id: ID of the beginning-of-sequence token
            eos_token_id: ID of the end-of-sequence token
            pad_token_id: ID of the padding token
            max_length: maximum generation length

        Returns:
            Generated sequences (batch_size, tgt_len)
        """
        self.reset()
        batch_size = src.size(0)
        device = src.device
        # Encode the source sequence.
        h, c = self.encoder(src)
        # Initialise the decoder state from the encoder.
        self.decoder.init_hidden(batch_size, device, h, c)
        # Output tensor, pre-filled with pad_token_id.
        sequences = torch.full(
            (batch_size, max_length),
            pad_token_id,
            dtype=torch.long,
            device=device
        )
        # The first decoder input is the BOS token.
        current_input = torch.full(
            (batch_size,),
            bos_token_id,
            dtype=torch.long,
            device=device
        )
        # Track which sequences have finished (initially all False).
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # Autoregressive generation loop.
        for t in range(max_length):
            # Stop once every sequence has finished.
            if finished.all():
                break
            # Output of the current time step.
            logits = self.decoder(current_input)  # (batch_size, vocab_size)
            next_tokens = logits.argmax(dim=-1)   # greedy choice (batch_size,)
            # Write the new tokens only for unfinished sequences.
            sequences[~finished, t] = next_tokens[~finished]
            # Detect EOS tokens and update the finished mask.
            eos_mask = (next_tokens == eos_token_id)
            finished = finished | eos_mask
            # Next input: new token for unfinished sequences, PAD for finished ones.
            current_input = torch.where(
                ~finished,
                next_tokens,
                torch.full_like(next_tokens, pad_token_id)
            )
        return sequences
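

if __name__ == "__main__":
    # Smoke-test sketch with hypothetical sizes and token IDs; assumes the MyLayer
    # module is importable. Not a definitive training or inference setup.
    model = Seq2SeqLSTM(vocab_size=1000, embedding_dim=128, padding_idx=0, num_layers=2, dropout=0.1)
    src = torch.randint(3, 1000, (2, 10))   # (batch, src_len)
    tgt = torch.randint(3, 1000, (2, 12))   # (batch, tgt_len)
    logits = model(src, tgt, tfr=0.5)       # training-style forward pass -> (2, 12, 1000)
    print(logits.shape)
    model.eval()                            # disable dropout for generation
    with torch.no_grad():
        generated = model.greedy_decode(
            src, bos_token_id=1, eos_token_id=2, pad_token_id=0, max_length=20
        )
    print(generated.shape)                  # (2, 20), padded with pad_token_id after EOS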