# Some test datasets and evaluation methods are adapted from https://github.com/OSU-NLP-Group/HippoRAG/tree/main/data , many thanks to the authors.
################################################################################
# Note: This evaluation script consumes a large number of LLM tokens, so make
# sure you have a sufficient token budget before running it.
################################################################################
import argparse
import ast
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from typing import List, Tuple

import pandas as pd

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files
from deepsearcher.online_query import naive_retrieve, retrieve

httpx_logger = logging.getLogger("httpx")  # disable openai's logger output
httpx_logger.setLevel(logging.WARNING)


warnings.simplefilter(action="ignore", category=FutureWarning)  # disable warning output


current_dir = os.path.dirname(os.path.abspath(__file__))

k_list = [2, 5]
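# k_list above sets the cutoffs for the recall@k metrics computed below (reported as R@2 and R@5).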


def _deepsearch_retrieve_titles(
    question: str,
    retry_num: int = 4,
    base_wait_time: int = 4,
    max_iter: int = 3,
) -> Tuple[List[str], int, bool]:
    """
    Retrieve document titles using DeepSearcher with a retry mechanism.

    Args:
        question (str): The query question.
        retry_num (int, optional): Number of retry attempts. Defaults to 4.
        base_wait_time (int, optional): Base wait time between retries in seconds. Defaults to 4.
        max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.

    Returns:
        Tuple[List[str], int, bool]: A tuple containing:
            - List of retrieved document titles
            - Number of tokens consumed
            - Boolean indicating whether the retrieval failed
    """
    retrieved_results = []
    consume_tokens = 0
    for i in range(retry_num):
        try:
            retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter)
            break
        except Exception:
            wait_time = base_wait_time * (2**i)
            print(f"Failed to parse the LLM output, retrying in {wait_time} seconds...")
            time.sleep(wait_time)
    if retrieved_results:
        retrieved_titles = [
            retrieved_result.metadata["title"] for retrieved_result in retrieved_results
        ]
        fail = False
    else:
        print("Pipeline error, no retrieved results.")
        retrieved_titles = []
        fail = True
    return retrieved_titles, consume_tokens, fail
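
# Note on the retry schedule above (illustrative, using the defaults): with
# base_wait_time=4 and retry_num=4, failed attempts wait 4, 8, 16, and 32
# seconds (base_wait_time * 2**i) before the function gives up.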


def _naive_retrieve_titles(question: str) -> List[str]:
    """
    Retrieve document titles using the naive retrieval method.

    Args:
        question (str): The query question.

    Returns:
        List[str]: List of retrieved document titles.
    """
    retrieved_results = naive_retrieve(question)
    retrieved_titles = [
        retrieved_result.metadata["title"] for retrieved_result in retrieved_results
    ]
    return retrieved_titles
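
# naive_retrieve provides the plain "naive RAG" baseline that the iterative
# DeepSearcher pipeline (retrieve with max_iter reflection rounds) is compared
# against in the recall lines printed during evaluation.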


def _calcu_recall(sample, retrieved_titles, dataset) -> dict:
    """
    Calculate recall metrics for retrieved titles.

    Args:
        sample: The sample data containing ground truth information.
        retrieved_titles: List of retrieved document titles.
        dataset (str): The name of the dataset being evaluated.

    Returns:
        dict: Dictionary containing recall values at different k values.

    Raises:
        NotImplementedError: If the dataset is not supported.
    """
    if dataset in ["2wikimultihopqa"]:
        gold_passages = [item for item in sample["supporting_facts"]]
        gold_items = set([item[0] for item in gold_passages])
        retrieved_items = retrieved_titles
    else:
        raise NotImplementedError

    recall = dict()
    for k in k_list:
        recall[k] = round(
            sum(1 for t in gold_items if t in retrieved_items[:k]) / len(gold_items), 4
        )
    return recall
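
# Worked example for the recall formula above (hypothetical titles): if the gold
# supporting-fact titles are {"Title A", "Title B"} and the retrieved titles are
# ["Title A", "Other", "Title B", ...], then recall@2 = 1/2 = 0.5 and
# recall@5 = 2/2 = 1.0.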


def _print_recall_line(recall: dict, pre_str="", post_str="\n"):
    """
    Print recall metrics in a formatted line.

    Args:
        recall (dict): Dictionary containing recall values at different k values.
        pre_str (str, optional): String to print before recall values. Defaults to "".
        post_str (str, optional): String to print after recall values. Defaults to "\n".
    """
    print(pre_str, end="")
    for k in k_list:
        print(f"R@{k}: {recall[k]:.3f} ", end="")
    print(post_str, end="")
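
# Example of a printed line (hypothetical numbers):
#   Average recall of DeepSearcher: R@2: 0.650 R@5: 0.800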


def evaluate(
    dataset: str,
    output_root: str,
    pre_num: int = 10,
    max_iter: int = 3,
    skip_load=False,
    flag: str = "result",
):
    """
    Evaluate the retrieval performance on a dataset.

    Args:
        dataset (str): Name of the dataset to evaluate.
        output_root (str): Root directory for output files.
        pre_num (int, optional): Number of samples to evaluate. Defaults to 10.
        max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.
        skip_load (bool, optional): Whether to skip loading the dataset. Defaults to False.
        flag (str, optional): Flag for the evaluation run. Defaults to "result".
    """
    corpus_file = os.path.join(current_dir, f"../examples/data/{dataset}_corpus.json")
    if not skip_load:
        # Set chunk_size to a very large number to avoid re-chunking, because the corpus is already chunked.
        load_from_local_files(
            corpus_file, force_new_collection=True, chunk_size=999999, chunk_overlap=0
        )
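
    # Expected data layout (per the paths above and below): the pre-chunked corpus
    # lives at ../examples/data/<dataset>_corpus.json and the QA samples with ground
    # truth at ../examples/data/<dataset>.json, both resolved relative to this
    # file's directory.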

    eval_output_subdir = os.path.join(output_root, flag)
    os.makedirs(eval_output_subdir, exist_ok=True)
    csv_file_path = os.path.join(eval_output_subdir, "details.csv")
    statistics_file_path = os.path.join(eval_output_subdir, "statistics.json")

    data_with_gt_file_path = os.path.join(current_dir, f"../examples/data/{dataset}.json")
    data_with_gt = json.load(open(data_with_gt_file_path, "r"))

    if not pre_num:
        pre_num = len(data_with_gt)

    pipeline_error_num = 0
    end_ind = min(pre_num, len(data_with_gt))

    start_ind = 0
    existing_df = pd.DataFrame()
    existing_statistics = defaultdict(dict)
    existing_token_usage = 0
    existing_error_num = 0
    existing_sample_num = 0
    if os.path.exists(csv_file_path):
        existing_df = pd.read_csv(csv_file_path)
        start_ind = len(existing_df)
        print(f"Loading results from {csv_file_path}, start_index = {start_ind}")

    if os.path.exists(statistics_file_path):
        existing_statistics = json.load(open(statistics_file_path, "r"))
        print(
            f"Loading statistics from {statistics_file_path}; the statistics will be recalculated from both new and existing results."
        )
        existing_token_usage = existing_statistics["deepsearcher"]["token_usage"]
        existing_error_num = existing_statistics["deepsearcher"].get("error_num", 0)
        existing_sample_num = existing_statistics["deepsearcher"].get("sample_num", 0)
    for sample_idx, sample in enumerate(data_with_gt[start_ind:end_ind]):
        global_idx = sample_idx + start_ind
        question = sample["question"]

        retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(
            question, max_iter=max_iter
        )
        retrieved_titles_naive = _naive_retrieve_titles(question)

        if fail:
            pipeline_error_num += 1
            print(
                f"Pipeline error, no retrieved results. Current pipeline_error_num = {pipeline_error_num}"
            )

        print(f"idx: {global_idx}: ")
        recall = _calcu_recall(sample, retrieved_titles, dataset)
        recall_naive = _calcu_recall(sample, retrieved_titles_naive, dataset)
        current_result = [
            {
                "idx": global_idx,
                "question": question,
                "recall": recall,
                "recall_naive": recall_naive,
                "gold_titles": [item[0] for item in sample["supporting_facts"]],
                "retrieved_titles": retrieved_titles,
                "retrieved_titles_naive": retrieved_titles_naive,
            }
        ]
        current_df = pd.DataFrame(current_result)
        existing_df = pd.concat([existing_df, current_df], ignore_index=True)
        existing_df.to_csv(csv_file_path, index=False)
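        # The CSV is rewritten after every sample, so an interrupted run can resume
        # from start_ind above. Recall dicts reloaded from that CSV come back as
        # strings, which is why ast.literal_eval is used in the averaging below.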
        average_recall = dict()
        average_recall_naive = dict()
        for k in k_list:
            average_recall[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall"]
                ]
            ) / len(existing_df)
            average_recall_naive[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall_naive"]
                ]
            ) / len(existing_df)
        _print_recall_line(average_recall, pre_str="Average recall of DeepSearcher: ")
        _print_recall_line(average_recall_naive, pre_str="Average recall of naive RAG : ")
        existing_token_usage += consume_tokens
        existing_error_num += 1 if fail else 0
        existing_sample_num += 1
        existing_statistics["deepsearcher"]["average_recall"] = average_recall
        existing_statistics["deepsearcher"]["token_usage"] = existing_token_usage
        existing_statistics["deepsearcher"]["error_num"] = existing_error_num
        existing_statistics["deepsearcher"]["sample_num"] = existing_sample_num
        existing_statistics["deepsearcher"]["token_usage_per_sample"] = (
            existing_token_usage / existing_sample_num
        )
        existing_statistics["naive_rag"]["average_recall"] = average_recall_naive
        json.dump(existing_statistics, open(statistics_file_path, "w"), indent=4)
    print("")
    print("Evaluation finished. Results saved.")


def main_eval():
    """
    Main function for running the evaluation from the command line.

    This function parses command line arguments and calls the evaluate function
    with the appropriate parameters.
    """
    parser = argparse.ArgumentParser(prog="evaluate", description="Deep Searcher evaluation.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="2wikimultihopqa",
        help="Dataset name, default is `2wikimultihopqa`. More datasets will be supported in the future.",
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./eval_config.yaml",
        help="Configuration yaml file path, default is `./eval_config.yaml`",
    )
    parser.add_argument(
        "--pre_num",
        type=int,
        default=30,
        help="Number of samples to evaluate, default is 30",
    )
    parser.add_argument(
        "--max_iter",
        type=int,
        default=3,
        help="Max iterations of reflection. Default is 3. It overrides the value in the config yaml file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_output",
        help="Output root directory, default is `./eval_output`",
    )
    parser.add_argument(
        "--skip_load",
        action="store_true",
        help="Skip loading the corpus with load_from_local_files. Loading is performed by default; set this flag to skip it.",
    )
    parser.add_argument(
        "--flag",
        type=str,
        default="result",
        help="Flag for evaluation, default is `result`",
    )

    args = parser.parse_args()

    config = Configuration(config_path=args.config_yaml)
    init_config(config=config)

    evaluate(
        dataset=args.dataset,
        output_root=args.output_dir,
        pre_num=args.pre_num,
        max_iter=args.max_iter,
        skip_load=args.skip_load,
        flag=args.flag,
    )


if __name__ == "__main__":
    main_eval()
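
# Example invocation (assuming this script is saved as evaluate.py; adjust the
# file name and paths to your checkout):
#   python evaluate.py --dataset 2wikimultihopqa --config_yaml ./eval_config.yaml --pre_num 30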