@@ -55,51 +55,6 @@ CONFIG = {
 }
 
 
-class CosineAnnealingTFR:
-    def __init__(self, T_max, eta_max=1.0, eta_min=0.0, last_epoch=-1, verbose=False):
-        self.T_max = T_max
-        self.eta_max = eta_max
-        self.eta_min = eta_min
-        self.verbose = verbose
-        self.last_epoch = last_epoch
-
-        # initialize state
-        self.current_tfr = eta_max if last_epoch == -1 else self._compute_tfr(last_epoch)
-        if last_epoch == -1:
-            self.step(0)
-
-    def _compute_tfr(self, epoch):
-        cos = math.cos(math.pi * epoch / self.T_max)
-        return self.eta_min + (self.eta_max - self.eta_min) * (1 + cos) / 2
-
-    def step(self, epoch=None):
-        if epoch is not None:
-            self.last_epoch = epoch
-        else:
-            self.last_epoch += 1
-
-        self.current_tfr = self._compute_tfr(self.last_epoch)
-
-        if self.verbose:
-            print(f'Epoch {self.last_epoch:5d}: TFR adjusted to {self.current_tfr:.4f}')
-
-    def get_last_tfr(self):
-        return self.current_tfr
-
-    def state_dict(self):
-        return {
-            'T_max': self.T_max,
-            'eta_max': self.eta_max,
-            'eta_min': self.eta_min,
-            'last_epoch': self.last_epoch,
-            'verbose': self.verbose
-        }
-
-    def load_state_dict(self, state_dict):
-        self.__dict__.update(state_dict)
-        self.current_tfr = self._compute_tfr(self.last_epoch)
-
-
 def setup_logger(log_file):
     logger = logging.getLogger("train_logger")
     logger.setLevel(logging.INFO)
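For reference, the class deleted above follows the standard half-cosine decay from eta_max down to eta_min over T_max steps. A minimal standalone sketch of that curve (not part of the diff; the helper name cosine_tfr is illustrative):

    import math

    def cosine_tfr(epoch, T_max, eta_max=1.0, eta_min=0.0):
        # same formula as the removed CosineAnnealingTFR._compute_tfr
        return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2

    # e.g. with T_max=10: cosine_tfr(0, 10) -> 1.0, cosine_tfr(5, 10) -> ~0.5, cosine_tfr(10, 10) -> 0.0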
@@ -120,7 +75,7 @@ def setup_logger(log_file):
     return logger
 
 
-def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch, bucket_idx, config):
+def save_checkpoint(model, optimizer, scaler, lr_scheduler, epoch, bucket_idx, config):
     os.makedirs(config["checkpoint_dir"], exist_ok=True)
     checkpoint = {
         'epoch': epoch,
@@ -130,7 +85,6 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
         'optim_state': optimizer.state_dict(),
         'scaler_state': scaler.state_dict(),
         'lr_scheduler_state': lr_scheduler.state_dict(),
-        'tf_scheduler_state': tf_scheduler.state_dict(),
         'random_state': random.getstate(),
         'numpy_random_state': np.random.get_state(),
         'torch_random_state': torch.get_rng_state(),
@@ -144,16 +98,15 @@ def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch,
     torch.save(checkpoint, os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt"))
 
 
-def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, config):
+def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, config):
     checkpoint_path = os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")
     if not os.path.exists(checkpoint_path):
-        return model, optimizer, scaler, lr_scheduler, tf_scheduler, 0, 0, config["bucket_list"]
+        return model, optimizer, scaler, lr_scheduler, 0, 0, config["bucket_list"]
 
     checkpoint = torch.load(checkpoint_path, weights_only=False)
     model.load_state_dict(checkpoint['model_state'])
     optimizer.load_state_dict(checkpoint['optim_state'])
     lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state'])
-    tf_scheduler.load_state_dict(checkpoint['tf_scheduler_state'])
     random.setstate(checkpoint['random_state'])
     np.random.set_state(checkpoint['numpy_random_state'])
     torch.set_rng_state(checkpoint['torch_random_state'])
@@ -162,8 +115,7 @@ def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler,
     scaler.load_state_dict(checkpoint['scaler_state'])
 
     return (
-        model, optimizer, scaler,
-        lr_scheduler, tf_scheduler,
+        model, optimizer, scaler, lr_scheduler,
         checkpoint['epoch'],
         checkpoint['bucket_idx'],
         checkpoint["bucket_list"]
@@ -212,7 +164,7 @@ def validate(model, config):
         labels = batch["label"].to(config["device"])
         batch_size = src.size(0)
 
-        output = model(src, tgt, 1)
+        output = model(src, tgt)
         loss = nn.CrossEntropyLoss(
             ignore_index=config["pad_token_id"],
             label_smoothing=config["label_smoothing"],
@@ -298,11 +250,6 @@ def train(config):
         T_max=config["T_max"] * 2,
         eta_min=config["lr_eta_min"]
     )
-    tf_scheduler = CosineAnnealingTFR(
-        T_max=config["T_max"],
-        eta_max=config["tf_eta_max"],
-        eta_min=config["tf_eta_min"]
-    )
     scaler = GradScaler(config["device"], enabled=config["use_amp"])
 
     # recover from checkpoint
@@ -310,8 +257,8 @@ def train(config):
     start_bucket_idx = 0
 
     if os.path.exists(os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")):
-        model, optimizer, scaler, lr_scheduler, tf_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
-            model, optimizer, scaler, lr_scheduler, tf_scheduler, config
+        model, optimizer, scaler, lr_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
+            model, optimizer, scaler, lr_scheduler, config
         )
 
     if start_epoch >= config["epochs"]:
@@ -321,7 +268,7 @@ def train(config):
     # main train loop
     for epoch in range(start_epoch, config["epochs"]):
        for bucket_idx in range(start_bucket_idx, len(config["bucket_list"])):
-            if bucket_idx % len(config["bucket_list"]) == 0 and epoch != 0:
+            if bucket_idx % len(config["bucket_list"]) == 0:
                 random.shuffle(config["bucket_list"])
 
             bucket = config["bucket_list"][bucket_idx]
@@ -377,11 +324,7 @@ def train(config):
 
             # mixed precision
             with autocast(config["device"], enabled=config["use_amp"]):
-                if epoch % 3 == 0:
-                    tfr = 1
-                else:
-                    tfr = tf_scheduler.get_last_tfr()
-                output = model(src, tgt, tfr)
+                output = model(src, tgt)
                 loss = nn.CrossEntropyLoss(
                     ignore_index=config["pad_token_id"],
                     label_smoothing=config["label_smoothing"]
@@ -403,8 +346,6 @@ def train(config):
 
             # update scheduler after a batch
             lr_scheduler.step()
-            if epoch % 3 != 0:
-                tf_scheduler.step()
 
             # performance eval
             with torch.no_grad():
@@ -422,8 +363,7 @@ def train(config):
                 loss=f'{loss.item():.3f}',
                 acc=f'{(correct.float() / total_valid).item() * 100:.3f}%' \
                     if total_valid > 0 else '0.000%',
-                lr=f'{lr_scheduler.get_last_lr()[0]:.3e}',
-                tf=f'{tfr:.3e}'
+                lr=f'{lr_scheduler.get_last_lr()[0]:.3e}'
             )
 
             # log parquet info
@@ -435,7 +375,7 @@ def train(config):
 
             # save checkpoint after a bucket
             save_checkpoint(
-                model, optimizer, scaler, lr_scheduler, tf_scheduler,
+                model, optimizer, scaler, lr_scheduler,
                 epoch, bucket_idx + 1, config
             )
 
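One practical note on the load path above: checkpoints written before this change still contain a 'tf_scheduler_state' entry, which the updated load_latest_checkpoint simply never reads, so they remain loadable as-is. A hedged sketch of stripping the stale key when migrating an old file (the path below is hypothetical):

    import torch

    ckpt_path = "checkpoints/latest_checkpoint.pt"  # hypothetical location
    ckpt = torch.load(ckpt_path, weights_only=False)
    ckpt.pop('tf_scheduler_state', None)  # entry the new code neither writes nor reads
    torch.save(ckpt, ckpt_path)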