import json
from typing import List, Optional, Generator, Dict, Tuple
from collections import defaultdict, Counter
from pathlib import Path
import shutil
import math
from MyDataset import MyDataset


class BucketManager:
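    """Groups (src, tgt) samples into 2-D length buckets for padded batching.

    Bucket boundaries are found per dimension with a dynamic-programming
    partition that minimises padding. Each 2-D bucket with enough samples is
    split into shards, every shard is handed to `MyDataset`, and the resulting
    layout plus corpus statistics are cached in `buckets_meta.json` under
    `cache_path`.
    """
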
    def __init__(
        self,
        cache_path: str,
        min_samples: int = 1024,
        max_samples: int = 524288,
        num_cuts: int = 16,
        processed_data: Optional[List[dict]] = None,
        force_rebuild: bool = False,
    ):
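        """Build buckets from `processed_data` or load a previously cached layout.

        Args:
            cache_path: Directory holding the bucket metadata and shard caches.
            min_samples: Minimum number of samples a 2-D bucket must contain to
                be kept; smaller buckets are discarded.
            max_samples: Target number of samples per shard; larger buckets are
                split into multiple shards.
            num_cuts: Number of 1-D length buckets per side (src and tgt).
            processed_data: Samples with tokenised "src" and "tgt" fields. If
                None, an existing cache at `cache_path` is required.
            force_rebuild: Delete any existing cache and rebuild from
                `processed_data`.
        """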
        self._validate_init_params(cache_path, processed_data, force_rebuild)
        self.cache_path = Path(cache_path)
        self.num_cuts = num_cuts
        self.min_samples = min_samples
        self.max_samples = max_samples
        self.buckets = defaultdict(lambda: defaultdict(list))
        self.valid_buckets = []
        self.total_samples = 0
        self.src_buckets = []
        self.tgt_buckets = []
        self.total_original_samples = 0
        self.discarded_samples = 0
        self.total_padding = 0
        self.total_actual_tokens = 0
        if processed_data is not None:
            if not force_rebuild and self._metadata_exists():
                self._load_from_metadata()
            else:
                if force_rebuild and self.cache_path.exists():
                    shutil.rmtree(self.cache_path)
                self.cache_path.mkdir(parents=True, exist_ok=True)
                self._process_data_with_dp(processed_data)
                self._save_metadata()
        else:
            if not self._metadata_exists():
                raise FileNotFoundError(f"No cache found at {cache_path}")
            self._load_from_metadata()

    def _validate_init_params(
        self,
        cache_path: str,
        processed_data: Optional[List[dict]],
        force_rebuild: bool
    ):
        if not cache_path:
            raise ValueError("cache_path cannot be empty")
        if processed_data is None and not force_rebuild and not Path(cache_path).exists():
            raise FileNotFoundError(f"Cache path missing: {cache_path}")

    @staticmethod
    def _optimal_1d_partition(lengths: List[int], num_cuts: int) -> List[Tuple[int, int]]:
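        """Split `lengths` into at most `num_cuts` half-open [start, end) buckets
        so that total padding (each length padded to its bucket's maximum) is
        minimised.

        dp[i][j] holds the minimum padding when the first i unique lengths are
        partitioned into j buckets; parent[i][j] stores the split index k so
        that the last bucket covers unique_lengths[k:i], padded to
        unique_lengths[i - 1].
        """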
        if not lengths:
            return []
        length_counts = Counter(lengths)
        unique_lengths = sorted(length_counts.keys())
        if len(unique_lengths) <= num_cuts:
            buckets = []
            for i, length in enumerate(unique_lengths):
                start = length
                end = unique_lengths[i + 1] if i + 1 < len(unique_lengths) else length + 1
                buckets.append((start, end))
            return buckets
        n = len(unique_lengths)
        dp = [[float('inf')] * (num_cuts + 1) for _ in range(n + 1)]
        parent = [[-1] * (num_cuts + 1) for _ in range(n + 1)]
        dp[0][0] = 0
        for i in range(1, n + 1):
            for j in range(1, min(i + 1, num_cuts + 1)):
                for k in range(j - 1, i):
                    if dp[k][j - 1] == float('inf'):
                        continue
                    bucket_max = unique_lengths[i - 1]
                    padding_in_bucket = 0
                    for idx in range(k, i):
                        length = unique_lengths[idx]
                        count = length_counts[length]
                        padding_in_bucket += (bucket_max - length) * count
                    total_padding = dp[k][j - 1] + padding_in_bucket
                    if total_padding < dp[i][j]:
                        dp[i][j] = total_padding
                        parent[i][j] = k
        if dp[n][num_cuts] == float('inf'):
            return [(min(lengths), max(lengths) + 1)]
        buckets = []
        i, j = n, num_cuts
        while j > 0 and i > 0:
            k = parent[i][j]
            if k < 0 or k >= i:
                break
            bucket_start = unique_lengths[k]
            bucket_end = unique_lengths[i - 1] + 1
            buckets.append((bucket_start, bucket_end))
            i, j = k, j - 1
        buckets.reverse()
        return buckets

    def _create_2d_buckets(
        self,
        data: List[Tuple[int, int]]
    ) -> List[Tuple[Tuple[int, int], Tuple[int, int]]]:
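        """Build the cartesian product of the 1-D src and tgt buckets and return
        the (src_range, tgt_range) pairs that contain at least `min_samples`
        length pairs from `data`."""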
        if not data:
            return []
        src_lengths = [item[0] for item in data]
        tgt_lengths = [item[1] for item in data]
        print("Calculating optimal buckets for src lengths...")
        self.src_buckets = self._optimal_1d_partition(src_lengths, self.num_cuts)
        print(f"src buckets: {self.src_buckets}")
        print("Calculating optimal buckets for tgt lengths...")
        self.tgt_buckets = self._optimal_1d_partition(tgt_lengths, self.num_cuts)
        print(f"tgt buckets: {self.tgt_buckets}")
        bucket_samples = {}
        for src_bucket in self.src_buckets:
            for tgt_bucket in self.tgt_buckets:
                bucket_key = (src_bucket, tgt_bucket)
                bucket_samples[bucket_key] = []
        for src_len, tgt_len in data:
            src_bucket = None
            tgt_bucket = None
            for bucket in self.src_buckets:
                if bucket[0] <= src_len < bucket[1]:
                    src_bucket = bucket
                    break
            for bucket in self.tgt_buckets:
                if bucket[0] <= tgt_len < bucket[1]:
                    tgt_bucket = bucket
                    break
            if src_bucket is not None and tgt_bucket is not None:
                bucket_key = (src_bucket, tgt_bucket)
                bucket_samples[bucket_key].append((src_len, tgt_len))
        valid_buckets = []
        for bucket_key, samples in bucket_samples.items():
            if len(samples) >= self.min_samples:
                valid_buckets.append(bucket_key)
        return valid_buckets

    def _process_data_with_dp(self, data: List[dict]):
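        """Assign every sample to its 2-D bucket, split each bucket into shards
        of roughly `max_samples` items (a final shard holding at most half of
        `max_samples` is merged into the previous one), and build one
        `MyDataset` per shard. Samples falling outside all valid buckets are
        counted as discarded."""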
        length_pairs = [(len(item["src"]), len(item["tgt"])) for item in data]
        self.total_original_samples = len(data)
        valid_buckets = self._create_2d_buckets(length_pairs)
        bucket_data = {bucket: [] for bucket in valid_buckets}
        total_actual_tokens = 0
        total_padding = 0
        used_samples = 0
        for item in data:
            src_len = len(item["src"])
            tgt_len = len(item["tgt"])
            for (src_start, src_end), (tgt_start, tgt_end) in valid_buckets:
                if src_start <= src_len < src_end and tgt_start <= tgt_len < tgt_end:
                    src_max = src_end - 1
                    tgt_max = tgt_end - 1
                    src_pad = src_max - src_len
                    tgt_pad = tgt_max - tgt_len
                    total_padding += src_pad + tgt_pad
                    total_actual_tokens += src_len + tgt_len
                    used_samples += 1
                    bucket_data[(src_start, src_end), (tgt_start, tgt_end)].append(item)
                    break
        self.discarded_samples = len(data) - used_samples
        self.total_padding = total_padding
        self.total_actual_tokens = total_actual_tokens
        print("\nBuilding datasets from buckets...")
        total_samples = 0
        self.valid_buckets = []
        sorted_buckets = sorted(valid_buckets, key=lambda x: (x[1][0], x[0][0]))
        for bucket_key in sorted_buckets:
            (src_start, src_end), (tgt_start, tgt_end) = bucket_key
            bucket_items = bucket_data[bucket_key]
            data_len = len(bucket_items)
            base_num_shards = max(1, (data_len + self.max_samples - 1) // self.max_samples)
            last_shard_size = data_len % self.max_samples
            if last_shard_size == 0 and data_len > 0:
                last_shard_size = self.max_samples
            if base_num_shards > 1 and last_shard_size <= self.max_samples * 0.5:
                num_shards = base_num_shards - 1
            else:
                num_shards = base_num_shards
            for shard_idx in range(num_shards):
                start_idx = shard_idx * self.max_samples
                if shard_idx == num_shards - 1:
                    end_idx = data_len
                else:
                    end_idx = start_idx + self.max_samples
                shard_data = bucket_items[start_idx:end_idx]
                shard_size = len(shard_data)
                MyDataset(
                    cache_path=str(self.cache_path),
                    processed_data=shard_data,
                    src_range=(src_start, src_end),
                    tgt_range=(tgt_start, tgt_end),
                    shard_idx=shard_idx,
                )
                bucket_info = {
                    "src_range": (src_start, src_end),
                    "tgt_range": (tgt_start, tgt_end),
                    "shard_idx": shard_idx,
                    "num_shards": num_shards,
                    "suggested_batch_size": 0,
                    "num_samples": shard_size
                }
                self.valid_buckets.append(bucket_info)
                total_samples += shard_size
        self.total_samples = total_samples
        print(f"Valid buckets: {len(self.valid_buckets)}")
        print(f"Total samples: {total_samples}")

    def _save_metadata(self):
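        """Write the bucket layout and corpus statistics to `buckets_meta.json`."""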
        meta = {
            "num_cuts": self.num_cuts,
            "valid_buckets": [{
                "src_range": list(b["src_range"]),
                "tgt_range": list(b["tgt_range"]),
                "shard_idx": b["shard_idx"],
                "num_shards": b["num_shards"],
                "suggested_batch_size": b["suggested_batch_size"],
                "num_samples": b["num_samples"]
            } for b in self.valid_buckets],
            "min_samples": self.min_samples,
            "max_samples": self.max_samples,
            "total_samples": self.total_samples,
            "total_original_samples": self.total_original_samples,
            "discarded_samples": self.discarded_samples,
            "total_padding": self.total_padding,
            "total_actual_tokens": self.total_actual_tokens
        }
        meta_path = self.cache_path / "buckets_meta.json"
        meta_path.write_text(json.dumps(meta, indent=2))
        print(f"Metadata saved to {meta_path}")

    def _load_from_metadata(self):
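        """Restore the bucket layout and statistics from `buckets_meta.json`."""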
        meta_path = self.cache_path / "buckets_meta.json"
        meta = json.loads(meta_path.read_text())
        self.num_cuts = meta.get("num_cuts", 8)
        self.min_samples = meta.get("min_samples", self.min_samples)
        self.max_samples = meta.get("max_samples", self.max_samples)
        self.total_samples = meta.get("total_samples", 0)
        self.total_original_samples = meta.get("total_original_samples", 0)
        self.discarded_samples = meta.get("discarded_samples", 0)
        self.total_padding = meta.get("total_padding", 0)
        self.total_actual_tokens = meta.get("total_actual_tokens", 0)
        self.valid_buckets = []
        for b in meta["valid_buckets"]:
            self.valid_buckets.append({
                "src_range": tuple(b["src_range"]),
                "tgt_range": tuple(b["tgt_range"]),
                "shard_idx": b["shard_idx"],
                "num_shards": b["num_shards"],
                "suggested_batch_size": b.get("suggested_batch_size", 0),
                "num_samples": b["num_samples"]
            })
        print(f"Loaded {len(self.valid_buckets)} buckets from {meta_path}")

    def _metadata_exists(self) -> bool:
        return (self.cache_path / "buckets_meta.json").exists()

    def __iter__(self) -> Generator[Dict, None, None]:
        yield from self.valid_buckets

    def __len__(self) -> int:
        return len(self.valid_buckets)

    def __getitem__(self, index: int) -> dict:
        return self.valid_buckets[index]

    def reset_batch_size(self) -> None:
        for b in self.valid_buckets:
            b["suggested_batch_size"] = 0

    def find_optimal_batch_size(self, find_batch_size_func) -> None:
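        """Fill in `suggested_batch_size` for every bucket shard.

        `find_batch_size_func(src_max, tgt_max)` is called with each range's
        upper bound, once per unique (src_range, tgt_range) pair; shards that
        share the same ranges reuse the result. The updated metadata is written
        back to disk."""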
        bucket_type_map = {}
        for i, bucket in enumerate(self.valid_buckets):
            src_range = bucket["src_range"]
            tgt_range = bucket["tgt_range"]
            bucket_type = (src_range, tgt_range)
            if bucket_type in bucket_type_map:
                bucket["suggested_batch_size"] = bucket_type_map[bucket_type]
                print(f"Bucket {i + 1}/{len(self.valid_buckets)} reused batch size: {bucket['suggested_batch_size']}\n")
                continue
            print(
                f"Searching optimal batch size for bucket {i + 1}/{len(self.valid_buckets)} "
                f"src: {src_range} tgt: {tgt_range} ..."
            )
            batch_size = find_batch_size_func(src_max=src_range[1], tgt_max=tgt_range[1])
            bucket["suggested_batch_size"] = batch_size
            bucket_type_map[bucket_type] = bucket["suggested_batch_size"]
            print(f"Found batch size: {bucket['suggested_batch_size']}\n")
        self._save_metadata()

    def found_optimal(self) -> bool:
        return all(b["suggested_batch_size"] > 0 for b in self.valid_buckets)

    def get_total_iterations(self, safety_factor: float, drop_last: bool = False) -> int:
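        """Return the number of batches one pass over all shards takes when each
        bucket uses `floor(suggested_batch_size * safety_factor)` as its batch
        size. If `drop_last` is False, a final partial batch counts as one
        iteration."""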
        total_iterations = 0
        for bucket in self.valid_buckets:
            batch_size = math.floor(bucket["suggested_batch_size"] * safety_factor)
            num_samples = bucket["num_samples"]
            if batch_size <= 0:
                raise ValueError("Batch size must be positive. Call find_optimal_batch_size first.")
            if drop_last:
                iterations = num_samples // batch_size
            else:
                iterations = (num_samples + batch_size - 1) // batch_size
            total_iterations += iterations
        return total_iterations

    def get_info(self) -> List[dict]:
        return self.valid_buckets

    def print_stats(self):
        print("\n" + "=" * 60)
        print("Bucket Statistics")
        print("=" * 60)
        print(f"Total buckets: {len(self.valid_buckets)}")
        print(f"Total samples: {self.total_samples}")
        print(f"Min samples per bucket: {self.min_samples}")
        print(f"Max samples per bucket: {self.max_samples}")
        print("\nData Distribution:")
        print("-" * 60)
        print(f"Original samples: {self.total_original_samples}")
        if self.total_original_samples > 0:
            discard_rate = self.discarded_samples / self.total_original_samples
        else:
            discard_rate = 0.0
        print(f"Discarded samples: {self.discarded_samples} ({discard_rate * 100:.2f}%)")
        total_tokens_with_padding = self.total_actual_tokens + self.total_padding
        if total_tokens_with_padding > 0:
            padding_rate = self.total_padding / total_tokens_with_padding
        else:
            padding_rate = 0.0
        print(f"Total padding tokens: {self.total_padding}")
        print(f"Padding rate: {padding_rate * 100:.2f}%")
        print("\nBucket Details:")
        print("-" * 80)
        print(f"{'Bucket ID':<10} {'Src Range':<15} {'Tgt Range':<15} {'Samples':<10} {'Shards':<8} {'Batch Size':<10}")
        print("-" * 80)
        for i, bucket in enumerate(self.valid_buckets):
            src_range = f"{bucket['src_range'][0]}-{bucket['src_range'][1] - 1}"
            tgt_range = f"{bucket['tgt_range'][0]}-{bucket['tgt_range'][1] - 1}"
            samples = bucket['num_samples']
            shards = f"{bucket['shard_idx'] + 1}/{bucket['num_shards']}"
            batch_size = bucket['suggested_batch_size']
            print(f"{i:<10} {src_range:<15} {tgt_range:<15} {samples:<10} {shards:<8} {batch_size:<10}")