
Refine MyLSTM

main
tanxing 4 days ago
parent
commit 3ee8d630b7
  1. .gitignore (2)
  2. MyLSTM.py (163)
  3. TryBatch.py (6)
  4. __pycache__/BucketManager.cpython-312.pyc (BIN)
  5. __pycache__/MyDataset.cpython-312.pyc (BIN)
  6. __pycache__/MyLSTM.cpython-312.pyc (BIN)
  7. __pycache__/MyLayer.cpython-312.pyc (BIN)
  8. __pycache__/MyTokenizer.cpython-312.pyc (BIN)
  9. __pycache__/TryBatch.cpython-312.pyc (BIN)
  10. inference.py (2)
  11. model/checkpoints/latest_checkpoint.pt (BIN)
  12. model/checkpoints/latest_model.pt (BIN)
  13. results/translation_comparison.txt (7944)
  14. train.py (82)
  15. training.log (1457)

2
.gitignore

@@ -1,4 +1,4 @@
-__pycache__/
+__pycache__
 *.pt
 !latest_checkpoint.pt
 !latest_model.pt

163
MyLSTM.py

@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
-import random
 from typing import Tuple, List, Optional
@@ -32,12 +31,14 @@ class MyLSTM(nn.Module):
     def init_hidden(
         self, batch_size: int, device: torch.device,
-        h_state: torch.Tensor = Optional[None], c_state: torch.Tensor = Optional[None]
+        h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
     ) -> None:
-        self.h_state = torch.zeros(batch_size, self.hidden_size, device=device) \
-            if h_state is not None else h_state
-        self.c_state = torch.zeros(batch_size, self.hidden_size, device=device) \
-            if c_state is not None else c_state
+        self.h_state = h_state if h_state is not None else torch.zeros(
+            batch_size, self.hidden_size, device=device
+        )
+        self.c_state = c_state if c_state is not None else torch.zeros(
+            batch_size, self.hidden_size, device=device
+        )
         self.initialized = True
     def _forward_seq(self, x: torch.Tensor) \
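
The signature fix above matters beyond typing: Optional[None] is a typing construct rather than a value, and the old conditional had its branches swapped, so a caller-supplied state was overwritten with zeros while a missing one stayed None. A standalone sketch of the corrected pattern (init_state is a hypothetical free-function version of the method, not code from this repo):

from typing import Optional
import torch

def init_state(batch_size: int, hidden_size: int, device: torch.device,
               h_state: Optional[torch.Tensor] = None) -> torch.Tensor:
    # fall back to zeros only when no state was passed in
    return h_state if h_state is not None else torch.zeros(batch_size, hidden_size, device=device)

print(init_state(2, 8, torch.device("cpu")).shape)                    # torch.Size([2, 8])
print(init_state(2, 8, torch.device("cpu"), torch.ones(2, 8))[0, 0])  # tensor(1.)
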
@@ -46,10 +47,11 @@ class MyLSTM(nn.Module):
         # store the outputs of every time step
         output = []
-        normed = self.norm(x)
+        x_o = x
+        x = self.norm(x)
         # iterate over every time step of the sequence
         for t in range(T):
-            x_t = normed[:, t, :]  # input at the current time step (B, C)
+            x_t = x[:, t, :]  # input at the current time step (B, C)
             # compute the linear transform of all gates at once (batch_size, 4*hidden_size)
             gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
@@ -62,17 +64,12 @@ class MyLSTM(nn.Module):
             g_t = SoftTanh(g_gate)     # candidate gate
             o_t = SoftSigmoid(o_gate)  # output gate
-            # i_t = torch.sigmoid(i_gate)  # input gate
-            # f_t = torch.sigmoid(f_gate)  # forget gate
-            # g_t = torch.tanh(g_gate)     # candidate gate
-            # o_t = torch.sigmoid(o_gate)  # output gate
             # update the cell state
             self.c_state = f_t * self.c_state + i_t * g_t
             # update the hidden state
             self.h_state = o_t * SoftTanh(self.c_state)
-            # self.h_state = o_t * torch.tanh(self.c_state)
             # output of the current time step
             output.append(self.h_state)
@@ -82,7 +79,7 @@ class MyLSTM(nn.Module):
         # add the residual connection
         if self.input_size == self.hidden_size:
-            output = output + x
+            output = output + x_o
         else:
             output = output
@@ -90,9 +87,10 @@ class MyLSTM(nn.Module):
     def _forward_step(self, x: torch.Tensor) \
             -> torch.Tensor:
-        normed = self.norm(x)
+        x_o = x
+        x = self.norm(x)
         # compute the linear transform of all gates at once (batch_size, 4*hidden_size)
-        gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
+        gates = x @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
         # split into four gates (each is batch_size, hidden_size)
         i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
@@ -102,34 +100,28 @@ class MyLSTM(nn.Module):
         g_t = SoftTanh(g_gate)     # candidate gate
         o_t = SoftSigmoid(o_gate)  # output gate
-        # i_t = torch.sigmoid(i_gate)  # input gate
-        # f_t = torch.sigmoid(f_gate)  # forget gate
-        # g_t = torch.tanh(g_gate)     # candidate gate
-        # o_t = torch.sigmoid(o_gate)  # output gate
         # update the cell state
         self.c_state = f_t * self.c_state + i_t * g_t
         # update the hidden state
         self.h_state = o_t * SoftTanh(self.c_state)
-        # self.h_state = o_t * torch.tanh(self.c_state)
         # output of the current time step
         output = self.h_state
         # add the residual connection
         if self.input_size == self.hidden_size:
-            output = output + x
+            output = output + x_o
         else:
             output = output
         return self.dropout(output)
     def forward(self, x) -> torch.Tensor:
-        if not self.initialized:
-            B = x.size(0)
-            self.init_hidden(B, x.device)
-            self.initialized = True
+        if self.initialized is False:
+            batch_size = x.size(0)
+            device = x.device
+            self.init_hidden(batch_size, device)
         if x.dim() == 2:
             return self._forward_step(x)
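
Both forward paths rely on the same fused-gate trick: one matmul produces all four gate pre-activations, which are then split with chunk. A self-contained shape check of that step, using random stand-in weights and the plain torch activations in place of SoftTanh/SoftSigmoid (an assumption, since MyLayer is not shown in this diff):

import torch

B, C, H = 4, 16, 16              # batch, input size, hidden size
x_t = torch.randn(B, C)
h_state = torch.zeros(B, H)
c_state = torch.zeros(B, H)
W_i = torch.randn(4 * H, C)      # input-to-hidden weights, all four gates stacked
W_h = torch.randn(4 * H, H)      # hidden-to-hidden weights, all four gates stacked
bias = torch.zeros(4 * H)

gates = x_t @ W_i.t() + h_state @ W_h.t() + bias          # (B, 4*H)
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)   # four (B, H) slices
i_t, f_t, o_t = torch.sigmoid(i_gate), torch.sigmoid(f_gate), torch.sigmoid(o_gate)
g_t = torch.tanh(g_gate)
c_state = f_t * c_state + i_t * g_t                       # cell update
h_state = o_t * torch.tanh(c_state)                       # hidden update
print(gates.shape, h_state.shape)                         # torch.Size([4, 64]) torch.Size([4, 16])
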
@@ -189,48 +181,22 @@ class LSTMDecoder(nn.Module):
         ])
         self.fc = nn.Linear(embedding_dim, vocab_size)
-        self.initialized = False
     def init_hidden(
         self, batch_size: int, device: torch.device,
         h: List[torch.Tensor], c: List[torch.Tensor]
     ) -> None:
         for i, layer in enumerate(self.layers):
             layer.init_hidden(batch_size, device, h[i], c[i])
-        self.initialized = True
-    def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
-        output = x
-        for i, layer in enumerate(self.layers):
-            output = layer(output)
-        return output
-    def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
-        _, T, _ = x.shape
-        outputs = []
-        for t in range(T):
-            if t == 0 or random.random() < tfr:
-                outputs.append(self._forward_step(x[:, t, :]))
-            else:
-                logits = torch.argmax(self.fc(outputs[-1].clone().detach()), dim=-1)
-                previous = self.embedding(logits)
-                outputs.append(self._forward_step(previous))
-        outputs = torch.stack(outputs, dim=1)
-        return outputs
-    def forward(self, x: torch.Tensor, tfr: float = 0.0):
+    def forward(self, x: torch.Tensor):
         x = self.embedding(x)
-        if x.dim() == 3:
-            return self.fc(self._forward_seq(x, tfr=tfr))
-        elif x.dim() == 2:
-            return self.fc(self._forward_step(x))
-        else:
-            raise ValueError("input dim must be 2(step) or 3(sequence)")
+        for layer in self.layers:
+            x = layer(x)
+        return self.fc(x)
     def reset(self):
         for layer in self.layers:
             layer.reset()
-        self.initialized = True
 class Seq2SeqLSTM(nn.Module):
@@ -248,13 +214,13 @@ class Seq2SeqLSTM(nn.Module):
             padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
         )
-    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0) -> torch.Tensor:
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
         self.reset()
         h, c = self.encoder(src)
         batch_size = src.size(0)
         device = src.device
         self.decoder.init_hidden(batch_size, device, h, c)
-        return self.decoder(tgt, tfr=tfr)
+        return self.decoder(tgt)
     def reset(self):
         self.encoder.reset()
@@ -268,69 +234,46 @@ class Seq2SeqLSTM(nn.Module):
         pad_token_id: int,
         max_length: int
     ) -> torch.Tensor:
-        """
-        Sequence generation for inference; supports batching.
-        Args:
-            src: source sequences with BOS/EOS added and padding applied (batch_size, src_len)
-            bos_token_id: beginning-of-sequence token ID
-            eos_token_id: end-of-sequence token ID
-            pad_token_id: padding token ID
-            max_length: maximum generation length
-        Returns:
-            generated sequences (batch_size, tgt_len)
-        """
-        self.reset()
         batch_size = src.size(0)
         device = src.device
+        # initialize the output sequence (filled with pad_token_id)
+        output_seq = torch.full((batch_size, max_length), pad_token_id, dtype=torch.long, device=device)
+        # the first decoder input is the BOS token
+        decoder_input = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
+        # mask of sequences that have already finished
+        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
         # encode the source sequence
+        self.reset()
         h, c = self.encoder(src)
-        # initialize the decoder state
         self.decoder.init_hidden(batch_size, device, h, c)
-        # prepare the output tensor (initialized to pad_token_id)
-        sequences = torch.full(
-            (batch_size, max_length),
-            pad_token_id,
-            dtype=torch.long,
-            device=device
-        )
-        # the initial input is BOS
-        current_input = torch.full(
-            (batch_size,),
-            bos_token_id,
-            dtype=torch.long,
-            device=device
-        )
-        # flags whether each sequence has finished (initially all False)
-        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
-        # autoregressive generation
-        for t in range(max_length):
-            # skip computation for finished sequences
-            if finished.all():
-                break
-            # get the output for the current time step
-            logits = self.decoder(current_input)  # (batch_size, vocab_size)
-            next_tokens = logits.argmax(dim=-1)  # greedy selection (batch_size,)
-            # update outputs of unfinished sequences
-            sequences[~finished, t] = next_tokens[~finished]
-            # detect EOS tokens
-            eos_mask = (next_tokens == eos_token_id)
-            finished = finished | eos_mask  # update the finished flags
-            # next-step input: new tokens for unfinished sequences, PAD for finished ones
-            current_input = torch.where(
-                ~finished,
-                next_tokens,
-                torch.full_like(next_tokens, pad_token_id)
-            )
-        return sequences
+        with torch.no_grad():
+            for step in range(max_length):
+                # run a single decoding step
+                logits = self.decoder(decoder_input)  # (batch_size, 1, vocab_size)
+                # get the predicted tokens for the current step
+                next_tokens = logits.argmax(dim=-1)  # (batch_size, 1)
+                # write the predictions into the output sequence
+                output_seq[:, step] = next_tokens.squeeze(1)
+                # update the mask of finished sequences
+                finished = finished | (next_tokens.squeeze(1) == eos_token_id)
+                # stop early once every sequence has finished
+                if finished.all():
+                    break
+                # prepare the next-step input:
+                # unfinished sequences use the predicted token, finished ones use pad_token_id
+                decoder_input = next_tokens.masked_fill(
+                    finished.unsqueeze(1),
+                    pad_token_id
+                )
+        return output_seq
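
The rewritten generate() is a batched greedy decoder: it keeps a finished mask, writes each step's argmax into a PAD-prefilled tensor, and feeds PAD back in for rows that have already emitted EOS. A minimal self-contained sketch of that loop (greedy_decode and fake_decoder are illustrative stand-ins with made-up token IDs, not code from this repo):

import torch

def greedy_decode(step_fn, batch_size: int, bos_id: int, eos_id: int, pad_id: int, max_length: int) -> torch.Tensor:
    output_seq = torch.full((batch_size, max_length), pad_id, dtype=torch.long)
    decoder_input = torch.full((batch_size, 1), bos_id, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    with torch.no_grad():
        for step in range(max_length):
            logits = step_fn(decoder_input)        # (batch_size, 1, vocab_size)
            next_tokens = logits.argmax(dim=-1)    # (batch_size, 1)
            output_seq[:, step] = next_tokens.squeeze(1)
            finished = finished | (next_tokens.squeeze(1) == eos_id)
            if finished.all():
                break
            # finished rows keep receiving PAD as input from here on
            decoder_input = next_tokens.masked_fill(finished.unsqueeze(1), pad_id)
    return output_seq

# stand-in "decoder": random logits over a 32-token vocabulary
fake_decoder = lambda tokens: torch.randn(tokens.size(0), tokens.size(1), 32)
print(greedy_decode(fake_decoder, batch_size=4, bos_id=1, eos_id=2, pad_id=0, max_length=10))
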

6
TryBatch.py

@@ -36,11 +36,7 @@ def try_batch(config: dict, src_max: int, tgt_max: int, batch_size: int, max_bat
     # training loop
     for i in range(max_batch_iter):
         with autocast(config["device"], enabled=config["use_amp"]):
-            if i % 2 == 0:
-                tfr = random.random()
-            else:
-                tfr = 1
-            output = model(srcs, tgts, tfr)
+            output = model(srcs, tgts)
             loss = nn.CrossEntropyLoss()(
                 output.reshape(-1, output.size(-1)),
                 labels.contiguous().view(-1)

BIN
__pycache__/BucketManager.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyDataset.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyLSTM.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyLayer.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyTokenizer.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/TryBatch.cpython-312.pyc

Binary file not shown.

2
inference.py

@@ -141,7 +141,7 @@ def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = torch.load("model/checkpoints/latest_model.pt", map_location=device, weights_only=False)
-    # model = torch.load("model/checkpoints/model_009.pt", map_location=device, weights_only=False)
+    # model = torch.load("model/checkpoints/model_003.pt", map_location=device, weights_only=False)
     # model = torch.load("temp/latest_model.pt", map_location=device, weights_only=False)
     model.eval().to(device)

BIN
model/checkpoints/latest_checkpoint.pt (Stored with Git LFS)

Binary file not shown.

BIN
model/checkpoints/latest_model.pt (Stored with Git LFS)

Binary file not shown.

7944
results/translation_comparison.txt

File diff suppressed because it is too large

82
train.py

@@ -55,51 +55,6 @@ CONFIG = {
 }
-class CosineAnnealingTFR:
-    def __init__(self, T_max, eta_max=1.0, eta_min=0.0, last_epoch=-1, verbose=False):
-        self.T_max = T_max
-        self.eta_max = eta_max
-        self.eta_min = eta_min
-        self.verbose = verbose
-        self.last_epoch = last_epoch
-        # initialize state
-        self.current_tfr = eta_max if last_epoch == -1 else self._compute_tfr(last_epoch)
-        if last_epoch == -1:
-            self.step(0)
-    def _compute_tfr(self, epoch):
-        cos = math.cos(math.pi * epoch / self.T_max)
-        return self.eta_min + (self.eta_max - self.eta_min) * (1 + cos) / 2
-    def step(self, epoch=None):
-        if epoch is not None:
-            self.last_epoch = epoch
-        else:
-            self.last_epoch += 1
-        self.current_tfr = self._compute_tfr(self.last_epoch)
-        if self.verbose:
-            print(f'Epoch {self.last_epoch:5d}: TFR adjusted to {self.current_tfr:.4f}')
-    def get_last_tfr(self):
-        return self.current_tfr
-    def state_dict(self):
-        return {
-            'T_max': self.T_max,
-            'eta_max': self.eta_max,
-            'eta_min': self.eta_min,
-            'last_epoch': self.last_epoch,
-            'verbose': self.verbose
-        }
-    def load_state_dict(self, state_dict):
-        self.__dict__.update(state_dict)
-        self.current_tfr = self._compute_tfr(self.last_epoch)
 def setup_logger(log_file):
     logger = logging.getLogger("train_logger")
     logger.setLevel(logging.INFO)
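
The deleted CosineAnnealingTFR scheduled the teacher-forcing ratio along a standard cosine curve from eta_max down to eta_min over T_max steps. A quick standalone check of that schedule (cosine_tfr is a hypothetical free-function restatement of the removed _compute_tfr, not code that remains in the repo):

import math

def cosine_tfr(epoch: int, T_max: int, eta_max: float = 1.0, eta_min: float = 0.0) -> float:
    # cosine decay from eta_max at epoch 0 to eta_min at epoch T_max
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2

print([round(cosine_tfr(e, T_max=10), 3) for e in (0, 5, 10)])  # [1.0, 0.5, 0.0]
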
@@ -120,7 +75,7 @@ def setup_logger(log_file):
     return logger
-def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch, bucket_idx, config):
+def save_checkpoint(model, optimizer, scaler, lr_scheduler, epoch, bucket_idx, config):
     os.makedirs(config["checkpoint_dir"], exist_ok=True)
     checkpoint = {
         'epoch': epoch,
@@ -130,7 +85,6 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
         'optim_state': optimizer.state_dict(),
         'scaler_state': scaler.state_dict(),
         'lr_scheduler_state': lr_scheduler.state_dict(),
-        'tf_scheduler_state': tf_scheduler.state_dict(),
         'random_state': random.getstate(),
         'numpy_random_state': np.random.get_state(),
         'torch_random_state': torch.get_rng_state(),
@@ -144,16 +98,15 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
     torch.save(checkpoint, os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt"))
-def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, config):
+def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, config):
     checkpoint_path = os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")
     if not os.path.exists(checkpoint_path):
-        return model, optimizer, scaler, lr_scheduler, tf_scheduler, 0, 0, config["bucket_list"]
+        return model, optimizer, scaler, lr_scheduler, 0, 0, config["bucket_list"]
     checkpoint = torch.load(checkpoint_path, weights_only=False)
     model.load_state_dict(checkpoint['model_state'])
     optimizer.load_state_dict(checkpoint['optim_state'])
     lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state'])
-    tf_scheduler.load_state_dict(checkpoint['tf_scheduler_state'])
     random.setstate(checkpoint['random_state'])
     np.random.set_state(checkpoint['numpy_random_state'])
     torch.set_rng_state(checkpoint['torch_random_state'])
@@ -162,8 +115,7 @@ def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler,
     scaler.load_state_dict(checkpoint['scaler_state'])
     return (
-        model, optimizer, scaler,
-        lr_scheduler, tf_scheduler,
+        model, optimizer, scaler, lr_scheduler,
         checkpoint['epoch'],
         checkpoint['bucket_idx'],
         checkpoint["bucket_list"]
@@ -212,7 +164,7 @@ def validate(model, config):
             labels = batch["label"].to(config["device"])
             batch_size = src.size(0)
-            output = model(src, tgt, 1)
+            output = model(src, tgt)
             loss = nn.CrossEntropyLoss(
                 ignore_index=config["pad_token_id"],
                 label_smoothing=config["label_smoothing"],
@@ -298,11 +250,6 @@ def train(config):
        T_max=config["T_max"] * 2,
        eta_min=config["lr_eta_min"]
     )
-    tf_scheduler = CosineAnnealingTFR(
-        T_max=config["T_max"],
-        eta_max=config["tf_eta_max"],
-        eta_min=config["tf_eta_min"]
-    )
     scaler = GradScaler(config["device"], enabled=config["use_amp"])
     # recover from checkpoint
@@ -310,8 +257,8 @@ def train(config):
     start_bucket_idx = 0
     if os.path.exists(os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")):
-        model, optimizer, scaler, lr_scheduler, tf_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
-            model, optimizer, scaler, lr_scheduler, tf_scheduler, config
+        model, optimizer, scaler, lr_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
+            model, optimizer, scaler, lr_scheduler, config
         )
     if start_epoch >= config["epochs"]:
@@ -321,7 +268,7 @@ def train(config):
     # main train loop
     for epoch in range(start_epoch, config["epochs"]):
         for bucket_idx in range(start_bucket_idx, len(config["bucket_list"])):
-            if bucket_idx % len(config["bucket_list"]) == 0 and epoch != 0:
+            if bucket_idx % len(config["bucket_list"]) == 0:
                 random.shuffle(config["bucket_list"])
             bucket = config["bucket_list"][bucket_idx]
@@ -377,11 +324,7 @@
                 # mixed precision
                 with autocast(config["device"], enabled=config["use_amp"]):
-                    if epoch % 3 == 0:
-                        tfr = 1
-                    else:
-                        tfr = tf_scheduler.get_last_tfr()
-                    output = model(src, tgt, tfr)
+                    output = model(src, tgt)
                     loss = nn.CrossEntropyLoss(
                         ignore_index=config["pad_token_id"],
                         label_smoothing=config["label_smoothing"]
@@ -403,8 +346,6 @@ def train(config):
                 # update scheduler after a batch
                 lr_scheduler.step()
-                if epoch % 3 != 0:
-                    tf_scheduler.step()
                 # preformance eval
                 with torch.no_grad():
@@ -422,8 +363,7 @@ def train(config):
                     loss=f'{loss.item():.3f}',
                     acc=f'{(correct.float() / total_valid).item() * 100:.3f}%' \
                         if total_valid > 0 else '0.000%',
-                    lr=f'{lr_scheduler.get_last_lr()[0]:.3e}',
-                    tf=f'{tfr:.3e}'
+                    lr=f'{lr_scheduler.get_last_lr()[0]:.3e}'
                 )
                 # log parquet info
@@ -435,7 +375,7 @@ def train(config):
             # save checkpoint after a bucket
             save_checkpoint(
-                model, optimizer, scaler, lr_scheduler, tf_scheduler,
+                model, optimizer, scaler, lr_scheduler,
                 epoch, bucket_idx + 1, config
             )

1457
training.log

File diff suppressed because it is too large