# inference.py
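"""Evaluate an English-to-Chinese translation model with corpus BLEU.

Loads a parallel parquet dataset, filters over-length sources, runs batched
greedy decoding, writes a side-by-side comparison file, and reports the
sacrebleu corpus score.
"""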

import torch
import pandas as pd
from sacrebleu import corpus_bleu
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from pathlib import Path

from MyTokenizer import load_tokenizer
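# MyTokenizer is the project's local tokenizer module; load_tokenizer is
# expected to return tokenizers exposing encode/decode and the pad/bos/eos
# special-token ids used below.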


def load_dataset(path):
    """Read a parquet parallel corpus into a list of (en, zh) text pairs."""
    df = pd.read_parquet(path)
    try:
        return list(zip(
            df['translation'].apply(lambda x: x['en']),
            df['translation'].apply(lambda x: x['zh'])
        ))
    except KeyError as e:
        raise ValueError(f"Missing required translation key: {e}")


def make_batches(items, batch_size):
    """Split `items` into consecutive chunks of at most `batch_size`.

    Example: make_batches([1, 2, 3, 4, 5], 2) -> [[1, 2], [3, 4], [5]]
    """
    return [items[i: i + batch_size] for i in range(0, len(items), batch_size)]


class BLEUEvaluator:
    """Batched decoding over a parallel dataset with corpus BLEU scoring."""

    def __init__(self, model, en_tokenizer, zh_tokenizer, device, batch_size, src_max_length, tgt_max_length):
        self.model = model
        self.en_tokenizer = en_tokenizer
        self.zh_tokenizer = zh_tokenizer
        self.device = device
        self.batch_size = batch_size
        self.src_max_length = src_max_length
        self.tgt_max_length = tgt_max_length
        # Padding follows the source (English) tokenizer; BOS/EOS ids come
        # from the target (Chinese) tokenizer that drives generation.
        self.pad_token_id = en_tokenizer.pad_token_id
        self.bos_token_id = zh_tokenizer.bos_token_id
        self.eos_token_id = zh_tokenizer.eos_token_id

    def preprocess_batch(self, batch):
        """Keep samples whose source fits src_max_length; count the rest as discarded."""
        valid_samples = []
        discarded = 0
        for en_text, zh_ref in batch:
            try:
                if len(self.en_tokenizer.encode(en_text)) <= self.src_max_length:
                    valid_samples.append((en_text, zh_ref))
                else:
                    discarded += 1
            except Exception as e:
                print(f"Tokenization error: {e}")
                discarded += 1
        return valid_samples, discarded

    def encode_batch(self, en_batch):
        """Tokenize a batch of English sentences into one padded id tensor."""
        batch_encoded = [self.en_tokenizer.encode(text) for text in en_batch]
        src_ids = [torch.tensor(x, device=self.device) for x in batch_encoded]
        return pad_sequence(
            src_ids,
            batch_first=True,
            padding_value=self.pad_token_id
        )

    def decode_predictions(self, outputs):
        # The tokenizer decodes with spaces between tokens; strip them, since
        # written Chinese uses no word separators.
        return [
            self.zh_tokenizer.decode(output, skip_special_tokens=True).replace(" ", "")
            for output in outputs
        ]

    def write_comparison(self, output_path, en_batch, zh_batch, hypos_batch):
        with open(output_path, 'a', encoding='utf-8') as f:
            for en, zh, hyp in zip(en_batch, zh_batch, hypos_batch):
                f.write(f"EN: {en}\nZH_REF: {zh}\nZH_HYP: {hyp}\n\n")

    def evaluate(self, dataset_path, output_path="translations_comparison.txt"):
        # Start from a clean comparison file, creating its directory if needed.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        Path(output_path).unlink(missing_ok=True)

        data = load_dataset(dataset_path)
        hypotheses = []
        references = []
        total_discarded = 0

        batches = make_batches(data, self.batch_size)
        progress_bar = tqdm(
            batches, desc="Evaluating", unit="batch", colour="blue",
            bar_format='{l_bar}{bar:32}{r_bar}',
            dynamic_ncols=True,
        )

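        # For each batch: drop over-length sources, encode, decode, then
        # accumulate hypotheses and references and update progress counters.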
        for batch in progress_bar:
            valid_samples, discarded = self.preprocess_batch(batch)
            total_discarded += discarded
            if not valid_samples:
                continue

            en_batch, zh_batch = zip(*valid_samples)
            src_ids = self.encode_batch(en_batch)

            with torch.no_grad():
                outputs = self.model.greedy_decode(
                    src_ids,
                    bos_token_id=self.bos_token_id,
                    eos_token_id=self.eos_token_id,
                    pad_token_id=self.pad_token_id,
                    max_length=self.tgt_max_length
                )
                # outputs = self.model.beam_decode(
                #     src_ids,
                #     bos_token_id=self.bos_token_id,
                #     eos_token_id=self.eos_token_id,
                #     max_length=self.tgt_max_length,
                #     beam_size=8,
                #     length_penalty=0.95,
                #     repetition_penalty=1.05
                # )
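                # beam_decode above is kept as a commented-out alternative to
                # greedy decoding, trading throughput for possible quality gains.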

            current_hypos = self.decode_predictions(outputs)
            hypotheses.extend(current_hypos)
            # Keep references as a flat list parallel to hypotheses; it is
            # wrapped as a single reference stream when scoring below.
            references.extend(zh_batch)

            self.write_comparison(output_path, en_batch, zh_batch, current_hypos)

            progress_bar.set_postfix({
                "Processed": f"{len(hypotheses)}/{len(data)}",
                "Discarded": total_discarded
            })

        # sacrebleu expects the references transposed: refs[k][i] is the k-th
        # reference for hypothesis i, so a single reference set is [references].
        bleu_score = corpus_bleu(hypotheses, [references], lowercase=False, smooth_method="floor", tokenize="zh")
        print(
            f"\nEvaluation Complete\n"
            f"Total samples: {len(data)}\n"
            f"Valid samples: {len(hypotheses)}\n"
            f"Discarded samples: {total_discarded}\n"
            f"BLEU score: {bleu_score.score:.2f}\n"
            f"Comparison file saved at: {output_path}"
        )
        return bleu_score


def main():
    en_tokenizer = load_tokenizer("model/tokenizers/en")
    zh_tokenizer = load_tokenizer("model/tokenizers/zh")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # weights_only=False: the checkpoint stores the full pickled model rather
    # than a bare state_dict, so only load checkpoints you trust.
    model = torch.load("model/checkpoints/latest_model.pt", map_location=device, weights_only=False)
    # model = torch.load("model/checkpoints/model_003.pt", map_location=device, weights_only=False)
    # model = torch.load("temp/latest_model.pt", map_location=device, weights_only=False)
    model.eval().to(device)

    evaluator = BLEUEvaluator(
        model=model,
        en_tokenizer=en_tokenizer,
        zh_tokenizer=zh_tokenizer,
        device=device,
        batch_size=512,
        src_max_length=128,
        tgt_max_length=256
    )
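
    # Note: batch_size=512 assumes a fairly large GPU; lower it (and/or the
    # max lengths above) if decoding runs out of memory.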

    _ = evaluator.evaluate(
        "data/valid/validation-00000-of-00001.parquet",
        output_path="results/translation_comparison.txt"
    )


if __name__ == "__main__":
    main()