import random
from typing import Tuple, List, Optional

import torch
import torch.nn as nn

from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh


class MyLSTM(nn.Module):
    """A single LSTM layer with input pre-normalization (AdaptiveSoftTanh), soft gate
    activations (SoftSigmoid / SoftTanh) and a residual connection when the input and
    hidden sizes match."""

    def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.norm = AdaptiveSoftTanh(input_size)
        self.dropout = nn.Dropout(dropout)

        # Fused weights for the four gates (input, forget, cell, output).
        self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))

        self.init_weight()

        self.h_state = None
        self.c_state = None
        self.initialized = False

    def init_weight(self) -> None:
        # Xavier-initialize all matrix parameters; zero all vector parameters.
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
            else:
                nn.init.zeros_(param)

    def init_hidden(
        self, batch_size: int, device: torch.device,
        h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
    ) -> None:
        # Start from zeros unless an explicit state (e.g. from an encoder) is given.
        self.h_state = torch.zeros(batch_size, self.hidden_size, device=device) \
            if h_state is None else h_state
        self.c_state = torch.zeros(batch_size, self.hidden_size, device=device) \
            if c_state is None else c_state
        self.initialized = True

    def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
        _, T, _ = x.shape

        # Collect the hidden state of every time step.
        output = []
        normed = self.norm(x)
        # Iterate over the time steps of the sequence.
        for t in range(T):
            x_t = normed[:, t, :]  # current time-step input (B, C)

            # Compute the linear transforms of all gates at once (batch_size, 4 * hidden_size).
            gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias

            # Split into the four gates (each batch_size, hidden_size).
            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

            i_t = SoftSigmoid(i_gate)  # input gate
            f_t = SoftSigmoid(f_gate)  # forget gate
            g_t = SoftTanh(g_gate)     # candidate (cell) gate
            o_t = SoftSigmoid(o_gate)  # output gate
            # A standard LSTM would use torch.sigmoid / torch.tanh here instead.

            # Update the cell state.
            self.c_state = f_t * self.c_state + i_t * g_t

            # Update the hidden state (torch.tanh in a standard LSTM).
            self.h_state = o_t * SoftTanh(self.c_state)

            # Output of the current time step.
            output.append(self.h_state)

        # Stack the per-step outputs into a tensor (B, T, hidden_size).
        output = torch.stack(output, dim=1)

        # Residual connection when the shapes allow it.
        if self.input_size == self.hidden_size:
            output = output + x

        return self.dropout(output)

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        # Compute the linear transforms of all gates at once (batch_size, 4 * hidden_size).
        gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias

        # Split into the four gates (each batch_size, hidden_size).
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate (cell) gate
        o_t = SoftSigmoid(o_gate)  # output gate
        # A standard LSTM would use torch.sigmoid / torch.tanh here instead.

        # Update the cell state.
        self.c_state = f_t * self.c_state + i_t * g_t

        # Update the hidden state (torch.tanh in a standard LSTM).
        self.h_state = o_t * SoftTanh(self.c_state)

        # Output of the current step.
        output = self.h_state

        # Residual connection when the shapes allow it.
        if self.input_size == self.hidden_size:
            output = output + x

        return self.dropout(output)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initialized:
            B = x.size(0)
            self.init_hidden(B, x.device)

        if x.dim() == 2:
            return self._forward_step(x)
        elif x.dim() == 3:
            return self._forward_seq(x)
        else:
            raise ValueError("input must be 2-D (single step) or 3-D (sequence)")

    def reset(self) -> None:
        self.h_state = None
        self.c_state = None
        self.initialized = False


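# A minimal sanity-check sketch (not part of the original module): runs a random
# batch through MyLSTM in sequence mode and then in single-step mode. The sizes
# below are illustrative assumptions, not values used elsewhere in this project.
def _demo_mylstm() -> None:
    torch.manual_seed(0)
    cell = MyLSTM(input_size=16, hidden_size=16, dropout=0.1)
    x_seq = torch.randn(4, 7, 16)   # (batch, time, features)
    y_seq = cell(x_seq)             # (4, 7, 16): hidden state at every time step
    cell.reset()                    # clear the recurrent state between sequences
    x_step = torch.randn(4, 16)     # a single time step
    y_step = cell(x_step)           # (4, 16)
    print(y_seq.shape, y_step.shape)

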
class LSTMEncoder(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        x = self.embedding(x)
        output = x
        # Collect the final hidden and cell state of every layer.
        h, c = [], []
        for layer in self.layers:
            output = layer(output)
            h.append(layer.h_state)
            c.append(layer.c_state)
        return h, c

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()


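# A hedged sketch (not part of the original module): feeds a random token batch
# through LSTMEncoder and inspects the per-layer states it returns. The vocabulary
# size, embedding dimension and padding id are illustrative assumptions.
def _demo_encoder() -> None:
    torch.manual_seed(0)
    enc = LSTMEncoder(vocab_size=32, embedding_dim=16, padding_idx=0,
                      num_layers=2, dropout=0.0)
    src = torch.randint(1, 32, (4, 6))  # (batch, src_len)
    h, c = enc(src)                     # lists with one (4, 16) tensor per layer
    print(len(h), h[0].shape, c[-1].shape)

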
class LSTMDecoder(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
        )
        self.layers = nn.ModuleList([
            MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.initialized = False

    def init_hidden(
        self, batch_size: int, device: torch.device,
        h: List[torch.Tensor], c: List[torch.Tensor]
    ) -> None:
        for i, layer in enumerate(self.layers):
            layer.init_hidden(batch_size, device, h[i], c[i])
        self.initialized = True

    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
        output = x
        for layer in self.layers:
            output = layer(output)
        return output

    def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        _, T, _ = x.shape
        outputs = []
        for t in range(T):
            if t == 0 or random.random() < tfr:
                # Teacher forcing: feed the ground-truth embedding for this step.
                outputs.append(self._forward_step(x[:, t, :]))
            else:
                # Otherwise feed back the model's own previous prediction.
                prev_tokens = torch.argmax(self.fc(outputs[-1].detach()), dim=-1)
                previous = self.embedding(prev_tokens)
                outputs.append(self._forward_step(previous))
        return torch.stack(outputs, dim=1)

    def forward(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        x = self.embedding(x)
        if x.dim() == 3:
            return self.fc(self._forward_seq(x, tfr=tfr))
        elif x.dim() == 2:
            return self.fc(self._forward_step(x))
        else:
            raise ValueError("input must be 2-D (single step) or 3-D (sequence)")

    def reset(self) -> None:
        for layer in self.layers:
            layer.reset()
        self.initialized = False


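# A hedged sketch (not part of the original module): drives LSTMDecoder from
# all-zero initial states and runs one fully teacher-forced pass (tfr=1.0).
# All sizes and the padding id are illustrative assumptions.
def _demo_decoder() -> None:
    torch.manual_seed(0)
    vocab, dim, n_layers = 32, 16, 2
    dec = LSTMDecoder(vocab_size=vocab, embedding_dim=dim, padding_idx=0,
                      num_layers=n_layers, dropout=0.0)
    B, T = 4, 5
    h0 = [torch.zeros(B, dim) for _ in range(n_layers)]
    c0 = [torch.zeros(B, dim) for _ in range(n_layers)]
    dec.init_hidden(B, torch.device("cpu"), h0, c0)
    tgt = torch.randint(1, vocab, (B, T))
    logits = dec(tgt, tfr=1.0)  # (B, T, vocab): every step sees the gold token
    print(logits.shape)

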
class Seq2SeqLSTM(nn.Module):
    def __init__(
        self, vocab_size: int, embedding_dim: int,
        padding_idx: int, num_layers: int, dropout: float
    ) -> None:
        super().__init__()
        self.encoder = LSTMEncoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )
        self.decoder = LSTMDecoder(
            vocab_size=vocab_size, embedding_dim=embedding_dim,
            padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
        self.reset()
        h, c = self.encoder(src)
        batch_size = src.size(0)
        device = src.device
        self.decoder.init_hidden(batch_size, device, h, c)
        return self.decoder(tgt, tfr=tfr)

    def reset(self) -> None:
        self.encoder.reset()
        self.decoder.reset()

    def greedy_decode(
        self,
        src: torch.Tensor,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int,
        max_length: int
    ) -> torch.Tensor:
        """
        Batched greedy sequence generation for inference.

        Args:
            src: source sequences, already wrapped in BOS/EOS and padded (batch_size, src_len)
            bos_token_id: id of the beginning-of-sequence token
            eos_token_id: id of the end-of-sequence token
            pad_token_id: id of the padding token
            max_length: maximum number of tokens to generate

        Returns:
            generated sequences (batch_size, max_length)
        """
        self.reset()
        batch_size = src.size(0)
        device = src.device

        # Encode the source sequences.
        h, c = self.encoder(src)

        # Initialize the decoder state from the encoder's final states.
        self.decoder.init_hidden(batch_size, device, h, c)

        # Output tensor, pre-filled with pad_token_id.
        sequences = torch.full(
            (batch_size, max_length),
            pad_token_id,
            dtype=torch.long,
            device=device
        )

        # The first decoder input is the BOS token.
        current_input = torch.full(
            (batch_size,),
            bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Track which sequences have already finished (all False initially).
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Autoregressive generation.
        for t in range(max_length):
            # Stop as soon as every sequence has finished.
            if finished.all():
                break

            # Decode one step.
            logits = self.decoder(current_input)  # (batch_size, vocab_size)
            next_tokens = logits.argmax(dim=-1)   # greedy choice (batch_size,)

            # Write tokens only for sequences that have not finished yet.
            sequences[~finished, t] = next_tokens[~finished]

            # Detect EOS and update the finished flags.
            eos_mask = (next_tokens == eos_token_id)
            finished = finished | eos_mask

            # Next input: the new token for running sequences, PAD for finished ones.
            current_input = torch.where(
                ~finished,
                next_tokens,
                torch.full_like(next_tokens, pad_token_id)
            )

        return sequences
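
# A hedged end-to-end sketch (not part of the original training code): wires a tiny
# Seq2SeqLSTM through one teacher-forced forward pass and through greedy decoding.
# The sizes and the special-token ids below are illustrative assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)
    PAD, BOS, EOS = 0, 1, 2  # assumed special-token ids
    model = Seq2SeqLSTM(vocab_size=32, embedding_dim=16,
                        padding_idx=PAD, num_layers=2, dropout=0.1)
    src = torch.randint(3, 32, (4, 6))  # (batch, src_len)
    tgt = torch.randint(3, 32, (4, 7))  # (batch, tgt_len); shifted targets in real training
    logits = model(src, tgt, tfr=0.5)   # (4, 7, 32) logits, e.g. for cross-entropy
    model.eval()
    with torch.no_grad():
        hyp = model.greedy_decode(src, bos_token_id=BOS, eos_token_id=EOS,
                                  pad_token_id=PAD, max_length=10)
    print(logits.shape, hyp.shape)      # (4, 7, 32) and (4, 10)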