import random
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from MyLayer import AdaptiveSoftTanh, SoftSigmoid, SoftTanh


class MyLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norm = AdaptiveSoftTanh(input_size)
        self.dropout = nn.Dropout(dropout)
        # Fused input-to-hidden and hidden-to-hidden weights for the four gates (i, f, g, o).
        self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))
        self.init_weight()
        self.h_state = None
        self.c_state = None
        self.initialized = False
    def init_weight(self) -> None:
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)
    def init_hidden(
        self, batch_size: int, device: torch.device,
        h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
    ) -> None:
        # Use the provided states if given, otherwise start from zeros.
        self.h_state = h_state if h_state is not None else \
            torch.zeros(batch_size, self.hidden_size, device=device)
        self.c_state = c_state if c_state is not None else \
            torch.zeros(batch_size, self.hidden_size, device=device)
        self.initialized = True
    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
        _, T, _ = x.shape
        # Collect the hidden state of every time step.
        output = []
        normed = self.norm(x)
        # Iterate over the sequence one time step at a time.
        for t in range(T):
            x_t = normed[:, t, :]  # current input, shape (B, C)
            # One fused linear transform for all four gates: (batch_size, 4 * hidden_size)
            gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
            # Split into the four gates, each of shape (batch_size, hidden_size).
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
            i_t = SoftSigmoid(i_gate)  # input gate
            f_t = SoftSigmoid(f_gate)  # forget gate
            g_t = SoftTanh(g_gate)     # candidate cell state
            o_t = SoftSigmoid(o_gate)  # output gate
            # (A vanilla LSTM would use torch.sigmoid / torch.tanh for the gates and cell output.)
            # Update the cell state.
            self.c_state = f_t * self.c_state + i_t * g_t
            # Update the hidden state.
            self.h_state = o_t * SoftTanh(self.c_state)
            # Output of the current time step.
            output.append(self.h_state)
        # Stack the per-step outputs into a (B, T, H) tensor.
        output = torch.stack(output, dim=1)
        # Residual connection when input and hidden sizes match.
        if self.input_size == self.hidden_size:
            output = output + x
        return self.dropout(output)
    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        # One fused linear transform for all four gates: (batch_size, 4 * hidden_size)
        gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
        # Split into the four gates, each of shape (batch_size, hidden_size).
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate cell state
        o_t = SoftSigmoid(o_gate)  # output gate
        # (A vanilla LSTM would use torch.sigmoid / torch.tanh for the gates and cell output.)
        # Update the cell state.
        self.c_state = f_t * self.c_state + i_t * g_t
        # Update the hidden state.
        self.h_state = o_t * SoftTanh(self.c_state)
        # Output of the current time step.
        output = self.h_state
        # Residual connection when input and hidden sizes match.
        if self.input_size == self.hidden_size:
            output = output + x
        return self.dropout(output)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initialized:
            B = x.size(0)
            self.init_hidden(B, x.device)
        if x.dim() == 2:
            return self._forward_step(x)
        elif x.dim() == 3:
            return self._forward_seq(x)
        else:
            raise ValueError("input dim must be 2 (step) or 3 (sequence)")

    def reset(self) -> None:
        self.h_state = None
        self.c_state = None
        self.initialized = False
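
# Usage note (illustrative, not from the original file): a MyLSTM layer keeps its
# hidden/cell state between calls. Either call init_hidden() explicitly or let
# forward() create zero states lazily, and call reset() before an unrelated batch.
# It accepts a (B, C) single step or a (B, T, C) sequence, e.g. (assumed shapes):
#     layer = MyLSTM(input_size=32, hidden_size=32, dropout=0.1)
#     seq_out = layer(torch.randn(4, 10, 32))  # -> (4, 10, 32)
#     layer.reset()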


class LSTMEncoder(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        x = self.embedding(x)
        output = x
        # Collect the final hidden and cell states of every layer.
        h, c = [], []
        for layer in self.layers:
            output = layer(output)
            h.append(layer.h_state)
            c.append(layer.c_state)
        return h, c

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()


class LSTMDecoder(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.initialized = False

    def init_hidden(
        self, batch_size: int, device: torch.device,
        h: List[torch.Tensor], c: List[torch.Tensor]
    ) -> None:
        for i, layer in enumerate(self.layers):
            layer.init_hidden(batch_size, device, h[i], c[i])
        self.initialized = True

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        output = x
        for layer in self.layers:
            output = layer(output)
        return output

    def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        _, T, _ = x.shape
        outputs = []
        for t in range(T):
            # Teacher forcing: with probability tfr feed the ground-truth embedding,
            # otherwise feed the embedding of the previously predicted token.
            if t == 0 or random.random() < tfr:
                outputs.append(self._forward_step(x[:, t, :]))
            else:
                prev_tokens = torch.argmax(self.fc(outputs[-1].detach()), dim=-1)
                previous = self.embedding(prev_tokens)
                outputs.append(self._forward_step(previous))
        outputs = torch.stack(outputs, dim=1)
        return outputs

    def forward(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        x = self.embedding(x)
        if x.dim() == 3:
            return self.fc(self._forward_seq(x, tfr=tfr))
        elif x.dim() == 2:
            return self.fc(self._forward_step(x))
        else:
            raise ValueError("input dim must be 2 (step) or 3 (sequence)")

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()
        self.initialized = False


class Seq2SeqLSTM(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.encoder = LSTMEncoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )
        self.decoder = LSTMDecoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        self.reset()
        h, c = self.encoder(src)
        batch_size = src.size(0)
        device = src.device
        self.decoder.init_hidden(batch_size, device, h, c)
        return self.decoder(tgt, tfr=tfr)

    def reset(self) -> None:
        self.encoder.reset()
        self.decoder.reset()
    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int,
        max_length: int
    ) -> torch.Tensor:
        """
        Batched greedy sequence generation for inference.

        Args:
            src: source sequences, already padded and wrapped with BOS/EOS (batch_size, src_len)
            bos_token_id: beginning-of-sequence token ID
            eos_token_id: end-of-sequence token ID
            pad_token_id: padding token ID
            max_length: maximum number of tokens to generate

        Returns:
            Generated sequences of shape (batch_size, tgt_len)
        """
        self.reset()
        batch_size = src.size(0)
        device = src.device
        # Encode the source sequence.
        h, c = self.encoder(src)
        # Initialise the decoder state from the encoder state.
        self.decoder.init_hidden(batch_size, device, h, c)
        # Output buffer, pre-filled with pad_token_id.
        sequences = torch.full(
            (batch_size, max_length),
            pad_token_id,
            dtype=torch.long,
            device=device
        )
        # The first decoder input is the BOS token.
        current_input = torch.full(
            (batch_size,),
            bos_token_id,
            dtype=torch.long,
            device=device
        )
        # Track which sequences have already emitted EOS (initially none).
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # Autoregressive generation loop.
        for t in range(max_length):
            # Stop early once every sequence has finished.
            if finished.all():
                break
            # Logits for the current time step.
            logits = self.decoder(current_input)  # (batch_size, vocab_size)
            next_tokens = logits.argmax(dim=-1)   # greedy choice, (batch_size,)
            # Write tokens only for sequences that are still running.
            sequences[~finished, t] = next_tokens[~finished]
            # Detect EOS tokens and update the finished mask.
            eos_mask = (next_tokens == eos_token_id)
            finished = finished | eos_mask
            # Next input: the new token for running sequences, PAD for finished ones.
            current_input = torch.where(
                ~finished,
                next_tokens,
                torch.full_like(next_tokens, pad_token_id)
            )
        return sequences
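

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# showing how the pieces fit together. The vocabulary size, special-token IDs,
# batch size and sequence lengths below are arbitrary example values, and it
# assumes MyLayer is importable as used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, embedding_dim, num_layers = 100, 32, 2
    pad_id, bos_id, eos_id = 0, 1, 2

    model = Seq2SeqLSTM(
        vocab_size=vocab_size, embedding_dim=embedding_dim,
        padding_idx=pad_id, num_layers=num_layers, dropout=0.1
    )

    src = torch.randint(3, vocab_size, (4, 10))  # (batch_size, src_len)
    tgt = torch.randint(3, vocab_size, (4, 12))  # (batch_size, tgt_len)

    # Training-style forward pass with 50% teacher forcing.
    logits = model(src, tgt, tfr=0.5)
    print(logits.shape)  # expected: torch.Size([4, 12, 100])

    # Greedy decoding at inference time.
    model.eval()
    with torch.no_grad():
        generated = model.greedy_decode(
            src, bos_token_id=bos_id, eos_token_id=eos_id,
            pad_token_id=pad_id, max_length=20
        )
    print(generated.shape)  # expected: torch.Size([4, 20])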