189 changed files with 183988 additions and 2 deletions
@@ -0,0 +1,3 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,16 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "cwd": "${fileDirname}"
        }
    ]
}
@@ -0,0 +1,6 @@
{
    "editor.formatOnSave": false,
    "editor.defaultFormatter": "charliermarsh.ruff",
    "ruff.lint.ignore": ["E501", "E502", "W293", "W391", "Q", "ANN", "D", "ERA", "PLR"],
    "ruff.lint.select": ["E", "W", "F"]
}
@@ -0,0 +1,423 @@
|||
import json |
|||
from typing import List, Optional, Generator, Dict, Tuple |
|||
from collections import defaultdict, Counter |
|||
from pathlib import Path |
|||
import shutil |
|||
import math |
|||
from MyDataset import MyDataset |
|||
|
|||
|
|||
class BucketManager: |
|||
def __init__( |
|||
self, |
|||
cache_path: str, |
|||
min_samples: int = 1024, |
|||
max_samples: int = 524288, |
|||
num_cuts: int = 16, |
|||
processed_data: Optional[List[dict]] = None, |
|||
force_rebuild: bool = False, |
|||
): |
|||
self._validate_init_params(cache_path, processed_data, force_rebuild) |
|||
|
|||
self.cache_path = Path(cache_path) |
|||
self.num_cuts = num_cuts |
|||
self.min_samples = min_samples |
|||
self.max_samples = max_samples |
|||
|
|||
self.buckets = defaultdict(lambda: defaultdict(list)) |
|||
self.valid_buckets = [] |
|||
self.total_samples = 0 |
|||
self.src_buckets = [] |
|||
self.tgt_buckets = [] |
|||
|
|||
self.total_original_samples = 0 |
|||
self.discarded_samples = 0 |
|||
self.total_padding = 0 |
|||
self.total_actual_tokens = 0 |
|||
|
|||
if processed_data is not None: |
|||
if not force_rebuild and self._metadata_exists(): |
|||
self._load_from_metadata() |
|||
else: |
|||
if force_rebuild and self.cache_path.exists(): |
|||
shutil.rmtree(self.cache_path) |
|||
self.cache_path.mkdir(parents=True, exist_ok=True) |
|||
self._process_data_with_dp(processed_data) |
|||
self._save_metadata() |
|||
else: |
|||
if not self._metadata_exists(): |
|||
raise FileNotFoundError(f"No cache found at {cache_path}") |
|||
self._load_from_metadata() |
|||
|
|||
def _validate_init_params( |
|||
self, |
|||
cache_path: str, |
|||
processed_data: Optional[List[dict]], |
|||
force_rebuild: bool |
|||
): |
|||
if not cache_path: |
|||
raise ValueError("cache_path cannot be empty") |
|||
|
|||
if processed_data is None and not force_rebuild and not Path(cache_path).exists(): |
|||
raise FileNotFoundError(f"Cache path missing: {cache_path}") |
|||
|
|||
@staticmethod |
|||
def _optimal_1d_partition(lengths: List[int], num_cuts: int) -> List[Tuple[int, int]]: |
|||
if not lengths: |
|||
return [] |
|||
|
|||
length_counts = Counter(lengths) |
|||
unique_lengths = sorted(length_counts.keys()) |
|||
|
|||
if len(unique_lengths) <= num_cuts: |
|||
buckets = [] |
|||
for i, length in enumerate(unique_lengths): |
|||
start = length |
|||
end = unique_lengths[i + 1] if i + 1 < len(unique_lengths) else length + 1 |
|||
buckets.append((start, end)) |
|||
return buckets |
|||
|
|||
n = len(unique_lengths) |
|||
|
|||
dp = [[float('inf')] * (num_cuts + 1) for _ in range(n + 1)] |
|||
parent = [[-1] * (num_cuts + 1) for _ in range(n + 1)] |
|||
|
|||
dp[0][0] = 0 |
|||
|
|||
for i in range(1, n + 1): |
|||
for j in range(1, min(i + 1, num_cuts + 1)): |
|||
for k in range(j - 1, i): |
|||
if dp[k][j - 1] == float('inf'): |
|||
continue |
|||
|
|||
bucket_start = unique_lengths[k] |
|||
bucket_end = unique_lengths[i - 1] + 1 |
|||
bucket_max = unique_lengths[i - 1] |
|||
|
|||
padding_in_bucket = 0 |
|||
for idx in range(k, i): |
|||
length = unique_lengths[idx] |
|||
count = length_counts[length] |
|||
padding_in_bucket += (bucket_max - length) * count |
|||
|
|||
total_padding = dp[k][j - 1] + padding_in_bucket |
|||
|
|||
if total_padding < dp[i][j]: |
|||
dp[i][j] = total_padding |
|||
parent[i][j] = k |
|||
|
|||
if dp[n][num_cuts] == float('inf'): |
|||
return [(min(lengths), max(lengths) + 1)] |
|||
|
|||
buckets = [] |
|||
i, j = n, num_cuts |
|||
|
|||
while j > 0 and i > 0: |
|||
k = parent[i][j] |
|||
if k < 0 or k >= i: |
|||
break |
|||
bucket_start = unique_lengths[k] |
|||
bucket_end = unique_lengths[i - 1] + 1 |
|||
buckets.append((bucket_start, bucket_end)) |
|||
i, j = k, j - 1 |
|||
|
|||
buckets.reverse() |
|||
return buckets |
|||
|
|||
def _create_2d_buckets( |
|||
self, |
|||
data: List[Tuple[int, int]] |
|||
) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]: |
|||
if not data: |
|||
return [] |
|||
|
|||
src_lengths = [item[0] for item in data] |
|||
tgt_lengths = [item[1] for item in data] |
|||
|
|||
print("Calculating optimal buckets for src lengths...") |
|||
self.src_buckets = self._optimal_1d_partition(src_lengths, self.num_cuts) |
|||
print(f"src buckets: {self.src_buckets}") |
|||
|
|||
print("Calculating optimal buckets for tgt lengths...") |
|||
self.tgt_buckets = self._optimal_1d_partition(tgt_lengths, self.num_cuts) |
|||
print(f"tgt buckets: {self.tgt_buckets}") |
|||
|
|||
bucket_samples = {} |
|||
for src_bucket in self.src_buckets: |
|||
for tgt_bucket in self.tgt_buckets: |
|||
bucket_key = (src_bucket, tgt_bucket) |
|||
bucket_samples[bucket_key] = [] |
|||
|
|||
for src_len, tgt_len in data: |
|||
src_bucket = None |
|||
tgt_bucket = None |
|||
|
|||
for bucket in self.src_buckets: |
|||
if bucket[0] <= src_len < bucket[1]: |
|||
src_bucket = bucket |
|||
break |
|||
|
|||
for bucket in self.tgt_buckets: |
|||
if bucket[0] <= tgt_len < bucket[1]: |
|||
tgt_bucket = bucket |
|||
break |
|||
|
|||
if src_bucket and tgt_bucket: |
|||
bucket_key = (src_bucket, tgt_bucket) |
|||
bucket_samples[bucket_key].append((src_len, tgt_len)) |
|||
|
|||
valid_buckets = [] |
|||
|
|||
for bucket_key, samples in bucket_samples.items(): |
|||
if len(samples) >= self.min_samples: |
|||
valid_buckets.append(bucket_key) |
|||
|
|||
return valid_buckets |
|||
|
|||
def _process_data_with_dp(self, data: List[dict]): |
|||
length_pairs = [(len(item["src"]), len(item["tgt"])) for item in data] |
|||
self.total_original_samples = len(data) |
|||
|
|||
valid_buckets = self._create_2d_buckets(length_pairs) |
|||
|
|||
bucket_data = {bucket: [] for bucket in valid_buckets} |
|||
|
|||
total_actual_tokens = 0 |
|||
total_padding = 0 |
|||
used_samples = 0 |
|||
|
|||
for item in data: |
|||
src_len = len(item["src"]) |
|||
tgt_len = len(item["tgt"]) |
|||
found_bucket = False |
|||
|
|||
for (src_start, src_end), (tgt_start, tgt_end) in valid_buckets: |
|||
if src_start <= src_len < src_end and tgt_start <= tgt_len < tgt_end: |
|||
src_max = src_end - 1 |
|||
tgt_max = tgt_end - 1 |
|||
src_pad = src_max - src_len |
|||
tgt_pad = tgt_max - tgt_len |
|||
|
|||
total_padding += src_pad + tgt_pad |
|||
total_actual_tokens += src_len + tgt_len |
|||
used_samples += 1 |
|||
|
|||
bucket_data[(src_start, src_end), (tgt_start, tgt_end)].append(item) |
|||
found_bucket = True |
|||
break |
|||
|
|||
if not found_bucket: |
|||
pass |
|||
|
|||
self.discarded_samples = len(data) - used_samples |
|||
self.total_padding = total_padding |
|||
self.total_actual_tokens = total_actual_tokens |
|||
|
|||
print("\nBuilding datasets from buckets...") |
|||
total_samples = 0 |
|||
self.valid_buckets = [] |
|||
|
|||
sorted_buckets = sorted(valid_buckets, key=lambda x: (x[1][0], x[0][0])) |
|||
|
|||
for bucket_key in sorted_buckets: |
|||
(src_start, src_end), (tgt_start, tgt_end) = bucket_key |
|||
bucket_items = bucket_data[bucket_key] |
|||
data_len = len(bucket_items) |
|||
|
|||
base_num_shards = max(1, (data_len + self.max_samples - 1) // self.max_samples) |
|||
|
|||
last_shard_size = data_len % self.max_samples |
|||
|
|||
if last_shard_size == 0 and data_len > 0: |
|||
last_shard_size = self.max_samples |
|||
|
|||
if base_num_shards > 1 and last_shard_size <= self.max_samples * 0.5: |
|||
num_shards = base_num_shards - 1 |
|||
else: |
|||
num_shards = base_num_shards |
|||
|
|||
for shard_idx in range(num_shards): |
|||
if shard_idx == num_shards - 1: |
|||
start_idx = shard_idx * self.max_samples |
|||
end_idx = data_len |
|||
else: |
|||
start_idx = shard_idx * self.max_samples |
|||
end_idx = start_idx + self.max_samples |
|||
|
|||
shard_data = bucket_items[start_idx:end_idx] |
|||
shard_size = len(shard_data) |
|||
|
|||
MyDataset( |
|||
cache_path=str(self.cache_path), |
|||
processed_data=shard_data, |
|||
src_range=(src_start, src_end), |
|||
tgt_range=(tgt_start, tgt_end), |
|||
shard_idx=shard_idx, |
|||
) |
|||
|
|||
bucket_info = { |
|||
"src_range": (src_start, src_end), |
|||
"tgt_range": (tgt_start, tgt_end), |
|||
"shard_idx": shard_idx, |
|||
"num_shards": num_shards, |
|||
"suggested_batch_size": 0, |
|||
"num_samples": shard_size |
|||
} |
|||
self.valid_buckets.append(bucket_info) |
|||
total_samples += shard_size |
|||
self.total_samples = total_samples |
|||
print(f"Valid buckets: {len(self.valid_buckets)}") |
|||
print(f"Total samples: {total_samples}") |
|||
|
|||
def _save_metadata(self): |
|||
meta = { |
|||
"num_cuts": self.num_cuts, |
|||
"valid_buckets": [{ |
|||
"src_range": list(b["src_range"]), |
|||
"tgt_range": list(b["tgt_range"]), |
|||
"shard_idx": b["shard_idx"], |
|||
"num_shards": b["num_shards"], |
|||
"suggested_batch_size": b["suggested_batch_size"], |
|||
"num_samples": b["num_samples"] |
|||
} for b in self.valid_buckets], |
|||
"min_samples": self.min_samples, |
|||
"max_samples": self.max_samples, |
|||
"total_samples": self.total_samples, |
|||
"total_original_samples": self.total_original_samples, |
|||
"discarded_samples": self.discarded_samples, |
|||
"total_padding": self.total_padding, |
|||
"total_actual_tokens": self.total_actual_tokens |
|||
} |
|||
|
|||
meta_path = self.cache_path / "buckets_meta.json" |
|||
meta_path.write_text(json.dumps(meta, indent=2)) |
|||
print(f"Metadata saved to {meta_path}") |
|||
|
|||
def _load_from_metadata(self): |
|||
meta_path = self.cache_path / "buckets_meta.json" |
|||
meta = json.loads(meta_path.read_text()) |
|||
|
|||
self.num_cuts = meta.get("num_cuts", 8) |
|||
self.min_samples = meta.get("min_samples", self.min_samples) |
|||
self.max_samples = meta.get("max_samples", self.max_samples) |
|||
self.total_samples = meta.get("total_samples", 0) |
|||
|
|||
self.total_original_samples = meta.get("total_original_samples", 0) |
|||
self.discarded_samples = meta.get("discarded_samples", 0) |
|||
self.total_padding = meta.get("total_padding", 0) |
|||
self.total_actual_tokens = meta.get("total_actual_tokens", 0) |
|||
|
|||
self.valid_buckets = [] |
|||
for b in meta["valid_buckets"]: |
|||
self.valid_buckets.append({ |
|||
"src_range": tuple(b["src_range"]), |
|||
"tgt_range": tuple(b["tgt_range"]), |
|||
"shard_idx": b["shard_idx"], |
|||
"num_shards": b["num_shards"], |
|||
"suggested_batch_size": b.get("suggested_batch_size", 0), |
|||
"num_samples": b["num_samples"] |
|||
}) |
|||
|
|||
print(f"Loaded {len(self.valid_buckets)} buckets from {meta_path}") |
|||
|
|||
def _metadata_exists(self) -> bool: |
|||
return (self.cache_path / "buckets_meta.json").exists() |
|||
|
|||
def __iter__(self) -> Generator[Dict, None, None]: |
|||
yield from self.valid_buckets |
|||
|
|||
def __len__(self) -> int: |
|||
return len(self.valid_buckets) |
|||
|
|||
def __getitem__(self, index: int) -> dict: |
|||
return self.valid_buckets[index] |
|||
|
|||
def reset_batch_size(self) -> None: |
|||
for b in self.valid_buckets: |
|||
b["suggested_batch_size"] = 0 |
|||
|
|||
def find_optimal_batch_size(self, find_batch_size_func) -> None: |
|||
bucket_type_map = {} |
|||
for i, bucket in enumerate(self.valid_buckets): |
|||
src_range = bucket["src_range"] |
|||
tgt_range = bucket["tgt_range"] |
|||
bucket_type = (src_range, tgt_range) |
|||
|
|||
if bucket_type in bucket_type_map: |
|||
bucket["suggested_batch_size"] = bucket_type_map[bucket_type] |
|||
print(f"Bucket {i + 1}/{len(self.valid_buckets)} reused batch size: {bucket['suggested_batch_size']}\n") |
|||
continue |
|||
|
|||
print( |
|||
f"Searching optimal batch size for bucket {i + 1}/{len(self.valid_buckets)} " + \ |
|||
f"src: {src_range} tgt: {tgt_range} ..." |
|||
) |
|||
|
|||
batch_size = find_batch_size_func(src_max=src_range[1], tgt_max=tgt_range[1]) |
|||
bucket["suggested_batch_size"] = batch_size |
|||
bucket_type_map[bucket_type] = bucket["suggested_batch_size"] |
|||
|
|||
print(f"Found batch size: {bucket['suggested_batch_size']}\n") |
|||
|
|||
self._save_metadata() |
|||
|
|||
def found_optimal(self) -> bool: |
|||
return all(b["suggested_batch_size"] > 0 for b in self.valid_buckets) |
|||
|
|||
def get_total_iterations(self, safety_factor: float, drop_last: bool = False) -> int: |
|||
total_iterations = 0 |
|||
|
|||
for bucket in self.valid_buckets: |
|||
batch_size = math.floor(bucket["suggested_batch_size"] * safety_factor) |
|||
num_samples = bucket["num_samples"] |
|||
|
|||
if batch_size <= 0: |
|||
raise ValueError("Batch size must be positive. Call find_optimal_batch_size first.") |
|||
|
|||
if drop_last: |
|||
iterations = num_samples // batch_size |
|||
else: |
|||
iterations = (num_samples + batch_size - 1) // batch_size |
|||
|
|||
total_iterations += iterations |
|||
|
|||
return total_iterations |
|||
|
|||
def get_info(self) -> List[dict]: |
|||
return self.valid_buckets |
|||
|
|||
def print_stats(self): |
|||
print("\n" + "=" * 60) |
|||
print("Bucket Statistics") |
|||
print("=" * 60) |
|||
print(f"Total buckets: {len(self.valid_buckets)}") |
|||
print(f"Total samples: {self.total_samples}") |
|||
print(f"Min samples per bucket: {self.min_samples}") |
|||
print(f"Max samples per bucket: {self.max_samples}") |
|||
|
|||
print("\nData Distribution:") |
|||
print("-" * 60) |
|||
print(f"Original samples: {self.total_original_samples}") |
|||
print(f"Discarded samples: {self.discarded_samples} ({self.discarded_samples / self.total_original_samples * 100:.2f}%)") |
|||
|
|||
total_tokens_with_padding = self.total_actual_tokens + self.total_padding |
|||
if total_tokens_with_padding > 0: |
|||
padding_rate = self.total_padding / total_tokens_with_padding |
|||
else: |
|||
padding_rate = 0.0 |
|||
print(f"Total padding tokens: {self.total_padding}") |
|||
print(f"Padding rate: {padding_rate * 100:.2f}%") |
|||
|
|||
print("\nBucket Details:") |
|||
print("-" * 80) |
|||
print(f"{'Bucket ID':<8} {'Src Range':<15} {'Tgt Range':<15} {'Samples':<10} {'Shards':<8} {'Batch Size':<10}") |
|||
print("-" * 80) |
|||
|
|||
for i, bucket in enumerate(self.valid_buckets): |
|||
src_range = f"{bucket['src_range'][0]}-{bucket['src_range'][1] - 1}" |
|||
tgt_range = f"{bucket['tgt_range'][0]}-{bucket['tgt_range'][1] - 1}" |
|||
samples = bucket['num_samples'] |
|||
shards = f"{bucket['shard_idx'] + 1}/{bucket['num_shards']}" |
|||
batch_size = bucket['suggested_batch_size'] |
|||
|
|||
print(f"{i:<8} {src_range:<15} {tgt_range:<15} {samples:<10} {shards:<8} {batch_size:<10}") |
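

# A minimal usage sketch (assumptions: `example_data` stands in for real tokenized
# samples; any dicts with "src"/"tgt" sequences that support len() match what
# _process_data_with_dp consumes above; the cache directory is a placeholder).
if __name__ == "__main__":
    example_data = [{"src": list(range(5)), "tgt": list(range(7))}] * 2048
    manager = BucketManager(
        cache_path="data/cache/buckets_demo",   # hypothetical cache directory
        min_samples=1024,
        max_samples=524288,
        num_cuts=4,
        processed_data=example_data,
        force_rebuild=True,
    )
    manager.print_stats()
    print(len(manager), "bucket shards")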
@@ -0,0 +1,84 @@
|||
import os |
|||
import pickle |
|||
import zstandard as zstd |
|||
import torch |
|||
from torch.utils.data import Dataset |
|||
import torch.nn.functional as F |
|||
|
|||
|
|||
class MyDataset(Dataset): |
|||
def __init__( |
|||
self, |
|||
cache_path: str, |
|||
processed_data: list[dict] = None, |
|||
src_range: tuple = (0, 0), |
|||
tgt_range: tuple = (0, 0), |
|||
shard_idx: int = 0, |
|||
zstd_level: int = 9 |
|||
): |
|||
self.src_range = src_range |
|||
self.tgt_range = tgt_range |
|||
self.zstd_level = zstd_level |
|||
self.processed_data = processed_data |
|||
|
|||
base_name = f"cached_src_{src_range[0]}_{src_range[1]}_tgt_{tgt_range[0]}_{tgt_range[1]}" |
|||
self.cache_path = os.path.join( |
|||
cache_path, |
|||
f"{base_name}_shard_{shard_idx}.pkl.zst" |
|||
) |
|||
|
|||
if os.path.exists(self.cache_path): |
|||
with open(self.cache_path, 'rb') as f: |
|||
dctx = zstd.ZstdDecompressor() |
|||
with dctx.stream_reader(f) as reader: |
|||
self.processed_data = pickle.load(reader) |
|||
return |
|||
elif self.processed_data is None: |
|||
raise ValueError(f"Cache file not found: {self.cache_path} and no data provided") |
|||
|
|||
# Save the cache
|||
print("Saving cache, please wait...") |
|||
cctx = zstd.ZstdCompressor(level=self.zstd_level) |
|||
with open(self.cache_path, 'wb') as f: |
|||
with cctx.stream_writer(f) as compressor: |
|||
pickle.dump( |
|||
self.processed_data, |
|||
compressor, |
|||
protocol=pickle.HIGHEST_PROTOCOL |
|||
) |
|||
print(f"Cache saved at: {self.cache_path}") |
|||
|
|||
def __len__(self): |
|||
return len(self.processed_data) |
|||
|
|||
def __getitem__(self, idx): |
|||
item = self.processed_data[idx] |
|||
return { |
|||
"src": torch.from_numpy(item["src"]).long(), |
|||
"tgt": torch.from_numpy(item["tgt"]).long(), |
|||
"label": torch.from_numpy(item["label"]).long(), |
|||
} |
|||
|
|||
|
|||
def collate_fn(batch, pad_token_id: int, src_max_length: int, tgt_max_length: int): |
|||
src = [item["src"] for item in batch] |
|||
tgt = [item["tgt"] for item in batch] |
|||
label = [item["label"] for item in batch] |
|||
|
|||
# pad src |
|||
src_padded = [F.pad(t, (0, src_max_length - t.size(0)), value=pad_token_id) for t in src] |
|||
src_padded = torch.stack(src_padded) |
|||
|
|||
# pad tgt |
|||
tgt_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in tgt] |
|||
tgt_padded = torch.stack(tgt_padded) |
|||
|
|||
# pad label |
|||
label_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in label] |
|||
label_padded = torch.stack(label_padded) |
|||
|
|||
return { |
|||
"src": src_padded, |
|||
"tgt": tgt_padded, |
|||
"label": label_padded, |
|||
} |
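

# A minimal wiring sketch for the collate_fn above (assumptions: a shard cache already
# exists under `cache_path`, and the PAD id and max lengths are placeholders; in
# practice src_max_length / tgt_max_length come from a bucket's src_range / tgt_range
# upper bounds minus one).
if __name__ == "__main__":
    from functools import partial
    from torch.utils.data import DataLoader

    dataset = MyDataset(
        cache_path="data/cache",    # hypothetical cache directory holding the shard
        src_range=(5, 6),
        tgt_range=(7, 8),
        shard_idx=0,
    )
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=partial(
            collate_fn,
            pad_token_id=0,         # placeholder PAD id
            src_max_length=5,       # src_range[1] - 1
            tgt_max_length=7,       # tgt_range[1] - 1
        ),
    )
    batch = next(iter(loader))
    print(batch["src"].shape, batch["tgt"].shape, batch["label"].shape)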
@@ -0,0 +1,336 @@
|||
import torch |
|||
import torch.nn as nn |
|||
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh |
|||
import random |
|||
from typing import Tuple, List, Optional |
|||
|
|||
|
|||
class MyLSTM(nn.Module): |
|||
def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None: |
|||
super().__init__() |
|||
self.input_size = input_size |
|||
self.hidden_size = hidden_size |
|||
self.norm = AdaptiveSoftTanh(input_size) |
|||
self.dropout = nn.Dropout(dropout) |
|||
|
|||
self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size)) |
|||
self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size)) |
|||
self.bias = nn.Parameter(torch.Tensor(4 * hidden_size)) |
|||
|
|||
self.init_weight() |
|||
|
|||
self.h_state = None |
|||
self.c_state = None |
|||
self.initialized = False |
|||
|
|||
def init_weight(self) -> None: |
|||
for param in self.parameters(): |
|||
if param.dim() > 1: |
|||
nn.init.xavier_uniform_(param) |
|||
else: |
|||
nn.init.zeros_(param) |
|||
|
|||
def init_hidden(
    self, batch_size: int, device: torch.device,
    h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
) -> None:
    # Use the provided states when given; otherwise fall back to zero states.
    if h_state is None:
        h_state = torch.zeros(batch_size, self.hidden_size, device=device)
    if c_state is None:
        c_state = torch.zeros(batch_size, self.hidden_size, device=device)
    self.h_state = h_state
    self.c_state = c_state
    self.initialized = True
|||
|
|||
def _forward_seq(self, x: torch.Tensor) -> torch.Tensor:
    _, T, _ = x.shape

    # Collect the output of every time step
    output = []
    normed = self.norm(x)
    # Iterate over each time step of the sequence
    for t in range(T):
        x_t = normed[:, t, :]  # current time-step input (B, C)

        # Compute the linear transform for all gates at once (batch_size, 4*hidden_size)
        gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias

        # Split into the four gates (each of shape batch_size, hidden_size)
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

        i_t = SoftSigmoid(i_gate)  # input gate
        f_t = SoftSigmoid(f_gate)  # forget gate
        g_t = SoftTanh(g_gate)     # candidate gate
        o_t = SoftSigmoid(o_gate)  # output gate

        # i_t = torch.sigmoid(i_gate)  # input gate
        # f_t = torch.sigmoid(f_gate)  # forget gate
        # g_t = torch.tanh(g_gate)     # candidate gate
        # o_t = torch.sigmoid(o_gate)  # output gate

        # Update the cell state
        self.c_state = f_t * self.c_state + i_t * g_t

        # Update the hidden state
        self.h_state = o_t * SoftTanh(self.c_state)
        # self.h_state = o_t * torch.tanh(self.c_state)

        # Output for the current time step
        output.append(self.h_state)

    # Stack the per-step outputs into a tensor
    output = torch.stack(output, dim=1)

    # Residual connection when input and hidden sizes match
    if self.input_size == self.hidden_size:
        output = output + x

    return self.dropout(output)
|||
|
|||
def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
    normed = self.norm(x)
    # Compute the linear transform for all gates at once (batch_size, 4*hidden_size)
    gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias

    # Split into the four gates (each of shape batch_size, hidden_size)
    i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

    i_t = SoftSigmoid(i_gate)  # input gate
    f_t = SoftSigmoid(f_gate)  # forget gate
    g_t = SoftTanh(g_gate)     # candidate gate
    o_t = SoftSigmoid(o_gate)  # output gate

    # i_t = torch.sigmoid(i_gate)  # input gate
    # f_t = torch.sigmoid(f_gate)  # forget gate
    # g_t = torch.tanh(g_gate)     # candidate gate
    # o_t = torch.sigmoid(o_gate)  # output gate

    # Update the cell state
    self.c_state = f_t * self.c_state + i_t * g_t

    # Update the hidden state
    self.h_state = o_t * SoftTanh(self.c_state)
    # self.h_state = o_t * torch.tanh(self.c_state)

    # Output for the current time step
    output = self.h_state

    # Residual connection when input and hidden sizes match
    if self.input_size == self.hidden_size:
        output = output + x

    return self.dropout(output)
|||
|
|||
def forward(self, x) -> torch.Tensor: |
|||
if not self.initialized: |
|||
B = x.size(0) |
|||
self.init_hidden(B, x.device) |
|||
self.initialized = True |
|||
|
|||
if x.dim() == 2: |
|||
return self._forward_step(x) |
|||
elif x.dim() == 3: |
|||
return self._forward_seq(x) |
|||
else: |
|||
raise ValueError("input dim must be 2(step) or 3(sequence)") |
|||
|
|||
def reset(self) -> None: |
|||
self.h_state = None |
|||
self.c_state = None |
|||
self.initialized = False |
|||
|
|||
|
|||
class LSTMEncoder(nn.Module): |
|||
def __init__( |
|||
self, vocab_size: int, embedding_dim: int, |
|||
padding_idx: int, num_layers: int, dropout: float |
|||
) -> None: |
|||
super(LSTMEncoder, self).__init__() |
|||
self.embedding = nn.Embedding( |
|||
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx |
|||
) |
|||
self.layers = nn.ModuleList([ |
|||
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \ |
|||
for _ in range(num_layers) |
|||
]) |
|||
|
|||
def forward(self, x) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: |
|||
x = self.embedding(x) |
|||
output = x |
|||
# Keep each layer's final hidden and cell state
|||
h, c = [], [] |
|||
for layer in self.layers: |
|||
output = layer(output) |
|||
h.append(layer.h_state) |
|||
c.append(layer.c_state) |
|||
return h, c |
|||
|
|||
def reset(self) -> None: |
|||
for layer in self.layers: |
|||
layer.reset() |
|||
|
|||
|
|||
class LSTMDecoder(nn.Module): |
|||
def __init__( |
|||
self, vocab_size: int, embedding_dim: int, |
|||
padding_idx: int, num_layers: int, dropout: float |
|||
) -> None: |
|||
super(LSTMDecoder, self).__init__() |
|||
self.embedding = nn.Embedding( |
|||
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx |
|||
) |
|||
self.layers = nn.ModuleList([ |
|||
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \ |
|||
for _ in range(num_layers) |
|||
]) |
|||
self.fc = nn.Linear(embedding_dim, vocab_size) |
|||
|
|||
self.initialized = False |
|||
|
|||
def init_hidden( |
|||
self, batch_size: int, device: torch.device, |
|||
h: List[torch.Tensor], c: List[torch.Tensor] |
|||
) -> None: |
|||
for i, layer in enumerate(self.layers): |
|||
layer.init_hidden(batch_size, device, h[i], c[i]) |
|||
self.initialized = True |
|||
|
|||
def _forward_step(self, x: torch.Tensor) -> torch.Tensor: |
|||
output = x |
|||
for i, layer in enumerate(self.layers): |
|||
output = layer(output) |
|||
return output |
|||
|
|||
def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor: |
|||
_, T, _ = x.shape |
|||
outputs = [] |
|||
for t in range(T): |
|||
if t == 0 or random.random() < tfr: |
|||
outputs.append(self._forward_step(x[:, t, :])) |
|||
else: |
|||
logits = torch.argmax(self.fc(outputs[-1].clone().detach()), dim=-1) |
|||
previous = self.embedding(logits) |
|||
outputs.append(self._forward_step(previous)) |
|||
outputs = torch.stack(outputs, dim=1) |
|||
return outputs |
|||
|
|||
def forward(self, x: torch.Tensor, tfr: float = 0.0): |
|||
x = self.embedding(x) |
|||
if x.dim() == 3: |
|||
return self.fc(self._forward_seq(x, tfr=tfr)) |
|||
elif x.dim() == 2: |
|||
return self.fc(self._forward_step(x)) |
|||
else: |
|||
raise ValueError("input dim must be 2(step) or 3(sequence)") |
|||
|
|||
def reset(self):
    for layer in self.layers:
        layer.reset()
    self.initialized = False
|||
|
|||
|
|||
class Seq2SeqLSTM(nn.Module): |
|||
def __init__( |
|||
self, vocab_size: int, embedding_dim: int, |
|||
padding_idx: int, num_layers: int, dropout: float |
|||
) -> None: |
|||
super().__init__() |
|||
self.encoder = LSTMEncoder( |
|||
vocab_size=vocab_size, embedding_dim=embedding_dim, |
|||
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout |
|||
) |
|||
self.decoder = LSTMDecoder( |
|||
vocab_size=vocab_size, embedding_dim=embedding_dim, |
|||
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout |
|||
) |
|||
|
|||
def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0) -> torch.Tensor: |
|||
self.reset() |
|||
h, c = self.encoder(src) |
|||
batch_size = src.size(0) |
|||
device = src.device |
|||
self.decoder.init_hidden(batch_size, device, h, c) |
|||
return self.decoder(tgt, tfr=tfr) |
|||
|
|||
def reset(self): |
|||
self.encoder.reset() |
|||
self.decoder.reset() |
|||
|
|||
def greedy_decode(
    self,
    src: torch.Tensor,
    bos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
    max_length: int
) -> torch.Tensor:
    """
    Sequence generation for inference (batched greedy decoding).
    Args:
        src: source sequences with BOS/EOS added and padding applied (batch_size, src_len)
        bos_token_id: beginning-of-sequence token ID
        eos_token_id: end-of-sequence token ID
        pad_token_id: padding token ID
        max_length: maximum generation length
    Returns:
        generated sequences (batch_size, tgt_len)
    """
    self.reset()
    batch_size = src.size(0)
    device = src.device

    # Encode the source sequence
    h, c = self.encoder(src)

    # Initialize the decoder state
    self.decoder.init_hidden(batch_size, device, h, c)

    # Prepare the output tensor (filled with pad_token_id)
    sequences = torch.full(
        (batch_size, max_length),
        pad_token_id,
        dtype=torch.long,
        device=device
    )

    # The first input is the BOS token
    current_input = torch.full(
        (batch_size,),
        bos_token_id,
        dtype=torch.long,
        device=device
    )

    # Track which sequences have finished (all False initially)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Autoregressive generation
    for t in range(max_length):
        # Stop once every sequence has finished
        if finished.all():
            break

        # Output for the current time step
        logits = self.decoder(current_input)  # (batch_size, vocab_size)
        next_tokens = logits.argmax(dim=-1)   # greedy choice (batch_size,)

        # Write tokens only for sequences that have not finished
        sequences[~finished, t] = next_tokens[~finished]

        # Detect EOS and update the finished flags
        eos_mask = (next_tokens == eos_token_id)
        finished = finished | eos_mask

        # Next input: new tokens for unfinished sequences, PAD for finished ones
        current_input = torch.where(
            ~finished,
            next_tokens,
            torch.full_like(next_tokens, pad_token_id)
        )

    return sequences
|||
|
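
# A minimal usage sketch with placeholder hyperparameters and token ids; shapes follow
# the forward() and greedy_decode() signatures above.
if __name__ == "__main__":
    model = Seq2SeqLSTM(vocab_size=100, embedding_dim=32, padding_idx=0, num_layers=2, dropout=0.1)
    src = torch.randint(4, 100, (2, 7))    # (batch, src_len)
    tgt = torch.randint(4, 100, (2, 9))    # (batch, tgt_len), teacher-forcing inputs
    logits = model(src, tgt, tfr=1.0)      # (batch, tgt_len, vocab_size)
    print(logits.shape)
    generated = model.greedy_decode(src, bos_token_id=1, eos_token_id=2, pad_token_id=0, max_length=12)
    print(generated.shape)                 # (batch, max_length)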
@@ -0,0 +1,221 @@
|||
import torch |
|||
import torch.nn as nn |
|||
import math |
|||
import matplotlib.pyplot as plt |
|||
|
|||
|
|||
class SoftTanhFunction(torch.autograd.Function): |
|||
''' |
|||
SoftTanh(x) = sign(x) * tanh(ln(1 + ln(1 + |x|))) |
|||
f(x) ∈ (-1, 1) |
|||
dy/dx ∈ (0, 1] |
|||
''' |
|||
@staticmethod |
|||
def forward(ctx, x): |
|||
abs_x = torch.abs(x) |
|||
u = torch.log1p(abs_x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
y = torch.sign(x) * tanh_v |
|||
ctx.save_for_backward(abs_x) |
|||
return y |
|||
|
|||
@staticmethod |
|||
def backward(ctx, grad_output): |
|||
abs_x, = ctx.saved_tensors |
|||
u = torch.log1p(abs_x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
sech_v_square = 1 - tanh_v.square() |
|||
denominator = (1 + u) * (1 + abs_x) |
|||
d_x = sech_v_square / denominator |
|||
grad_x = grad_output * d_x |
|||
return grad_x |
|||
|
|||
|
|||
def SoftTanh(x): |
|||
return SoftTanhFunction.apply(x) |
|||
|
|||
|
|||
class SoftSigmoidFunction(torch.autograd.Function): |
|||
''' |
|||
SoftSigmoid(x) = (sign(x) * tanh(ln(1 + ln(1 + |2x|))) + 1) / 2 |
|||
f(x) ∈ (0, 1) |
|||
dy/dx ∈ (0, 1] |
|||
''' |
|||
@staticmethod |
|||
def forward(ctx, x): |
|||
abs_2x = torch.abs(2 * x) |
|||
u = torch.log1p(abs_2x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
y = (torch.sign(x) * tanh_v + 1) * 0.5 |
|||
ctx.save_for_backward(abs_2x) |
|||
return y |
|||
|
|||
@staticmethod |
|||
def backward(ctx, grad_output): |
|||
abs_2x, = ctx.saved_tensors |
|||
u = torch.log1p(abs_2x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
sech_v_square = 1 - tanh_v.square() |
|||
denominator = (1 + u) * (1 + abs_2x) |
|||
d_x = sech_v_square / denominator |
|||
grad_x = grad_output * d_x |
|||
return grad_x |
|||
|
|||
|
|||
def SoftSigmoid(x): |
|||
return SoftSigmoidFunction.apply(x) |
|||
|
|||
|
|||
class AdaptiveSoftTanhFunction(torch.autograd.Function): |
|||
''' |
|||
AdaptiveSoftTanh(x) = alpha * sign(x) * tanh(ln(1 + ln(1 + |x|))) + beta |
|||
f(x) ∈ (beta - |alpha|, beta + |alpha|) |
|||
dy/dx ∈ (0, alpha] if alpha > 0 |
|||
dy/dx ∈ [alpha, 0) if alpha < 0 |
|||
''' |
|||
@staticmethod |
|||
def forward(ctx, x, alpha, beta): |
|||
abs_x = torch.abs(x) |
|||
u = torch.log1p(abs_x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
y = torch.sign(x) * tanh_v * alpha + beta |
|||
ctx.save_for_backward(x, alpha) |
|||
return y |
|||
|
|||
@staticmethod |
|||
def backward(ctx, grad_output): |
|||
x, alpha, = ctx.saved_tensors |
|||
abs_x = torch.abs(x) |
|||
u = torch.log1p(abs_x) |
|||
v = torch.log1p(u) |
|||
tanh_v = torch.tanh(v) |
|||
sech_v_square = 1 - tanh_v.square() |
|||
denominator = (1 + u) * (1 + abs_x) |
|||
|
|||
d_x = sech_v_square / denominator |
|||
|
|||
grad_x = grad_output * d_x * alpha  # chain rule: dy/dx carries the per-channel alpha scale
|||
grad_alpha = grad_output * torch.sign(x) * tanh_v |
|||
grad_beta = grad_output.clone() |
|||
|
|||
sum_dims = [d for d in range(grad_output.dim()) if d != grad_output.dim() - 1] |
|||
if sum_dims: |
|||
grad_alpha = grad_alpha.sum(dim=sum_dims) |
|||
grad_beta = grad_beta.sum(dim=sum_dims) |
|||
|
|||
return grad_x, grad_alpha, grad_beta |
|||
|
|||
|
|||
class AdaptiveSoftTanh(nn.Module): |
|||
def __init__(self, channels): |
|||
super().__init__() |
|||
self.alpha = nn.Parameter(torch.Tensor(channels)) |
|||
self.beta = nn.Parameter(torch.zeros(channels)) |
|||
nn.init.normal_(self.alpha, mean=0, std=math.sqrt(2 / channels)) |
|||
|
|||
def forward(self, x): |
|||
return AdaptiveSoftTanhFunction.apply(x, self.alpha, self.beta) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
print("=== Test the x = 0 case ===")

# Create an input tensor that contains zero
x = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)

# Forward pass
y = SoftTanh(x)
print(f"Forward output (x=0): {y.item()}")

# Backward pass
y.backward()
print(f"Gradient (x=0): {x.grad.item()}")

# Verify the gradient
grad_ok = torch.allclose(x.grad, torch.tensor([1.0]))
print(f"Gradient check: {'passed' if grad_ok else 'failed'}")
|||
|
|||
x = torch.linspace(-10, 10, 10000, requires_grad=True) |
|||
|
|||
y_soft = SoftTanh(x) |
|||
y_tanh = torch.tanh(x) |
|||
|
|||
grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0] |
|||
grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0] |
|||
|
|||
x_np = x.detach().numpy() |
|||
y_soft_np = y_soft.detach().numpy() |
|||
y_tanh_np = y_tanh.detach().numpy() |
|||
grad_soft_np = grad_soft.detach().numpy() |
|||
grad_tanh_np = grad_tanh.detach().numpy() |
|||
|
|||
plt.figure(figsize=(12, 10)) |
|||
# function graph |
|||
plt.subplot(2, 1, 1) |
|||
plt.plot(x_np, y_soft_np, 'b-', linewidth=2, label='SoftTanh') |
|||
plt.plot(x_np, y_tanh_np, 'r--', linewidth=2, label='Tanh') |
|||
plt.title('Function Comparison') |
|||
plt.xlabel('x') |
|||
plt.ylabel('y') |
|||
plt.grid(True) |
|||
plt.legend() |
|||
# grad graph |
|||
plt.subplot(2, 1, 2) |
|||
plt.plot(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient') |
|||
plt.plot(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient') |
|||
plt.title('Gradient Comparison') |
|||
plt.xlabel('x') |
|||
plt.ylabel('dy/dx') |
|||
plt.grid(True) |
|||
plt.legend() |
|||
plt.tight_layout() |
|||
plt.savefig('./softtanh_comparison.png', dpi=300) |
|||
plt.close() |
|||
|
|||
# Generate x values in log space; consider only the positive half-axis
|||
x = torch.logspace(-2, 3, 100000, base=10, requires_grad=True) |
|||
|
|||
y_soft = SoftTanh(x) |
|||
y_tanh = torch.tanh(x) |
|||
|
|||
grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0] |
|||
grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0] |
|||
|
|||
# Convert to numpy arrays for plotting
|||
x_np = x.detach().numpy() |
|||
y_soft_np = y_soft.detach().numpy() |
|||
y_tanh_np = y_tanh.detach().numpy() |
|||
grad_soft_np = grad_soft.detach().numpy() |
|||
grad_tanh_np = grad_tanh.detach().numpy() |
|||
|
|||
# Create the figure and subplots
|||
plt.figure(figsize=(14, 12)) |
|||
|
|||
# Function plot (log-log scale)
|||
plt.subplot(2, 1, 1) |
|||
plt.loglog(x_np, 1 - y_soft_np, 'b-', linewidth=2, label='SoftTanh (1-y)') |
|||
plt.loglog(x_np, 1 - y_tanh_np, 'r--', linewidth=2, label='Tanh (1-y)') |
|||
plt.title('Function Asymptotic Behavior (Log-Log Scale)') |
|||
plt.xlabel('x (log scale)') |
|||
plt.ylabel('1 - y (log scale)') |
|||
plt.grid(True, which="both", ls="--") |
|||
plt.legend() |
|||
|
|||
# Gradient plot (log-log scale)
|||
plt.subplot(2, 1, 2) |
|||
plt.loglog(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient') |
|||
plt.loglog(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient') |
|||
plt.title('Gradient Decay (Log-Log Scale)') |
|||
plt.xlabel('x (log scale)') |
|||
plt.ylabel('dy/dx (log scale)') |
|||
plt.grid(True, which="both", ls="--") |
|||
plt.legend() |
|||
|
|||
plt.tight_layout() |
|||
plt.savefig('./softtanh_log_comparison.png', dpi=300) |
|||
plt.close() |
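
# Extra sanity check (a sketch): compare the hand-written backward passes against
# numerical gradients in double precision via torch.autograd.gradcheck.
x_check = torch.randn(16, dtype=torch.double, requires_grad=True)
print("SoftTanh gradcheck:", torch.autograd.gradcheck(SoftTanhFunction.apply, (x_check,)))
print("SoftSigmoid gradcheck:", torch.autograd.gradcheck(SoftSigmoidFunction.apply, (x_check,)))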
@@ -0,0 +1,123 @@
|||
from tokenizers import ( |
|||
decoders, |
|||
models, |
|||
normalizers, |
|||
pre_tokenizers, |
|||
trainers, |
|||
Tokenizer, |
|||
processors |
|||
) |
|||
from transformers import PreTrainedTokenizerFast |
|||
import os |
|||
import jieba |
|||
|
|||
SPECIAL_TOKENS = { |
|||
"pad_token": "[PAD]", |
|||
"bos_token": "[BOS]", |
|||
"eos_token": "[EOS]", |
|||
"unk_token": "[UNK]", |
|||
} |
|||
|
|||
|
|||
def create_tokenizer( |
|||
corpus_path: str, |
|||
save_dir: str, |
|||
language: str, |
|||
vocab_size: int,
):
"""Create a universal BPE tokenizer."""
|||
if not os.path.exists(corpus_path): |
|||
raise FileNotFoundError(f"corpus file {corpus_path} does not exist")
|||
|
|||
tokenizer = Tokenizer(models.BPE( |
|||
unk_token=SPECIAL_TOKENS["unk_token"] |
|||
)) |
|||
|
|||
tokenizer.normalizer = normalizers.Sequence([ |
|||
normalizers.NFKC(), |
|||
normalizers.StripAccents(), |
|||
normalizers.NFD(), |
|||
normalizers.Lowercase(), |
|||
]) |
|||
|
|||
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([ |
|||
pre_tokenizers.Whitespace(), |
|||
pre_tokenizers.Punctuation(), |
|||
pre_tokenizers.Digits(individual_digits=True), |
|||
pre_tokenizers.ByteLevel( |
|||
add_prefix_space=True, |
|||
use_regex=True |
|||
) |
|||
]) |
|||
|
|||
trainer = trainers.BpeTrainer( |
|||
special_tokens=list(SPECIAL_TOKENS.values()), |
|||
vocab_size=vocab_size, |
|||
min_frequency=4, |
|||
show_progress=True |
|||
) |
|||
|
|||
tokenizer.train(files=[corpus_path], trainer=trainer) |
|||
tokenizer.decoder = decoders.ByteLevel() |
|||
|
|||
# Add a post-processor that automatically inserts the special tokens
|||
bos_token = SPECIAL_TOKENS["bos_token"] |
|||
eos_token = SPECIAL_TOKENS["eos_token"] |
|||
tokenizer.post_processor = processors.TemplateProcessing( |
|||
single=f"{bos_token} $A {eos_token}", |
|||
pair=f"{bos_token} $A {eos_token} {bos_token} $B {eos_token}", |
|||
special_tokens=[ |
|||
(bos_token, tokenizer.token_to_id(bos_token)), |
|||
(eos_token, tokenizer.token_to_id(eos_token)), |
|||
], |
|||
) |
|||
|
|||
fast_tokenizer = PreTrainedTokenizerFast( |
|||
tokenizer_object=tokenizer, |
|||
**SPECIAL_TOKENS, |
|||
padding_side="right", |
|||
truncation_side="right", |
|||
do_lower_case=True |
|||
) |
|||
|
|||
os.makedirs(save_dir, exist_ok=True) |
|||
fast_tokenizer.save_pretrained(save_dir) |
|||
return fast_tokenizer |
|||
|
|||
|
|||
def load_tokenizer(path: str): |
|||
if not os.path.exists(path): |
|||
raise FileNotFoundError("cannot find tokenizer, please train it first")
|||
return PreTrainedTokenizerFast.from_pretrained(path) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
en_tokenizer = create_tokenizer( |
|||
corpus_path="data/txt/corpus.en", |
|||
save_dir="model/tokenizers/en", |
|||
language="en", |
|||
vocab_size=16384, |
|||
) |
|||
zh_tokenizer = create_tokenizer( |
|||
corpus_path="data/txt/corpus.zh", |
|||
save_dir="model/tokenizers/zh", |
|||
language="zh", |
|||
vocab_size=16384, |
|||
) |
|||
|
|||
en_tokenizer = load_tokenizer("model/tokenizers/en") |
|||
zh_tokenizer = load_tokenizer("model/tokenizers/zh") |
|||
|
|||
# Test English processing
|||
en_text = "How many books do you think you've read so far?" |
|||
en_encoding = en_tokenizer(en_text) |
|||
print("en encoding:", en_encoding.tokens()) |
|||
print("en decoding:", en_tokenizer.decode(en_encoding.input_ids)) |
|||
|
|||
# Test Chinese processing
|||
zh_text = "到目前为止你认为你读过多少书?" |
|||
jieba.initialize() |
|||
zh_text = " ".join(jieba.lcut(zh_text)) |
|||
zh_encoding = zh_tokenizer(zh_text) |
|||
print("zh encoding:", zh_encoding.tokens()) |
|||
print("zh decoding:", zh_tokenizer.decode(zh_encoding.input_ids)) |
@@ -1,2 +1 @@
# MyArena
@@ -0,0 +1,140 @@
|||
# try_batch.py |
|||
import torch |
|||
import torch.nn as nn |
|||
from torch.amp import autocast, GradScaler |
|||
from MyLSTM import Seq2SeqLSTM |
|||
import multiprocessing |
|||
import gc |
|||
import math |
|||
import random |
|||
|
|||
|
|||
def try_batch(config: dict, src_max: int, tgt_max: int, batch_size: int, max_batch_iter: int = 2): |
|||
# init model, optimizer, scaler |
|||
model = Seq2SeqLSTM( |
|||
vocab_size=config["vocab_size"], |
|||
embedding_dim=config["embedding_dim"], |
|||
padding_idx=config["pad_token_id"], |
|||
num_layers=config["num_layers"], |
|||
dropout=config["dropout"], |
|||
).to(config["device"]) |
|||
|
|||
optimizer = torch.optim.AdamW( |
|||
model.parameters(), |
|||
lr=config["lr_eta_max"], |
|||
betas=config["betas"], |
|||
weight_decay=config["weight_decay"] |
|||
) |
|||
|
|||
scaler = GradScaler(config["device"], enabled=config["use_amp"]) |
|||
|
|||
# construct fake batch for testing |
|||
srcs = torch.full((1, src_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"]) |
|||
tgts = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"]) |
|||
labels = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"]) |
|||
|
|||
# training loop |
|||
for i in range(max_batch_iter): |
|||
with autocast(config["device"], enabled=config["use_amp"]): |
|||
if i % 2 == 0: |
|||
tfr = random.random() |
|||
else: |
|||
tfr = 1 |
|||
output = model(srcs, tgts, tfr) |
|||
loss = nn.CrossEntropyLoss()( |
|||
output.reshape(-1, output.size(-1)), |
|||
labels.contiguous().view(-1) |
|||
) |
|||
|
|||
scaler.scale(loss).backward() |
|||
scaler.unscale_(optimizer) |
|||
torch.nn.utils.clip_grad_norm_( |
|||
model.parameters(), |
|||
max_norm=config["grad_clip"] |
|||
) |
|||
scaler.step(optimizer) |
|||
scaler.update() |
|||
optimizer.zero_grad(set_to_none=True) |
|||
|
|||
|
|||
def _worker(config, src_max, tgt_max, batch_size, max_batch_iter, result_queue): |
|||
try: |
|||
torch.cuda.reset_peak_memory_stats() |
|||
try_batch(config, src_max, tgt_max, batch_size, max_batch_iter) |
|||
peak_mem = torch.cuda.max_memory_allocated() |
|||
result_queue.put(('success', peak_mem)) |
|||
except Exception as e: |
|||
if 'out of memory' in str(e).lower() or 'alloc' in str(e).lower(): |
|||
peak_mem = torch.cuda.max_memory_allocated() |
|||
result_queue.put(('oom', peak_mem)) |
|||
else: |
|||
result_queue.put(('error', (type(e).__name__, str(e))))  # keep the (status, payload) shape the caller unpacks
|||
finally: |
|||
gc.collect() |
|||
torch.cuda.empty_cache() |
|||
torch.cuda.synchronize() |
|||
|
|||
|
|||
def find_batch_size( |
|||
src_max: int, |
|||
tgt_max: int, |
|||
config: dict, |
|||
initial_batch_size: int = 16384, |
|||
max_batch_iter: int = 2, |
|||
timeout: int = 30 |
|||
) -> int: |
|||
low, high = 1, initial_batch_size |
|||
best_size = 0 |
|||
max_attempts = math.ceil(math.log2(initial_batch_size)) |
|||
attempt = 0 |
|||
|
|||
while low <= high and attempt < max_attempts: |
|||
attempt += 1 |
|||
mid = (low + high) // 2 |
|||
|
|||
ctx = multiprocessing.get_context('spawn') |
|||
result_queue = ctx.Queue() |
|||
process = ctx.Process( |
|||
target=_worker, |
|||
args=(config, src_max, tgt_max, mid, max_batch_iter, result_queue) |
|||
) |
|||
|
|||
try: |
|||
torch.cuda.reset_peak_memory_stats() |
|||
process.start() |
|||
process.join(timeout=timeout) |
|||
|
|||
if process.is_alive(): |
|||
process.terminate() |
|||
process.join() |
|||
raise TimeoutError(f"Batch size {mid} timed out after {timeout}s") |
|||
|
|||
if result_queue.empty(): |
|||
raise RuntimeError("Subprocess exited without returning results") |
|||
|
|||
status, data = result_queue.get() |
|||
|
|||
if status == 'success': |
|||
current_peak = data / (1024 ** 3) |
|||
print(f"Attempt {attempt:3d}: {mid:5d} | OK (Peak Memory: {current_peak:5.2f}GB)") |
|||
best_size = mid |
|||
low = mid + 1 |
|||
elif status == 'oom': |
|||
current_peak = data / (1024 ** 3) |
|||
print(f"Attempt {attempt:3d}: {mid:5d} | Failed (OOM, Peak: {current_peak:5.2f}GB)") |
|||
high = mid - 1 |
|||
elif status == 'error': |
|||
exc_type, exc_msg = data |
|||
print(f"Attempt {attempt:3d}: {mid:5d} | CRITICAL ERROR: {exc_type} - {exc_msg}") |
|||
raise RuntimeError(f"Critical error at batch size {mid}: {exc_type} - {exc_msg}") |
|||
|
|||
except Exception as e: |
|||
print(f"Attempt {attempt:3d}: {mid:5d} | Error: {str(e)}") |
|||
high = mid - 1 |
|||
|
|||
finally: |
|||
gc.collect() |
|||
torch.cuda.empty_cache() |
|||
torch.cuda.synchronize() |
|||
|
|||
return best_size |
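

# A minimal wiring sketch: BucketManager.find_optimal_batch_size calls the supplied
# function as find_batch_size_func(src_max=..., tgt_max=...), so find_batch_size can be
# partially applied with a config. The config keys mirror the ones read in try_batch()
# above; the concrete values and the cache path are placeholders.
if __name__ == "__main__":
    from functools import partial
    from BucketManager import BucketManager

    config = {
        "vocab_size": 16384,
        "embedding_dim": 512,
        "pad_token_id": 0,
        "unk_token_id": 3,
        "num_layers": 2,
        "dropout": 0.1,
        "lr_eta_max": 1e-3,
        "betas": (0.9, 0.999),
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        "use_amp": True,
        "device": "cuda",
    }
    manager = BucketManager(cache_path="data/cache/buckets")  # expects an existing cache
    manager.find_optimal_batch_size(partial(find_batch_size, config=config))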
Binary file not shown.
@@ -0,0 +1,122 @@
|||
import os |
|||
import pandas as pd |
|||
import glob |
|||
import jieba |
|||
from multiprocessing import Pool |
|||
from tqdm import tqdm |
|||
import logging |
|||
|
|||
logging.basicConfig( |
|||
level=logging.INFO, |
|||
format='%(asctime)s - %(levelname)s - %(message)s' |
|||
) |
|||
|
|||
NUM_WORKERS = 20 |
|||
INPUT_DIRS = { |
|||
'train': 'data/train', |
|||
'valid': 'data/valid' |
|||
} |
|||
OUTPUT_DIR = 'data/cache/txt' |
|||
EN_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.en') |
|||
ZH_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.zh') |
|||
|
|||
|
|||
def init_jieba(): |
|||
jieba_logger = logging.getLogger('jieba') |
|||
jieba_logger.setLevel(logging.WARNING) |
|||
jieba.disable_parallel() |
|||
|
|||
|
|||
def process_line(record): |
|||
try: |
|||
en_text = record['translation']['en'] |
|||
zh_text = record['translation']['zh'] |
|||
|
|||
# Chinese word segmentation
|||
zh_words = jieba.lcut(zh_text) |
|||
zh_sentence = ' '.join(zh_words) |
|||
|
|||
return (en_text, zh_sentence) |
|||
except KeyError as e: |
|||
logging.warning(f"Missing field in record: {str(e)}") |
|||
return None |
|||
except Exception as e: |
|||
logging.warning(f"Line processing error: {str(e)}") |
|||
return None |
|||
|
|||
|
|||
def process_shard(shard_path): |
|||
try: |
|||
df = pd.read_parquet(shard_path) |
|||
records = df.to_dict(orient='records') |
|||
total = len(records) |
|||
logging.info(f"Processing {shard_path} ({total} lines)") |
|||
|
|||
with Pool(NUM_WORKERS, initializer=init_jieba) as pool: |
|||
results = [] |
|||
for result in tqdm( |
|||
pool.imap(process_line, records), |
|||
total=total, |
|||
desc=f"Processing {os.path.basename(shard_path)}", |
|||
unit="lines", |
|||
colour='green', |
|||
bar_format='{l_bar}{bar:32}{r_bar}' |
|||
): |
|||
if result is not None: |
|||
results.append(result) |
|||
|
|||
en_sentences, zh_sentences = zip(*results) if results else ([], []) |
|||
logging.info(f"Processed {len(results)} lines from {shard_path}") |
|||
return list(en_sentences), list(zh_sentences) |
|||
|
|||
except Exception as e: |
|||
logging.error(f"Shard processing failed: {shard_path} - {str(e)}") |
|||
return [], [] |
|||
|
|||
|
|||
def main(): |
|||
os.makedirs(OUTPUT_DIR, exist_ok=True) |
|||
|
|||
all_shards = [] |
|||
for _, dir_path in INPUT_DIRS.items(): |
|||
shards = glob.glob(os.path.join(dir_path, '*.parquet')) |
|||
if not shards: |
|||
logging.warning(f"No Parquet files found in {dir_path}") |
|||
all_shards.extend(shards) |
|||
|
|||
if not all_shards: |
|||
logging.error("No Parquet files found in any input directories") |
|||
return |
|||
|
|||
all_shards.sort(key=lambda x: os.path.abspath(x)) |
|||
|
|||
logging.info(f"Found {len(all_shards)} shards to process") |
|||
jieba.initialize() |
|||
|
|||
all_en = [] |
|||
all_zh = [] |
|||
|
|||
for shard_path in all_shards: |
|||
en_sentences, zh_sentences = process_shard(shard_path) |
|||
all_en.extend(en_sentences) |
|||
all_zh.extend(zh_sentences) |
|||
|
|||
if len(all_en) != len(all_zh): |
|||
logging.warning(f"Data length mismatch: {len(all_en)} English vs {len(all_zh)} Chinese sentences") |
|||
|
|||
logging.info(f"Writing {len(all_en)} sentences to final files") |
|||
|
|||
with open(EN_OUTPUT_PATH, 'w', encoding='utf-8') as f_en, \ |
|||
open(ZH_OUTPUT_PATH, 'w', encoding='utf-8') as f_zh: |
|||
|
|||
for en, zh in tqdm(zip(all_en, all_zh), total=len(all_en), desc="Writing files"): |
|||
f_en.write(en + '\n') |
|||
f_zh.write(zh + '\n') |
|||
|
|||
logging.info("Corpus generation completed successfully") |
|||
logging.info(f"English corpus saved at: {EN_OUTPUT_PATH}") |
|||
logging.info(f"Chinese corpus saved at: {ZH_OUTPUT_PATH}") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main() |
File diff suppressed because it is too large
Binary file not shown.
Some files were not shown because too many files changed in this diff