import os
import pickle

import torch
import torch.nn.functional as F
import zstandard as zstd
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(
        self,
        cache_path: str,
        processed_data: list[dict] | None = None,
        src_range: tuple[int, int] = (0, 0),
        tgt_range: tuple[int, int] = (0, 0),
        shard_idx: int = 0,
        zstd_level: int = 9,
    ):
        self.src_range = src_range
        self.tgt_range = tgt_range
        self.zstd_level = zstd_level
        self.processed_data = processed_data

        base_name = f"cached_src_{src_range[0]}_{src_range[1]}_tgt_{tgt_range[0]}_{tgt_range[1]}"
        self.cache_path = os.path.join(
            cache_path,
            f"{base_name}_shard_{shard_idx}.pkl.zst"
        )

        # Load from the cache file if it already exists.
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as f:
                dctx = zstd.ZstdDecompressor()
                with dctx.stream_reader(f) as reader:
                    self.processed_data = pickle.load(reader)
            return
        elif self.processed_data is None:
            raise ValueError(f"Cache file not found ({self.cache_path}) and no processed_data provided")
        # Save cache
print("Saving cache, please wait...")
cctx = zstd.ZstdCompressor(level=self.zstd_level)
with open(self.cache_path, 'wb') as f:
with cctx.stream_writer(f) as compressor:
pickle.dump(
self.processed_data,
compressor,
protocol=pickle.HIGHEST_PROTOCOL
)
print(f"Cache saved at: {self.cache_path}")

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        item = self.processed_data[idx]
        return {
            "src": torch.from_numpy(item["src"]).long(),
            "tgt": torch.from_numpy(item["tgt"]).long(),
            "label": torch.from_numpy(item["label"]).long(),
        }


def collate_fn(batch, pad_token_id: int, src_max_length: int, tgt_max_length: int):
    src = [item["src"] for item in batch]
    tgt = [item["tgt"] for item in batch]
    label = [item["label"] for item in batch]
    # Right-pad every sequence to its fixed maximum length. Note that
    # F.pad with a negative pad amount truncates, so sequences longer
    # than the maximum are silently cut to length.
    src_padded = [F.pad(t, (0, src_max_length - t.size(0)), value=pad_token_id) for t in src]
    src_padded = torch.stack(src_padded)
    tgt_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in tgt]
    tgt_padded = torch.stack(tgt_padded)
    label_padded = [F.pad(t, (0, tgt_max_length - t.size(0)), value=pad_token_id) for t in label]
    label_padded = torch.stack(label_padded)
    return {
        "src": src_padded,
        "tgt": tgt_padded,
        "label": label_padded,
    }
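

if __name__ == "__main__":
    # Minimal usage sketch. Everything below is illustrative and not part
    # of the original module: the synthetic data, the cache directory ".",
    # and the hyperparameters (pad_token_id=0, max lengths of 16) are all
    # assumptions chosen for the demo.
    from functools import partial

    import numpy as np
    from torch.utils.data import DataLoader

    rng = np.random.default_rng(0)
    processed = [
        {
            "src": rng.integers(1, 100, size=int(rng.integers(5, 16))),
            "tgt": rng.integers(1, 100, size=int(rng.integers(5, 16))),
            "label": rng.integers(1, 100, size=int(rng.integers(5, 16))),
        }
        for _ in range(8)
    ]

    # The first run builds and compresses the cache shard; later runs with
    # the same ranges/shard index load it instead of using processed_data.
    dataset = MyDataset(
        cache_path=".",
        processed_data=processed,
        src_range=(0, 16),
        tgt_range=(0, 16),
    )

    loader = DataLoader(
        dataset,
        batch_size=4,
        collate_fn=partial(
            collate_fn,
            pad_token_id=0,
            src_max_length=16,
            tgt_max_length=16,
        ),
    )

    for batch in loader:
        print(batch["src"].shape, batch["tgt"].shape, batch["label"].shape)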