# train.py
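"""Training script for a bucketed Seq2Seq LSTM model.

Summary of what the code below does:
  * loads cached, pre-tokenized shards bucket by bucket via BucketManager / MyDataset,
  * trains Seq2SeqLSTM with mixed precision (torch.amp), gradient clipping,
    a cosine-annealed learning rate and a cosine-annealed teacher-forcing ratio,
  * saves a checkpoint after every bucket so training can resume mid-epoch,
  * validates on a fixed held-out bucket after each training bucket.
"""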
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from MyDataset import MyDataset, collate_fn
from functools import partial
from MyLSTM import Seq2SeqLSTM
from MyTokenizer import load_tokenizer
import os
from tqdm import tqdm
import gc
import logging
import math
from BucketManager import BucketManager
import random
import numpy as np
from TryBatch import find_batch_size


CONFIG = {
    # dirs
    "train_cache_dir": "data/cache/train",
    "valid_cache_dir": "data/cache/valid",
    "checkpoint_dir": "model/checkpoints",

    # data settings
    "num_workers": 20,
    "safty_factor": 0.7,

    # model config
    "embedding_dim": 768,
    "num_layers": 4,
    "dropout": 0.01,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # optimizer and scheduler config
    "epochs": 8,
    "lr_eta_max": 1e-3,
    "lr_eta_min": 1e-5,
    "tf_eta_max": 1,
    "tf_eta_min": 0.5,
    "betas": (0.9, 0.999),
    "weight_decay": 1e-4,

    "label_smoothing": 0.1,

    # grad and precision
    "use_amp": True,
    "grad_clip": 1.0,

    # logging
    "log_file": "training.log",
}

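
# Teacher-forcing ratio (TFR) scheduler. The ratio follows the same cosine
# curve that CosineAnnealingLR uses for the learning rate:
#
#     tfr(t) = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2
#
# so tfr(0) == eta_max and tfr(T_max) == eta_min. With the CONFIG values above
# (tf_eta_max=1, tf_eta_min=0.5) training starts fully teacher-forced and the
# ratio decays toward 0.5, presumably letting the model feed back more of its
# own predictions (exact behaviour depends on how Seq2SeqLSTM uses the ratio).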
class CosineAnnealingTFR:
    def __init__(self, T_max, eta_max=1.0, eta_min=0.0, last_epoch=-1, verbose=False):
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose
        self.last_epoch = last_epoch

        # initialize state
        self.current_tfr = eta_max if last_epoch == -1 else self._compute_tfr(last_epoch)
        if last_epoch == -1:
            self.step(0)

    def _compute_tfr(self, epoch):
        cos = math.cos(math.pi * epoch / self.T_max)
        return self.eta_min + (self.eta_max - self.eta_min) * (1 + cos) / 2

    def step(self, epoch=None):
        if epoch is not None:
            self.last_epoch = epoch
        else:
            self.last_epoch += 1

        self.current_tfr = self._compute_tfr(self.last_epoch)

        if self.verbose:
            print(f'Epoch {self.last_epoch:5d}: TFR adjusted to {self.current_tfr:.4f}')

    def get_last_tfr(self):
        return self.current_tfr

    def state_dict(self):
        return {
            'T_max': self.T_max,
            'eta_max': self.eta_max,
            'eta_min': self.eta_min,
            'last_epoch': self.last_epoch,
            'verbose': self.verbose
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self.current_tfr = self._compute_tfr(self.last_epoch)

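
# Example usage of CosineAnnealingTFR (illustrative values, not from the run below):
#
#     tf_sched = CosineAnnealingTFR(T_max=10_000, eta_max=1.0, eta_min=0.5)
#     for _ in range(10_000):
#         ratio = tf_sched.get_last_tfr()
#         ...  # run one training step with `ratio`
#         tf_sched.step()
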
def setup_logger(log_file):
    logger = logging.getLogger("train_logger")
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    fh = logging.FileHandler(log_file, mode='a', encoding='utf-8')
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setFormatter(formatter)

    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

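
# Checkpointing: a "latest" model/checkpoint pair is overwritten after every
# bucket; an additional numbered pair (model_XXX.pt / checkpoint_XXX.pt) is
# kept once all buckets of an epoch have been processed. Besides model and
# optimizer state, the Python / NumPy / torch RNG states are stored so that
# bucket shuffling stays reproducible across restarts.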
def save_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, epoch, bucket_idx, config):
    os.makedirs(config["checkpoint_dir"], exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'bucket_idx': bucket_idx,
        'bucket_list': config["bucket_list"],
        'model_state': model.state_dict(),
        'optim_state': optimizer.state_dict(),
        'scaler_state': scaler.state_dict(),
        'lr_scheduler_state': lr_scheduler.state_dict(),
        'tf_scheduler_state': tf_scheduler.state_dict(),
        'random_state': random.getstate(),
        'numpy_random_state': np.random.get_state(),
        'torch_random_state': torch.get_rng_state(),
    }
    # if a whole epoch has been trained, save an extra numbered checkpoint
    if bucket_idx == len(config["bucket_list"]):
        torch.save(model, os.path.join(config["checkpoint_dir"], f"model_{epoch:03d}.pt"))
        torch.save(checkpoint, os.path.join(config["checkpoint_dir"], f"checkpoint_{epoch:03d}.pt"))
    # save latest
    torch.save(model, os.path.join(config["checkpoint_dir"], "latest_model.pt"))
    torch.save(checkpoint, os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt"))

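
# Resuming: restores model/optimizer/scheduler/scaler state plus all RNG
# states, and returns the epoch and bucket index at which to continue.
# weights_only=False is used because the checkpoint also contains pickled
# Python/NumPy objects (RNG states, the bucket list), which the weights-only
# loader would reject.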
def load_latest_checkpoint(model, optimizer, scaler, lr_scheduler, tf_scheduler, config):
    checkpoint_path = os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")
    if not os.path.exists(checkpoint_path):
        return model, optimizer, scaler, lr_scheduler, tf_scheduler, 0, 0, config["bucket_list"]

    checkpoint = torch.load(checkpoint_path, weights_only=False)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state'])
    tf_scheduler.load_state_dict(checkpoint['tf_scheduler_state'])
    random.setstate(checkpoint['random_state'])
    np.random.set_state(checkpoint['numpy_random_state'])
    torch.set_rng_state(checkpoint['torch_random_state'])

    if scaler and 'scaler_state' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state'])

    return (
        model, optimizer, scaler,
        lr_scheduler, tf_scheduler,
        checkpoint['epoch'],
        checkpoint['bucket_idx'],
        checkpoint["bucket_list"]
    )

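
# Validation runs on a single fixed bucket (sequences up to 128/128 tokens,
# shard 0) with full teacher forcing (ratio 1). Loss is averaged per sample;
# accuracy is token-level over non-padding positions. Note the valid cache
# path is hard-coded here rather than read from CONFIG["valid_cache_dir"].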
def validate(model, config):
    dataset = MyDataset(
        cache_path="data/cache/valid",
        src_range=(0, 128),
        tgt_range=(0, 128),
        shard_idx=0
    )
    dataloader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=False,
        num_workers=config["num_workers"],
        collate_fn=partial(
            collate_fn,
            pad_token_id=config["pad_token_id"],
            src_max_length=128,
            tgt_max_length=128,
        )
    )

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_valid_tokens = 0
    total_samples = 0

    with torch.no_grad():
        loop = tqdm(
            dataloader,
            unit="batch",
            colour="blue",
            desc="Validating",
            bar_format='{l_bar}{bar:32}{r_bar}',
            dynamic_ncols=True,
            leave=False
        )
        for batch in loop:
            src = batch["src"].to(config["device"])
            tgt = batch["tgt"].to(config["device"])
            labels = batch["label"].to(config["device"])
            batch_size = src.size(0)

            output = model(src, tgt, 1)
            loss = nn.CrossEntropyLoss(
                ignore_index=config["pad_token_id"],
                label_smoothing=config["label_smoothing"],
            )(
                output.reshape(-1, output.size(-1)),
                labels.contiguous().view(-1)
            )

            preds = torch.argmax(output, dim=-1)
            mask = (labels != config["pad_token_id"])
            correct = (preds[mask] == labels[mask]).sum()

            total_valid = mask.sum()
            total_loss += loss.item() * batch_size
            total_correct += correct.item()
            total_valid_tokens += total_valid.item()
            total_samples += batch_size

            loop.set_postfix(
                loss=f'{loss.item():.3f}',
                acc=f'{(correct.float() / total_valid).item() * 100:.3f}%' \
                    if total_valid > 0 else '0.000%',
            )

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_valid_tokens if total_valid_tokens > 0 else 0.0
    del dataset, dataloader
    gc.collect()
    torch.cuda.empty_cache()
    return avg_loss, avg_acc

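
# Main training entry point. High-level flow:
#   1. load the en/zh tokenizers and copy their special-token ids into CONFIG,
#   2. build the bucket list (probing per-bucket batch sizes if needed),
#   3. construct the model, optimizer, LR/TFR schedulers and AMP grad scaler,
#   4. optionally resume from the latest checkpoint,
#   5. iterate epochs -> buckets -> batches, checkpointing and validating
#      after every bucket.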
def train(config):
    logger = setup_logger(config["log_file"])

    # init tokenizers
    en_tokenizer = load_tokenizer("model/tokenizers/en")
    zh_tokenizer = load_tokenizer("model/tokenizers/zh")
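    # Both tokenizers must agree on special-token ids and vocabulary size,
    # presumably because Seq2SeqLSTM is built with a single shared vocabulary
    # (see the vocab_size / padding_idx arguments below).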
    assert \
        zh_tokenizer.bos_token_id == en_tokenizer.bos_token_id and \
        zh_tokenizer.eos_token_id == en_tokenizer.eos_token_id and \
        zh_tokenizer.unk_token_id == en_tokenizer.unk_token_id and \
        zh_tokenizer.pad_token_id == en_tokenizer.pad_token_id and \
        zh_tokenizer.vocab_size == en_tokenizer.vocab_size

    config["bos_token_id"] = zh_tokenizer.bos_token_id
    config["eos_token_id"] = zh_tokenizer.eos_token_id
    config["unk_token_id"] = zh_tokenizer.unk_token_id
    config["pad_token_id"] = zh_tokenizer.pad_token_id
    config["vocab_size"] = zh_tokenizer.vocab_size

    # init bucket manager
    buckets = BucketManager(
        cache_path=config["train_cache_dir"],
        force_rebuild=False
    )

    # buckets.reset_batch_size()
    if not buckets.found_optimal():
        buckets.find_optimal_batch_size(partial(find_batch_size, config=config))

    config["bucket_list"] = buckets.get_info()
    config["T_max"] = buckets.get_total_iterations(safety_factor=config["safty_factor"])

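    # Each bucket entry from buckets.get_info() is used below with the keys
    # src_range, tgt_range, shard_idx, num_shards and suggested_batch_size,
    # e.g. (illustrative values only):
    #     {"src_range": (0, 32), "tgt_range": (0, 32), "shard_idx": 0,
    #      "num_shards": 4, "suggested_batch_size": 1024}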
    # init model
    model = Seq2SeqLSTM(
        vocab_size=config["vocab_size"],
        embedding_dim=config["embedding_dim"],
        padding_idx=config["pad_token_id"],
        num_layers=config["num_layers"],
        dropout=config["dropout"],
    ).to(config["device"])

    # init optimizer and schedulers
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr_eta_max"],
        betas=config["betas"],
        weight_decay=config["weight_decay"]
    )
    lr_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=config["T_max"] * 2,
        eta_min=config["lr_eta_min"]
    )
    tf_scheduler = CosineAnnealingTFR(
        T_max=config["T_max"],
        eta_max=config["tf_eta_max"],
        eta_min=config["tf_eta_min"]
    )
    scaler = GradScaler(config["device"], enabled=config["use_amp"])

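    # Note: the LR schedule runs over 2 * T_max steps, so after the estimated
    # T_max training iterations the learning rate has only reached the midpoint
    # of its cosine curve, while the teacher-forcing ratio completes its full
    # decay to tf_eta_min over T_max steps.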
    # recover from checkpoint
    start_epoch = 0
    start_bucket_idx = 0

    if os.path.exists(os.path.join(config["checkpoint_dir"], "latest_checkpoint.pt")):
        model, optimizer, scaler, lr_scheduler, tf_scheduler, start_epoch, start_bucket_idx, config["bucket_list"] = load_latest_checkpoint(
            model, optimizer, scaler, lr_scheduler, tf_scheduler, config
        )

    if start_epoch >= config["epochs"]:
        logger.info("Training already completed.")
        return

    # main train loop
    for epoch in range(start_epoch, config["epochs"]):
        for bucket_idx in range(start_bucket_idx, len(config["bucket_list"])):
            if bucket_idx % len(config["bucket_list"]) == 0 and epoch != 0:
                random.shuffle(config["bucket_list"])

            bucket = config["bucket_list"][bucket_idx]

            logger.info(
                f"Processing bucket {bucket_idx + 1}/{len(config['bucket_list'])} " + \
                f"src: {str(bucket['src_range'])} tgt: {str(bucket['tgt_range'])} " + \
                f"shard: {str(bucket['shard_idx'] + 1)}/{str(bucket['num_shards'])}"
            )

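            # Build the dataset for this bucket's shard. The batch size is the
            # probed "suggested" size scaled down by the safety factor, e.g.
            # (hypothetically) floor(1024 * 0.7) -> 716, to leave GPU-memory headroom.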
            dataset = MyDataset(
                cache_path=config["train_cache_dir"],
                src_range=bucket["src_range"],
                tgt_range=bucket["tgt_range"],
                shard_idx=bucket["shard_idx"]
            )

            dataloader = DataLoader(
                dataset,
                batch_size=math.floor(bucket["suggested_batch_size"] * config["safty_factor"]),
                num_workers=config["num_workers"],
                collate_fn=partial(
                    collate_fn,
                    pad_token_id=config["pad_token_id"],
                    src_max_length=bucket["src_range"][1],
                    tgt_max_length=bucket["tgt_range"][1],
                ),
                shuffle=False,
            )

            # train loop
            model.train()
            total_loss = 0.0
            total_correct = 0
            total_valid_tokens = 0
            total_samples = 0

            loop = tqdm(
                dataloader,
                unit="batch",
                colour="green",
                desc=f"Epoch {epoch + 1} | Bucket {bucket_idx + 1}",
                bar_format='{l_bar}{bar:32}{r_bar}',
                dynamic_ncols=True,
                leave=False,
            )

            for batch in loop:
                src = batch["src"].to(config["device"])
                tgt = batch["tgt"].to(config["device"])
                labels = batch["label"].to(config["device"])
                batch_size = src.size(0)

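                # Forward pass under autocast. Every third epoch (epoch % 3 == 0)
                # runs with full teacher forcing (tfr = 1) and the TFR scheduler
                # is frozen; in the other epochs the cosine-annealed ratio is
                # used and stepped once per batch (see below).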
                # mixed precision
                with autocast(config["device"], enabled=config["use_amp"]):
                    if epoch % 3 == 0:
                        tfr = 1
                    else:
                        tfr = tf_scheduler.get_last_tfr()
                    output = model(src, tgt, tfr)
                    loss = nn.CrossEntropyLoss(
                        ignore_index=config["pad_token_id"],
                        label_smoothing=config["label_smoothing"]
                    )(
                        output.reshape(-1, output.size(-1)),
                        labels.contiguous().view(-1)
                    )

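                # Gradients are unscaled before clipping so that max_norm applies
                # to the true gradient magnitudes rather than the AMP-scaled ones.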
                # grad
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=config["grad_clip"]
                )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                # update schedulers after each batch
                lr_scheduler.step()
                if epoch % 3 != 0:
                    tf_scheduler.step()

                # performance eval
                with torch.no_grad():
                    preds = torch.argmax(output, dim=-1)
                    mask = (labels != config["pad_token_id"])
                    correct = (preds[mask] == labels[mask]).sum()

                    total_valid = mask.sum()
                    total_loss += loss.item() * batch_size
                    total_correct += correct.item()
                    total_valid_tokens += total_valid.item()
                    total_samples += batch_size

                    loop.set_postfix(
                        loss=f'{loss.item():.3f}',
                        acc=f'{(correct.float() / total_valid).item() * 100:.3f}%' \
                            if total_valid > 0 else '0.000%',
                        lr=f'{lr_scheduler.get_last_lr()[0]:.3e}',
                        tf=f'{tfr:.3e}'
                    )

            # log per-bucket stats
            avg_loss = total_loss / total_samples
            avg_acc = total_correct / total_valid_tokens if total_valid_tokens > 0 else 0.0
            logger.info(
                f"Epoch: {epoch + 1} | Bucket: {bucket_idx + 1} | Loss: {avg_loss:.3f} | Acc: {avg_acc * 100:.3f}%"
            )

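            # bucket_idx + 1 is stored so that a resumed run continues with the
            # next unfinished bucket rather than repeating this one.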
            # save checkpoint after a bucket
            save_checkpoint(
                model, optimizer, scaler, lr_scheduler, tf_scheduler,
                epoch, bucket_idx + 1, config
            )

            # clean up memory
            del dataset, dataloader
            gc.collect()
            torch.cuda.empty_cache()

            # validate after a bucket
            valid_loss, valid_acc = validate(model, config)
            logger.info(f"Validation | Loss: {valid_loss:.3f} | Acc: {valid_acc * 100:.3f}%")

        # reset index
        start_bucket_idx = 0

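
# Seed everything for reproducibility; cudnn.benchmark is disabled and
# deterministic mode enabled so runs are repeatable (at some cost in speed).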
if __name__ == "__main__":
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    train(CONFIG)