You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

165 lines
5.9 KiB

# inference.py
import torch
import pandas as pd
from sacrebleu import corpus_bleu
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from pathlib import Path
from MyTokenizer import load_tokenizer
def load_dataset(path, src_lang="en", tgt_lang="zh"):
    """Load parallel sentence pairs from a parquet file.

    The file must contain a ``translation`` column whose entries are
    mapping-like objects keyed by language code.

    Args:
        path: Anything accepted by ``pandas.read_parquet``.
        src_lang: Key of the source-language text in each ``translation`` entry.
        tgt_lang: Key of the target-language text in each ``translation`` entry.

    Returns:
        A list of ``(source_text, target_text)`` tuples in file order.

    Raises:
        ValueError: If the ``translation`` column or a language key is missing.
    """
    df = pd.read_parquet(path)
    if "translation" not in df.columns:
        raise ValueError("Missing required column: 'translation'")
    try:
        # One pass over the column instead of two ``apply`` calls.
        return [(row[src_lang], row[tgt_lang]) for row in df["translation"]]
    except KeyError as e:
        # Chain the original KeyError so the offending key stays in the traceback.
        raise ValueError(f"Missing required translation key: {e}") from e
def make_batches(items, batch_size):
    """Split a sliceable sequence into consecutive chunks of ``batch_size``.

    The final chunk may be shorter. Each chunk is a slice of ``items``, so it
    has the same type as ``items``.
    """
    offsets = range(0, len(items), batch_size)
    batches = [items[offset:offset + batch_size] for offset in offsets]
    return batches
class BLEUEvaluator:
    """Translate an EN->ZH parallel dataset with a model and score it with corpus BLEU.

    The model only needs a ``greedy_decode`` method that accepts the keyword
    arguments used in :meth:`evaluate`.
    """

    def __init__(self, model, en_tokenizer, zh_tokenizer, device, batch_size, src_max_length, tgt_max_length):
        """
        Args:
            model: Model exposing ``greedy_decode`` (see :meth:`evaluate`).
            en_tokenizer: Source tokenizer; must provide ``encode`` and ``pad_token_id``.
            zh_tokenizer: Target tokenizer; must provide ``decode``, ``bos_token_id``
                and ``eos_token_id``.
            device: torch device the padded source batch is moved to.
            batch_size: Number of dataset pairs per batch (before length filtering).
            src_max_length: Pairs whose encoded source is longer than this are discarded.
            tgt_max_length: ``max_length`` forwarded to ``greedy_decode``.
        """
        self.model = model
        self.en_tokenizer = en_tokenizer
        self.zh_tokenizer = zh_tokenizer
        self.device = device
        self.batch_size = batch_size
        self.src_max_length = src_max_length
        self.tgt_max_length = tgt_max_length
        # NOTE(review): the source tokenizer's pad id is used both to pad the
        # source batch and as greedy_decode's pad_token_id; this assumes both
        # tokenizers share one pad id -- confirm.
        self.pad_token_id = en_tokenizer.pad_token_id
        self.bos_token_id = zh_tokenizer.bos_token_id
        self.eos_token_id = zh_tokenizer.eos_token_id

    def preprocess_batch(self, batch):
        """Filter out pairs whose source is too long or fails to tokenize.

        Args:
            batch: Iterable of ``(en_text, zh_ref)`` pairs.

        Returns:
            ``(valid_samples, discarded)``: the kept pairs, in order, and the
            number of pairs dropped.
        """
        valid_samples = []
        discarded = 0
        for en_text, zh_ref in batch:
            try:
                n_tokens = len(self.en_tokenizer.encode(en_text))
            except Exception as e:
                # Best-effort: one untokenizable sentence must not abort the whole run.
                print(f"Tokenization error: {e}")
                discarded += 1
                continue
            if n_tokens <= self.src_max_length:
                valid_samples.append((en_text, zh_ref))
            else:
                discarded += 1
        return valid_samples, discarded

    def encode_batch(self, en_batch):
        """Tokenize ``en_batch`` and right-pad it into one LongTensor on ``self.device``.

        Returns:
            A ``(len(en_batch), longest_encoding)`` tensor padded with ``self.pad_token_id``.
        """
        # Explicit dtype: an empty encoding would otherwise be inferred as float32.
        src_ids = [
            torch.tensor(self.en_tokenizer.encode(text), dtype=torch.long)
            for text in en_batch
        ]
        # Pad on the host and transfer once rather than once per sentence.
        padded = pad_sequence(src_ids, batch_first=True, padding_value=self.pad_token_id)
        return padded.to(self.device)

    def decode_predictions(self, outputs):
        """Decode each generated id sequence to text with all spaces removed.

        Spaces are stripped, presumably because ``decode`` joins tokens with
        spaces. This also removes spaces inside any Latin-script spans.
        """
        return [
            self.zh_tokenizer.decode(output, skip_special_tokens=True).replace(" ", "")
            for output in outputs
        ]

    def write_comparison(self, output_path, en_batch, zh_batch, hypos_batch):
        """Append one ``EN / ZH_REF / ZH_HYP`` record per sample to ``output_path``."""
        records = (
            f"EN: {en}\nZH_REF: {zh}\nZH_HYP: {hyp}\n\n"
            for en, zh, hyp in zip(en_batch, zh_batch, hypos_batch)
        )
        with open(output_path, 'a', encoding='utf-8') as f:
            f.writelines(records)

    @staticmethod
    def _reset_output_file(output_path):
        """Delete any previous comparison file and ensure its parent directory exists."""
        path = Path(output_path)
        path.unlink(missing_ok=True)
        # Without this, write_comparison fails when e.g. "results/" does not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)

    def evaluate(self, dataset_path, output_path="translations_comparison.txt"):
        """Translate every pair in ``dataset_path`` and return the corpus BLEU score.

        Overlong or untokenizable sources are skipped (see :meth:`preprocess_batch`).
        A human-readable comparison of each kept sample is written to ``output_path``,
        replacing any existing file.

        Returns:
            The object returned by ``sacrebleu.corpus_bleu``.

        Raises:
            ValueError: If every sample was discarded, leaving nothing to score.
        """
        self._reset_output_file(output_path)
        data = load_dataset(dataset_path)
        hypotheses = []
        references = []  # flat list, aligned 1:1 with ``hypotheses``
        total_discarded = 0
        progress_bar = tqdm(
            make_batches(data, self.batch_size), desc="Evaluating", unit="batch", colour="blue",
            bar_format='{l_bar}{bar:32}{r_bar}',
            dynamic_ncols=True,
        )
        for batch in progress_bar:
            valid_samples, discarded = self.preprocess_batch(batch)
            total_discarded += discarded
            if not valid_samples:
                continue
            en_batch, zh_batch = zip(*valid_samples)
            src_ids = self.encode_batch(en_batch)
            with torch.no_grad():
                outputs = self.model.greedy_decode(
                    src_ids,
                    bos_token_id=self.bos_token_id,
                    eos_token_id=self.eos_token_id,
                    pad_token_id=self.pad_token_id,
                    max_length=self.tgt_max_length
                )
                # outputs = self.model.beam_decode(
                #     src_ids,
                #     bos_token_id=self.bos_token_id,
                #     eos_token_id=self.eos_token_id,
                #     max_length=self.tgt_max_length,
                #     beam_size=8,
                #     length_penalty=0.95,
                #     repetition_penalty=1.05
                # )
            current_hypos = self.decode_predictions(outputs)
            hypotheses.extend(current_hypos)
            references.extend(zh_batch)
            self.write_comparison(output_path, en_batch, zh_batch, current_hypos)
            progress_bar.set_postfix({
                "Processed": f"{len(hypotheses)}/{len(data)}",
                "Discarded": total_discarded
            })
        if not hypotheses:
            raise ValueError("No valid samples to evaluate: every sample was discarded.")
        # sacrebleu expects a list of reference *streams*, each stream holding one
        # reference per hypothesis -- i.e. [refs], not [[ref] for ref in refs].
        bleu_score = corpus_bleu(hypotheses, [references], lowercase=False, smooth_method="floor", tokenize="zh")
        print(
            f"\nEvaluation Complete\n"
            f"Total samples: {len(data)}\n"
            f"Valid samples: {len(hypotheses)}\n"
            f"Discarded samples: {total_discarded}\n"
            f"BLEU score: {bleu_score.score:.2f}\n"
            f"Comparison file saved at: {output_path}"
        )
        return bleu_score
def main():
    """Evaluate the latest checkpoint on the validation split with BLEU.

    Loads the EN and ZH tokenizers and the saved model, then runs
    ``BLEUEvaluator.evaluate``. That call writes a per-sentence comparison
    file under ``results/``.
    """
    en_tokenizer = load_tokenizer("model/tokenizers/en")
    zh_tokenizer = load_tokenizer("model/tokenizers/zh")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Previously used alternatives: "model/checkpoints/model_009.pt",
    # "temp/latest_model.pt".
    checkpoint_path = "model/checkpoints/latest_model.pt"
    # weights_only=False unpickles a full model object, which can run arbitrary
    # code: only load checkpoints you produced yourself.
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.eval().to(device)

    settings = dict(batch_size=512, src_max_length=128, tgt_max_length=256)
    evaluator = BLEUEvaluator(
        model=model,
        en_tokenizer=en_tokenizer,
        zh_tokenizer=zh_tokenizer,
        device=device,
        **settings,
    )
    evaluator.evaluate(
        "data/valid/validation-00000-of-00001.parquet",
        output_path="results/translation_comparison.txt",
    )
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()