# try_batch.py
import gc
import math
import multiprocessing

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

from MyLSTM import Seq2SeqLSTM


def try_batch(config: dict, src_max: int, tgt_max: int, batch_size: int, max_batch_iter: int = 2):
    # Init model, optimizer, scaler.
    model = Seq2SeqLSTM(
        vocab_size=config["vocab_size"],
        embedding_dim=config["embedding_dim"],
        padding_idx=config["pad_token_id"],
        num_layers=config["num_layers"],
        dropout=config["dropout"],
    ).to(config["device"])
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr_eta_max"],
        betas=config["betas"],
        weight_decay=config["weight_decay"],
    )
    scaler = GradScaler(config["device"], enabled=config["use_amp"])

    # Construct a fake batch at the maximum source/target lengths for testing.
    srcs = torch.full((1, src_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])
    tgts = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])
    labels = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])

    # Training loop: run a few full forward/backward/optimizer steps to reach peak memory.
    for i in range(max_batch_iter):
        with autocast(config["device"], enabled=config["use_amp"]):
            output = model(srcs, tgts)
            loss = nn.CrossEntropyLoss()(
                output.reshape(-1, output.size(-1)),
                labels.contiguous().view(-1),
            )
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=config["grad_clip"]
        )
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)


def _worker(config, src_max, tgt_max, batch_size, max_batch_iter, result_queue):
    try:
        torch.cuda.reset_peak_memory_stats()
        try_batch(config, src_max, tgt_max, batch_size, max_batch_iter)
        peak_mem = torch.cuda.max_memory_allocated()
        result_queue.put(('success', peak_mem))
    except Exception as e:
        if 'out of memory' in str(e).lower() or 'alloc' in str(e).lower():
            peak_mem = torch.cuda.max_memory_allocated()
            result_queue.put(('oom', peak_mem))
        else:
            # Pack the error details as a single payload so the parent can
            # unpack (status, data) uniformly for all three result types.
            result_queue.put(('error', (type(e).__name__, str(e))))
    finally:
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def find_batch_size(
    src_max: int,
    tgt_max: int,
    config: dict,
    initial_batch_size: int = 16384,
    max_batch_iter: int = 2,
    timeout: int = 30,
) -> int:
    # Binary-search the largest batch size that trains without running out of memory.
    low, high = 1, initial_batch_size
    best_size = 0
    # Searching [1, initial_batch_size] needs at most log2(N) + 1 probes.
    max_attempts = math.ceil(math.log2(initial_batch_size)) + 1
    attempt = 0

    while low <= high and attempt < max_attempts:
        attempt += 1
        mid = (low + high) // 2

        # Probe each candidate in a fresh 'spawn' subprocess so an OOM cannot
        # corrupt or fragment the parent's CUDA context.
        ctx = multiprocessing.get_context('spawn')
        result_queue = ctx.Queue()
        process = ctx.Process(
            target=_worker,
            args=(config, src_max, tgt_max, mid, max_batch_iter, result_queue),
        )

        try:
            process.start()
            process.join(timeout=timeout)

            if process.is_alive():
                process.terminate()
                process.join()
                print(f"Attempt {attempt:3d}: {mid:5d} | Failed (timed out after {timeout}s)")
                high = mid - 1
                continue

            if result_queue.empty():
                # The subprocess died without reporting a result (e.g. killed
                # by the OS); treat the candidate batch size as too large.
                print(f"Attempt {attempt:3d}: {mid:5d} | Failed (subprocess exited without returning results)")
                high = mid - 1
                continue

            status, data = result_queue.get()

            if status == 'success':
                current_peak = data / (1024 ** 3)
                print(f"Attempt {attempt:3d}: {mid:5d} | OK (Peak Memory: {current_peak:5.2f}GB)")
                best_size = mid
                low = mid + 1
            elif status == 'oom':
                current_peak = data / (1024 ** 3)
                print(f"Attempt {attempt:3d}: {mid:5d} | Failed (OOM, Peak: {current_peak:5.2f}GB)")
                high = mid - 1
            elif status == 'error':
                # A non-OOM failure in the worker is not recoverable by
                # shrinking the batch size, so abort the search.
                exc_type, exc_msg = data
                print(f"Attempt {attempt:3d}: {mid:5d} | CRITICAL ERROR: {exc_type} - {exc_msg}")
                raise RuntimeError(f"Critical error at batch size {mid}: {exc_type} - {exc_msg}")
        finally:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    return best_size
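

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how find_batch_size might be called. The config keys
# mirror the ones this module reads; every value below (vocab size, learning
# rate, sequence lengths, ...) is a placeholder assumption, not a project
# default. The __main__ guard matters because find_batch_size launches
# 'spawn' subprocesses, which re-import this module in the child process.
if __name__ == "__main__":
    example_config = {
        # Model (assumed values, for illustration only)
        "vocab_size": 32000,
        "embedding_dim": 512,
        "num_layers": 2,
        "dropout": 0.1,
        "pad_token_id": 0,
        "unk_token_id": 1,
        # Optimizer / training (assumed values)
        "lr_eta_max": 1e-3,
        "betas": (0.9, 0.999),
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        # Runtime
        "device": "cuda",
        "use_amp": True,
    }

    # Probe the largest batch size that fits for 128-token sources and targets.
    best = find_batch_size(src_max=128, tgt_max=128, config=example_config)
    print(f"Largest working batch size: {best}")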