Browse Source

initial commit

main
tanxing 4 days ago
parent
commit
3ed5fc1db0
  1. .gitattributes (3)
  2. .vscode/launch.json (16)
  3. .vscode/settings.json (6)
  4. BucketManager.py (423)
  5. MyDataset.py (84)
  6. MyLSTM.py (336)
  7. MyLayer.py (221)
  8. MyTokenizer.py (123)
  9. MyTransformer.py (0)
  10. README.md (3)
  11. TryBatch.py (140)
  12. __pycache__/BucketManager.cpython-312.pyc (BIN)
  13. __pycache__/MyDataset.cpython-312.pyc (BIN)
  14. __pycache__/MyLSTM.cpython-312.pyc (BIN)
  15. __pycache__/MyLayer.cpython-312.pyc (BIN)
  16. __pycache__/MyTokenizer.cpython-312.pyc (BIN)
  17. __pycache__/TryBatch.cpython-312.pyc (BIN)
  18. corpus.py (122)
  19. data/cache/train/buckets_meta.json (1846)
  20. data/cache/train/cached_src_102_129_tgt_82_99_shard_0.pkl.zst (BIN)
  21. data/cache/train/cached_src_102_129_tgt_99_128_shard_0.pkl.zst (BIN)
  22. data/cache/train/cached_src_10_13_tgt_12_15_shard_0.pkl.zst (BIN)
  23. data/cache/train/cached_src_10_13_tgt_15_18_shard_0.pkl.zst (BIN)
  24. data/cache/train/cached_src_10_13_tgt_2_6_shard_0.pkl.zst (BIN)
  25. data/cache/train/cached_src_10_13_tgt_6_9_shard_0.pkl.zst (BIN)
  26. data/cache/train/cached_src_10_13_tgt_6_9_shard_1.pkl.zst (BIN)
  27. data/cache/train/cached_src_10_13_tgt_9_12_shard_0.pkl.zst (BIN)
  28. data/cache/train/cached_src_10_13_tgt_9_12_shard_1.pkl.zst (BIN)
  29. data/cache/train/cached_src_10_13_tgt_9_12_shard_2.pkl.zst (BIN)
  30. data/cache/train/cached_src_13_16_tgt_12_15_shard_0.pkl.zst (BIN)
  31. data/cache/train/cached_src_13_16_tgt_12_15_shard_1.pkl.zst (BIN)
  32. data/cache/train/cached_src_13_16_tgt_12_15_shard_2.pkl.zst (BIN)
  33. data/cache/train/cached_src_13_16_tgt_15_18_shard_0.pkl.zst (BIN)
  34. data/cache/train/cached_src_13_16_tgt_18_21_shard_0.pkl.zst (BIN)
  35. data/cache/train/cached_src_13_16_tgt_6_9_shard_0.pkl.zst (BIN)
  36. data/cache/train/cached_src_13_16_tgt_9_12_shard_0.pkl.zst (BIN)
  37. data/cache/train/cached_src_13_16_tgt_9_12_shard_1.pkl.zst (BIN)
  38. data/cache/train/cached_src_16_19_tgt_12_15_shard_0.pkl.zst (BIN)
  39. data/cache/train/cached_src_16_19_tgt_12_15_shard_1.pkl.zst (BIN)
  40. data/cache/train/cached_src_16_19_tgt_15_18_shard_0.pkl.zst (BIN)
  41. data/cache/train/cached_src_16_19_tgt_15_18_shard_1.pkl.zst (BIN)
  42. data/cache/train/cached_src_16_19_tgt_15_18_shard_2.pkl.zst (BIN)
  43. data/cache/train/cached_src_16_19_tgt_18_21_shard_0.pkl.zst (BIN)
  44. data/cache/train/cached_src_16_19_tgt_21_24_shard_0.pkl.zst (BIN)
  45. data/cache/train/cached_src_16_19_tgt_9_12_shard_0.pkl.zst (BIN)
  46. data/cache/train/cached_src_19_22_tgt_12_15_shard_0.pkl.zst (BIN)
  47. data/cache/train/cached_src_19_22_tgt_15_18_shard_0.pkl.zst (BIN)
  48. data/cache/train/cached_src_19_22_tgt_15_18_shard_1.pkl.zst (BIN)
  49. data/cache/train/cached_src_19_22_tgt_18_21_shard_0.pkl.zst (BIN)
  50. data/cache/train/cached_src_19_22_tgt_18_21_shard_1.pkl.zst (BIN)
  51. data/cache/train/cached_src_19_22_tgt_21_24_shard_0.pkl.zst (BIN)
  52. data/cache/train/cached_src_19_22_tgt_24_27_shard_0.pkl.zst (BIN)
  53. data/cache/train/cached_src_19_22_tgt_27_30_shard_0.pkl.zst (BIN)
  54. data/cache/train/cached_src_19_22_tgt_9_12_shard_0.pkl.zst (BIN)
  55. data/cache/train/cached_src_22_25_tgt_12_15_shard_0.pkl.zst (BIN)
  56. data/cache/train/cached_src_22_25_tgt_15_18_shard_0.pkl.zst (BIN)
  57. data/cache/train/cached_src_22_25_tgt_18_21_shard_0.pkl.zst (BIN)
  58. data/cache/train/cached_src_22_25_tgt_18_21_shard_1.pkl.zst (BIN)
  59. data/cache/train/cached_src_22_25_tgt_21_24_shard_0.pkl.zst (BIN)
  60. data/cache/train/cached_src_22_25_tgt_21_24_shard_1.pkl.zst (BIN)
  61. data/cache/train/cached_src_22_25_tgt_24_27_shard_0.pkl.zst (BIN)
  62. data/cache/train/cached_src_22_25_tgt_27_30_shard_0.pkl.zst (BIN)
  63. data/cache/train/cached_src_22_25_tgt_30_33_shard_0.pkl.zst (BIN)
  64. data/cache/train/cached_src_25_28_tgt_15_18_shard_0.pkl.zst (BIN)
  65. data/cache/train/cached_src_25_28_tgt_18_21_shard_0.pkl.zst (BIN)
  66. data/cache/train/cached_src_25_28_tgt_21_24_shard_0.pkl.zst (BIN)
  67. data/cache/train/cached_src_25_28_tgt_21_24_shard_1.pkl.zst (BIN)
  68. data/cache/train/cached_src_25_28_tgt_24_27_shard_0.pkl.zst (BIN)
  69. data/cache/train/cached_src_25_28_tgt_24_27_shard_1.pkl.zst (BIN)
  70. data/cache/train/cached_src_25_28_tgt_27_30_shard_0.pkl.zst (BIN)
  71. data/cache/train/cached_src_25_28_tgt_30_33_shard_0.pkl.zst (BIN)
  72. data/cache/train/cached_src_25_28_tgt_33_36_shard_0.pkl.zst (BIN)
  73. data/cache/train/cached_src_28_31_tgt_15_18_shard_0.pkl.zst (BIN)
  74. data/cache/train/cached_src_28_31_tgt_18_21_shard_0.pkl.zst (BIN)
  75. data/cache/train/cached_src_28_31_tgt_21_24_shard_0.pkl.zst (BIN)
  76. data/cache/train/cached_src_28_31_tgt_24_27_shard_0.pkl.zst (BIN)
  77. data/cache/train/cached_src_28_31_tgt_24_27_shard_1.pkl.zst (BIN)
  78. data/cache/train/cached_src_28_31_tgt_27_30_shard_0.pkl.zst (BIN)
  79. data/cache/train/cached_src_28_31_tgt_30_33_shard_0.pkl.zst (BIN)
  80. data/cache/train/cached_src_28_31_tgt_33_36_shard_0.pkl.zst (BIN)
  81. data/cache/train/cached_src_28_31_tgt_36_40_shard_0.pkl.zst (BIN)
  82. data/cache/train/cached_src_31_34_tgt_18_21_shard_0.pkl.zst (BIN)
  83. data/cache/train/cached_src_31_34_tgt_21_24_shard_0.pkl.zst (BIN)
  84. data/cache/train/cached_src_31_34_tgt_24_27_shard_0.pkl.zst (BIN)
  85. data/cache/train/cached_src_31_34_tgt_27_30_shard_0.pkl.zst (BIN)
  86. data/cache/train/cached_src_31_34_tgt_30_33_shard_0.pkl.zst (BIN)
  87. data/cache/train/cached_src_31_34_tgt_33_36_shard_0.pkl.zst (BIN)
  88. data/cache/train/cached_src_31_34_tgt_36_40_shard_0.pkl.zst (BIN)
  89. data/cache/train/cached_src_34_38_tgt_21_24_shard_0.pkl.zst (BIN)
  90. data/cache/train/cached_src_34_38_tgt_24_27_shard_0.pkl.zst (BIN)
  91. data/cache/train/cached_src_34_38_tgt_27_30_shard_0.pkl.zst (BIN)
  92. data/cache/train/cached_src_34_38_tgt_30_33_shard_0.pkl.zst (BIN)
  93. data/cache/train/cached_src_34_38_tgt_33_36_shard_0.pkl.zst (BIN)
  94. data/cache/train/cached_src_34_38_tgt_36_40_shard_0.pkl.zst (BIN)
  95. data/cache/train/cached_src_34_38_tgt_40_44_shard_0.pkl.zst (BIN)
  96. data/cache/train/cached_src_38_42_tgt_24_27_shard_0.pkl.zst (BIN)
  97. data/cache/train/cached_src_38_42_tgt_27_30_shard_0.pkl.zst (BIN)
  98. data/cache/train/cached_src_38_42_tgt_30_33_shard_0.pkl.zst (BIN)
  99. data/cache/train/cached_src_38_42_tgt_33_36_shard_0.pkl.zst (BIN)
  100. data/cache/train/cached_src_38_42_tgt_36_40_shard_0.pkl.zst (BIN)

3
.gitattributes

@@ -0,0 +1,3 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text

16
.vscode/launch.json

@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"cwd": "${fileDirname}"
}
]
}

6
.vscode/settings.json

@@ -0,0 +1,6 @@
{
"editor.formatOnSave": false,
"editor.defaultFormatter": "charliermarsh.ruff",
"ruff.lint.ignore": ["E501", "E502", "W293", "W391", "Q", "ANN", "D", "ERA", "PLR"],
"ruff.lint.select": ["E", "W", "F"]
}

423
BucketManager.py

@@ -0,0 +1,423 @@
import json
from typing import List, Optional, Generator, Dict, Tuple
from collections import defaultdict, Counter
from pathlib import Path
import shutil
import math
from MyDataset import MyDataset
class BucketManager:
def __init__(
self,
cache_path: str,
min_samples: int = 1024,
max_samples: int = 524288,
num_cuts: int = 16,
processed_data: Optional[List[dict]] = None,
force_rebuild: bool = False,
):
self._validate_init_params(cache_path, processed_data, force_rebuild)
self.cache_path = Path(cache_path)
self.num_cuts = num_cuts
self.min_samples = min_samples
self.max_samples = max_samples
self.buckets = defaultdict(lambda: defaultdict(list))
self.valid_buckets = []
self.total_samples = 0
self.src_buckets = []
self.tgt_buckets = []
self.total_original_samples = 0
self.discarded_samples = 0
self.total_padding = 0
self.total_actual_tokens = 0
if processed_data is not None:
if not force_rebuild and self._metadata_exists():
self._load_from_metadata()
else:
if force_rebuild and self.cache_path.exists():
shutil.rmtree(self.cache_path)
self.cache_path.mkdir(parents=True, exist_ok=True)
self._process_data_with_dp(processed_data)
self._save_metadata()
else:
if not self._metadata_exists():
raise FileNotFoundError(f"No cache found at {cache_path}")
self._load_from_metadata()
def _validate_init_params(
self,
cache_path: str,
processed_data: Optional[List[dict]],
force_rebuild: bool
):
if not cache_path:
raise ValueError("cache_path cannot be empty")
if processed_data is None and not force_rebuild and not Path(cache_path).exists():
raise FileNotFoundError(f"Cache path missing: {cache_path}")
@staticmethod
def _optimal_1d_partition(lengths: List[int], num_cuts: int) -> List[Tuple[int, int]]:
if not lengths:
return []
length_counts = Counter(lengths)
unique_lengths = sorted(length_counts.keys())
if len(unique_lengths) <= num_cuts:
buckets = []
for i, length in enumerate(unique_lengths):
start = length
end = unique_lengths[i + 1] if i + 1 < len(unique_lengths) else length + 1
buckets.append((start, end))
return buckets
n = len(unique_lengths)
dp = [[float('inf')] * (num_cuts + 1) for _ in range(n + 1)]
parent = [[-1] * (num_cuts + 1) for _ in range(n + 1)]
dp[0][0] = 0
for i in range(1, n + 1):
for j in range(1, min(i + 1, num_cuts + 1)):
for k in range(j - 1, i):
if dp[k][j - 1] == float('inf'):
continue
bucket_start = unique_lengths[k]
bucket_end = unique_lengths[i - 1] + 1
bucket_max = unique_lengths[i - 1]
padding_in_bucket = 0
for idx in range(k, i):
length = unique_lengths[idx]
count = length_counts[length]
padding_in_bucket += (bucket_max - length) * count
total_padding = dp[k][j - 1] + padding_in_bucket
if total_padding < dp[i][j]:
dp[i][j] = total_padding
parent[i][j] = k
if dp[n][num_cuts] == float('inf'):
return [(min(lengths), max(lengths) + 1)]
buckets = []
i, j = n, num_cuts
while j > 0 and i > 0:
k = parent[i][j]
if k < 0 or k >= i:
break
bucket_start = unique_lengths[k]
bucket_end = unique_lengths[i - 1] + 1
buckets.append((bucket_start, bucket_end))
i, j = k, j - 1
buckets.reverse()
return buckets
def _create_2d_buckets(
self,
data: List[Tuple[int, int]]
) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
if not data:
return []
src_lengths = [item[0] for item in data]
tgt_lengths = [item[1] for item in data]
print("Calculating optimal buckets for src lengths...")
self.src_buckets = self._optimal_1d_partition(src_lengths, self.num_cuts)
print(f"src buckets: {self.src_buckets}")
print("Calculating optimal buckets for tgt lengths...")
self.tgt_buckets = self._optimal_1d_partition(tgt_lengths, self.num_cuts)
print(f"tgt buckets: {self.tgt_buckets}")
bucket_samples = {}
for src_bucket in self.src_buckets:
for tgt_bucket in self.tgt_buckets:
bucket_key = (src_bucket, tgt_bucket)
bucket_samples[bucket_key] = []
for src_len, tgt_len in data:
src_bucket = None
tgt_bucket = None
for bucket in self.src_buckets:
if bucket[0] <= src_len < bucket[1]:
src_bucket = bucket
break
for bucket in self.tgt_buckets:
if bucket[0] <= tgt_len < bucket[1]:
tgt_bucket = bucket
break
if src_bucket and tgt_bucket:
bucket_key = (src_bucket, tgt_bucket)
bucket_samples[bucket_key].append((src_len, tgt_len))
valid_buckets = []
for bucket_key, samples in bucket_samples.items():
if len(samples) >= self.min_samples:
valid_buckets.append(bucket_key)
return valid_buckets
def _process_data_with_dp(self, data: List[dict]):
length_pairs = [(len(item["src"]), len(item["tgt"])) for item in data]
self.total_original_samples = len(data)
valid_buckets = self._create_2d_buckets(length_pairs)
bucket_data = {bucket: [] for bucket in valid_buckets}
total_actual_tokens = 0
total_padding = 0
used_samples = 0
for item in data:
src_len = len(item["src"])
tgt_len = len(item["tgt"])
found_bucket = False
for (src_start, src_end), (tgt_start, tgt_end) in valid_buckets:
if src_start <= src_len < src_end and tgt_start <= tgt_len < tgt_end:
src_max = src_end - 1
tgt_max = tgt_end - 1
src_pad = src_max - src_len
tgt_pad = tgt_max - tgt_len
total_padding += src_pad + tgt_pad
total_actual_tokens += src_len + tgt_len
used_samples += 1
bucket_data[(src_start, src_end), (tgt_start, tgt_end)].append(item)
found_bucket = True
break
if not found_bucket:
pass
self.discarded_samples = len(data) - used_samples
self.total_padding = total_padding
self.total_actual_tokens = total_actual_tokens
print("\nBuilding datasets from buckets...")
total_samples = 0
self.valid_buckets = []
sorted_buckets = sorted(valid_buckets, key=lambda x: (x[1][0], x[0][0]))
for bucket_key in sorted_buckets:
(src_start, src_end), (tgt_start, tgt_end) = bucket_key
bucket_items = bucket_data[bucket_key]
data_len = len(bucket_items)
base_num_shards = max(1, (data_len + self.max_samples - 1) // self.max_samples)
last_shard_size = data_len % self.max_samples
if last_shard_size == 0 and data_len > 0:
last_shard_size = self.max_samples
if base_num_shards > 1 and last_shard_size <= self.max_samples * 0.5:
num_shards = base_num_shards - 1
else:
num_shards = base_num_shards
for shard_idx in range(num_shards):
if shard_idx == num_shards - 1:
start_idx = shard_idx * self.max_samples
end_idx = data_len
else:
start_idx = shard_idx * self.max_samples
end_idx = start_idx + self.max_samples
shard_data = bucket_items[start_idx:end_idx]
shard_size = len(shard_data)
MyDataset(
cache_path=str(self.cache_path),
processed_data=shard_data,
src_range=(src_start, src_end),
tgt_range=(tgt_start, tgt_end),
shard_idx=shard_idx,
)
bucket_info = {
"src_range": (src_start, src_end),
"tgt_range": (tgt_start, tgt_end),
"shard_idx": shard_idx,
"num_shards": num_shards,
"suggested_batch_size": 0,
"num_samples": shard_size
}
self.valid_buckets.append(bucket_info)
total_samples += shard_size
self.total_samples = total_samples
print(f"Valid buckets: {len(self.valid_buckets)}")
print(f"Total samples: {total_samples}")
def _save_metadata(self):
meta = {
"num_cuts": self.num_cuts,
"valid_buckets": [{
"src_range": list(b["src_range"]),
"tgt_range": list(b["tgt_range"]),
"shard_idx": b["shard_idx"],
"num_shards": b["num_shards"],
"suggested_batch_size": b["suggested_batch_size"],
"num_samples": b["num_samples"]
} for b in self.valid_buckets],
"min_samples": self.min_samples,
"max_samples": self.max_samples,
"total_samples": self.total_samples,
"total_original_samples": self.total_original_samples,
"discarded_samples": self.discarded_samples,
"total_padding": self.total_padding,
"total_actual_tokens": self.total_actual_tokens
}
meta_path = self.cache_path / "buckets_meta.json"
meta_path.write_text(json.dumps(meta, indent=2))
print(f"Metadata saved to {meta_path}")
def _load_from_metadata(self):
meta_path = self.cache_path / "buckets_meta.json"
meta = json.loads(meta_path.read_text())
self.num_cuts = meta.get("num_cuts", 8)
self.min_samples = meta.get("min_samples", self.min_samples)
self.max_samples = meta.get("max_samples", self.max_samples)
self.total_samples = meta.get("total_samples", 0)
self.total_original_samples = meta.get("total_original_samples", 0)
self.discarded_samples = meta.get("discarded_samples", 0)
self.total_padding = meta.get("total_padding", 0)
self.total_actual_tokens = meta.get("total_actual_tokens", 0)
self.valid_buckets = []
for b in meta["valid_buckets"]:
self.valid_buckets.append({
"src_range": tuple(b["src_range"]),
"tgt_range": tuple(b["tgt_range"]),
"shard_idx": b["shard_idx"],
"num_shards": b["num_shards"],
"suggested_batch_size": b.get("suggested_batch_size", 0),
"num_samples": b["num_samples"]
})
print(f"Loaded {len(self.valid_buckets)} buckets from {meta_path}")
def _metadata_exists(self) -> bool:
return (self.cache_path / "buckets_meta.json").exists()
def __iter__(self) -> Generator[Dict, None, None]:
yield from self.valid_buckets
def __len__(self) -> int:
return len(self.valid_buckets)
def __getitem__(self, index: int) -> dict:
return self.valid_buckets[index]
def reset_batch_size(self) -> None:
for b in self.valid_buckets:
b["suggested_batch_size"] = 0
def find_optimal_batch_size(self, find_batch_size_func) -> None:
bucket_type_map = {}
for i, bucket in enumerate(self.valid_buckets):
src_range = bucket["src_range"]
tgt_range = bucket["tgt_range"]
bucket_type = (src_range, tgt_range)
if bucket_type in bucket_type_map:
bucket["suggested_batch_size"] = bucket_type_map[bucket_type]
print(f"Bucket {i + 1}/{len(self.valid_buckets)} reused batch size: {bucket['suggested_batch_size']}\n")
continue
print(
f"Searching optimal batch size for bucket {i + 1}/{len(self.valid_buckets)} " + \
f"src: {src_range} tgt: {tgt_range} ..."
)
batch_size = find_batch_size_func(src_max=src_range[1], tgt_max=tgt_range[1])
bucket["suggested_batch_size"] = batch_size
bucket_type_map[bucket_type] = bucket["suggested_batch_size"]
print(f"Found batch size: {bucket['suggested_batch_size']}\n")
self._save_metadata()
def found_optimal(self) -> bool:
return all(b["suggested_batch_size"] > 0 for b in self.valid_buckets)
def get_total_iterations(self, safety_factor: float, drop_last: bool = False) -> int:
total_iterations = 0
for bucket in self.valid_buckets:
batch_size = math.floor(bucket["suggested_batch_size"] * safety_factor)
num_samples = bucket["num_samples"]
if batch_size <= 0:
raise ValueError("Batch size must be positive. Call find_optimal_batch_size first.")
if drop_last:
iterations = num_samples // batch_size
else:
iterations = (num_samples + batch_size - 1) // batch_size
total_iterations += iterations
return total_iterations
def get_info(self) -> List[dict]:
return self.valid_buckets
def print_stats(self):
print("\n" + "=" * 60)
print("Bucket Statistics")
print("=" * 60)
print(f"Total buckets: {len(self.valid_buckets)}")
print(f"Total samples: {self.total_samples}")
print(f"Min samples per bucket: {self.min_samples}")
print(f"Max samples per bucket: {self.max_samples}")
print("\nData Distribution:")
print("-" * 60)
print(f"Original samples: {self.total_original_samples}")
print(f"Discarded samples: {self.discarded_samples} ({self.discarded_samples / self.total_original_samples * 100:.2f}%)")
total_tokens_with_padding = self.total_actual_tokens + self.total_padding
if total_tokens_with_padding > 0:
padding_rate = self.total_padding / total_tokens_with_padding
else:
padding_rate = 0.0
print(f"Total padding tokens: {self.total_padding}")
print(f"Padding rate: {padding_rate * 100:.2f}%")
print("\nBucket Details:")
print("-" * 80)
print(f"{'Bucket ID':<8} {'Src Range':<15} {'Tgt Range':<15} {'Samples':<10} {'Shards':<8} {'Batch Size':<10}")
print("-" * 80)
for i, bucket in enumerate(self.valid_buckets):
src_range = f"{bucket['src_range'][0]}-{bucket['src_range'][1] - 1}"
tgt_range = f"{bucket['tgt_range'][0]}-{bucket['tgt_range'][1] - 1}"
samples = bucket['num_samples']
shards = f"{bucket['shard_idx'] + 1}/{bucket['num_shards']}"
batch_size = bucket['suggested_batch_size']
print(f"{i:<8} {src_range:<15} {tgt_range:<15} {samples:<10} {shards:<8} {batch_size:<10}")

84
MyDataset.py

@@ -0,0 +1,84 @@
import os
import pickle
import zstandard as zstd
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
class MyDataset(Dataset):
def __init__(
self,
cache_path: str,
processed_data: list[dict] = None,
src_range: tuple = (0, 0),
tgt_range: tuple = (0, 0),
shard_idx: int = 0,
zstd_level: int = 9
):
self.src_range = src_range
self.tgt_range = tgt_range
self.zstd_level = zstd_level
self.processed_data = processed_data
base_name = f"cached_src_{src_range[0]}_{src_range[1]}_tgt_{tgt_range[0]}_{tgt_range[1]}"
self.cache_path = os.path.join(
cache_path,
f"{base_name}_shard_{shard_idx}.pkl.zst"
)
if os.path.exists(self.cache_path):
with open(self.cache_path, 'rb') as f:
dctx = zstd.ZstdDecompressor()
with dctx.stream_reader(f) as reader:
self.processed_data = pickle.load(reader)
return
elif self.processed_data is None:
raise ValueError(f"Cache file not found: {self.cache_path} and no data provided")
# save the cache
print("Saving cache, please wait...")
cctx = zstd.ZstdCompressor(level=self.zstd_level)
with open(self.cache_path, 'wb') as f:
with cctx.stream_writer(f) as compressor:
pickle.dump(
self.processed_data,
compressor,
protocol=pickle.HIGHEST_PROTOCOL
)
print(f"Cache saved at: {self.cache_path}")
def __len__(self):
return len(self.processed_data)
def __getitem__(self, idx):
item = self.processed_data[idx]
return {
"src": torch.from_numpy(item["src"]).long(),
"tgt": torch.from_numpy(item["tgt"]).long(),
"label": torch.from_numpy(item["label"]).long(),
}
def collate_fn(batch, pad_token_id: int, src_max_length: int, tgt_max_length: int):
src = [item["src"] for item in batch]
tgt = [item["tgt"] for item in batch]
label = [item["label"] for item in batch]
# pad src
src_padded = [F.pad(t, (0, src_max_length - t.size(0)), value=pad_token_id) for t in src]
src_padded = torch.stack(src_padded)
# pad tgt
tgt_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in tgt]
tgt_padded = torch.stack(tgt_padded)
# pad label
label_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in label]
label_padded = torch.stack(label_padded)
return {
"src": src_padded,
"tgt": tgt_padded,
"label": label_padded,
}
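
A usage sketch for the cached shards above, assuming the [PAD] id is 0 (take the real value from the tokenizer) and picking one shard that exists in this commit. collate_fn pads every field up to the bucket's maximum lengths, so src_max_length / tgt_max_length are the exclusive bucket ends minus one.

from functools import partial
from torch.utils.data import DataLoader
from MyDataset import MyDataset, collate_fn

# Load one existing shard; leaving processed_data=None means "read the cache or fail".
dataset = MyDataset(
    cache_path="data/cache/train",   # directory holding the *.pkl.zst shards
    src_range=(10, 13),              # must match a shard file, e.g. cached_src_10_13_tgt_9_12_shard_0.pkl.zst
    tgt_range=(9, 12),
    shard_idx=0,
)
loader = DataLoader(
    dataset,
    batch_size=256,                  # in training this would come from the bucket's suggested_batch_size
    shuffle=True,
    collate_fn=partial(
        collate_fn,
        pad_token_id=0,              # assumed [PAD] id; use the tokenizer's actual value
        src_max_length=12,           # bucket ends are exclusive, so the max length is end - 1
        tgt_max_length=11,
    ),
)
batch = next(iter(loader))
print(batch["src"].shape, batch["tgt"].shape, batch["label"].shape)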

336
MyLSTM.py

@@ -0,0 +1,336 @@
import torch
import torch.nn as nn
from MyLayer import SoftTanh, SoftSigmoid, AdaptiveSoftTanh
import random
from typing import Tuple, List, Optional
class MyLSTM(nn.Module):
def __init__(self, input_size: int, hidden_size: int, dropout: float) -> None:
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.norm = AdaptiveSoftTanh(input_size)
self.dropout = nn.Dropout(dropout)
self.W_i = nn.Parameter(torch.Tensor(4 * hidden_size, input_size))
self.W_h = nn.Parameter(torch.Tensor(4 * hidden_size, hidden_size))
self.bias = nn.Parameter(torch.Tensor(4 * hidden_size))
self.init_weight()
self.h_state = None
self.c_state = None
self.initialized = False
def init_weight(self) -> None:
for param in self.parameters():
if param.dim() > 1:
nn.init.xavier_uniform_(param)
else:
nn.init.zeros_(param)
def init_hidden(
self, batch_size: int, device: torch.device,
h_state: Optional[torch.Tensor] = None, c_state: Optional[torch.Tensor] = None
) -> None:
self.h_state = h_state if h_state is not None \
else torch.zeros(batch_size, self.hidden_size, device=device)
self.c_state = c_state if c_state is not None \
else torch.zeros(batch_size, self.hidden_size, device=device)
self.initialized = True
def _forward_seq(self, x: torch.Tensor) \
-> torch.Tensor:
_, T, _ = x.shape
# store the outputs of every time step
output = []
normed = self.norm(x)
# iterate over each time step of the sequence
for t in range(T):
x_t = normed[:, t, :] # input at the current step (B, C)
# compute the linear transforms of all gates at once (batch_size, 4*hidden_size)
gates = x_t @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
# split into the four gates (each is batch_size, hidden_size)
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
i_t = SoftSigmoid(i_gate) # input gate
f_t = SoftSigmoid(f_gate) # forget gate
g_t = SoftTanh(g_gate) # candidate gate
o_t = SoftSigmoid(o_gate) # output gate
# i_t = torch.sigmoid(i_gate) # input gate
# f_t = torch.sigmoid(f_gate) # forget gate
# g_t = torch.tanh(g_gate) # candidate gate
# o_t = torch.sigmoid(o_gate) # output gate
# update the cell state
self.c_state = f_t * self.c_state + i_t * g_t
# update the hidden state
self.h_state = o_t * SoftTanh(self.c_state)
# self.h_state = o_t * torch.tanh(self.c_state)
# output of the current time step
output.append(self.h_state)
# stack the step outputs into a single tensor
output = torch.stack(output, dim=1)
# residual connection when input and hidden sizes match
if self.input_size == self.hidden_size:
output = output + x
else:
output = output
return self.dropout(output)
def _forward_step(self, x: torch.Tensor) \
-> torch.Tensor:
normed = self.norm(x)
# compute the linear transforms of all gates at once (batch_size, 4*hidden_size)
gates = normed @ self.W_i.t() + self.h_state @ self.W_h.t() + self.bias
# split into the four gates (each is batch_size, hidden_size)
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)
i_t = SoftSigmoid(i_gate) # input gate
f_t = SoftSigmoid(f_gate) # forget gate
g_t = SoftTanh(g_gate) # candidate gate
o_t = SoftSigmoid(o_gate) # output gate
# i_t = torch.sigmoid(i_gate) # input gate
# f_t = torch.sigmoid(f_gate) # forget gate
# g_t = torch.tanh(g_gate) # candidate gate
# o_t = torch.sigmoid(o_gate) # output gate
# update the cell state
self.c_state = f_t * self.c_state + i_t * g_t
# update the hidden state
self.h_state = o_t * SoftTanh(self.c_state)
# self.h_state = o_t * torch.tanh(self.c_state)
# output of the current time step
output = self.h_state
# residual connection when input and hidden sizes match
if self.input_size == self.hidden_size:
output = output + x
else:
output = output
return self.dropout(output)
def forward(self, x) -> torch.Tensor:
if not self.initialized:
B = x.size(0)
self.init_hidden(B, x.device)
self.initialized = True
if x.dim() == 2:
return self._forward_step(x)
elif x.dim() == 3:
return self._forward_seq(x)
else:
raise ValueError("input dim must be 2(step) or 3(sequence)")
def reset(self) -> None:
self.h_state = None
self.c_state = None
self.initialized = False
class LSTMEncoder(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super(LSTMEncoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
)
self.layers = nn.ModuleList([
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \
for _ in range(num_layers)
])
def forward(self, x) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
x = self.embedding(x)
output = x
# keep each layer's final hidden and cell states
h, c = [], []
for layer in self.layers:
output = layer(output)
h.append(layer.h_state)
c.append(layer.c_state)
return h, c
def reset(self) -> None:
for layer in self.layers:
layer.reset()
class LSTMDecoder(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super(LSTMDecoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx
)
self.layers = nn.ModuleList([
MyLSTM(input_size=embedding_dim, hidden_size=embedding_dim, dropout=dropout) \
for _ in range(num_layers)
])
self.fc = nn.Linear(embedding_dim, vocab_size)
self.initialized = False
def init_hidden(
self, batch_size: int, device: torch.device,
h: List[torch.Tensor], c: List[torch.Tensor]
) -> None:
for i, layer in enumerate(self.layers):
layer.init_hidden(batch_size, device, h[i], c[i])
self.initialized = True
def _forward_step(self, x: torch.Tensor) -> torch.Tensor:
output = x
for i, layer in enumerate(self.layers):
output = layer(output)
return output
def _forward_seq(self, x: torch.Tensor, tfr: float = 0.0) -> torch.Tensor:
_, T, _ = x.shape
outputs = []
for t in range(T):
if t == 0 or random.random() < tfr:
outputs.append(self._forward_step(x[:, t, :]))
else:
logits = torch.argmax(self.fc(outputs[-1].clone().detach()), dim=-1)
previous = self.embedding(logits)
outputs.append(self._forward_step(previous))
outputs = torch.stack(outputs, dim=1)
return outputs
def forward(self, x: torch.Tensor, tfr: float = 0.0):
x = self.embedding(x)
if x.dim() == 3:
return self.fc(self._forward_seq(x, tfr=tfr))
elif x.dim() == 2:
return self.fc(self._forward_step(x))
else:
raise ValueError("input dim must be 2(step) or 3(sequence)")
def reset(self):
for layer in self.layers:
layer.reset()
self.initialized = True
class Seq2SeqLSTM(nn.Module):
def __init__(
self, vocab_size: int, embedding_dim: int,
padding_idx: int, num_layers: int, dropout: float
) -> None:
super().__init__()
self.encoder = LSTMEncoder(
vocab_size=vocab_size, embedding_dim=embedding_dim,
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
)
self.decoder = LSTMDecoder(
vocab_size=vocab_size, embedding_dim=embedding_dim,
padding_idx=padding_idx, num_layers=num_layers, dropout=dropout
)
def forward(self, src: torch.Tensor, tgt: torch.Tensor, tfr: float = 0) -> torch.Tensor:
self.reset()
h, c = self.encoder(src)
batch_size = src.size(0)
device = src.device
self.decoder.init_hidden(batch_size, device, h, c)
return self.decoder(tgt, tfr=tfr)
def reset(self):
self.encoder.reset()
self.decoder.reset()
def greedy_decode(
self,
src: torch.Tensor,
bos_token_id: int,
eos_token_id: int,
pad_token_id: int,
max_length: int
) -> torch.Tensor:
"""
Batched greedy decoding for inference.
Args:
src: source sequences, already wrapped in BOS/EOS and padded (batch_size, src_len)
bos_token_id: ID of the beginning-of-sequence token
eos_token_id: ID of the end-of-sequence token
pad_token_id: ID of the padding token
max_length: maximum number of tokens to generate
Returns:
the generated sequences (batch_size, tgt_len)
"""
self.reset()
batch_size = src.size(0)
device = src.device
# encode the source sequence
h, c = self.encoder(src)
# initialize the decoder state from the encoder
self.decoder.init_hidden(batch_size, device, h, c)
# output tensor, filled with pad_token_id
sequences = torch.full(
(batch_size, max_length),
pad_token_id,
dtype=torch.long,
device=device
)
# the first decoder input is BOS
current_input = torch.full(
(batch_size,),
bos_token_id,
dtype=torch.long,
device=device
)
# flags marking which sequences have finished (all False initially)
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
# autoregressive generation loop
for t in range(max_length):
# stop once every sequence has finished
if finished.all():
break
# logits for the current time step
logits = self.decoder(current_input) # (batch_size, vocab_size)
next_tokens = logits.argmax(dim=-1) # greedy choice (batch_size,)
# write tokens only for sequences that have not finished
sequences[~finished, t] = next_tokens[~finished]
# detect EOS tokens
eos_mask = (next_tokens == eos_token_id)
finished = finished | eos_mask # update the finished flags
# next input: new token for active sequences, PAD for finished ones
current_input = torch.where(
~finished,
next_tokens,
torch.full_like(next_tokens, pad_token_id)
)
return sequences
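
A minimal smoke test of Seq2SeqLSTM; every hyperparameter and token id below is a placeholder chosen only to make the shapes concrete.

import torch
from MyLSTM import Seq2SeqLSTM

model = Seq2SeqLSTM(vocab_size=100, embedding_dim=32, padding_idx=0, num_layers=2, dropout=0.1)
src = torch.randint(4, 100, (8, 11))                     # (batch, src_len) of token ids
tgt = torch.randint(4, 100, (8, 10))                     # decoder input for teacher forcing
logits = model(src, tgt, tfr=1.0)                        # tfr=1.0: always teacher-force
print(logits.shape)                                      # (8, 10, 100)

model.eval()
with torch.no_grad():
    out = model.greedy_decode(src, bos_token_id=1, eos_token_id=2, pad_token_id=0, max_length=20)
print(out.shape)                                         # (8, 20)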

221
MyLayer.py

@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt
class SoftTanhFunction(torch.autograd.Function):
'''
SoftTanh(x) = sign(x) * tanh(ln(1 + ln(1 + |x|)))
f(x) (-1, 1)
dy/dx (0, 1]
'''
@staticmethod
def forward(ctx, x):
abs_x = torch.abs(x)
u = torch.log1p(abs_x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
y = torch.sign(x) * tanh_v
ctx.save_for_backward(abs_x)
return y
@staticmethod
def backward(ctx, grad_output):
abs_x, = ctx.saved_tensors
u = torch.log1p(abs_x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
sech_v_square = 1 - tanh_v.square()
denominator = (1 + u) * (1 + abs_x)
d_x = sech_v_square / denominator
grad_x = grad_output * d_x
return grad_x
def SoftTanh(x):
return SoftTanhFunction.apply(x)
class SoftSigmoidFunction(torch.autograd.Function):
'''
SoftSigmoid(x) = (sign(x) * tanh(ln(1 + ln(1 + |2x|))) + 1) / 2
f(x) (0, 1)
dy/dx (0, 1]
'''
@staticmethod
def forward(ctx, x):
abs_2x = torch.abs(2 * x)
u = torch.log1p(abs_2x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
y = (torch.sign(x) * tanh_v + 1) * 0.5
ctx.save_for_backward(abs_2x)
return y
@staticmethod
def backward(ctx, grad_output):
abs_2x, = ctx.saved_tensors
u = torch.log1p(abs_2x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
sech_v_square = 1 - tanh_v.square()
denominator = (1 + u) * (1 + abs_2x)
d_x = sech_v_square / denominator
grad_x = grad_output * d_x
return grad_x
def SoftSigmoid(x):
return SoftSigmoidFunction.apply(x)
class AdaptiveSoftTanhFunction(torch.autograd.Function):
'''
AdaptiveSoftTanh(x) = alpha * sign(x) * tanh(ln(1 + ln(1 + |x|))) + beta
f(x) (beta - |alpha|, beta + |alpha|)
dy/dx (0, alpha] if alpha > 0
dy/dx [alpha, 0) if alpha < 0
'''
@staticmethod
def forward(ctx, x, alpha, beta):
abs_x = torch.abs(x)
u = torch.log1p(abs_x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
y = torch.sign(x) * tanh_v * alpha + beta
ctx.save_for_backward(x, alpha)
return y
@staticmethod
def backward(ctx, grad_output):
x, alpha, = ctx.saved_tensors
abs_x = torch.abs(x)
u = torch.log1p(abs_x)
v = torch.log1p(u)
tanh_v = torch.tanh(v)
sech_v_square = 1 - tanh_v.square()
denominator = (1 + u) * (1 + abs_x)
d_x = sech_v_square / denominator
grad_x = grad_output * d_x * alpha
grad_alpha = grad_output * torch.sign(x) * tanh_v
grad_beta = grad_output.clone()
sum_dims = [d for d in range(grad_output.dim()) if d != grad_output.dim() - 1]
if sum_dims:
grad_alpha = grad_alpha.sum(dim=sum_dims)
grad_beta = grad_beta.sum(dim=sum_dims)
return grad_x, grad_alpha, grad_beta
class AdaptiveSoftTanh(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.Tensor(channels))
self.beta = nn.Parameter(torch.zeros(channels))
nn.init.normal_(self.alpha, mean=0, std=math.sqrt(2 / channels))
def forward(self, x):
return AdaptiveSoftTanhFunction.apply(x, self.alpha, self.beta)
if __name__ == "__main__":
print("=== Testing the x = 0 case ===")
# create an input tensor containing 0
x = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
# forward pass
y = SoftTanh(x)
print(f"Forward output (x=0): {y.item()}")
# backward pass
y.backward()
print(f"Gradient (x=0): {x.grad.item()}")
# verify the gradient value
grad_ok = torch.allclose(x.grad, torch.tensor([1.0]))
print(f"Gradient check: {'passed' if grad_ok else 'failed'}")
x = torch.linspace(-10, 10, 10000, requires_grad=True)
y_soft = SoftTanh(x)
y_tanh = torch.tanh(x)
grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]
x_np = x.detach().numpy()
y_soft_np = y_soft.detach().numpy()
y_tanh_np = y_tanh.detach().numpy()
grad_soft_np = grad_soft.detach().numpy()
grad_tanh_np = grad_tanh.detach().numpy()
plt.figure(figsize=(12, 10))
# function graph
plt.subplot(2, 1, 1)
plt.plot(x_np, y_soft_np, 'b-', linewidth=2, label='SoftTanh')
plt.plot(x_np, y_tanh_np, 'r--', linewidth=2, label='Tanh')
plt.title('Function Comparison')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.legend()
# grad graph
plt.subplot(2, 1, 2)
plt.plot(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
plt.plot(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
plt.title('Gradient Comparison')
plt.xlabel('x')
plt.ylabel('dy/dx')
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig('./softtanh_comparison.png', dpi=300)
plt.close()
# generate x values in log space, positive half-axis only
x = torch.logspace(-2, 3, 100000, base=10, requires_grad=True)
y_soft = SoftTanh(x)
y_tanh = torch.tanh(x)
grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]
# convert to numpy arrays for plotting
x_np = x.detach().numpy()
y_soft_np = y_soft.detach().numpy()
y_tanh_np = y_tanh.detach().numpy()
grad_soft_np = grad_soft.detach().numpy()
grad_tanh_np = grad_tanh.detach().numpy()
# create the figure and subplots
plt.figure(figsize=(14, 12))
# function curves (log-log scale)
plt.subplot(2, 1, 1)
plt.loglog(x_np, 1 - y_soft_np, 'b-', linewidth=2, label='SoftTanh (1-y)')
plt.loglog(x_np, 1 - y_tanh_np, 'r--', linewidth=2, label='Tanh (1-y)')
plt.title('Function Asymptotic Behavior (Log-Log Scale)')
plt.xlabel('x (log scale)')
plt.ylabel('1 - y (log scale)')
plt.grid(True, which="both", ls="--")
plt.legend()
# gradient curves (log-log scale)
plt.subplot(2, 1, 2)
plt.loglog(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
plt.loglog(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
plt.title('Gradient Decay (Log-Log Scale)')
plt.xlabel('x (log scale)')
plt.ylabel('dy/dx (log scale)')
plt.grid(True, which="both", ls="--")
plt.legend()
plt.tight_layout()
plt.savefig('./softtanh_log_comparison.png', dpi=300)
plt.close()
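
Beyond the plots, the hand-written backward passes can be sanity-checked numerically with torch.autograd.gradcheck, which compares them against finite differences in double precision. A quick sketch (tensor shapes are arbitrary):

import torch
from torch.autograd import gradcheck
from MyLayer import SoftTanhFunction, SoftSigmoidFunction, AdaptiveSoftTanhFunction

# gradcheck perturbs the inputs and compares the analytic gradients against finite differences.
x = torch.randn(16, 8, dtype=torch.double, requires_grad=True)
print(gradcheck(SoftTanhFunction.apply, (x,), eps=1e-6, atol=1e-4))
print(gradcheck(SoftSigmoidFunction.apply, (x,), eps=1e-6, atol=1e-4))

alpha = torch.randn(8, dtype=torch.double, requires_grad=True)
beta = torch.randn(8, dtype=torch.double, requires_grad=True)
print(gradcheck(AdaptiveSoftTanhFunction.apply, (x, alpha, beta), eps=1e-6, atol=1e-4))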

123
MyTokenizer.py

@@ -0,0 +1,123 @@
from tokenizers import (
decoders,
models,
normalizers,
pre_tokenizers,
trainers,
Tokenizer,
processors
)
from transformers import PreTrainedTokenizerFast
import os
import jieba
SPECIAL_TOKENS = {
"pad_token": "[PAD]",
"bos_token": "[BOS]",
"eos_token": "[EOS]",
"unk_token": "[UNK]",
}
def create_tokenizer(
corpus_path: str,
save_dir: str,
language: str,
vocab_size: int,
):
"""Create a universal BPE tokenizer."""
if not os.path.exists(corpus_path):
raise FileNotFoundError(f"corpus file {corpus_path} not exists")
tokenizer = Tokenizer(models.BPE(
unk_token=SPECIAL_TOKENS["unk_token"]
))
tokenizer.normalizer = normalizers.Sequence([
normalizers.NFKC(),
normalizers.StripAccents(),
normalizers.NFD(),
normalizers.Lowercase(),
])
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Whitespace(),
pre_tokenizers.Punctuation(),
pre_tokenizers.Digits(individual_digits=True),
pre_tokenizers.ByteLevel(
add_prefix_space=True,
use_regex=True
)
])
trainer = trainers.BpeTrainer(
special_tokens=list(SPECIAL_TOKENS.values()),
vocab_size=vocab_size,
min_frequency=4,
show_progress=True
)
tokenizer.train(files=[corpus_path], trainer=trainer)
tokenizer.decoder = decoders.ByteLevel()
# post-processor that automatically adds the special tokens
bos_token = SPECIAL_TOKENS["bos_token"]
eos_token = SPECIAL_TOKENS["eos_token"]
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{bos_token} $A {eos_token}",
pair=f"{bos_token} $A {eos_token} {bos_token} $B {eos_token}",
special_tokens=[
(bos_token, tokenizer.token_to_id(bos_token)),
(eos_token, tokenizer.token_to_id(eos_token)),
],
)
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
**SPECIAL_TOKENS,
padding_side="right",
truncation_side="right",
do_lower_case=True
)
os.makedirs(save_dir, exist_ok=True)
fast_tokenizer.save_pretrained(save_dir)
return fast_tokenizer
def load_tokenizer(path: str):
if not os.path.exists(path):
raise FileNotFoundError("cannot find tokenizer; please train it first")
return PreTrainedTokenizerFast.from_pretrained(path)
if __name__ == "__main__":
en_tokenizer = create_tokenizer(
corpus_path="data/txt/corpus.en",
save_dir="model/tokenizers/en",
language="en",
vocab_size=16384,
)
zh_tokenizer = create_tokenizer(
corpus_path="data/txt/corpus.zh",
save_dir="model/tokenizers/zh",
language="zh",
vocab_size=16384,
)
en_tokenizer = load_tokenizer("model/tokenizers/en")
zh_tokenizer = load_tokenizer("model/tokenizers/zh")
# test English processing
en_text = "How many books do you think you've read so far?"
en_encoding = en_tokenizer(en_text)
print("en encoding:", en_encoding.tokens())
print("en decoding:", en_tokenizer.decode(en_encoding.input_ids))
# test Chinese processing
zh_text = "到目前为止你认为你读过多少书?"
jieba.initialize()
zh_text = " ".join(jieba.lcut(zh_text))
zh_encoding = zh_tokenizer(zh_text)
print("zh encoding:", zh_encoding.tokens())
print("zh decoding:", zh_tokenizer.decode(zh_encoding.input_ids))

0
MyTransformer.py

3
README.md

@@ -1,2 +1 @@
# MyArena
# MyArena

140
TryBatch.py

@@ -0,0 +1,140 @@
# try_batch.py
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from MyLSTM import Seq2SeqLSTM
import multiprocessing
import gc
import math
import random
def try_batch(config: dict, src_max: int, tgt_max: int, batch_size: int, max_batch_iter: int = 2):
# init model, optimizer, scaler
model = Seq2SeqLSTM(
vocab_size=config["vocab_size"],
embedding_dim=config["embedding_dim"],
padding_idx=config["pad_token_id"],
num_layers=config["num_layers"],
dropout=config["dropout"],
).to(config["device"])
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config["lr_eta_max"],
betas=config["betas"],
weight_decay=config["weight_decay"]
)
scaler = GradScaler(config["device"], enabled=config["use_amp"])
# construct fake batch for testing
srcs = torch.full((1, src_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])
tgts = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])
labels = torch.full((1, tgt_max), config["unk_token_id"], dtype=torch.long).expand(batch_size, -1).to(config["device"])
# training loop
for i in range(max_batch_iter):
with autocast(config["device"], enabled=config["use_amp"]):
if i % 2 == 0:
tfr = random.random()
else:
tfr = 1
output = model(srcs, tgts, tfr)
loss = nn.CrossEntropyLoss()(
output.reshape(-1, output.size(-1)),
labels.contiguous().view(-1)
)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=config["grad_clip"]
)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
def _worker(config, src_max, tgt_max, batch_size, max_batch_iter, result_queue):
try:
torch.cuda.reset_peak_memory_stats()
try_batch(config, src_max, tgt_max, batch_size, max_batch_iter)
peak_mem = torch.cuda.max_memory_allocated()
result_queue.put(('success', peak_mem))
except Exception as e:
if 'out of memory' in str(e).lower() or 'alloc' in str(e).lower():
peak_mem = torch.cuda.max_memory_allocated()
result_queue.put(('oom', peak_mem))
else:
result_queue.put(('error', (type(e).__name__, str(e))))
finally:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def find_batch_size(
src_max: int,
tgt_max: int,
config: dict,
initial_batch_size: int = 16384,
max_batch_iter: int = 2,
timeout: int = 30
) -> int:
low, high = 1, initial_batch_size
best_size = 0
max_attempts = math.ceil(math.log2(initial_batch_size))
attempt = 0
while low <= high and attempt < max_attempts:
attempt += 1
mid = (low + high) // 2
ctx = multiprocessing.get_context('spawn')
result_queue = ctx.Queue()
process = ctx.Process(
target=_worker,
args=(config, src_max, tgt_max, mid, max_batch_iter, result_queue)
)
try:
torch.cuda.reset_peak_memory_stats()
process.start()
process.join(timeout=timeout)
if process.is_alive():
process.terminate()
process.join()
raise TimeoutError(f"Batch size {mid} timed out after {timeout}s")
if result_queue.empty():
raise RuntimeError("Subprocess exited without returning results")
status, data = result_queue.get()
if status == 'success':
current_peak = data / (1024 ** 3)
print(f"Attempt {attempt:3d}: {mid:5d} | OK (Peak Memory: {current_peak:5.2f}GB)")
best_size = mid
low = mid + 1
elif status == 'oom':
current_peak = data / (1024 ** 3)
print(f"Attempt {attempt:3d}: {mid:5d} | Failed (OOM, Peak: {current_peak:5.2f}GB)")
high = mid - 1
elif status == 'error':
exc_type, exc_msg = data
print(f"Attempt {attempt:3d}: {mid:5d} | CRITICAL ERROR: {exc_type} - {exc_msg}")
raise RuntimeError(f"Critical error at batch size {mid}: {exc_type} - {exc_msg}")
except Exception as e:
print(f"Attempt {attempt:3d}: {mid:5d} | Error: {str(e)}")
high = mid - 1
finally:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
return best_size
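
A possible way to call find_batch_size; every key in the config dict below is read somewhere in try_batch / find_batch_size, but the values are placeholders (the vocab size matches the tokenizer setting, the rest are guesses). The __main__ guard matters because the probe runs in a spawned subprocess that re-imports the calling module.

from TryBatch import find_batch_size

if __name__ == "__main__":
    config = {
        "vocab_size": 16384,                  # matches the tokenizer's vocab_size
        "embedding_dim": 512,
        "pad_token_id": 0,
        "unk_token_id": 3,
        "num_layers": 2,
        "dropout": 0.1,
        "lr_eta_max": 1e-3,
        "betas": (0.9, 0.98),
        "weight_decay": 0.01,
        "grad_clip": 1.0,
        "use_amp": True,
        "device": "cuda",                     # a CUDA device is assumed; the probe tracks GPU peak memory
    }
    best = find_batch_size(src_max=13, tgt_max=12, config=config, initial_batch_size=8192)
    print(f"Largest batch size that fit: {best}")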

BIN
__pycache__/BucketManager.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyDataset.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyLSTM.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyLayer.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/MyTokenizer.cpython-312.pyc

Binary file not shown.

BIN
__pycache__/TryBatch.cpython-312.pyc

Binary file not shown.

122
corpus.py

@@ -0,0 +1,122 @@
import os
import pandas as pd
import glob
import jieba
from multiprocessing import Pool
from tqdm import tqdm
import logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
NUM_WORKERS = 20
INPUT_DIRS = {
'train': 'data/train',
'valid': 'data/valid'
}
OUTPUT_DIR = 'data/cache/txt'
EN_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.en')
ZH_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.zh')
def init_jieba():
jieba_logger = logging.getLogger('jieba')
jieba_logger.setLevel(logging.WARNING)
jieba.disable_parallel()
def process_line(record):
try:
en_text = record['translation']['en']
zh_text = record['translation']['zh']
# Chinese word segmentation
zh_words = jieba.lcut(zh_text)
zh_sentence = ' '.join(zh_words)
return (en_text, zh_sentence)
except KeyError as e:
logging.warning(f"Missing field in record: {str(e)}")
return None
except Exception as e:
logging.warning(f"Line processing error: {str(e)}")
return None
def process_shard(shard_path):
try:
df = pd.read_parquet(shard_path)
records = df.to_dict(orient='records')
total = len(records)
logging.info(f"Processing {shard_path} ({total} lines)")
with Pool(NUM_WORKERS, initializer=init_jieba) as pool:
results = []
for result in tqdm(
pool.imap(process_line, records),
total=total,
desc=f"Processing {os.path.basename(shard_path)}",
unit="lines",
colour='green',
bar_format='{l_bar}{bar:32}{r_bar}'
):
if result is not None:
results.append(result)
en_sentences, zh_sentences = zip(*results) if results else ([], [])
logging.info(f"Processed {len(results)} lines from {shard_path}")
return list(en_sentences), list(zh_sentences)
except Exception as e:
logging.error(f"Shard processing failed: {shard_path} - {str(e)}")
return [], []
def main():
os.makedirs(OUTPUT_DIR, exist_ok=True)
all_shards = []
for _, dir_path in INPUT_DIRS.items():
shards = glob.glob(os.path.join(dir_path, '*.parquet'))
if not shards:
logging.warning(f"No Parquet files found in {dir_path}")
all_shards.extend(shards)
if not all_shards:
logging.error("No Parquet files found in any input directories")
return
all_shards.sort(key=lambda x: os.path.abspath(x))
logging.info(f"Found {len(all_shards)} shards to process")
jieba.initialize()
all_en = []
all_zh = []
for shard_path in all_shards:
en_sentences, zh_sentences = process_shard(shard_path)
all_en.extend(en_sentences)
all_zh.extend(zh_sentences)
if len(all_en) != len(all_zh):
logging.warning(f"Data length mismatch: {len(all_en)} English vs {len(all_zh)} Chinese sentences")
logging.info(f"Writing {len(all_en)} sentences to final files")
with open(EN_OUTPUT_PATH, 'w', encoding='utf-8') as f_en, \
open(ZH_OUTPUT_PATH, 'w', encoding='utf-8') as f_zh:
for en, zh in tqdm(zip(all_en, all_zh), total=len(all_en), desc="Writing files"):
f_en.write(en + '\n')
f_zh.write(zh + '\n')
logging.info("Corpus generation completed successfully")
logging.info(f"English corpus saved at: {EN_OUTPUT_PATH}")
logging.info(f"Chinese corpus saved at: {ZH_OUTPUT_PATH}")
if __name__ == "__main__":
main()
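
corpus.py expects Parquet shards whose rows carry a translation struct with en and zh fields. A tiny sketch that writes such a shard for a dry run (paths are illustrative; writing Parquet this way needs pyarrow installed):

import os
import pandas as pd

os.makedirs("data/train", exist_ok=True)
rows = [
    {"translation": {"en": "Hello world.", "zh": "你好,世界。"}},
    {"translation": {"en": "How are you?", "zh": "你好吗?"}},
]
pd.DataFrame(rows).to_parquet("data/train/toy-00000.parquet")
# Running `python corpus.py` afterwards should produce data/cache/txt/corpus.en and corpus.zh,
# with the Chinese side whitespace-segmented by jieba.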

1846
data/cache/train/buckets_meta.json

File diff suppressed because it is too large

BIN
data/cache/train/cached_src_102_129_tgt_82_99_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_102_129_tgt_99_128_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_12_15_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_2_6_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_6_9_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_6_9_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_9_12_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_9_12_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_10_13_tgt_9_12_shard_2.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_12_15_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_12_15_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_12_15_shard_2.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_6_9_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_9_12_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_13_16_tgt_9_12_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_12_15_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_12_15_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_15_18_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_15_18_shard_2.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_16_19_tgt_9_12_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_12_15_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_15_18_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_18_21_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_19_22_tgt_9_12_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_12_15_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_18_21_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_21_24_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_22_25_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_21_24_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_24_27_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_25_28_tgt_33_36_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_15_18_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_24_27_shard_1.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_33_36_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_28_31_tgt_36_40_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_18_21_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_33_36_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_31_34_tgt_36_40_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_21_24_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_33_36_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_36_40_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_34_38_tgt_40_44_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_38_42_tgt_24_27_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_38_42_tgt_27_30_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_38_42_tgt_30_33_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_38_42_tgt_33_36_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

BIN
data/cache/train/cached_src_38_42_tgt_36_40_shard_0.pkl.zst (Stored with Git LFS)

Binary file not shown.

Some files were not shown because too many files changed in this diff
