You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
122 lines
3.7 KiB
122 lines
3.7 KiB
import os
|
|
import pandas as pd
|
|
import glob
|
|
import jieba
|
|
from multiprocessing import Pool
|
|
from tqdm import tqdm
|
|
import logging
|
|
|
|
# Configure the root logger once at import time so every module-level
# logging.* call below emits timestamped messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Number of worker processes in the per-shard segmentation pool.
NUM_WORKERS = 20
# Dataset splits mapped to the directories holding their Parquet shards.
# NOTE(review): both splits are concatenated into one corpus pair below —
# presumably intentional; confirm if a separate valid corpus is needed.
INPUT_DIRS = {
    'train': 'data/train',
    'valid': 'data/valid'
}
# Output location for the aligned parallel corpus (one sentence per line).
OUTPUT_DIR = 'data/cache/txt'
EN_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.en')  # English side
ZH_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'corpus.zh')  # segmented Chinese side
|
|
|
|
|
|
def init_jieba():
    """Pool-worker initializer: silence jieba's chatter and keep it single-process.

    Runs once in each worker spawned by ``Pool(..., initializer=init_jieba)``.
    """
    # Demote jieba's INFO-level load messages so 20 workers don't spam the log.
    logging.getLogger('jieba').setLevel(logging.WARNING)
    # Each worker is already a separate process; jieba's own parallel mode
    # must stay off inside them.
    jieba.disable_parallel()
|
|
|
|
|
|
def process_line(record):
    """Turn one parallel record into an (English, segmented-Chinese) pair.

    Expects ``record['translation']`` to hold ``'en'`` and ``'zh'`` texts.
    Returns a tuple ``(en_text, zh_tokens_joined_by_spaces)``, or ``None``
    on any malformed record — problems are logged, never raised, so a bad
    row cannot kill a pool worker.
    """
    try:
        pair = record['translation']
        english = pair['en']
        # Chinese word segmentation: jieba tokens re-joined with single spaces.
        segmented = ' '.join(jieba.lcut(pair['zh']))
    except KeyError as e:
        logging.warning(f"Missing field in record: {str(e)}")
        return None
    except Exception as e:
        logging.warning(f"Line processing error: {str(e)}")
        return None
    return (english, segmented)
|
|
|
|
|
|
def process_shard(shard_path):
    """Segment every record of one Parquet shard in parallel.

    Loads the shard fully into memory, fans the rows out to a worker pool
    for jieba segmentation, and returns ``(en_sentences, zh_sentences)`` as
    two equal-length lists with row order preserved. On any failure the
    shard is skipped: the error is logged and two empty lists are returned,
    so one bad shard cannot abort the whole run.
    """
    try:
        df = pd.read_parquet(shard_path)
        records = df.to_dict(orient='records')
        total = len(records)
        logging.info(f"Processing {shard_path} ({total} lines)")

        results = []
        with Pool(NUM_WORKERS, initializer=init_jieba) as pool:
            # chunksize > 1 batches rows per IPC round-trip — the default
            # chunksize=1 is a serialization bottleneck for tiny per-row
            # tasks. imap (not imap_unordered) keeps the output in input
            # order so the two corpora stay sentence-aligned.
            for result in tqdm(
                pool.imap(process_line, records, chunksize=256),
                total=total,
                desc=f"Processing {os.path.basename(shard_path)}",
                unit="lines",
                colour='green',
                bar_format='{l_bar}{bar:32}{r_bar}'
            ):
                # process_line returns None for malformed rows; drop them
                # from BOTH sides to keep alignment.
                if result is not None:
                    results.append(result)

        # Transpose [(en, zh), ...] into two parallel lists.
        en_sentences, zh_sentences = zip(*results) if results else ([], [])
        logging.info(f"Processed {len(results)} lines from {shard_path}")
        return list(en_sentences), list(zh_sentences)

    except Exception as e:
        # Best-effort pipeline boundary: log and continue with empty output.
        logging.error(f"Shard processing failed: {shard_path} - {str(e)}")
        return [], []
|
|
|
|
|
|
def main():
    """Collect every Parquet shard, segment it, and emit the parallel corpus."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Gather shards from every configured split into one flat list.
    all_shards = []
    for _, dir_path in INPUT_DIRS.items():
        found = glob.glob(os.path.join(dir_path, '*.parquet'))
        if not found:
            logging.warning(f"No Parquet files found in {dir_path}")
        all_shards.extend(found)

    if not all_shards:
        logging.error("No Parquet files found in any input directories")
        return

    # Deterministic processing order regardless of glob's directory order.
    all_shards.sort(key=os.path.abspath)

    logging.info(f"Found {len(all_shards)} shards to process")
    # Warm jieba's dictionary in the parent before forking workers.
    jieba.initialize()

    all_en, all_zh = [], []
    for shard_path in all_shards:
        en_part, zh_part = process_shard(shard_path)
        all_en.extend(en_part)
        all_zh.extend(zh_part)

    if len(all_en) != len(all_zh):
        logging.warning(f"Data length mismatch: {len(all_en)} English vs {len(all_zh)} Chinese sentences")

    logging.info(f"Writing {len(all_en)} sentences to final files")

    # Write both sides in lockstep so line N of corpus.en pairs with
    # line N of corpus.zh.
    with open(EN_OUTPUT_PATH, 'w', encoding='utf-8') as f_en, \
         open(ZH_OUTPUT_PATH, 'w', encoding='utf-8') as f_zh:
        for en, zh in tqdm(zip(all_en, all_zh), total=len(all_en), desc="Writing files"):
            f_en.write(en + '\n')
            f_zh.write(zh + '\n')

    logging.info("Corpus generation completed successfully")
    logging.info(f"English corpus saved at: {EN_OUTPUT_PATH}")
    logging.info(f"Chinese corpus saved at: {ZH_OUTPUT_PATH}")
|
|
|
|
|
|
# Script entry point: run the full corpus-generation pipeline.
if __name__ == "__main__":
    main()