import torch
import torch.nn as nn
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
from typing import Tuple, List, Optional
class MyLSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.norm = AdaptiveSoftTanh(input_size)
self.dropout = nn.Dropout(dropout)
        # fused weights and bias for the four gates (input, forget, cell, output)
        self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))
self.init_weight()
        # recurrent state, created lazily by init_hidden()
        self.h_state: Optional[torch.Tensor] = None
        self.c_state: Optional[torch.Tensor] = None
        self.initialized = False
def init_weight(self) -> None:
for param in self.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
else:
nn.init.zeros_(param)
def init_hidden(
self, batch_size: int, device: torch.device,
h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
) -> None:
self.h_state = h_state if h_state is not None else torch.zeros(
batch_size, self.hidden_size, device=device
)
self.c_state = c_state if c_state is not None else torch.zeros(
batch_size, self.hidden_size, device=device
)
self.initialized = True
    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
_, T, _ = x.shape
        # collect the hidden-state output of every time step
output = []
        x_o = x  # keep the un-normalized input for the residual connection
        x = self.norm(x)
        # iterate over the time steps of the sequence
for t in range(T):
            x_t = x[:, t, :]  # input at the current time step (batch_size, input_size)
            # compute the linear transform for all four gates at once (batch_size, 4 * hidden_size)
gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
            # split into the four gates (each is batch_size x hidden_size)
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
            i_t = SoftSigmoid(i_gate)  # input gate
            f_t = SoftSigmoid(f_gate)  # forget gate
            g_t = SoftTanh(g_gate)     # candidate cell update
            o_t = SoftSigmoid(o_gate)  # output gate
            # update the cell state
            self.c_state = f_t * self.c_state + i_t * g_t
            # update the hidden state
            self.h_state = o_t * SoftTanh(self.c_state)
            # output of the current time step
            output.append(self.h_state)
        # stack the per-step outputs into a (batch_size, T, hidden_size) tensor
        output = torch.stack(output, dim=1)
        # residual connection, only possible when input and hidden sizes match
        if self.input_size == self.hidden_size:
            output = output + x_o
return self.dropout(output)
    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        x_o = x  # keep the un-normalized input for the residual connection
        x = self.norm(x)
        # compute the linear transform for all four gates at once (batch_size, 4 * hidden_size)
gates = x @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
        # split into the four gates (each is batch_size x hidden_size)
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate cell update
        o_t = SoftSigmoid(o_gate)  # output gate
        # update the cell state
        self.c_state = f_t * self.c_state + i_t * g_t
        # update the hidden state
        self.h_state = o_t * SoftTanh(self.c_state)
        # the single-step output is the new hidden state
        output = self.h_state
        # residual connection, only possible when input and hidden sizes match
        if self.input_size == self.hidden_size:
            output = output + x_o
return self.dropout(output)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # lazily create zero states on the first call
        if not self.initialized:
batch_size = x.size(0)
device = x.device
self.init_hidden(batch_size, device)
if x.dim() == 2:
return self._forward_step(x)
elif x.dim() == 3:
return self._forward_seq(x)
else:
            raise ValueError("input must be 2-D (single step) or 3-D (sequence)")
def reset(self) -> None:
self.h_state = None
self.c_state = None
self.initialized = False
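

# Minimal usage sketch for a single MyLSTM layer, assuming the activations from
# MyLayer are element-wise. The helper name and all sizes below are illustrative
# assumptions only; the function is not called anywhere in this module.
def _demo_mylstm() -> None:
    layer = MyLSTM(input_size=16, hidden_size=16, dropout=0.1)
    seq = torch.randn(2, 5, 16)   # (batch_size, seq_len, input_size)
    out_seq = layer(seq)          # (2, 5, 16); hidden/cell state is kept inside the layer
    step = torch.randn(2, 16)     # one extra time step, continues from the kept state
    out_step = layer(step)        # (2, 16)
    layer.reset()                 # clear the state before the next sequence
    print(out_seq.shape, out_step.shape)
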
class LSTMEncoder(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super(LSTMEncoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
)
self.layers = nn.ModuleList([
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \
for _ in range(num_layers)
])
    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
x = self.embedding(x)
output = x
        # collect the final hidden and cell states of every layer
h, c = [], []
for layer in self.layers:
output = layer(output)
h.append(layer.h_state)
c.append(layer.c_state)
return h, c
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class LSTMDecoder(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super(LSTMDecoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
)
self.layers = nn.ModuleList([
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \
for _ in range(num_layers)
])
self.fc = nn.Linear(embedding_dim, vocab_size)
def init_hidden(
self, batch_size: int, device: torch.device,
h: List[torch.Tensor], c: List[torch.Tensor]
) -> None:
for i, layer in enumerate(self.layers):
layer.init_hidden(batch_size, device, h[i], c[i])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
for layer in self.layers:
x = layer(x)
return self.fc(x)
    def reset(self) -> None:
for layer in self.layers:
layer.reset()
class Seq2SeqLSTM(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super().__init__()
self.encoder = LSTMEncoder(
vocab_size=vocab_size, embedding_dim=embedding_dim,
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
)
self.decoder = LSTMDecoder(
vocab_size=vocab_size, embedding_dim=embedding_dim,
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
)
def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
self.reset()
h, c = self.encoder(src)
batch_size = src.size(0)
device = src.device
self.decoder.init_hidden(batch_size, device, h, c)
return self.decoder(tgt)
    def reset(self) -> None:
self.encoder.reset()
self.decoder.reset()
def greedy_decode(
self,
src: torch.Tensor,
bos_token_id: int,
eos_token_id: int,
pad_token_id: int,
max_length: int
) -> torch.Tensor:
batch_size = src.size(0)
device = src.device
        # output buffer, pre-filled with pad_token_id
        output_seq = torch.full((batch_size, max_length), pad_token_id, dtype=torch.long, device=device)
        # the first decoder input is the BOS token
        decoder_input = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
        # mask marking sequences that have already produced EOS
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # encode the source sequence and seed the decoder states
self.reset()
h, c = self.encoder(src)
self.decoder.init_hidden(batch_size, device, h, c)
with torch.no_grad():
for step in range(max_length):
                # decode a single step
logits = self.decoder(decoder_input) # (batch_size, 1, vocab_size)
                # greedily pick the most likely token at this step
next_tokens = logits.argmax(dim=-1) # (batch_size, 1)
                # write the prediction into the output buffer
                output_seq[:, step] = next_tokens.squeeze(1)
                # mark sequences that just emitted EOS as finished
                finished = finished | (next_tokens.squeeze(1) == eos_token_id)
                # stop early once every sequence has finished
if finished.all():
break
                # prepare the next input: unfinished sequences feed back their
                # prediction, finished sequences are fed pad_token_id
decoder_input = next_tokens.masked_fill(
finished.unsqueeze(1),
pad_token_id
)
return output_seq
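

# Minimal end-to-end sketch of Seq2SeqLSTM: one teacher-forced forward pass and
# one greedy decode. The vocabulary size, embedding size and the special token
# ids (pad=0, bos=1, eos=2) are arbitrary assumptions for illustration; a real
# setup would take them from the tokenizer/vocabulary.
if __name__ == "__main__":
    model = Seq2SeqLSTM(
        vocab_size=100, embedding_dim=32,
        padding_idx=0, num_layers=2, dropout=0.1
    )
    src = torch.randint(1, 100, (4, 7))  # (batch_size, src_len) token ids
    tgt = torch.randint(1, 100, (4, 6))  # (batch_size, tgt_len) token ids
    logits = model(src, tgt)             # (4, 6, 100): one distribution per target position
    model.eval()                         # disable dropout before generation
    decoded = model.greedy_decode(
        src, bos_token_id=1, eos_token_id=2, pad_token_id=0, max_length=10
    )                                    # (4, 10) generated token ids
    print(logits.shape, decoded.shape)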