# Some test datasets and evaluation methods are referenced from
# https://github.com/OSU-NLP-Group/HippoRAG/tree/main/data. Many thanks to the authors.
################################################################################
# Note: This evaluation script consumes a large amount of LLM tokens.
# Please make sure you have a sufficient token budget before running it.
################################################################################

import argparse
import ast
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from typing import List, Tuple

import pandas as pd

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files
from deepsearcher.online_query import naive_retrieve, retrieve

httpx_logger = logging.getLogger("httpx")  # silence httpx request logs (httpx is used by the OpenAI client)
httpx_logger.setLevel(logging.WARNING)

warnings.simplefilter(action="ignore", category=FutureWarning)  # suppress FutureWarning output

current_dir = os.path.dirname(os.path.abspath(__file__))

# Recall is reported at these cutoffs (R@2 and R@5).
k_list = [2, 5]


def _deepsearch_retrieve_titles(
    question: str,
    retry_num: int = 4,
    base_wait_time: int = 4,
    max_iter: int = 3,
) -> Tuple[List[str], int, bool]:
    """
    Retrieve document titles using DeepSearcher with a retry mechanism.

    Args:
        question (str): The query question.
        retry_num (int, optional): Number of retry attempts. Defaults to 4.
        base_wait_time (int, optional): Base wait time between retries in seconds. Defaults to 4.
        max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.

    Returns:
        Tuple[List[str], int, bool]: A tuple containing:
            - List of retrieved document titles
            - Number of tokens consumed
            - Boolean indicating whether the retrieval failed
    """
    retrieved_results = []
    consume_tokens = 0
    for i in range(retry_num):
        try:
            retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter)
            break
        except Exception:
            wait_time = base_wait_time * (2**i)
            print(f"Parsing the LLM output failed, retrying in {wait_time} seconds...")
            time.sleep(wait_time)
    if retrieved_results:
        retrieved_titles = [
            retrieved_result.metadata["title"] for retrieved_result in retrieved_results
        ]
        fail = False
    else:
        print("Pipeline error, no retrieved results.")
        retrieved_titles = []
        fail = True
    return retrieved_titles, consume_tokens, fail
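# Note on the helper above: retries use exponential backoff (base_wait_time * 2**i),
# so with the defaults a question whose retrievals all fail waits 4, 8, 16, and 32
# seconds across its four attempts before being reported as a pipeline error.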


def _naive_retrieve_titles(question: str) -> List[str]:
    """
    Retrieve document titles using naive retrieval method.

    Args:
        question (str): The query question.

    Returns:
        List[str]: List of retrieved document titles.
    """
    retrieved_results = naive_retrieve(question)
    retrieved_titles = [
        retrieved_result.metadata["title"] for retrieved_result in retrieved_results
    ]
    return retrieved_titles


def _calcu_recall(sample, retrieved_titles, dataset) -> dict:
    """
    Calculate recall metrics for retrieved titles.

    Args:
        sample: The sample data containing ground truth information.
        retrieved_titles: List of retrieved document titles.
        dataset (str): The name of the dataset being evaluated.

    Returns:
        dict: Dictionary containing recall values at different k values.

    Raises:
        NotImplementedError: If the dataset is not supported.
    """
    if dataset in ["2wikimultihopqa"]:
        gold_passages = [item for item in sample["supporting_facts"]]
        gold_items = set([item[0] for item in gold_passages])
        retrieved_items = retrieved_titles
    else:
        raise NotImplementedError
    recall = dict()
    for k in k_list:
        recall[k] = round(
            sum(1 for t in gold_items if t in retrieved_items[:k]) / len(gold_items), 4
        )
    return recall
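# Worked example for _calcu_recall (illustrative only, not executed): with gold titles
# {"Director A", "Film B"} and retrieved titles ["Director A", "X", "Film B", "Y"],
# R@2 = 1/2 = 0.5 (only "Director A" is in the top 2) and R@5 = 2/2 = 1.0.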


def _print_recall_line(recall: dict, pre_str="", post_str="\n"):
    """
    Print recall metrics in a formatted line.

    Args:
        recall (dict): Dictionary containing recall values at different k values.
        pre_str (str, optional): String to print before recall values. Defaults to "".
        post_str (str, optional): String to print after recall values. Defaults to "\n".
    """
    print(pre_str, end="")
    for k in k_list:
        print(f"R@{k}: {recall[k]:.3f} ", end="")
    print(post_str, end="")


def evaluate(
    dataset: str,
    output_root: str,
    pre_num: int = 10,
    max_iter: int = 3,
    skip_load=False,
    flag: str = "result",
):
    """
    Evaluate the retrieval performance on a dataset.

    Args:
        dataset (str): Name of the dataset to evaluate.
        output_root (str): Root directory for output files.
        pre_num (int, optional): Number of samples to evaluate. Defaults to 10.
        max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.
        skip_load (bool, optional): Whether to skip loading the dataset. Defaults to False.
        flag (str, optional): Flag for the evaluation run. Defaults to "result".
    """
    corpus_file = os.path.join(current_dir, f"../examples/data/{dataset}_corpus.json")
    if not skip_load:
        # Set chunk size to a large number to avoid re-chunking, because the dataset is already chunked.
        load_from_local_files(
            corpus_file, force_new_collection=True, chunk_size=999999, chunk_overlap=0
        )

    eval_output_subdir = os.path.join(output_root, flag)
    os.makedirs(eval_output_subdir, exist_ok=True)
    csv_file_path = os.path.join(eval_output_subdir, "details.csv")
    statistics_file_path = os.path.join(eval_output_subdir, "statistics.json")

    data_with_gt_file_path = os.path.join(current_dir, f"../examples/data/{dataset}.json")
    data_with_gt = json.load(open(data_with_gt_file_path, "r"))
    if not pre_num:
        pre_num = len(data_with_gt)

    pipeline_error_num = 0
    end_ind = min(pre_num, len(data_with_gt))
    start_ind = 0
    existing_df = pd.DataFrame()
    existing_statistics = defaultdict(dict)
    existing_token_usage = 0
    existing_error_num = 0
    existing_sample_num = 0
    # Resume from previous partial results if they exist.
    if os.path.exists(csv_file_path):
        existing_df = pd.read_csv(csv_file_path)
        start_ind = len(existing_df)
        print(f"Loading results from {csv_file_path}, start_index = {start_ind}")
    if os.path.exists(statistics_file_path):
        existing_statistics = json.load(open(statistics_file_path, "r"))
        print(
            f"Loading statistics from {statistics_file_path}, will recalculate the statistics based on both new and existing results."
        )
        existing_token_usage = existing_statistics["deepsearcher"]["token_usage"]
        existing_error_num = existing_statistics["deepsearcher"].get("error_num", 0)
        existing_sample_num = existing_statistics["deepsearcher"].get("sample_num", 0)

    for sample_idx, sample in enumerate(data_with_gt[start_ind:end_ind]):
        global_idx = sample_idx + start_ind
        question = sample["question"]
        retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(
            question, max_iter=max_iter
        )
        retrieved_titles_naive = _naive_retrieve_titles(question)
        if fail:
            pipeline_error_num += 1
            print(
                f"Pipeline error, no retrieved results. Current pipeline_error_num = {pipeline_error_num}"
            )
        print(f"idx: {global_idx}: ")
        recall = _calcu_recall(sample, retrieved_titles, dataset)
        recall_naive = _calcu_recall(sample, retrieved_titles_naive, dataset)
        current_result = [
            {
                "idx": global_idx,
                "question": question,
                "recall": recall,
                "recall_naive": recall_naive,
                "gold_titles": [item[0] for item in sample["supporting_facts"]],
                "retrieved_titles": retrieved_titles,
                "retrieved_titles_naive": retrieved_titles_naive,
            }
        ]
        current_df = pd.DataFrame(current_result)
        existing_df = pd.concat([existing_df, current_df], ignore_index=True)
        existing_df.to_csv(csv_file_path, index=False)

        # Recompute running averages over all results saved so far. Rows read back from
        # the CSV store the recall dicts as strings, hence the ast.literal_eval.
        average_recall = dict()
        average_recall_naive = dict()
        for k in k_list:
            average_recall[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall"]
                ]
            ) / len(existing_df)
            average_recall_naive[k] = sum(
                [
                    ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
                    for d in existing_df["recall_naive"]
                ]
            ) / len(existing_df)
        _print_recall_line(average_recall, pre_str="Average recall of DeepSearcher: ")
        _print_recall_line(average_recall_naive, pre_str="Average recall of naive RAG : ")

        existing_token_usage += consume_tokens
        existing_error_num += 1 if fail else 0
        existing_sample_num += 1
        existing_statistics["deepsearcher"]["average_recall"] = average_recall
        existing_statistics["deepsearcher"]["token_usage"] = existing_token_usage
        existing_statistics["deepsearcher"]["error_num"] = existing_error_num
        existing_statistics["deepsearcher"]["sample_num"] = existing_sample_num
        existing_statistics["deepsearcher"]["token_usage_per_sample"] = (
            existing_token_usage / existing_sample_num
        )
        existing_statistics["naive_rag"]["average_recall"] = average_recall_naive
        json.dump(existing_statistics, open(statistics_file_path, "w"), indent=4)

    print("")
    print("Finished. Results have been saved.")


def main_eval():
    """
    Main function for running the evaluation from the command line.

    This function parses command line arguments and calls the evaluate function
    with the appropriate parameters.
    """
    parser = argparse.ArgumentParser(prog="evaluate", description="Deep Searcher evaluation.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="2wikimultihopqa",
        help="Dataset name, default is `2wikimultihopqa`. More datasets will be supported in the future.",
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        default="./eval_config.yaml",
        help="Configuration yaml file path, default is `./eval_config.yaml`",
    )
    parser.add_argument(
        "--pre_num",
        type=int,
        default=30,
        help="Number of samples to evaluate, default is 30",
    )
    parser.add_argument(
        "--max_iter",
        type=int,
        default=3,
        help="Max iterations of reflection. Default is 3. It overrides the value in the config yaml file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./eval_output",
        help="Output root directory, default is `./eval_output`",
    )
    parser.add_argument(
        "--skip_load",
        action="store_true",
        help="Skip loading the dataset. Loading is not skipped by default; set this flag to skip it.",
    )
    parser.add_argument(
        "--flag",
        type=str,
        default="result",
        help="Flag for evaluation, default is `result`",
    )
    args = parser.parse_args()

    config = Configuration(config_path=args.config_yaml)
    init_config(config=config)

    evaluate(
        dataset=args.dataset,
        output_root=args.output_dir,
        pre_num=args.pre_num,
        max_iter=args.max_iter,
        skip_load=args.skip_load,
        flag=args.flag,
    )


if __name__ == "__main__":
    main_eval()
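
# Example invocation (the filename `evaluate.py` is illustrative; substitute this file's actual name):
#   python evaluate.py --dataset 2wikimultihopqa --config_yaml ./eval_config.yaml \
#       --pre_num 30 --max_iter 3 --output_dir ./eval_output --flag result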