
Improve MyLSTM

Branch: main
Author: tanxing (4 days ago)
Commit: 3ee8d630b7
  1.    2  .gitignore
  2.  175  MyLSTM.py
  3.    6  TryBatch.py
  4.  BIN  __pycache__/BucketManager.cpython-312.pyc
  5.  BIN  __pycache__/MyDataset.cpython-312.pyc
  6.  BIN  __pycache__/MyLSTM.cpython-312.pyc
  7.  BIN  __pycache__/MyLayer.cpython-312.pyc
  8.  BIN  __pycache__/MyTokenizer.cpython-312.pyc
  9.  BIN  __pycache__/TryBatch.cpython-312.pyc
 10.    2  inference.py
 11.  BIN  model/checkpoints/latest_checkpoint.pt
 12.  BIN  model/checkpoints/latest_model.pt
 13. 7944  results/translation_comparison.txt
 14.   82  train.py
 15. 1457  training.log

.gitignore (2 lines changed)

@@ -1,4 +1,4 @@
__pycache__/
__pycache__
*.pt
!latest_checkpoint.pt
!latest_model.pt

MyLSTM.py (175 lines changed)

@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
import random
from typing import Tuple, List, Optional
@@ -32,12 +31,14 @@ class MyLSTM(nn.Module):
def init_hidden(
self, batch_size: int, device: torch.device,
h_state: torch.Tensor = Optional[None], c_state: torch.Tensor = Optional[None]
h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
) -> None:
self.h_state = torch.zeros(batch_size, self.hidden_size, device=device) \
if h_state is not None else h_state
self.c_state = torch.zeros(batch_size, self.hidden_size, device=device) \
if c_state is not None else c_state
self.h_state = h_state if h_state is not None else torch.zeros(
batch_size, self.hidden_size, device=device
)
self.c_state = c_state if c_state is not None else torch.zeros(
batch_size, self.hidden_size, device=device
)
self.initialized = True
def _forward_seq(self, x: torch.Tensor) \
@@ -46,10 +47,11 @@ class MyLSTM(nn.Module):
# store the outputs of all time steps
output = []
normed = self.norm(x)
x_o = x
x = self.norm(x)
# iterate over each time step of the sequence
for t in range(T):
x_t = normed[:, t, :] # input at the current time step (B, C)
x_t = x[:, t, :] # input at the current time step (B, C)
# fused linear transform for all gates (batch_size, 4*hidden_size)
gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
@@ -62,17 +64,12 @@ class MyLSTM(nn.Module):
g_t = SoftTanh(g_gate) # candidate gate
o_t = SoftSigmoid(o_gate) # output gate
# i_t = torch.sigmoid(i_gate) # input gate
# f_t = torch.sigmoid(f_gate) # forget gate
# g_t = torch.tanh(g_gate) # candidate gate
# o_t = torch.sigmoid(o_gate) # output gate
# update the cell state
self.c_state = f_t * self.c_state + i_t * g_t
# update the hidden state
self.h_state = o_t * SoftTanh(self.c_state)
# self.h_state = o_t * torch.tanh(self.c_state)
# output of the current time step
output.append(self.h_state)
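For context, the fused-gate update in the loop above follows the standard LSTM cell equations, with the SoftSigmoid/SoftTanh activations imported from MyLayer in place of the usual ones. A minimal standalone sketch of one step, using plain torch.sigmoid and torch.tanh since MyLayer's definitions are not part of this diff:

import torch

def lstm_step(x_t, h, c, W_i, W_h, bias):
    # x_t: (B, input_size), h and c: (B, hidden_size)
    # W_i: (4*hidden_size, input_size), W_h: (4*hidden_size, hidden_size)
    gates = x_t @ W_i.t() + h @ W_h.t() + bias
    i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
    i_t = torch.sigmoid(i_gate)      # input gate
    f_t = torch.sigmoid(f_gate)      # forget gate
    g_t = torch.tanh(g_gate)         # candidate values
    o_t = torch.sigmoid(o_gate)      # output gate
    c_new = f_t * c + i_t * g_t      # cell state update
    h_new = o_t * torch.tanh(c_new)  # hidden state update
    return h_new, c_new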
@@ -82,7 +79,7 @@ class MyLSTM(nn.Module):
# add a residual connection
if self.input_size == self.hidden_size:
output = output + x
output = output + x_o
else:
output = output
@@ -90,9 +87,10 @@ class MyLSTM(nn.Module):
def _forward_step(self, x: torch.Tensor) \
-> torch.Tensor:
normed = self.norm(x)
x_o = x
x = self.norm(x)
# fused linear transform for all gates (batch_size, 4*hidden_size)
gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
gates = x @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
# split into the four gates (each of shape batch_size, hidden_size)
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
@@ -101,35 +99,29 @@ class MyLSTM(nn.Module):
f_t = SoftSigmoid(f_gate) # forget gate
g_t = SoftTanh(g_gate) # candidate gate
o_t = SoftSigmoid(o_gate) # output gate
# i_t = torch.sigmoid(i_gate) # input gate
# f_t = torch.sigmoid(f_gate) # forget gate
# g_t = torch.tanh(g_gate) # candidate gate
# o_t = torch.sigmoid(o_gate) # output gate
# update the cell state
self.c_state = f_t * self.c_state + i_t * g_t
# update the hidden state
self.h_state = o_t * SoftTanh(self.c_state)
# self.h_state = o_t * torch.tanh(self.c_state)
# output of the current time step
output = self.h_state
# add a residual connection
if self.input_size == self.hidden_size:
output = output + x
output = output + x_o
else:
output = output
return self.dropout(output)
def forward(self, x) -> torch.Tensor:
if not self.initialized:
B = x.size(0)
self.init_hidden(B, x.device)
self.initialized = True
if self.initialized is False:
batch_size = x.size(0)
device = x.device
self.init_hidden(batch_size, device)
if x.dim() == 2:
return self._forward_step(x)
@@ -188,8 +180,6 @@ class LSTMDecoder(nn.Module):
for _ in range(num_layers)
])
self.fc = nn.Linear(embedding_dim, vocab_size)
self.initialized = False
def init_hidden(
self, batch_size: int, device: torch.device,
@@ -197,40 +187,16 @@ class LSTMDecoder(nn.Module):
) -> None:
for i, layer in enumerate(self.layers):
layer.init_hidden(batch_size, device, h[i], c[i])
self.initialized = True
def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
output = x
for i, layer in enumerate(self.layers):
output = layer(output)
return output
def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
_, T, _ = x.shape
outputs = []
for t in range(T):
if t == 0 or random.random() < tfr:
outputs.append(self._forward_step(x[:, t, :]))
else:
logits = torch.argmax(self.fc(outputs[-1].clone().detach()), dim=-1)
previous = self.embedding(logits)
outputs.append(self._forward_step(previous))
outputs = torch.stack(outputs, dim=1)
return outputs
def forward(self, x: torch.Tensor, tfr: float = 0.0):
def forward(self, x: torch.Tensor):
x = self.embedding(x)
if x.dim() == 3:
return self.fc(self._forward_seq(x, tfr=tfr))
elif x.dim() == 2:
return self.fc(self._forward_step(x))
else:
raise ValueError("input dim must be 2(step) or 3(sequence)")
for layer in self.layers:
x = layer(x)
return self.fc(x)
def reset(self):
for layer in self.layers:
layer.reset()
self.initialized = True
class Seq2SeqLSTM(nn.Module):
@@ -248,13 +214,13 @@ class Seq2SeqLSTM(nn.Module):
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
)
def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0) -> torch.Tensor:
def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
self.reset()
h, c = self.encoder(src)
batch_size = src.size(0)
device = src.device
self.decoder.init_hidden(batch_size, device, h, c)
return self.decoder(tgt, tfr=tfr)
return self.decoder(tgt)
def reset(self):
self.encoder.reset()
@@ -268,69 +234,46 @@ class Seq2SeqLSTM(nn.Module):
pad_token_id: int,
max_length: int
) -> torch.Tensor:
"""
Sequence generation for inference; supports batching.
Args:
src: source sequences with BOS/EOS added and padding applied (batch_size, src_len)
bos_token_id: beginning-of-sequence token ID
eos_token_id: end-of-sequence token ID
pad_token_id: padding token ID
max_length: maximum generation length
Returns:
generated sequences (batch_size, tgt_len)
"""
self.reset()
batch_size = src.size(0)
device = src.device
# encode the source sequence
h, c = self.encoder(src)
# initialize the decoder state
self.decoder.init_hidden(batch_size, device, h, c)
# prepare the output sequence tensor (all values initialized to pad_token_id)
sequences = torch.full(
(batch_size, max_length),
pad_token_id,
dtype=torch.long,
device=device
)
# initialize the output sequence (filled with pad_token_id)
output_seq = torch.full((batch_size, max_length), pad_token_id, dtype=torch.long, device=device)
# the initial input is BOS
current_input = torch.full(
(batch_size,),
bos_token_id,
dtype=torch.long,
device=device
)
# initialize the first decoder input with the BOS token
decoder_input = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
# mark whether each sequence has finished (initially all False)
# initialize the mask of finished sequences
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
# generate the sequence autoregressively
for t in range(max_length):
# skip computation for sequences that have already finished
if finished.all():
break
# encode the source sequence
self.reset()
h, c = self.encoder(src)
self.decoder.init_hidden(batch_size, device, h, c)
with torch.no_grad():
for step in range(max_length):
# perform a single decoding step
logits = self.decoder(decoder_input) # (batch_size, 1, vocab_size)
# get the output at the current time step
logits = self.decoder(current_input) # (batch_size, vocab_size)
next_tokens = logits.argmax(dim=-1) # greedy selection (batch_size,)
# update the outputs of unfinished sequences
sequences[~finished, t] = next_tokens[~finished]
# detect the EOS token
eos_mask = (next_tokens == eos_token_id)
finished = finished | eos_mask # update the finished flags
# prepare the next input: new tokens for unfinished sequences, PAD for finished ones
current_input = torch.where(
~finished,
next_tokens,
torch.full_like(next_tokens, pad_token_id)
)
# get the predicted tokens for the current step
next_tokens = logits.argmax(dim=-1) # (batch_size, 1)
# write the predictions into the output sequence
output_seq[:, step] = next_tokens.squeeze(1)
# update the mask of finished sequences
finished = finished | (next_tokens.squeeze(1) == eos_token_id)
# terminate early if all sequences have finished
if finished.all():
break
# prepare the input for the next step:
# unfinished sequences use their predicted token, finished sequences use pad_token_id
decoder_input = next_tokens.masked_fill(
finished.unsqueeze(1),
pad_token_id
)
return sequences
return output_seq
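The generate method above implements batched greedy decoding with an explicit finished mask. A simplified standalone sketch of that pattern, where decode_step is a hypothetical stand-in for the call to self.decoder and the encoder/hidden-state setup is omitted:

import torch

def greedy_decode(decode_step, batch_size, bos_id, eos_id, pad_id, max_length, device):
    # decode_step maps (batch_size, 1) token ids to (batch_size, 1, vocab_size) logits
    output_seq = torch.full((batch_size, max_length), pad_id, dtype=torch.long, device=device)
    decoder_input = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    with torch.no_grad():
        for step in range(max_length):
            logits = decode_step(decoder_input)    # (B, 1, V)
            next_tokens = logits.argmax(dim=-1)    # (B, 1), greedy choice
            output_seq[:, step] = next_tokens.squeeze(1)
            finished = finished | (next_tokens.squeeze(1) == eos_id)
            if finished.all():
                break
            # finished sequences keep receiving PAD as their next input
            decoder_input = next_tokens.masked_fill(finished.unsqueeze(1), pad_id)
    return output_seq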

TryBatch.py (6 lines changed)

@@ -36,11 +36,7 @@ def try_batch(config: dict, src_max: int, tgt_max: int, batch_size: int, max_bat
# training loop
for i in range(max_batch_iter):
with autocast(config["device"], enabled=config["use_amp"]):
if i % 2 == 0:
tfr = random.random()
else:
tfr = 1
output = model(srcs, tgts, tfr)
output = model(srcs, tgts)
loss = nn.CrossEntropyLoss()(
output.reshape(-1, output.size(-1)),
labels.contiguous().view(-1)
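As a reminder of the shapes behind the loss call above, a minimal standalone sketch with purely hypothetical sizes (the real values come from the config, which this hunk does not show):

import torch
import torch.nn as nn

batch_size, tgt_len, vocab_size = 8, 25, 32000          # hypothetical sizes
output = torch.randn(batch_size, tgt_len, vocab_size)   # decoder logits
labels = torch.randint(0, vocab_size, (batch_size, tgt_len))

# CrossEntropyLoss expects (N, C) logits and (N,) class indices,
# so both tensors are flattened over the batch and time dimensions
loss = nn.CrossEntropyLoss()(
    output.reshape(-1, vocab_size),   # (batch_size * tgt_len, vocab_size)
    labels.reshape(-1)                # (batch_size * tgt_len,)
)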

__pycache__/BucketManager.cpython-312.pyc (binary)

Binary file not shown.

__pycache__/MyDataset.cpython-312.pyc (binary)

Binary file not shown.

__pycache__/MyLSTM.cpython-312.pyc (binary)

Binary file not shown.

__pycache__/MyLayer.cpython-312.pyc (binary)

Binary file not shown.

__pycache__/MyTokenizer.cpython-312.pyc (binary)

Binary file not shown.

__pycache__/TryBatch.cpython-312.pyc (binary)

Binary file not shown.

inference.py (2 lines changed)

@@ -141,7 +141,7 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model/checkpoints/latest_model.pt", map_location=device, weights_only=False)
# model = torch.load("model/checkpoints/model_009.pt", map_location=device, weights_only=False)
# model = torch.load("model/checkpoints/model_003.pt", map_location=device, weights_only=False)
# model = torch.load("temp/latest_model.pt", map_location=device, weights_only=False)
model.eval().to(device)

model/checkpoints/latest_checkpoint.pt (binary, stored with Git LFS)

Binary file not shown.

model/checkpoints/latest_model.pt (binary, stored with Git LFS)

Binary file not shown.

results/translation_comparison.txt (7944 lines changed)

File diff suppressed because it is too large

train.py (82 lines changed)

@@ -55,51 +55,6 @@ CONFIG = {
}
class CosineAnnealingTFR:
def __init__(self, T_max, eta_max=1.0, eta_min=0.0, last_epoch=-1, verbose=False):
self.T_max = T_max
self.eta_max = eta_max
self.eta_min = eta_min
self.verbose = verbose
self.last_epoch = last_epoch
# initialize state
self.current_tfr = eta_max if last_epoch == -1 else self._compute_tfr(last_epoch)
if last_epoch == -1:
self.step(0)
def _compute_tfr(self, epoch):
cos = math.cos(math.pi * epoch / self.T_max)
return self.eta_min + (self.eta_max - self.eta_min) * (1 + cos) / 2
def step(self, epoch=None):
if epoch is not None:
self.last_epoch = epoch
else:
self.last_epoch += 1
self.current_tfr = self._compute_tfr(self.last_epoch)
if self.verbose:
print(f'Epoch {self.last_epoch:5d}: TFR adjusted to {self.current_tfr:.4f}')
def get_last_tfr(self):
return self.current_tfr
def state_dict(self):
return {
'T_max': self.T_max,
'eta_max': self.eta_max,
'eta_min': self.eta_min,
'last_epoch': self.last_epoch,
'verbose': self.verbose
}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
self.current_tfr = self._compute_tfr(self.last_epoch)
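For context, the scheduler removed here is plain cosine annealing of the teacher-forcing ratio from eta_max down to eta_min over T_max steps; a minimal standalone sketch of the same formula:

import math

def cosine_tfr(epoch: int, T_max: int, eta_max: float = 1.0, eta_min: float = 0.0) -> float:
    # the teacher-forcing ratio decays from eta_max to eta_min along a half cosine
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2

# with T_max=10 this gives 1.0 at epoch 0, 0.5 at epoch 5 and 0.0 at epoch 10
values = [round(cosine_tfr(e, T_max=10), 3) for e in (0, 5, 10)]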
def setup_logger(log_file):
logger = logging.getLogger("train_logger")
logger.setLevel(logging.INFO)
@@ -120,7 +75,7 @@ def setup_logger(log_file):
return logger
def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch, bucket_idx, config):
def save_checkpoint(model, optimizer, scaler, lr_scheduler, epoch, bucket_idx, config):
os.makedirs(config["checkpoint_dir"], exist_ok=True)
checkpoint = {
'epoch': epoch,
@@ -130,7 +85,6 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
'optim_state': optimizer.state_dict(),
'scaler_state': scaler.state_dict(),
'lr_scheduler_state': lr_scheduler.state_dict(),
'tf_scheduler_state': tf_scheduler.state_dict(),
'random_state': random.getstate(),
'numpy_random_state': np.random.get_state(),
'torch_random_state': torch.get_rng_state(),
@@ -144,16 +98,15 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
torch.save(checkpoint, os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt"))
def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, config):
def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, config):
checkpoint_path = os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")
if not os.path.exists(checkpoint_path):
return model, optimizer, scaler, lr_scheduler, tf_scheduler, 0, 0, config["bucket_list"]
return model, optimizer, scaler, lr_scheduler, 0, 0, config["bucket_list"]
checkpoint = torch.load(checkpoint_path, weights_only=False)
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state'])
tf_scheduler.load_state_dict(checkpoint['tf_scheduler_state'])
random.setstate(checkpoint['random_state'])
np.random.set_state(checkpoint['numpy_random_state'])
torch.set_rng_state(checkpoint['torch_random_state'])
@@ -162,8 +115,7 @@ def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler,
scaler.load_state_dict(checkpoint['scaler_state'])
return (
model, optimizer, scaler,
lr_scheduler, tf_scheduler,
model, optimizer, scaler, lr_scheduler,
checkpoint['epoch'],
checkpoint['bucket_idx'],
checkpoint["bucket_list"]
@@ -212,7 +164,7 @@ def validate(model, config):
labels = batch["label"].to(config["device"])
batch_size = src.size(0)
output = model(src, tgt, 1)
output = model(src, tgt)
loss = nn.CrossEntropyLoss(
ignore_index=config["pad_token_id"],
label_smoothing=config["label_smoothing"],
@@ -298,11 +250,6 @@ def train(config):
T_max=config["T_max"] * 2,
eta_min=config["lr_eta_min"]
)
tf_scheduler = CosineAnnealingTFR(
T_max=config["T_max"],
eta_max=config["tf_eta_max"],
eta_min=config["tf_eta_min"]
)
scaler = GradScaler(config["device"], enabled=config["use_amp"])
# recover from checkpoint
@@ -310,8 +257,8 @@
start_bucket_idx = 0
if os.path.exists(os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")):
model, optimizer, scaler, lr_scheduler, tf_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
model, optimizer, scaler, lr_scheduler, tf_scheduler, config
model, optimizer, scaler, lr_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
model, optimizer, scaler, lr_scheduler, config
)
if start_epoch >= config["epochs"]:
@@ -321,7 +268,7 @@
# main train loop
for epoch in range(start_epoch, config["epochs"]):
for bucket_idx in range(start_bucket_idx, len(config["bucket_list"])):
if bucket_idx % len(config["bucket_list"]) == 0 and epoch != 0:
if bucket_idx % len(config["bucket_list"]) == 0:
random.shuffle(config["bucket_list"])
bucket = config["bucket_list"][bucket_idx]
@@ -377,11 +324,7 @@
# mixed precision
with autocast(config["device"], enabled=config["use_amp"]):
if epoch % 3 == 0:
tfr = 1
else:
tfr = tf_scheduler.get_last_tfr()
output = model(src, tgt, tfr)
output = model(src, tgt)
loss = nn.CrossEntropyLoss(
ignore_index=config["pad_token_id"],
label_smoothing=config["label_smoothing"]
@@ -403,8 +346,6 @@
# update scheduler after a batch
lr_scheduler.step()
if epoch % 3 != 0:
tf_scheduler.step()
# performance eval
with torch.no_grad():
@@ -422,8 +363,7 @@
loss=f'{loss.item():.3f}',
acc=f'{(correct.float() / total_valid).item() * 100:.3f}%' \
if total_valid > 0 else '0.000%',
lr=f'{lr_scheduler.get_last_lr()[0]:.3e}',
tf=f'{tfr:.3e}'
lr=f'{lr_scheduler.get_last_lr()[0]:.3e}'
)
# log parquet info
@@ -435,7 +375,7 @@
# save checkpoint after a bucket
save_checkpoint(
model, optimizer, scaler, lr_scheduler, tf_scheduler,
model, optimizer, scaler, lr_scheduler,
epoch, bucket_idx + 1, config
)

training.log (1457 lines changed)

File diff suppressed because it is too large