75 changed files with 0 additions and 10370 deletions
@ -1,19 +0,0 @@ |
|||
FROM ghcr.io/astral-sh/uv:python3.10-bookworm-slim |
|||
|
|||
WORKDIR /app |
|||
|
|||
RUN mkdir -p /tmp/uv-cache /app/data /app/logs |
|||
|
|||
COPY pyproject.toml uv.lock LICENSE README.md ./ |
|||
COPY deepsearcher/ ./deepsearcher/ |
|||
|
|||
RUN uv sync |
|||
|
|||
COPY . . |
|||
|
|||
EXPOSE 8000 |
|||
|
|||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ |
|||
CMD curl -f http://localhost:8000/docs || exit 1 |
|||
|
|||
CMD ["uv", "run", "python", "main.py", "--enable-cors", "true"] |
@ -1,201 +0,0 @@ |
|||
Apache License |
|||
Version 2.0, January 2004 |
|||
http://www.apache.org/licenses/ |
|||
|
|||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION |
|||
|
|||
1. Definitions. |
|||
|
|||
"License" shall mean the terms and conditions for use, reproduction, |
|||
and distribution as defined by Sections 1 through 9 of this document. |
|||
|
|||
"Licensor" shall mean the copyright owner or entity authorized by |
|||
the copyright owner that is granting the License. |
|||
|
|||
"Legal Entity" shall mean the union of the acting entity and all |
|||
other entities that control, are controlled by, or are under common |
|||
control with that entity. For the purposes of this definition, |
|||
"control" means (i) the power, direct or indirect, to cause the |
|||
direction or management of such entity, whether by contract or |
|||
otherwise, or (ii) ownership of fifty percent (50%) or more of the |
|||
outstanding shares, or (iii) beneficial ownership of such entity. |
|||
|
|||
"You" (or "Your") shall mean an individual or Legal Entity |
|||
exercising permissions granted by this License. |
|||
|
|||
"Source" form shall mean the preferred form for making modifications, |
|||
including but not limited to software source code, documentation |
|||
source, and configuration files. |
|||
|
|||
"Object" form shall mean any form resulting from mechanical |
|||
transformation or translation of a Source form, including but |
|||
not limited to compiled object code, generated documentation, |
|||
and conversions to other media types. |
|||
|
|||
"Work" shall mean the work of authorship, whether in Source or |
|||
Object form, made available under the License, as indicated by a |
|||
copyright notice that is included in or attached to the work |
|||
(an example is provided in the Appendix below). |
|||
|
|||
"Derivative Works" shall mean any work, whether in Source or Object |
|||
form, that is based on (or derived from) the Work and for which the |
|||
editorial revisions, annotations, elaborations, or other modifications |
|||
represent, as a whole, an original work of authorship. For the purposes |
|||
of this License, Derivative Works shall not include works that remain |
|||
separable from, or merely link (or bind by name) to the interfaces of, |
|||
the Work and Derivative Works thereof. |
|||
|
|||
"Contribution" shall mean any work of authorship, including |
|||
the original version of the Work and any modifications or additions |
|||
to that Work or Derivative Works thereof, that is intentionally |
|||
submitted to Licensor for inclusion in the Work by the copyright owner |
|||
or by an individual or Legal Entity authorized to submit on behalf of |
|||
the copyright owner. For the purposes of this definition, "submitted" |
|||
means any form of electronic, verbal, or written communication sent |
|||
to the Licensor or its representatives, including but not limited to |
|||
communication on electronic mailing lists, source code control systems, |
|||
and issue tracking systems that are managed by, or on behalf of, the |
|||
Licensor for the purpose of discussing and improving the Work, but |
|||
excluding communication that is conspicuously marked or otherwise |
|||
designated in writing by the copyright owner as "Not a Contribution." |
|||
|
|||
"Contributor" shall mean Licensor and any individual or Legal Entity |
|||
on behalf of whom a Contribution has been received by Licensor and |
|||
subsequently incorporated within the Work. |
|||
|
|||
2. Grant of Copyright License. Subject to the terms and conditions of |
|||
this License, each Contributor hereby grants to You a perpetual, |
|||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|||
copyright license to reproduce, prepare Derivative Works of, |
|||
publicly display, publicly perform, sublicense, and distribute the |
|||
Work and such Derivative Works in Source or Object form. |
|||
|
|||
3. Grant of Patent License. Subject to the terms and conditions of |
|||
this License, each Contributor hereby grants to You a perpetual, |
|||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable |
|||
(except as stated in this section) patent license to make, have made, |
|||
use, offer to sell, sell, import, and otherwise transfer the Work, |
|||
where such license applies only to those patent claims licensable |
|||
by such Contributor that are necessarily infringed by their |
|||
Contribution(s) alone or by combination of their Contribution(s) |
|||
with the Work to which such Contribution(s) was submitted. If You |
|||
institute patent litigation against any entity (including a |
|||
cross-claim or counterclaim in a lawsuit) alleging that the Work |
|||
or a Contribution incorporated within the Work constitutes direct |
|||
or contributory patent infringement, then any patent licenses |
|||
granted to You under this License for that Work shall terminate |
|||
as of the date such litigation is filed. |
|||
|
|||
4. Redistribution. You may reproduce and distribute copies of the |
|||
Work or Derivative Works thereof in any medium, with or without |
|||
modifications, and in Source or Object form, provided that You |
|||
meet the following conditions: |
|||
|
|||
(a) You must give any other recipients of the Work or |
|||
Derivative Works a copy of this License; and |
|||
|
|||
(b) You must cause any modified files to carry prominent notices |
|||
stating that You changed the files; and |
|||
|
|||
(c) You must retain, in the Source form of any Derivative Works |
|||
that You distribute, all copyright, patent, trademark, and |
|||
attribution notices from the Source form of the Work, |
|||
excluding those notices that do not pertain to any part of |
|||
the Derivative Works; and |
|||
|
|||
(d) If the Work includes a "NOTICE" text file as part of its |
|||
distribution, then any Derivative Works that You distribute must |
|||
include a readable copy of the attribution notices contained |
|||
within such NOTICE file, excluding those notices that do not |
|||
pertain to any part of the Derivative Works, in at least one |
|||
of the following places: within a NOTICE text file distributed |
|||
as part of the Derivative Works; within the Source form or |
|||
documentation, if provided along with the Derivative Works; or, |
|||
within a display generated by the Derivative Works, if and |
|||
wherever such third-party notices normally appear. The contents |
|||
of the NOTICE file are for informational purposes only and |
|||
do not modify the License. You may add Your own attribution |
|||
notices within Derivative Works that You distribute, alongside |
|||
or as an addendum to the NOTICE text from the Work, provided |
|||
that such additional attribution notices cannot be construed |
|||
as modifying the License. |
|||
|
|||
You may add Your own copyright statement to Your modifications and |
|||
may provide additional or different license terms and conditions |
|||
for use, reproduction, or distribution of Your modifications, or |
|||
for any such Derivative Works as a whole, provided Your use, |
|||
reproduction, and distribution of the Work otherwise complies with |
|||
the conditions stated in this License. |
|||
|
|||
5. Submission of Contributions. Unless You explicitly state otherwise, |
|||
any Contribution intentionally submitted for inclusion in the Work |
|||
by You to the Licensor shall be under the terms and conditions of |
|||
this License, without any additional terms or conditions. |
|||
Notwithstanding the above, nothing herein shall supersede or modify |
|||
the terms of any separate license agreement you may have executed |
|||
with Licensor regarding such Contributions. |
|||
|
|||
6. Trademarks. This License does not grant permission to use the trade |
|||
names, trademarks, service marks, or product names of the Licensor, |
|||
except as required for reasonable and customary use in describing the |
|||
origin of the Work and reproducing the content of the NOTICE file. |
|||
|
|||
7. Disclaimer of Warranty. Unless required by applicable law or |
|||
agreed to in writing, Licensor provides the Work (and each |
|||
Contributor provides its Contributions) on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|||
implied, including, without limitation, any warranties or conditions |
|||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
|||
PARTICULAR PURPOSE. You are solely responsible for determining the |
|||
appropriateness of using or redistributing the Work and assume any |
|||
risks associated with Your exercise of permissions under this License. |
|||
|
|||
8. Limitation of Liability. In no event and under no legal theory, |
|||
whether in tort (including negligence), contract, or otherwise, |
|||
unless required by applicable law (such as deliberate and grossly |
|||
negligent acts) or agreed to in writing, shall any Contributor be |
|||
liable to You for damages, including any direct, indirect, special, |
|||
incidental, or consequential damages of any character arising as a |
|||
result of this License or out of the use or inability to use the |
|||
Work (including but not limited to damages for loss of goodwill, |
|||
work stoppage, computer failure or malfunction, or any and all |
|||
other commercial damages or losses), even if such Contributor |
|||
has been advised of the possibility of such damages. |
|||
|
|||
9. Accepting Warranty or Additional Liability. While redistributing |
|||
the Work or Derivative Works thereof, You may choose to offer, |
|||
and charge a fee for, acceptance of support, warranty, indemnity, |
|||
or other liability obligations and/or rights consistent with this |
|||
License. However, in accepting such obligations, You may act only |
|||
on Your own behalf and on Your sole responsibility, not on behalf |
|||
of any other Contributor, and only if You agree to indemnify, |
|||
defend, and hold each Contributor harmless for any liability |
|||
incurred by, or claims asserted against, such Contributor by reason |
|||
of your accepting any such warranty or additional liability. |
|||
|
|||
END OF TERMS AND CONDITIONS |
|||
|
|||
APPENDIX: How to apply the Apache License to your work. |
|||
|
|||
To apply the Apache License to your work, attach the following |
|||
boilerplate notice, with the fields enclosed by brackets "[]" |
|||
replaced with your own identifying information. (Don't include |
|||
the brackets!) The text should be enclosed in the appropriate |
|||
comment syntax for the file format. We also recommend that a |
|||
file or class name and description of purpose be included on the |
|||
same "printed page" as the copyright notice for easier |
|||
identification within third-party archives. |
|||
|
|||
Copyright 2019 Zilliz |
|||
|
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
http://www.apache.org/licenses/LICENSE-2.0 |
|||
|
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
@ -1,7 +0,0 @@ |
|||
lint: |
|||
uv run ruff format --diff |
|||
uv run ruff check |
|||
|
|||
format: |
|||
uv run ruff format |
|||
uv run ruff check --fix |
Before Width: | Height: | Size: 307 KiB |
Before Width: | Height: | Size: 3.4 MiB |
Before Width: | Height: | Size: 54 KiB |
@ -1,53 +0,0 @@ |
|||
# Evaluation of DeepSearcher |
|||
## Introduction |
|||
DeepSearcher is very good at answering complex queries. In this evaluation introduction, we provide some scripts to evaluate the performance of DeepSearcher vs. naive RAG. |
|||
|
|||
The evaluation is based on the Recall metric: |
|||
|
|||
> Recall@K: The percentage of relevant documents that are retrieved among the top K documents returned by the search engine. |
|||
|
|||
Currently, we support the multi-hop question answering dataset of [2WikiMultiHopQA](https://paperswithcode.com/dataset/2wikimultihopqa). More dataset will be added in the future. |
|||
|
|||
## Evaluation Script |
|||
The main evaluation script is `evaluate.py`. |
|||
|
|||
Your can provide a config file, say `eval_config.yaml`, to specify the LLM, embedding model, and other provider and parameters. |
|||
```shell |
|||
python evaluate.py \ |
|||
--dataset 2wikimultihopqa \ |
|||
--config_yaml ./eval_config.yaml \ |
|||
--pre_num 5 \ |
|||
--output_dir ./eval_output |
|||
``` |
|||
`pre_num` is the number of samples to evaluate, the more samples, the more accurate the results will be, but it will consume more time and your LLM api token usage. |
|||
|
|||
After you have loaded the dataset into vectorDB in the first run, if you want to skip loading dataset again, you can set the flag `--skip_load` in the command line. |
|||
|
|||
For more arguments details, you can run |
|||
```shell |
|||
python evaluate.py --help |
|||
``` |
|||
|
|||
## Evaluation Results |
|||
We conducted tests using the commonly used 2WikiMultiHopQA dataset. (Due to the high consumption of API tokens for testing, we only tested the first 50 samples. This may introduce some fluctuations compared to testing the entire dataset, but it can still roughly reflect the general landscape of performance.) |
|||
|
|||
### Recall Comparison between Naive RAG and DeepSearcher with Different Models |
|||
With Max Iterations on the horizontal axis and Recall on the vertical axis, the following chart compares the recall rates of Deep Searcher and naive RAG. |
|||
 |
|||
#### Performance Improvement with Iterations |
|||
As we can see, as the number of Max Iterations increases, the recall performance of Deep Searcher improves significantly. And all the model results from Deep Searcher are significantly higher than those from naive RAG. |
|||
|
|||
#### Diminishing Returns |
|||
However, it is also evident that as the number of iterations gradually increases, the marginal gains decrease, indicating that there may be a certain limit reached after increasing the feedback iterations, and further feedback might not yield significantly better results. |
|||
|
|||
#### Model Performance Comparison |
|||
Claude-3-7-sonnet (red line) demonstrates superior performance throughout, achieving nearly perfect recall at 7 iterations. Most models show significant improvement as iterations increase, with the steepest gains occurring between 2-4 iterations. Models like o1-mini (yellow) and deepseek-r1 (green) exhibit strong performance at higher iteration counts. Since our sample number for testing is limited, the results of each test may vary somewhat. |
|||
Overall, reasoning models generally perform better than non-reasoning models. |
|||
|
|||
#### Limitations of Non-Reasoning Models |
|||
Additionally, in our tests, weaker and smaller non-reasoning models sometimes failed to complete the entire agent query pipeline, due to their inadequate instruction-following capabilities. |
|||
|
|||
### Token Consumption |
|||
We plotted the graph below with the number of iterations on the horizontal axis and the average token consumption per sample on the vertical axis: |
|||
 |
|||
It is evident that as the number of iterations increases, the token consumption of Deep Searcher rises linearly. Based on this approximate token consumption, you can check the pricing on your model provider's website to estimate the cost of running evaluations with different iteration settings. |
@ -1,119 +0,0 @@ |
|||
provide_settings: |
|||
llm: |
|||
provider: "OpenAI" |
|||
config: |
|||
model: "o1-mini" |
|||
# api_key: "sk-xxxx" # Uncomment to override the `OPENAI_API_KEY` set in the environment variable |
|||
# base_url: "" |
|||
|
|||
# provider: "DeepSeek" |
|||
# config: |
|||
# model: "deepseek-reasoner" |
|||
## api_key: "sk-xxxx" # Uncomment to override the `DEEPSEEK_API_KEY` set in the environment variable |
|||
## base_url: "" |
|||
|
|||
# provider: "SiliconFlow" |
|||
# config: |
|||
# model: "deepseek-ai/DeepSeek-R1" |
|||
## api_key: "xxxx" # Uncomment to override the `SILICONFLOW_API_KEY` set in the environment variable |
|||
## base_url: "" |
|||
|
|||
# provider: "PPIO" |
|||
# config: |
|||
# model: "deepseek/deepseek-r1-turbo" |
|||
## api_key: "xxxx" # Uncomment to override the `PPIO_API_KEY` set in the environment variable |
|||
## base_url: "" |
|||
|
|||
# provider: "TogetherAI" |
|||
# config: |
|||
# model: "deepseek-ai/DeepSeek-R1" |
|||
## api_key: "xxxx" # Uncomment to override the `TOGETHER_API_KEY` set in the environment variable |
|||
|
|||
# provider: "AzureOpenAI" |
|||
# config: |
|||
# model: "" |
|||
# api_version: "" |
|||
## azure_endpoint: "xxxx" # Uncomment to override the `AZURE_OPENAI_ENDPOINT` set in the environment variable |
|||
## api_key: "xxxx" # Uncomment to override the `AZURE_OPENAI_KEY` set in the environment variable |
|||
|
|||
# provider: "Ollama" |
|||
# config: |
|||
# model: "qwq" |
|||
## base_url: "" |
|||
|
|||
# provider: "Novita" |
|||
# config: |
|||
# model: "deepseek/deepseek-v3-0324" |
|||
## api_key: "xxxx" # Uncomment to override the `NOVITA_API_KEY` set in the environment variable |
|||
## base_url: "" |
|||
|
|||
embedding: |
|||
provider: "OpenAIEmbedding" |
|||
config: |
|||
model: "text-embedding-ada-002" |
|||
# api_key: "" # Uncomment to override the `OPENAI_API_KEY` set in the environment variable |
|||
|
|||
|
|||
# provider: "MilvusEmbedding" |
|||
# config: |
|||
# model: "default" |
|||
|
|||
# provider: "VoyageEmbedding" |
|||
# config: |
|||
# model: "voyage-3" |
|||
## api_key: "" # Uncomment to override the `VOYAGE_API_KEY` set in the environment variable |
|||
|
|||
# provider: "BedrockEmbedding" |
|||
# config: |
|||
# model: "amazon.titan-embed-text-v2:0" |
|||
## aws_access_key_id: "" # Uncomment to override the `AWS_ACCESS_KEY_ID` set in the environment variable |
|||
## aws_secret_access_key: "" # Uncomment to override the `AWS_SECRET_ACCESS_KEY` set in the environment variable |
|||
|
|||
# provider: "SiliconflowEmbedding" |
|||
# config: |
|||
# model: "BAAI/bge-m3" |
|||
# . api_key: "" # Uncomment to override the `SILICONFLOW_API_KEY` set in the environment variable |
|||
|
|||
# provider: "NovitaEmbedding" |
|||
# config: |
|||
# model: "baai/bge-m3" |
|||
# . api_key: "" # Uncomment to override the `NOVITA_API_KEY` set in the environment variable |
|||
|
|||
file_loader: |
|||
# provider: "PDFLoader" |
|||
# config: {} |
|||
|
|||
provider: "JsonFileLoader" |
|||
config: |
|||
text_key: "text" |
|||
|
|||
# provider: "TextLoader" |
|||
# config: {} |
|||
|
|||
# provider: "UnstructuredLoader" |
|||
# config: {} |
|||
|
|||
web_crawler: |
|||
provider: "FireCrawlCrawler" |
|||
config: {} |
|||
|
|||
# provider: "Crawl4AICrawler" |
|||
# config: {} |
|||
|
|||
# provider: "JinaCrawler" |
|||
# config: {} |
|||
|
|||
vector_db: |
|||
provider: "Milvus" |
|||
config: |
|||
default_collection: "deepsearcher" |
|||
uri: "./milvus.db" |
|||
token: "root:Milvus" |
|||
db: "default" |
|||
|
|||
query_settings: |
|||
max_iter: 3 |
|||
|
|||
load_settings: |
|||
chunk_size: 1500 |
|||
chunk_overlap: 100 |
@ -1,329 +0,0 @@ |
|||
# Some test dataset and evaluation method are ref from https://github.com/OSU-NLP-Group/HippoRAG/tree/main/data , many thanks |
|||
|
|||
################################################################################ |
|||
# Note: This evaluation script will cost a lot of LLM token usage, please make sure you have enough token budget. |
|||
################################################################################ |
|||
import argparse |
|||
import ast |
|||
import json |
|||
import logging |
|||
import os |
|||
import time |
|||
import warnings |
|||
from collections import defaultdict |
|||
from typing import List, Tuple |
|||
|
|||
import pandas as pd |
|||
|
|||
from deepsearcher.configuration import Configuration, init_config |
|||
from deepsearcher.offline_loading import load_from_local_files |
|||
from deepsearcher.online_query import naive_retrieve, retrieve |
|||
|
|||
httpx_logger = logging.getLogger("httpx") # disable openai's logger output |
|||
httpx_logger.setLevel(logging.WARNING) |
|||
|
|||
|
|||
warnings.simplefilter(action="ignore", category=FutureWarning) # disable warning output |
|||
|
|||
|
|||
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|||
|
|||
k_list = [2, 5] |
|||
|
|||
|
|||
def _deepsearch_retrieve_titles( |
|||
question: str, |
|||
retry_num: int = 4, |
|||
base_wait_time: int = 4, |
|||
max_iter: int = 3, |
|||
) -> Tuple[List[str], int, bool]: |
|||
""" |
|||
Retrieve document titles using DeepSearcher with retry mechanism. |
|||
|
|||
Args: |
|||
question (str): The query question. |
|||
retry_num (int, optional): Number of retry attempts. Defaults to 4. |
|||
base_wait_time (int, optional): Base wait time between retries in seconds. Defaults to 4. |
|||
max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3. |
|||
|
|||
Returns: |
|||
Tuple[List[str], int, bool]: A tuple containing: |
|||
- List of retrieved document titles |
|||
- Number of tokens consumed |
|||
- Boolean indicating whether the retrieval failed |
|||
""" |
|||
retrieved_results = [] |
|||
consume_tokens = 0 |
|||
for i in range(retry_num): |
|||
try: |
|||
retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter) |
|||
break |
|||
except Exception: |
|||
wait_time = base_wait_time * (2**i) |
|||
print(f"Parse LLM's output failed, retry again after {wait_time} seconds...") |
|||
time.sleep(wait_time) |
|||
if retrieved_results: |
|||
retrieved_titles = [ |
|||
retrieved_result.metadata["title"] for retrieved_result in retrieved_results |
|||
] |
|||
fail = False |
|||
else: |
|||
print("Pipeline error, no retrieved results.") |
|||
retrieved_titles = [] |
|||
fail = True |
|||
return retrieved_titles, consume_tokens, fail |
|||
|
|||
|
|||
def _naive_retrieve_titles(question: str) -> List[str]: |
|||
""" |
|||
Retrieve document titles using naive retrieval method. |
|||
|
|||
Args: |
|||
question (str): The query question. |
|||
|
|||
Returns: |
|||
List[str]: List of retrieved document titles. |
|||
""" |
|||
retrieved_results = naive_retrieve(question) |
|||
retrieved_titles = [ |
|||
retrieved_result.metadata["title"] for retrieved_result in retrieved_results |
|||
] |
|||
return retrieved_titles |
|||
|
|||
|
|||
def _calcu_recall(sample, retrieved_titles, dataset) -> dict: |
|||
""" |
|||
Calculate recall metrics for retrieved titles. |
|||
|
|||
Args: |
|||
sample: The sample data containing ground truth information. |
|||
retrieved_titles: List of retrieved document titles. |
|||
dataset (str): The name of the dataset being evaluated. |
|||
|
|||
Returns: |
|||
dict: Dictionary containing recall values at different k values. |
|||
|
|||
Raises: |
|||
NotImplementedError: If the dataset is not supported. |
|||
""" |
|||
if dataset in ["2wikimultihopqa"]: |
|||
gold_passages = [item for item in sample["supporting_facts"]] |
|||
gold_items = set([item[0] for item in gold_passages]) |
|||
retrieved_items = retrieved_titles |
|||
else: |
|||
raise NotImplementedError |
|||
|
|||
recall = dict() |
|||
for k in k_list: |
|||
recall[k] = round( |
|||
sum(1 for t in gold_items if t in retrieved_items[:k]) / len(gold_items), 4 |
|||
) |
|||
return recall |
|||
|
|||
|
|||
def _print_recall_line(recall: dict, pre_str="", post_str="\n"): |
|||
""" |
|||
Print recall metrics in a formatted line. |
|||
|
|||
Args: |
|||
recall (dict): Dictionary containing recall values at different k values. |
|||
pre_str (str, optional): String to print before recall values. Defaults to "". |
|||
post_str (str, optional): String to print after recall values. Defaults to "\n". |
|||
""" |
|||
print(pre_str, end="") |
|||
for k in k_list: |
|||
print(f"R@{k}: {recall[k]:.3f} ", end="") |
|||
print(post_str, end="") |
|||
|
|||
|
|||
def evaluate( |
|||
dataset: str, |
|||
output_root: str, |
|||
pre_num: int = 10, |
|||
max_iter: int = 3, |
|||
skip_load=False, |
|||
flag: str = "result", |
|||
): |
|||
""" |
|||
Evaluate the retrieval performance on a dataset. |
|||
|
|||
Args: |
|||
dataset (str): Name of the dataset to evaluate. |
|||
output_root (str): Root directory for output files. |
|||
pre_num (int, optional): Number of samples to evaluate. Defaults to 10. |
|||
max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3. |
|||
skip_load (bool, optional): Whether to skip loading the dataset. Defaults to False. |
|||
flag (str, optional): Flag for the evaluation run. Defaults to "result". |
|||
""" |
|||
corpus_file = os.path.join(current_dir, f"../examples/data/{dataset}_corpus.json") |
|||
if not skip_load: |
|||
# set chunk size to a large number to avoid chunking, because the dataset was chunked already. |
|||
load_from_local_files( |
|||
corpus_file, force_new_collection=True, chunk_size=999999, chunk_overlap=0 |
|||
) |
|||
|
|||
eval_output_subdir = os.path.join(output_root, flag) |
|||
os.makedirs(eval_output_subdir, exist_ok=True) |
|||
csv_file_path = os.path.join(eval_output_subdir, "details.csv") |
|||
statistics_file_path = os.path.join(eval_output_subdir, "statistics.json") |
|||
|
|||
data_with_gt_file_path = os.path.join(current_dir, f"../examples/data/{dataset}.json") |
|||
data_with_gt = json.load(open(data_with_gt_file_path, "r")) |
|||
|
|||
if not pre_num: |
|||
pre_num = len(data_with_gt) |
|||
|
|||
pipeline_error_num = 0 |
|||
end_ind = min(pre_num, len(data_with_gt)) |
|||
|
|||
start_ind = 0 |
|||
existing_df = pd.DataFrame() |
|||
existing_statistics = defaultdict(dict) |
|||
existing_token_usage = 0 |
|||
existing_error_num = 0 |
|||
existing_sample_num = 0 |
|||
if os.path.exists(csv_file_path): |
|||
existing_df = pd.read_csv(csv_file_path) |
|||
start_ind = len(existing_df) |
|||
print(f"Loading results from {csv_file_path}, start_index = {start_ind}") |
|||
|
|||
if os.path.exists(statistics_file_path): |
|||
existing_statistics = json.load(open(statistics_file_path, "r")) |
|||
print( |
|||
f"Loading statistics from {statistics_file_path}, will recalculate the statistics based on both new and existing results." |
|||
) |
|||
existing_token_usage = existing_statistics["deepsearcher"]["token_usage"] |
|||
existing_error_num = existing_statistics["deepsearcher"].get("error_num", 0) |
|||
existing_sample_num = existing_statistics["deepsearcher"].get("sample_num", 0) |
|||
for sample_idx, sample in enumerate(data_with_gt[start_ind:end_ind]): |
|||
global_idx = sample_idx + start_ind |
|||
question = sample["question"] |
|||
|
|||
retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles( |
|||
question, max_iter=max_iter |
|||
) |
|||
retrieved_titles_naive = _naive_retrieve_titles(question) |
|||
|
|||
if fail: |
|||
pipeline_error_num += 1 |
|||
print( |
|||
f"Pipeline error, no retrieved results. Current pipeline_error_num = {pipeline_error_num}" |
|||
) |
|||
|
|||
print(f"idx: {global_idx}: ") |
|||
recall = _calcu_recall(sample, retrieved_titles, dataset) |
|||
recall_naive = _calcu_recall(sample, retrieved_titles_naive, dataset) |
|||
current_result = [ |
|||
{ |
|||
"idx": global_idx, |
|||
"question": question, |
|||
"recall": recall, |
|||
"recall_naive": recall_naive, |
|||
"gold_titles": [item[0] for item in sample["supporting_facts"]], |
|||
"retrieved_titles": retrieved_titles, |
|||
"retrieved_titles_naive": retrieved_titles_naive, |
|||
} |
|||
] |
|||
current_df = pd.DataFrame(current_result) |
|||
existing_df = pd.concat([existing_df, current_df], ignore_index=True) |
|||
existing_df.to_csv(csv_file_path, index=False) |
|||
average_recall = dict() |
|||
average_recall_naive = dict() |
|||
for k in k_list: |
|||
average_recall[k] = sum( |
|||
[ |
|||
ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k) |
|||
for d in existing_df["recall"] |
|||
] |
|||
) / len(existing_df) |
|||
average_recall_naive[k] = sum( |
|||
[ |
|||
ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k) |
|||
for d in existing_df["recall_naive"] |
|||
] |
|||
) / len(existing_df) |
|||
_print_recall_line(average_recall, pre_str="Average recall of DeepSearcher: ") |
|||
_print_recall_line(average_recall_naive, pre_str="Average recall of naive RAG : ") |
|||
existing_token_usage += consume_tokens |
|||
existing_error_num += 1 if fail else 0 |
|||
existing_sample_num += 1 |
|||
existing_statistics["deepsearcher"]["average_recall"] = average_recall |
|||
existing_statistics["deepsearcher"]["token_usage"] = existing_token_usage |
|||
existing_statistics["deepsearcher"]["error_num"] = existing_error_num |
|||
existing_statistics["deepsearcher"]["sample_num"] = existing_sample_num |
|||
existing_statistics["deepsearcher"]["token_usage_per_sample"] = ( |
|||
existing_token_usage / existing_sample_num |
|||
) |
|||
existing_statistics["naive_rag"]["average_recall"] = average_recall_naive |
|||
json.dump(existing_statistics, open(statistics_file_path, "w"), indent=4) |
|||
print("") |
|||
print("Finish results to save.") |
|||
|
|||
|
|||
def main_eval(): |
|||
""" |
|||
Main function for running the evaluation from command line. |
|||
|
|||
This function parses command line arguments and calls the evaluate function |
|||
with the appropriate parameters. |
|||
""" |
|||
parser = argparse.ArgumentParser(prog="evaluate", description="Deep Searcher evaluation.") |
|||
parser.add_argument( |
|||
"--dataset", |
|||
type=str, |
|||
default="2wikimultihopqa", |
|||
help="Dataset name, default is `2wikimultihopqa`. More datasets will be supported in the future.", |
|||
) |
|||
parser.add_argument( |
|||
"--config_yaml", |
|||
type=str, |
|||
default="./eval_config.yaml", |
|||
help="Configuration yaml file path, default is `./eval_config.yaml`", |
|||
) |
|||
parser.add_argument( |
|||
"--pre_num", |
|||
type=int, |
|||
default=30, |
|||
help="Number of samples to evaluate, default is 30", |
|||
) |
|||
parser.add_argument( |
|||
"--max_iter", |
|||
type=int, |
|||
default=3, |
|||
help="Max iterations of reflection. Default is 3. It will overwrite the one in config yaml file.", |
|||
) |
|||
parser.add_argument( |
|||
"--output_dir", |
|||
type=str, |
|||
default="./eval_output", |
|||
help="Output root directory, default is `./eval_output`", |
|||
) |
|||
parser.add_argument( |
|||
"--skip_load", |
|||
action="store_true", |
|||
help="Whether to skip loading the dataset. Default it don't skip loading. If you want to skip loading, please set this flag.", |
|||
) |
|||
parser.add_argument( |
|||
"--flag", |
|||
type=str, |
|||
default="result", |
|||
help="Flag for evaluation, default is `result`", |
|||
) |
|||
|
|||
args = parser.parse_args() |
|||
|
|||
config = Configuration(config_path=args.config_yaml) |
|||
init_config(config=config) |
|||
|
|||
evaluate( |
|||
dataset=args.dataset, |
|||
output_root=args.output_dir, |
|||
pre_num=args.pre_num, |
|||
max_iter=args.max_iter, |
|||
skip_load=args.skip_load, |
|||
flag=args.flag, |
|||
) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main_eval() |
Before Width: | Height: | Size: 124 KiB |
Before Width: | Height: | Size: 92 KiB |
Before Width: | Height: | Size: 130 KiB |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher package |
@ -1 +0,0 @@ |
|||
# Tests for the agent module |
@ -1,149 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import MagicMock |
|||
import numpy as np |
|||
|
|||
from deepsearcher.llm.base import BaseLLM, ChatResponse |
|||
from deepsearcher.embedding.base import BaseEmbedding |
|||
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, CollectionInfo |
|||
|
|||
|
|||
class MockLLM(BaseLLM): |
|||
"""Mock LLM implementation for testing agents.""" |
|||
|
|||
def __init__(self, predefined_responses=None): |
|||
""" |
|||
Initialize the MockLLM. |
|||
|
|||
Args: |
|||
predefined_responses: Dictionary mapping prompt substrings to responses |
|||
""" |
|||
self.chat_called = False |
|||
self.last_messages = None |
|||
self.predefined_responses = predefined_responses or {} |
|||
|
|||
def chat(self, messages, **kwargs): |
|||
"""Mock implementation of chat that returns predefined responses or a default response.""" |
|||
self.chat_called = True |
|||
self.last_messages = messages |
|||
|
|||
if self.predefined_responses: |
|||
message_content = messages[0]["content"] if messages else "" |
|||
for key, response in self.predefined_responses.items(): |
|||
if key in message_content: |
|||
return ChatResponse(content=response, total_tokens=10) |
|||
|
|||
# Default response for RERANK_PROMPT - treat all chunks as relevant |
|||
if "Based on the query questions and the retrieved chunks" in message_content: |
|||
# Count the number of chunks in the message |
|||
chunk_count = message_content.count("<chunk_") |
|||
# Return a list with "YES" for each chunk |
|||
return ChatResponse(content=str(["YES"] * chunk_count), total_tokens=10) |
|||
|
|||
return ChatResponse(content="This is a test answer", total_tokens=10) |
|||
|
|||
def literal_eval(self, text): |
|||
"""Mock implementation of literal_eval.""" |
|||
# Default implementation returns a list with test_collection |
|||
# Override this in specific tests if needed |
|||
if text.strip().startswith("[") and text.strip().endswith("]"): |
|||
# Return the list as is if it's already in list format |
|||
try: |
|||
import ast |
|||
return ast.literal_eval(text) |
|||
except: |
|||
pass |
|||
|
|||
return ["test_collection"] |
|||
|
|||
|
|||
class MockEmbedding(BaseEmbedding): |
|||
"""Mock embedding model implementation for testing agents.""" |
|||
|
|||
def __init__(self, dimension=8): |
|||
"""Initialize the MockEmbedding with a specific dimension.""" |
|||
self._dimension = dimension |
|||
|
|||
@property |
|||
def dimension(self): |
|||
"""Return the dimension of the embedding model.""" |
|||
return self._dimension |
|||
|
|||
def embed_query(self, text): |
|||
"""Mock implementation that returns a random vector of the specified dimension.""" |
|||
return np.random.random(self._dimension).tolist() |
|||
|
|||
def embed_documents(self, documents): |
|||
"""Mock implementation that returns random vectors for each document.""" |
|||
return [np.random.random(self._dimension).tolist() for _ in documents] |
|||
|
|||
|
|||
class MockVectorDB(BaseVectorDB): |
|||
"""Mock vector database implementation for testing agents.""" |
|||
|
|||
def __init__(self, collections=None): |
|||
""" |
|||
Initialize the MockVectorDB. |
|||
|
|||
Args: |
|||
collections: List of collection names to initialize with |
|||
""" |
|||
self.default_collection = "test_collection" |
|||
self.search_called = False |
|||
self.insert_called = False |
|||
self._collections = [] |
|||
|
|||
if collections: |
|||
for collection in collections: |
|||
self._collections.append( |
|||
CollectionInfo(collection_name=collection, description=f"Test collection {collection}") |
|||
) |
|||
else: |
|||
self._collections = [ |
|||
CollectionInfo(collection_name="test_collection", description="Test collection for testing") |
|||
] |
|||
|
|||
def search_data(self, collection, vector, top_k=10, **kwargs): |
|||
"""Mock implementation that returns test results.""" |
|||
self.search_called = True |
|||
self.last_search_collection = collection |
|||
self.last_search_vector = vector |
|||
self.last_search_top_k = top_k |
|||
|
|||
return [ |
|||
RetrievalResult( |
|||
embedding=vector, |
|||
text=f"Test result {i} for collection {collection}", |
|||
reference=f"test_reference_{collection}_{i}", |
|||
metadata={"a": i, "wider_text": f"Wider context for test result {i} in collection {collection}"} |
|||
) |
|||
for i in range(min(3, top_k)) |
|||
] |
|||
|
|||
def insert_data(self, collection, chunks): |
|||
"""Mock implementation of insert_data.""" |
|||
self.insert_called = True |
|||
self.last_insert_collection = collection |
|||
self.last_insert_chunks = chunks |
|||
return True |
|||
|
|||
def init_collection(self, dim, collection, **kwargs): |
|||
"""Mock implementation of init_collection.""" |
|||
return True |
|||
|
|||
def list_collections(self, dim=None): |
|||
"""Mock implementation that returns the list of collections.""" |
|||
return self._collections |
|||
|
|||
def clear_db(self, collection): |
|||
"""Mock implementation of clear_db.""" |
|||
return True |
|||
|
|||
|
|||
class BaseAgentTest(unittest.TestCase): |
|||
"""Base test class for agent tests with common setup.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for agent tests.""" |
|||
self.llm = MockLLM() |
|||
self.embedding_model = MockEmbedding(dimension=8) |
|||
self.vector_db = MockVectorDB() |
@ -1,237 +0,0 @@ |
|||
from unittest.mock import MagicMock, patch |
|||
|
|||
from deepsearcher.agent import ChainOfRAG |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
from tests.agent.test_base import BaseAgentTest |
|||
|
|||
|
|||
class TestChainOfRAG(BaseAgentTest): |
|||
"""Test class for ChainOfRAG agent.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for ChainOfRAG tests.""" |
|||
super().setUp() |
|||
|
|||
# Set up predefined responses for the LLM for exact prompt substrings |
|||
self.llm.predefined_responses = { |
|||
"previous queries and answers, generate a new simple follow-up question": "What is the significance of deep learning?", |
|||
"Given the following documents, generate an appropriate answer": "Deep learning is a subset of machine learning that uses neural networks with multiple layers.", |
|||
"given the following intermediate queries and answers, judge whether you have enough information": "Yes", |
|||
"Given a list of agent indexes and corresponding descriptions": "1", |
|||
"Given the following documents, select the ones that are support the Q-A pair": "[0, 1]", |
|||
"Given the following intermediate queries and answers, generate a final answer": "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers." |
|||
} |
|||
|
|||
self.chain_of_rag = ChainOfRAG( |
|||
llm=self.llm, |
|||
embedding_model=self.embedding_model, |
|||
vector_db=self.vector_db, |
|||
max_iter=3, |
|||
early_stopping=True, |
|||
route_collection=True, |
|||
text_window_splitter=True |
|||
) |
|||
|
|||
def test_init(self): |
|||
"""Test the initialization of ChainOfRAG.""" |
|||
self.assertEqual(self.chain_of_rag.llm, self.llm) |
|||
self.assertEqual(self.chain_of_rag.embedding_model, self.embedding_model) |
|||
self.assertEqual(self.chain_of_rag.vector_db, self.vector_db) |
|||
self.assertEqual(self.chain_of_rag.max_iter, 3) |
|||
self.assertEqual(self.chain_of_rag.early_stopping, True) |
|||
self.assertEqual(self.chain_of_rag.route_collection, True) |
|||
self.assertEqual(self.chain_of_rag.text_window_splitter, True) |
|||
|
|||
def test_reflect_get_subquery(self): |
|||
"""Test the _reflect_get_subquery method.""" |
|||
query = "What is deep learning?" |
|||
intermediate_context = ["Previous query: What is AI?", "Previous answer: AI is artificial intelligence."] |
|||
|
|||
# Direct mock for this specific method |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content="What is the significance of deep learning?", |
|||
total_tokens=10 |
|||
)) |
|||
|
|||
subquery, tokens = self.chain_of_rag._reflect_get_subquery(query, intermediate_context) |
|||
|
|||
self.assertEqual(subquery, "What is the significance of deep learning?") |
|||
self.assertEqual(tokens, 10) |
|||
self.assertTrue(self.llm.chat.called) |
|||
|
|||
def test_retrieve_and_answer(self): |
|||
"""Test the _retrieve_and_answer method.""" |
|||
query = "What is deep learning?" |
|||
|
|||
# Mock the collection_router.invoke method |
|||
self.chain_of_rag.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5)) |
|||
|
|||
# Direct mock for this specific method |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content="Deep learning is a subset of machine learning that uses neural networks with multiple layers.", |
|||
total_tokens=10 |
|||
)) |
|||
|
|||
answer, results, tokens = self.chain_of_rag._retrieve_and_answer(query) |
|||
|
|||
# Check if correct methods were called |
|||
self.chain_of_rag.collection_router.invoke.assert_called_once() |
|||
self.assertTrue(self.vector_db.search_called) |
|||
|
|||
# Check the results |
|||
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.") |
|||
self.assertEqual(tokens, 15) # 5 from collection_router + 10 from LLM |
|||
|
|||
def test_get_supported_docs(self): |
|||
"""Test the _get_supported_docs method.""" |
|||
results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text=f"Test result {i}", |
|||
reference="test_reference", |
|||
metadata={"a": i} |
|||
) |
|||
for i in range(3) |
|||
] |
|||
|
|||
query = "What is deep learning?" |
|||
answer = "Deep learning is a subset of machine learning that uses neural networks with multiple layers." |
|||
|
|||
# Mock the literal_eval method to return indices as integers |
|||
self.llm.literal_eval = MagicMock(return_value=[0, 1]) |
|||
|
|||
supported_docs, tokens = self.chain_of_rag._get_supported_docs(results, query, answer) |
|||
|
|||
self.assertEqual(len(supported_docs), 2) # Based on our mock response of [0, 1] |
|||
self.assertEqual(tokens, 10) |
|||
|
|||
def test_check_has_enough_info(self): |
|||
"""Test the _check_has_enough_info method.""" |
|||
query = "What is deep learning?" |
|||
intermediate_contexts = [ |
|||
"Intermediate query1: What is deep learning?", |
|||
"Intermediate answer1: Deep learning is a subset of machine learning that uses neural networks with multiple layers." |
|||
] |
|||
|
|||
# Direct mock for this specific method |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content="Yes", |
|||
total_tokens=10 |
|||
)) |
|||
|
|||
has_enough, tokens = self.chain_of_rag._check_has_enough_info(query, intermediate_contexts) |
|||
|
|||
self.assertTrue(has_enough) # Based on our mock response of "Yes" |
|||
self.assertEqual(tokens, 10) |
|||
|
|||
def test_retrieve(self): |
|||
"""Test the retrieve method.""" |
|||
query = "What is deep learning?" |
|||
|
|||
# Mock all the methods that retrieve calls |
|||
self.chain_of_rag._reflect_get_subquery = MagicMock(return_value=("What is the significance of deep learning?", 5)) |
|||
self.chain_of_rag._retrieve_and_answer = MagicMock( |
|||
return_value=("Deep learning is important in AI", [RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Test result", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
)], 10) |
|||
) |
|||
self.chain_of_rag._get_supported_docs = MagicMock(return_value=([RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Test result", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
)], 5)) |
|||
self.chain_of_rag._check_has_enough_info = MagicMock(return_value=(True, 5)) |
|||
|
|||
results, tokens, metadata = self.chain_of_rag.retrieve(query) |
|||
|
|||
# Check if methods were called |
|||
self.chain_of_rag._reflect_get_subquery.assert_called_once() |
|||
self.chain_of_rag._retrieve_and_answer.assert_called_once() |
|||
self.chain_of_rag._get_supported_docs.assert_called_once() |
|||
|
|||
# With early stopping, it should check if we have enough info |
|||
self.chain_of_rag._check_has_enough_info.assert_called_once() |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 1) |
|||
self.assertEqual(tokens, 25) # 5 + 10 + 5 + 5 |
|||
self.assertIn("intermediate_context", metadata) |
|||
|
|||
def test_query(self): |
|||
"""Test the query method.""" |
|||
query = "What is deep learning?" |
|||
|
|||
# Mock the retrieve method |
|||
retrieved_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text=f"Test result {i}", |
|||
reference="test_reference", |
|||
metadata={"a": i, "wider_text": f"Wider context for test result {i}"} |
|||
) |
|||
for i in range(3) |
|||
] |
|||
|
|||
self.chain_of_rag.retrieve = MagicMock( |
|||
return_value=(retrieved_results, 20, {"intermediate_context": ["Some context"]}) |
|||
) |
|||
|
|||
# Direct mock for this specific method |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content="Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.", |
|||
total_tokens=10 |
|||
)) |
|||
|
|||
answer, results, tokens = self.chain_of_rag.query(query) |
|||
|
|||
# Check if methods were called |
|||
self.chain_of_rag.retrieve.assert_called_once_with(query) |
|||
self.assertTrue(self.llm.chat.called) |
|||
|
|||
# Check results |
|||
self.assertEqual(answer, "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.") |
|||
self.assertEqual(results, retrieved_results) |
|||
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM |
|||
|
|||
def test_format_retrieved_results(self): |
|||
"""Test the _format_retrieved_results method.""" |
|||
retrieved_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Test result 1", |
|||
reference="test_reference", |
|||
metadata={"a": 1, "wider_text": "Wider context for test result 1"} |
|||
), |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Test result 2", |
|||
reference="test_reference", |
|||
metadata={"a": 2, "wider_text": "Wider context for test result 2"} |
|||
) |
|||
] |
|||
|
|||
# Test with text_window_splitter enabled |
|||
self.chain_of_rag.text_window_splitter = True |
|||
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results) |
|||
|
|||
self.assertIn("Wider context for test result 1", formatted) |
|||
self.assertIn("Wider context for test result 2", formatted) |
|||
|
|||
# Test with text_window_splitter disabled |
|||
self.chain_of_rag.text_window_splitter = False |
|||
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results) |
|||
|
|||
self.assertIn("Test result 1", formatted) |
|||
self.assertIn("Test result 2", formatted) |
|||
self.assertNotIn("Wider context for test result 1", formatted) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
import unittest |
|||
unittest.main() |
@ -1,154 +0,0 @@ |
|||
from unittest.mock import MagicMock, patch |
|||
|
|||
from deepsearcher.agent.collection_router import CollectionRouter |
|||
from deepsearcher.llm.base import ChatResponse |
|||
from deepsearcher.vector_db.base import CollectionInfo |
|||
|
|||
from tests.agent.test_base import BaseAgentTest |
|||
|
|||
|
|||
class TestCollectionRouter(BaseAgentTest): |
|||
"""Test class for CollectionRouter.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for CollectionRouter tests.""" |
|||
super().setUp() |
|||
|
|||
# Create mock collections |
|||
self.collection_infos = [ |
|||
CollectionInfo(collection_name="books", description="Collection of book summaries"), |
|||
CollectionInfo(collection_name="science", description="Scientific articles and papers"), |
|||
CollectionInfo(collection_name="news", description="Recent news articles") |
|||
] |
|||
|
|||
# Configure vector_db mock |
|||
self.vector_db.list_collections = MagicMock(return_value=self.collection_infos) |
|||
self.vector_db.default_collection = "books" |
|||
|
|||
# Create the CollectionRouter |
|||
self.collection_router = CollectionRouter( |
|||
llm=self.llm, |
|||
vector_db=self.vector_db, |
|||
dim=8 |
|||
) |
|||
|
|||
def test_init(self): |
|||
"""Test the initialization of CollectionRouter.""" |
|||
self.assertEqual(self.collection_router.llm, self.llm) |
|||
self.assertEqual(self.collection_router.vector_db, self.vector_db) |
|||
self.assertEqual( |
|||
self.collection_router.all_collections, |
|||
["books", "science", "news"] |
|||
) |
|||
|
|||
def test_invoke_with_multiple_collections(self): |
|||
"""Test the invoke method with multiple collections.""" |
|||
query = "What are the latest scientific breakthroughs?" |
|||
|
|||
# Mock LLM to return specific collections based on query |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content='["science", "news"]', |
|||
total_tokens=10 |
|||
)) |
|||
|
|||
# Disable log output for testing |
|||
with patch('deepsearcher.utils.log.color_print'): |
|||
selected_collections, tokens = self.collection_router.invoke(query, dim=8) |
|||
|
|||
# Check results |
|||
self.assertTrue("science" in selected_collections) |
|||
self.assertTrue("news" in selected_collections) |
|||
self.assertTrue("books" in selected_collections) # Default collection is always included |
|||
self.assertEqual(tokens, 10) |
|||
|
|||
# Verify that the LLM was called with the right prompt |
|||
self.llm.chat.assert_called_once() |
|||
self.assertIn(query, self.llm.chat.call_args[1]["messages"][0]["content"]) |
|||
self.assertIn("collection_name", self.llm.chat.call_args[1]["messages"][0]["content"]) |
|||
|
|||
def test_invoke_with_empty_response(self): |
|||
"""Test the invoke method when LLM returns an empty list.""" |
|||
query = "Something completely unrelated" |
|||
|
|||
# Mock LLM to return empty list |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content='[]', |
|||
total_tokens=5 |
|||
)) |
|||
|
|||
# Disable log output for testing |
|||
with patch('deepsearcher.utils.log.color_print'): |
|||
selected_collections, tokens = self.collection_router.invoke(query, dim=8) |
|||
|
|||
# Only default collection should be included |
|||
self.assertEqual(len(selected_collections), 1) |
|||
self.assertEqual(selected_collections[0], "books") |
|||
self.assertEqual(tokens, 5) |
|||
|
|||
def test_invoke_with_no_collections(self): |
|||
"""Test the invoke method when no collections are available.""" |
|||
query = "Test query" |
|||
|
|||
# Mock vector_db to return empty list |
|||
self.vector_db.list_collections = MagicMock(return_value=[]) |
|||
|
|||
# Disable log warnings for testing |
|||
with patch('deepsearcher.utils.log.warning'): |
|||
with patch('deepsearcher.utils.log.color_print'): |
|||
selected_collections, tokens = self.collection_router.invoke(query, dim=8) |
|||
|
|||
# Should return empty list and zero tokens |
|||
self.assertEqual(selected_collections, []) |
|||
self.assertEqual(tokens, 0) |
|||
|
|||
def test_invoke_with_single_collection(self): |
|||
"""Test the invoke method when only one collection is available.""" |
|||
query = "Test query" |
|||
|
|||
# Create a fresh mock for llm.chat to verify it's not called |
|||
mock_chat = MagicMock(return_value=ChatResponse(content='[]', total_tokens=0)) |
|||
self.llm.chat = mock_chat |
|||
|
|||
# Mock vector_db to return single collection |
|||
single_collection = [CollectionInfo(collection_name="single", description="The only collection")] |
|||
self.vector_db.list_collections = MagicMock(return_value=single_collection) |
|||
|
|||
# Disable log output for testing |
|||
with patch('deepsearcher.utils.log.color_print'): |
|||
selected_collections, tokens = self.collection_router.invoke(query, dim=8) |
|||
|
|||
# Should return the only collection without calling LLM |
|||
self.assertEqual(selected_collections, ["single"]) |
|||
self.assertEqual(tokens, 0) |
|||
mock_chat.assert_not_called() |
|||
|
|||
def test_invoke_with_no_description(self): |
|||
"""Test the invoke method when a collection has no description.""" |
|||
query = "Test query" |
|||
|
|||
# Create collections with one having no description |
|||
collections_with_no_desc = [ |
|||
CollectionInfo(collection_name="with_desc", description="Has description"), |
|||
CollectionInfo(collection_name="no_desc", description="") |
|||
] |
|||
self.vector_db.list_collections = MagicMock(return_value=collections_with_no_desc) |
|||
self.vector_db.default_collection = "with_desc" |
|||
|
|||
# Mock LLM to return only the first collection |
|||
self.llm.chat = MagicMock(return_value=ChatResponse( |
|||
content='["with_desc"]', |
|||
total_tokens=5 |
|||
)) |
|||
|
|||
# Disable log output for testing |
|||
with patch('deepsearcher.utils.log.color_print'): |
|||
selected_collections, tokens = self.collection_router.invoke(query, dim=8) |
|||
|
|||
# Both collections should be included (one from LLM, one with no description) |
|||
self.assertEqual(set(selected_collections), {"with_desc", "no_desc"}) |
|||
self.assertEqual(tokens, 5) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
import unittest |
|||
unittest.main() |
@ -1,235 +0,0 @@ |
|||
from unittest.mock import MagicMock, patch |
|||
import asyncio |
|||
|
|||
from deepsearcher.agent import DeepSearch |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
from tests.agent.test_base import BaseAgentTest |
|||
|
|||
|
|||
class TestDeepSearch(BaseAgentTest): |
|||
"""Test class for DeepSearch agent.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for DeepSearch tests.""" |
|||
super().setUp() |
|||
|
|||
# Set up predefined responses for the LLM for exact prompt substrings |
|||
self.llm.predefined_responses = { |
|||
"Original Question:": '["What is deep learning?", "How does deep learning work?", "What are applications of deep learning?"]', |
|||
"Is the chunk helpful": "YES", |
|||
"Respond exclusively in valid List": '["What are limitations of deep learning?"]', |
|||
"You are a AI content analysis expert": "Deep learning is a subset of machine learning that uses neural networks with multiple layers." |
|||
} |
|||
|
|||
self.deep_search = DeepSearch( |
|||
llm=self.llm, |
|||
embedding_model=self.embedding_model, |
|||
vector_db=self.vector_db, |
|||
max_iter=2, |
|||
route_collection=True, |
|||
text_window_splitter=True |
|||
) |
|||
|
|||
def test_init(self): |
|||
"""Test the initialization of DeepSearch.""" |
|||
self.assertEqual(self.deep_search.llm, self.llm) |
|||
self.assertEqual(self.deep_search.embedding_model, self.embedding_model) |
|||
self.assertEqual(self.deep_search.vector_db, self.vector_db) |
|||
self.assertEqual(self.deep_search.max_iter, 2) |
|||
self.assertEqual(self.deep_search.route_collection, True) |
|||
self.assertEqual(self.deep_search.text_window_splitter, True) |
|||
|
|||
def test_generate_sub_queries(self): |
|||
"""Test the _generate_sub_queries method.""" |
|||
query = "Tell me about deep learning" |
|||
|
|||
sub_queries, tokens = self.deep_search._generate_sub_queries(query) |
|||
|
|||
self.assertEqual(len(sub_queries), 3) |
|||
self.assertEqual(sub_queries[0], "What is deep learning?") |
|||
self.assertEqual(sub_queries[1], "How does deep learning work?") |
|||
self.assertEqual(sub_queries[2], "What are applications of deep learning?") |
|||
self.assertEqual(tokens, 10) |
|||
self.assertTrue(self.llm.chat_called) |
|||
|
|||
def test_search_chunks_from_vectordb(self): |
|||
"""Test the _search_chunks_from_vectordb method.""" |
|||
query = "What is deep learning?" |
|||
sub_queries = ["What is deep learning?", "How does deep learning work?"] |
|||
|
|||
# Mock the collection_router.invoke method |
|||
self.deep_search.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5)) |
|||
|
|||
# Run the async method using asyncio.run |
|||
results, tokens = asyncio.run( |
|||
self.deep_search._search_chunks_from_vectordb(query, sub_queries) |
|||
) |
|||
|
|||
# Check if correct methods were called |
|||
self.deep_search.collection_router.invoke.assert_called_once() |
|||
self.assertTrue(self.vector_db.search_called) |
|||
self.assertTrue(self.llm.chat_called) |
|||
|
|||
# With our mock returning "YES" for RERANK_PROMPT, all chunks should be accepted |
|||
self.assertEqual(len(results), 3) # 3 mock results from MockVectorDB |
|||
self.assertEqual(tokens, 15) # 5 from collection_router + 10*1 from LLM calls for reranking (batch) |
|||
|
|||
def test_generate_gap_queries(self): |
|||
"""Test the _generate_gap_queries method.""" |
|||
query = "Tell me about deep learning" |
|||
all_sub_queries = ["What is deep learning?", "How does deep learning work?"] |
|||
all_chunks = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Deep learning is a subset of machine learning", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
), |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Deep learning uses neural networks", |
|||
reference="test_reference", |
|||
metadata={"a": 2} |
|||
) |
|||
] |
|||
|
|||
gap_queries, tokens = self.deep_search._generate_gap_queries(query, all_sub_queries, all_chunks) |
|||
|
|||
self.assertEqual(len(gap_queries), 1) |
|||
self.assertEqual(gap_queries[0], "What are limitations of deep learning?") |
|||
self.assertEqual(tokens, 10) |
|||
|
|||
def test_retrieve(self): |
|||
"""Test the retrieve method.""" |
|||
query = "Tell me about deep learning" |
|||
|
|||
# Mock async method to run synchronously |
|||
async def mock_async_retrieve(*args, **kwargs): |
|||
# Create some test results |
|||
results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Deep learning is a subset of machine learning", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
), |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Deep learning uses neural networks", |
|||
reference="test_reference", |
|||
metadata={"a": 2} |
|||
) |
|||
] |
|||
# Return the results, token count, and additional info |
|||
return results, 30, {"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]} |
|||
|
|||
# Replace the async method with our mock |
|||
self.deep_search.async_retrieve = mock_async_retrieve |
|||
|
|||
results, tokens, metadata = self.deep_search.retrieve(query) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 2) |
|||
self.assertEqual(tokens, 30) |
|||
self.assertIn("all_sub_queries", metadata) |
|||
self.assertEqual(len(metadata["all_sub_queries"]), 2) |
|||
|
|||
def test_async_retrieve(self): |
|||
"""Test the async_retrieve method.""" |
|||
query = "Tell me about deep learning" |
|||
|
|||
# Create mock results |
|||
mock_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Deep learning is a subset of machine learning", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
) |
|||
] |
|||
|
|||
# Create a mock async_retrieve result |
|||
mock_retrieve_result = ( |
|||
mock_results, |
|||
20, |
|||
{"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]} |
|||
) |
|||
|
|||
# Mock the async_retrieve method |
|||
async def mock_async_retrieve(*args, **kwargs): |
|||
return mock_retrieve_result |
|||
|
|||
self.deep_search.async_retrieve = mock_async_retrieve |
|||
|
|||
# Run the async method using asyncio.run |
|||
results, tokens, metadata = asyncio.run(self.deep_search.async_retrieve(query)) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 1) |
|||
self.assertEqual(tokens, 20) |
|||
self.assertIn("all_sub_queries", metadata) |
|||
|
|||
def test_query(self): |
|||
"""Test the query method.""" |
|||
query = "Tell me about deep learning" |
|||
|
|||
# Mock the retrieve method |
|||
retrieved_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text=f"Test result {i}", |
|||
reference="test_reference", |
|||
metadata={"a": i, "wider_text": f"Wider context for test result {i}"} |
|||
) |
|||
for i in range(3) |
|||
] |
|||
|
|||
self.deep_search.retrieve = MagicMock( |
|||
return_value=(retrieved_results, 20, {"all_sub_queries": ["What is deep learning?"]}) |
|||
) |
|||
|
|||
answer, results, tokens = self.deep_search.query(query) |
|||
|
|||
# Check if methods were called |
|||
self.deep_search.retrieve.assert_called_once_with(query) |
|||
self.assertTrue(self.llm.chat_called) |
|||
|
|||
# Check results |
|||
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.") |
|||
self.assertEqual(results, retrieved_results) |
|||
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM |
|||
|
|||
def test_query_no_results(self): |
|||
"""Test the query method when no results are found.""" |
|||
query = "Tell me about deep learning" |
|||
|
|||
# Mock the retrieve method to return no results |
|||
self.deep_search.retrieve = MagicMock( |
|||
return_value=([], 10, {"all_sub_queries": ["What is deep learning?"]}) |
|||
) |
|||
|
|||
answer, results, tokens = self.deep_search.query(query) |
|||
|
|||
# Should return a message saying no results found |
|||
self.assertIn("No relevant information found", answer) |
|||
self.assertEqual(results, []) |
|||
self.assertEqual(tokens, 10) # Only tokens from retrieve |
|||
|
|||
def test_format_chunk_texts(self): |
|||
"""Test the _format_chunk_texts method.""" |
|||
chunk_texts = ["Text 1", "Text 2", "Text 3"] |
|||
|
|||
formatted = self.deep_search._format_chunk_texts(chunk_texts) |
|||
|
|||
self.assertIn("<chunk_0>", formatted) |
|||
self.assertIn("Text 1", formatted) |
|||
self.assertIn("<chunk_1>", formatted) |
|||
self.assertIn("Text 2", formatted) |
|||
self.assertIn("<chunk_2>", formatted) |
|||
self.assertIn("Text 3", formatted) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
import unittest |
|||
unittest.main() |
@ -1,130 +0,0 @@ |
|||
from unittest.mock import MagicMock |
|||
|
|||
from deepsearcher.agent import NaiveRAG |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
from tests.agent.test_base import BaseAgentTest |
|||
|
|||
|
|||
class TestNaiveRAG(BaseAgentTest): |
|||
"""Test class for NaiveRAG agent.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for NaiveRAG tests.""" |
|||
super().setUp() |
|||
self.naive_rag = NaiveRAG( |
|||
llm=self.llm, |
|||
embedding_model=self.embedding_model, |
|||
vector_db=self.vector_db, |
|||
top_k=5, |
|||
route_collection=True, |
|||
text_window_splitter=True |
|||
) |
|||
|
|||
def test_init(self): |
|||
"""Test the initialization of NaiveRAG.""" |
|||
self.assertEqual(self.naive_rag.llm, self.llm) |
|||
self.assertEqual(self.naive_rag.embedding_model, self.embedding_model) |
|||
self.assertEqual(self.naive_rag.vector_db, self.vector_db) |
|||
self.assertEqual(self.naive_rag.top_k, 5) |
|||
self.assertEqual(self.naive_rag.route_collection, True) |
|||
self.assertEqual(self.naive_rag.text_window_splitter, True) |
|||
|
|||
def test_retrieve(self): |
|||
"""Test the retrieve method.""" |
|||
query = "Test query" |
|||
|
|||
# Mock the collection_router.invoke method |
|||
self.naive_rag.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5)) |
|||
|
|||
results, tokens, metadata = self.naive_rag.retrieve(query) |
|||
|
|||
# Check if correct methods were called |
|||
self.naive_rag.collection_router.invoke.assert_called_once() |
|||
self.assertTrue(self.vector_db.search_called) |
|||
|
|||
# Check the results |
|||
self.assertIsInstance(results, list) |
|||
self.assertEqual(len(results), 3) # Should match our mock return of 3 results |
|||
for result in results: |
|||
self.assertIsInstance(result, RetrievalResult) |
|||
|
|||
# Check token count |
|||
self.assertEqual(tokens, 5) # From our mocked collection_router.invoke |
|||
|
|||
def test_retrieve_without_routing(self): |
|||
"""Test retrieve method with routing disabled.""" |
|||
self.naive_rag.route_collection = False |
|||
query = "Test query without routing" |
|||
|
|||
results, tokens, metadata = self.naive_rag.retrieve(query) |
|||
|
|||
# Check that routing was not called |
|||
self.assertTrue(self.vector_db.search_called) |
|||
|
|||
# Check the results |
|||
self.assertIsInstance(results, list) |
|||
for result in results: |
|||
self.assertIsInstance(result, RetrievalResult) |
|||
|
|||
# Check token count |
|||
self.assertEqual(tokens, 0) # No tokens used for routing |
|||
|
|||
def test_query(self): |
|||
"""Test the query method.""" |
|||
query = "Test query for full RAG" |
|||
|
|||
# Mock the retrieve method |
|||
mock_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text=f"Test result {i}", |
|||
reference="test_reference", |
|||
metadata={"a": i, "wider_text": f"Wider context for test result {i}"} |
|||
) |
|||
for i in range(3) |
|||
] |
|||
self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {})) |
|||
|
|||
answer, retrieved_results, tokens = self.naive_rag.query(query) |
|||
|
|||
# Check if correct methods were called |
|||
self.naive_rag.retrieve.assert_called_once_with(query) |
|||
self.assertTrue(self.llm.chat_called) |
|||
|
|||
# Check the messages sent to LLM |
|||
self.assertIn("content", self.llm.last_messages[0]) |
|||
self.assertIn(query, self.llm.last_messages[0]["content"]) |
|||
|
|||
# Check the results |
|||
self.assertEqual(answer, "This is a test answer") |
|||
self.assertEqual(retrieved_results, mock_results) |
|||
self.assertEqual(tokens, 15) # 5 from retrieve + 10 from LLM |
|||
|
|||
def test_with_window_splitter_disabled(self): |
|||
"""Test with text window splitter disabled.""" |
|||
self.naive_rag.text_window_splitter = False |
|||
query = "Test query with window splitter off" |
|||
|
|||
# Mock the retrieve method |
|||
mock_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text=f"Test result {i}", |
|||
reference="test_reference", |
|||
metadata={"a": i, "wider_text": f"Wider context for test result {i}"} |
|||
) |
|||
for i in range(3) |
|||
] |
|||
self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {})) |
|||
|
|||
answer, retrieved_results, tokens = self.naive_rag.query(query) |
|||
|
|||
# Check that regular text is used instead of wider_text |
|||
self.assertIn("Test result 0", self.llm.last_messages[0]["content"]) |
|||
self.assertNotIn("Wider context", self.llm.last_messages[0]["content"]) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
import unittest |
|||
unittest.main() |
@ -1,162 +0,0 @@ |
|||
from unittest.mock import MagicMock, patch |
|||
|
|||
from deepsearcher.agent import NaiveRAG, ChainOfRAG, DeepSearch |
|||
from deepsearcher.agent.rag_router import RAGRouter |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
from tests.agent.test_base import BaseAgentTest |
|||
|
|||
|
|||
class TestRAGRouter(BaseAgentTest): |
|||
"""Test class for RAGRouter agent.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures for RAGRouter tests.""" |
|||
super().setUp() |
|||
|
|||
# Create mock agent instances |
|||
self.naive_rag = MagicMock(spec=NaiveRAG) |
|||
self.chain_of_rag = MagicMock(spec=ChainOfRAG) |
|||
|
|||
# Create the RAGRouter with the mock agents |
|||
self.rag_router = RAGRouter( |
|||
llm=self.llm, |
|||
rag_agents=[self.naive_rag, self.chain_of_rag], |
|||
agent_descriptions=[ |
|||
"This agent is suitable for simple factual queries", |
|||
"This agent is suitable for complex multi-hop questions" |
|||
] |
|||
) |
|||
|
|||
def test_init(self): |
|||
"""Test the initialization of RAGRouter.""" |
|||
self.assertEqual(self.rag_router.llm, self.llm) |
|||
self.assertEqual(len(self.rag_router.rag_agents), 2) |
|||
self.assertEqual(len(self.rag_router.agent_descriptions), 2) |
|||
self.assertEqual(self.rag_router.agent_descriptions[0], "This agent is suitable for simple factual queries") |
|||
self.assertEqual(self.rag_router.agent_descriptions[1], "This agent is suitable for complex multi-hop questions") |
|||
|
|||
def test_route(self): |
|||
"""Test the _route method.""" |
|||
query = "What is the capital of France?" |
|||
|
|||
# Directly mock the chat method to return a numeric response |
|||
self.llm.chat = MagicMock(return_value=ChatResponse(content="1", total_tokens=10)) |
|||
|
|||
agent, tokens = self.rag_router._route(query) |
|||
|
|||
# Should select the first agent based on our mock response |
|||
self.assertEqual(agent, self.naive_rag) |
|||
self.assertEqual(tokens, 10) |
|||
self.assertTrue(self.llm.chat.called) |
|||
|
|||
def test_route_with_non_numeric_response(self): |
|||
"""Test the _route method with a non-numeric response from LLM.""" |
|||
query = "What is the history of deep learning?" |
|||
|
|||
# Mock the LLM to return a response with a trailing digit |
|||
self.llm.chat = MagicMock(return_value=ChatResponse(content="I recommend agent 2", total_tokens=10)) |
|||
self.rag_router.find_last_digit = MagicMock(return_value="2") |
|||
|
|||
agent, tokens = self.rag_router._route(query) |
|||
|
|||
# Should select the second agent based on our mock response |
|||
self.assertEqual(agent, self.chain_of_rag) |
|||
self.assertTrue(self.rag_router.find_last_digit.called) |
|||
|
|||
def test_retrieve(self): |
|||
"""Test the retrieve method.""" |
|||
query = "What is the capital of France?" |
|||
|
|||
# Mock the _route method to return the first agent |
|||
mock_retrieved_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Paris is the capital of France", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
) |
|||
] |
|||
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5)) |
|||
self.naive_rag.retrieve = MagicMock(return_value=(mock_retrieved_results, 10, {"some": "metadata"})) |
|||
|
|||
results, tokens, metadata = self.rag_router.retrieve(query) |
|||
|
|||
# Check if methods were called |
|||
self.rag_router._route.assert_called_once_with(query) |
|||
self.naive_rag.retrieve.assert_called_once_with(query) |
|||
|
|||
# Check results |
|||
self.assertEqual(results, mock_retrieved_results) |
|||
self.assertEqual(tokens, 15) # 5 from route + 10 from retrieve |
|||
self.assertEqual(metadata, {"some": "metadata"}) |
|||
|
|||
def test_query(self): |
|||
"""Test the query method.""" |
|||
query = "What is the capital of France?" |
|||
|
|||
# Mock the _route method to return the first agent |
|||
mock_retrieved_results = [ |
|||
RetrievalResult( |
|||
embedding=[0.1] * 8, |
|||
text="Paris is the capital of France", |
|||
reference="test_reference", |
|||
metadata={"a": 1} |
|||
) |
|||
] |
|||
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5)) |
|||
self.naive_rag.query = MagicMock(return_value=("Paris is the capital of France", mock_retrieved_results, 10)) |
|||
|
|||
answer, results, tokens = self.rag_router.query(query) |
|||
|
|||
# Check if methods were called |
|||
self.rag_router._route.assert_called_once_with(query) |
|||
self.naive_rag.query.assert_called_once_with(query) |
|||
|
|||
# Check results |
|||
self.assertEqual(answer, "Paris is the capital of France") |
|||
self.assertEqual(results, mock_retrieved_results) |
|||
self.assertEqual(tokens, 15) # 5 from route + 10 from query |
|||
|
|||
def test_find_last_digit(self): |
|||
"""Test the find_last_digit method.""" |
|||
self.assertEqual(self.rag_router.find_last_digit("Agent 2 is better"), "2") |
|||
self.assertEqual(self.rag_router.find_last_digit("I recommend agent number 1"), "1") |
|||
self.assertEqual(self.rag_router.find_last_digit("Choose 3"), "3") |
|||
|
|||
# Test with no digit |
|||
with self.assertRaises(ValueError): |
|||
self.rag_router.find_last_digit("No digits here") |
|||
|
|||
def test_auto_description_fallback(self): |
|||
"""Test that RAGRouter falls back to __description__ when no descriptions provided.""" |
|||
# Create classes with __description__ attribute |
|||
class MockAgent1: |
|||
__description__ = "Auto description 1" |
|||
|
|||
class MockAgent2: |
|||
__description__ = "Auto description 2" |
|||
|
|||
# Create instances of these classes |
|||
mock_agent1 = MagicMock(spec=MockAgent1) |
|||
mock_agent1.__class__ = MockAgent1 |
|||
|
|||
mock_agent2 = MagicMock(spec=MockAgent2) |
|||
mock_agent2.__class__ = MockAgent2 |
|||
|
|||
# Create the RAGRouter without explicit descriptions |
|||
router = RAGRouter( |
|||
llm=self.llm, |
|||
rag_agents=[mock_agent1, mock_agent2] |
|||
) |
|||
|
|||
# Check that descriptions were pulled from the class attributes |
|||
self.assertEqual(len(router.agent_descriptions), 2) |
|||
self.assertEqual(router.agent_descriptions[0], "Auto description 1") |
|||
self.assertEqual(router.agent_descriptions[1], "Auto description 2") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
import unittest |
|||
unittest.main() |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher.embedding package |
@ -1,105 +0,0 @@ |
|||
import unittest |
|||
from typing import List |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from deepsearcher.embedding.base import BaseEmbedding |
|||
from deepsearcher.loader.splitter import Chunk |
|||
|
|||
|
|||
class ConcreteEmbedding(BaseEmbedding): |
|||
"""A concrete implementation of BaseEmbedding for testing.""" |
|||
|
|||
def __init__(self, dimension=768): |
|||
self._dimension = dimension |
|||
|
|||
def embed_query(self, text: str) -> List[float]: |
|||
"""Simple implementation that returns a vector of the given dimension.""" |
|||
return [0.1] * self._dimension |
|||
|
|||
@property |
|||
def dimension(self) -> int: |
|||
return self._dimension |
|||
|
|||
|
|||
class TestBaseEmbedding(unittest.TestCase): |
|||
"""Tests for the BaseEmbedding base class.""" |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test the embed_query method.""" |
|||
embedding = ConcreteEmbedding() |
|||
result = embedding.embed_query("test text") |
|||
self.assertEqual(len(result), 768) |
|||
self.assertEqual(result, [0.1] * 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test the embed_documents method.""" |
|||
embedding = ConcreteEmbedding() |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check we got the right number of embeddings |
|||
self.assertEqual(len(results), 3) |
|||
|
|||
# Check each embedding |
|||
for result in results: |
|||
self.assertEqual(len(result), 768) |
|||
self.assertEqual(result, [0.1] * 768) |
|||
|
|||
@patch('deepsearcher.embedding.base.tqdm') |
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_chunks(self, mock_tqdm): |
|||
"""Test the embed_chunks method.""" |
|||
embedding = ConcreteEmbedding() |
|||
|
|||
# Set up mock tqdm to just return the iterable |
|||
mock_tqdm.return_value = lambda x, desc: x |
|||
|
|||
# Create test chunks |
|||
chunks = [ |
|||
Chunk(text="text 1", reference="ref1"), |
|||
Chunk(text="text 2", reference="ref2"), |
|||
Chunk(text="text 3", reference="ref3") |
|||
] |
|||
|
|||
# Create a spy on embed_documents |
|||
original_embed_documents = embedding.embed_documents |
|||
embed_documents_calls = [] |
|||
|
|||
def mock_embed_documents(texts): |
|||
embed_documents_calls.append(texts) |
|||
return original_embed_documents(texts) |
|||
|
|||
embedding.embed_documents = mock_embed_documents |
|||
|
|||
# Mock tqdm to return the batch_texts directly |
|||
mock_tqdm.side_effect = lambda x, **kwargs: x |
|||
|
|||
# Call the method |
|||
result_chunks = embedding.embed_chunks(chunks, batch_size=2) |
|||
|
|||
# Verify embed_documents was called correctly |
|||
self.assertEqual(len(embed_documents_calls), 2) # Should be called twice with batch_size=2 |
|||
self.assertEqual(embed_documents_calls[0], ["text 1", "text 2"]) |
|||
self.assertEqual(embed_documents_calls[1], ["text 3"]) |
|||
|
|||
# Verify chunks were updated with embeddings |
|||
self.assertEqual(len(result_chunks), 3) |
|||
for chunk in result_chunks: |
|||
self.assertEqual(len(chunk.embedding), 768) |
|||
self.assertEqual(chunk.embedding, [0.1] * 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
embedding = ConcreteEmbedding() |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# Test with different dimension |
|||
embedding = ConcreteEmbedding(dimension=1024) |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,148 +0,0 @@ |
|||
import unittest |
|||
import json |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.embedding import BedrockEmbedding |
|||
from deepsearcher.embedding.bedrock_embedding import ( |
|||
MODEL_ID_TITAN_TEXT_V2, |
|||
MODEL_ID_TITAN_TEXT_G1, |
|||
MODEL_ID_COHERE_ENGLISH_V3, |
|||
) |
|||
|
|||
|
|||
class TestBedrockEmbedding(unittest.TestCase): |
|||
"""Tests for the BedrockEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_boto3 = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_boto3.client = MagicMock(return_value=self.mock_client) |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'boto3': self.mock_boto3}) |
|||
self.module_patcher.start() |
|||
|
|||
# Configure mock response |
|||
self.mock_response = { |
|||
"body": MagicMock(), |
|||
"ResponseMetadata": {"HTTPStatusCode": 200} |
|||
} |
|||
self.mock_response["body"].read.return_value = json.dumps({"embedding": [0.1] * 1024}) |
|||
self.mock_client.invoke_model.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create instance to test |
|||
embedding = BedrockEmbedding() |
|||
|
|||
# Check that boto3 client was created correctly |
|||
self.mock_boto3.client.assert_called_once_with( |
|||
"bedrock-runtime", |
|||
region_name="us-east-1", |
|||
aws_access_key_id=None, |
|||
aws_secret_access_key=None |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(embedding.model, MODEL_ID_TITAN_TEXT_V2) |
|||
|
|||
# Ensure no coroutine warnings |
|||
self.mock_client.invoke_model.return_value = self.mock_response |
|||
|
|||
@patch.dict('os.environ', { |
|||
'AWS_ACCESS_KEY_ID': 'test_key', |
|||
'AWS_SECRET_ACCESS_KEY': 'test_secret' |
|||
}, clear=True) |
|||
def test_init_with_credentials(self): |
|||
"""Test initialization with AWS credentials.""" |
|||
embedding = BedrockEmbedding() |
|||
self.mock_boto3.client.assert_called_with( |
|||
"bedrock-runtime", |
|||
region_name="us-east-1", |
|||
aws_access_key_id="test_key", |
|||
aws_secret_access_key="test_secret" |
|||
) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_different_models(self): |
|||
"""Test initialization with different models.""" |
|||
# Test Titan Text G1 |
|||
embedding = BedrockEmbedding(model=MODEL_ID_TITAN_TEXT_G1) |
|||
self.assertEqual(embedding.model, MODEL_ID_TITAN_TEXT_G1) |
|||
|
|||
# Test Cohere English V3 |
|||
embedding = BedrockEmbedding(model=MODEL_ID_COHERE_ENGLISH_V3) |
|||
self.assertEqual(embedding.model, MODEL_ID_COHERE_ENGLISH_V3) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create instance to test |
|||
embedding = BedrockEmbedding() |
|||
|
|||
query = "test query" |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Check that invoke_model was called correctly |
|||
self.mock_client.invoke_model.assert_called_once_with( |
|||
modelId=MODEL_ID_TITAN_TEXT_V2, |
|||
body=json.dumps({"inputText": query}) |
|||
) |
|||
|
|||
# Check result |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create instance to test |
|||
embedding = BedrockEmbedding() |
|||
|
|||
texts = ["text 1", "text 2", "text 3"] |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that invoke_model was called for each text |
|||
self.assertEqual(self.mock_client.invoke_model.call_count, 3) |
|||
for text in texts: |
|||
self.mock_client.invoke_model.assert_any_call( |
|||
modelId=MODEL_ID_TITAN_TEXT_V2, |
|||
body=json.dumps({"inputText": text}) |
|||
) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 3) |
|||
for result in results: |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property for different models.""" |
|||
# Create instance to test with Titan Text V2 |
|||
embedding = BedrockEmbedding() |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# Test Titan Text G1 |
|||
embedding = BedrockEmbedding(model=MODEL_ID_TITAN_TEXT_G1) |
|||
self.assertEqual(embedding.dimension, 1536) |
|||
|
|||
# Test Cohere English V3 |
|||
embedding = BedrockEmbedding(model=MODEL_ID_COHERE_ENGLISH_V3) |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,144 +0,0 @@ |
|||
import unittest |
|||
import numpy as np |
|||
from unittest.mock import patch, MagicMock |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.embedding import FastEmbedEmbedding |
|||
|
|||
|
|||
class TestFastEmbedEmbedding(unittest.TestCase): |
|||
"""Tests for the FastEmbedEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_fastembed = MagicMock() |
|||
self.mock_text_embedding = MagicMock() |
|||
self.mock_fastembed.TextEmbedding = MagicMock(return_value=self.mock_text_embedding) |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'fastembed': self.mock_fastembed}) |
|||
self.module_patcher.start() |
|||
|
|||
# Set up mock embeddings |
|||
self.mock_embedding = np.array([0.1] * 384) # BGE-small has 384 dimensions |
|||
self.mock_text_embedding.query_embed.return_value = iter([self.mock_embedding]) |
|||
self.mock_text_embedding.embed.return_value = [self.mock_embedding] * 3 |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create instance to test |
|||
embedding = FastEmbedEmbedding() |
|||
|
|||
# Access a method to trigger lazy loading |
|||
embedding.embed_query("test") |
|||
|
|||
# Check that TextEmbedding was initialized correctly |
|||
self.mock_fastembed.TextEmbedding.assert_called_once_with( |
|||
model_name="BAAI/bge-small-en-v1.5" |
|||
) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
custom_model = "custom/model-name" |
|||
embedding = FastEmbedEmbedding(model=custom_model) |
|||
|
|||
# Access a method to trigger lazy loading |
|||
embedding.embed_query("test") |
|||
|
|||
self.mock_fastembed.TextEmbedding.assert_called_with( |
|||
model_name=custom_model |
|||
) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_kwargs(self): |
|||
"""Test initialization with additional kwargs.""" |
|||
kwargs = {"batch_size": 32, "max_length": 512} |
|||
embedding = FastEmbedEmbedding(**kwargs) |
|||
|
|||
# Access a method to trigger lazy loading |
|||
embedding.embed_query("test") |
|||
|
|||
self.mock_fastembed.TextEmbedding.assert_called_with( |
|||
model_name="BAAI/bge-small-en-v1.5", |
|||
**kwargs |
|||
) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create instance to test |
|||
embedding = FastEmbedEmbedding() |
|||
|
|||
query = "test query" |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Check that query_embed was called correctly |
|||
self.mock_text_embedding.query_embed.assert_called_once_with([query]) |
|||
|
|||
# Check result |
|||
self.assertEqual(len(result), 384) |
|||
np.testing.assert_array_equal(result, [0.1] * 384) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create instance to test |
|||
embedding = FastEmbedEmbedding() |
|||
|
|||
texts = ["text 1", "text 2", "text 3"] |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that embed was called correctly |
|||
self.mock_text_embedding.embed.assert_called_once_with(texts) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 3) |
|||
for result in results: |
|||
self.assertEqual(len(result), 384) |
|||
np.testing.assert_array_equal(result, [0.1] * 384) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create instance to test |
|||
embedding = FastEmbedEmbedding() |
|||
|
|||
# Mock a sample embedding |
|||
sample_embedding = np.array([0.1] * 384) |
|||
self.mock_text_embedding.query_embed.return_value = iter([sample_embedding]) |
|||
|
|||
# Check dimension |
|||
self.assertEqual(embedding.dimension, 384) |
|||
|
|||
# Verify that query_embed was called with sample text |
|||
self.mock_text_embedding.query_embed.assert_called_with(["SAMPLE TEXT"]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_lazy_loading(self): |
|||
"""Test that the model is loaded lazily.""" |
|||
# Create a new instance |
|||
embedding = FastEmbedEmbedding() |
|||
|
|||
# Check that TextEmbedding wasn't called during initialization |
|||
self.mock_fastembed.TextEmbedding.reset_mock() |
|||
self.mock_fastembed.TextEmbedding.assert_not_called() |
|||
|
|||
# Access a method that requires the model |
|||
embedding.embed_query("test") |
|||
|
|||
# Now TextEmbedding should have been called |
|||
self.mock_fastembed.TextEmbedding.assert_called_once() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,266 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
import logging |
|||
import warnings |
|||
import multiprocessing.resource_tracker |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
# Suppress resource tracker warning |
|||
warnings.filterwarnings("ignore", category=ResourceWarning) |
|||
|
|||
# Patch resource tracker to avoid warnings |
|||
def _resource_tracker(): |
|||
pass |
|||
multiprocessing.resource_tracker._resource_tracker = _resource_tracker |
|||
|
|||
from deepsearcher.embedding import GeminiEmbedding |
|||
from deepsearcher.embedding.gemini_embedding import GEMINI_MODEL_DIM_MAP |
|||
|
|||
|
|||
class TestGeminiEmbedding(unittest.TestCase): |
|||
"""Tests for the GeminiEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_genai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_types = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_genai.Client = MagicMock(return_value=self.mock_client) |
|||
self.mock_genai.types = self.mock_types |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'google.genai': self.mock_genai}) |
|||
self.module_patcher.start() |
|||
|
|||
# Set up mock response for embed_content |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.embeddings = [ |
|||
MagicMock(values=[0.1] * 768) # Default embedding for text-embedding-004 |
|||
] |
|||
self.mock_client.models.embed_content.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
# Check that Client was initialized correctly |
|||
self.mock_genai.Client.assert_called_once_with(api_key=None) |
|||
|
|||
# Check default model and dimension |
|||
self.assertEqual(embedding.model, "text-embedding-004") |
|||
self.assertEqual(embedding.dim, 768) |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
@patch.dict('os.environ', {'GEMINI_API_KEY': 'test_api_key_from_env'}, clear=True) |
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
embedding = GeminiEmbedding() |
|||
self.mock_genai.Client.assert_called_with(api_key='test_api_key_from_env') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
api_key = "test_api_key_param" |
|||
embedding = GeminiEmbedding(api_key=api_key) |
|||
self.mock_genai.Client.assert_called_with(api_key=api_key) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
model = "gemini-embedding-exp-03-07" |
|||
embedding = GeminiEmbedding(model=model) |
|||
|
|||
self.assertEqual(embedding.model, model) |
|||
self.assertEqual(embedding.dim, GEMINI_MODEL_DIM_MAP[model]) |
|||
self.assertEqual(embedding.dimension, 3072) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_custom_dimension(self): |
|||
"""Test initialization with custom dimension.""" |
|||
custom_dim = 1024 |
|||
embedding = GeminiEmbedding(dimension=custom_dim) |
|||
|
|||
self.assertEqual(embedding.dim, custom_dim) |
|||
self.assertEqual(embedding.dimension, custom_dim) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query_single_char(self): |
|||
"""Test embedding a single character query.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
query = "a" |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Check that embed_content was called correctly |
|||
self.mock_client.models.embed_content.assert_called_once() |
|||
call_args = self.mock_client.models.embed_content.call_args |
|||
|
|||
# For single character, it should be passed as is |
|||
self.assertEqual(call_args[1]["model"], "text-embedding-004") |
|||
self.assertEqual(call_args[1]["contents"], query) |
|||
|
|||
# Check result |
|||
self.assertEqual(len(result), 768) |
|||
self.assertEqual(result, [0.1] * 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query_multi_char(self): |
|||
"""Test embedding a multi-character query.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
query = "test query" |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Check that embed_content was called correctly |
|||
self.mock_client.models.embed_content.assert_called_once() |
|||
call_args = self.mock_client.models.embed_content.call_args |
|||
|
|||
# For multi-character string, it should be joined with spaces |
|||
self.assertEqual(call_args[1]["model"], "text-embedding-004") |
|||
self.assertEqual(call_args[1]["contents"], "t e s t q u e r y") |
|||
|
|||
# Check result |
|||
self.assertEqual(len(result), 768) |
|||
self.assertEqual(result, [0.1] * 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_embeddings = [ |
|||
MagicMock(values=[0.1] * 768), |
|||
MagicMock(values=[0.2] * 768), |
|||
MagicMock(values=[0.3] * 768) |
|||
] |
|||
self.mock_response.embeddings = mock_embeddings |
|||
|
|||
texts = ["text 1", "text 2", "text 3"] |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that embed_content was called correctly |
|||
self.mock_client.models.embed_content.assert_called_once() |
|||
call_args = self.mock_client.models.embed_content.call_args |
|||
self.assertEqual(call_args[1]["model"], "text-embedding-004") |
|||
self.assertEqual(call_args[1]["contents"], texts) |
|||
|
|||
# Check that EmbedContentConfig was used |
|||
config_arg = call_args[1]["config"] |
|||
self.mock_types.EmbedContentConfig.assert_called_once_with(output_dimensionality=768) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(results), 3) |
|||
expected_results = [[0.1] * 768, [0.2] * 768, [0.3] * 768] |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(len(result), 768) |
|||
self.assertEqual(result, expected_results[i]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_chunks(self): |
|||
"""Test embedding chunks with batch processing.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
# Set up mock response for batched documents |
|||
batch1_embeddings = [MagicMock(values=[0.1] * 768)] * 100 |
|||
batch2_embeddings = [MagicMock(values=[0.2] * 768)] * 50 |
|||
|
|||
# Mock multiple calls to embed_content |
|||
self.mock_client.models.embed_content.side_effect = [ |
|||
MagicMock(embeddings=batch1_embeddings), |
|||
MagicMock(embeddings=batch2_embeddings) |
|||
] |
|||
|
|||
# Create mock chunks |
|||
class MockChunk: |
|||
def __init__(self, text: str): |
|||
self.text = text |
|||
self.embedding = None |
|||
|
|||
chunks = [MockChunk(f"text {i}") for i in range(150)] |
|||
results = embedding.embed_chunks(chunks, batch_size=100) |
|||
|
|||
# Check that embed_content was called twice (150 chunks split into 2 batches) |
|||
self.assertEqual(self.mock_client.models.embed_content.call_count, 2) |
|||
|
|||
# Check that the same chunk objects are returned |
|||
self.assertEqual(len(results), 150) |
|||
self.assertEqual(results, chunks) |
|||
|
|||
# Check that each chunk has an embedding |
|||
for i, chunk in enumerate(results): |
|||
self.assertIsNotNone(chunk.embedding) |
|||
if i < 100: |
|||
self.assertEqual(chunk.embedding, [0.1] * 768) |
|||
else: |
|||
self.assertEqual(chunk.embedding, [0.2] * 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property_different_models(self): |
|||
"""Test the dimension property for different models.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
# Test default model |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# Test experimental model |
|||
embedding_exp = GeminiEmbedding(model="gemini-embedding-exp-03-07") |
|||
self.assertEqual(embedding_exp.dimension, 3072) |
|||
|
|||
# Test custom dimension |
|||
embedding_custom = GeminiEmbedding(dimension=512) |
|||
self.assertEqual(embedding_custom.dimension, 512) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_get_dim_method(self): |
|||
"""Test the private _get_dim method.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
# Test default dimension |
|||
self.assertEqual(embedding._get_dim(), 768) |
|||
|
|||
# Test custom dimension |
|||
embedding_custom = GeminiEmbedding(dimension=1024) |
|||
self.assertEqual(embedding_custom._get_dim(), 1024) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_content_method(self): |
|||
"""Test the private _embed_content method.""" |
|||
# Create instance to test |
|||
embedding = GeminiEmbedding() |
|||
|
|||
texts = ["test text 1", "test text 2"] |
|||
result = embedding._embed_content(texts) |
|||
|
|||
# Check that embed_content was called with correct parameters |
|||
self.mock_client.models.embed_content.assert_called_once() |
|||
call_args = self.mock_client.models.embed_content.call_args |
|||
|
|||
self.assertEqual(call_args[1]["model"], "text-embedding-004") |
|||
self.assertEqual(call_args[1]["contents"], texts) |
|||
self.mock_types.EmbedContentConfig.assert_called_once_with(output_dimensionality=768) |
|||
|
|||
# Check that the response embeddings are returned |
|||
self.assertEqual(result, self.mock_response.embeddings) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,143 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from deepsearcher.embedding import GLMEmbedding |
|||
|
|||
|
|||
class TestGLMEmbedding(unittest.TestCase): |
|||
"""Tests for the GLMEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_zhipuai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_embeddings = MagicMock() |
|||
|
|||
# Set up mock response |
|||
mock_data_item = MagicMock() |
|||
mock_data_item.embedding = [0.1] * 2048 # embedding-3 has 2048 dimensions |
|||
mock_response = MagicMock() |
|||
mock_response.data = [mock_data_item] |
|||
self.mock_embeddings.create.return_value = mock_response |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_zhipuai.ZhipuAI.return_value = self.mock_client |
|||
self.mock_client.embeddings = self.mock_embeddings |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'zhipuai': self.mock_zhipuai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = GLMEmbedding() |
|||
|
|||
# Check that ZhipuAI was initialized correctly |
|||
self.mock_zhipuai.ZhipuAI.assert_called_once_with( |
|||
api_key='fake-api-key', |
|||
base_url='https://open.bigmodel.cn/api/paas/v4/' |
|||
) |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'embedding-3') |
|||
self.assertEqual(embedding.client, self.mock_client) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = GLMEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that ZhipuAI was initialized with the provided API key |
|||
self.mock_zhipuai.ZhipuAI.assert_called_with( |
|||
api_key='test-api-key', |
|||
base_url='https://open.bigmodel.cn/api/paas/v4/' |
|||
) |
|||
|
|||
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_base_url(self): |
|||
"""Test initialization with base URL parameter.""" |
|||
# Initialize with base URL |
|||
embedding = GLMEmbedding(base_url='https://custom-api.example.com') |
|||
|
|||
# Check that ZhipuAI was initialized with the provided base URL |
|||
self.mock_zhipuai.ZhipuAI.assert_called_with( |
|||
api_key='fake-api-key', |
|||
base_url='https://custom-api.example.com' |
|||
) |
|||
|
|||
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = GLMEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that create was called correctly |
|||
self.mock_embeddings.create.assert_called_once_with( |
|||
input=[query], |
|||
model='embedding-3' |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 2048) |
|||
|
|||
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = GLMEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_data_items = [] |
|||
for i in range(3): |
|||
mock_data_item = MagicMock() |
|||
mock_data_item.embedding = [0.1 * (i + 1)] * 2048 |
|||
mock_data_items.append(mock_data_item) |
|||
|
|||
mock_response = MagicMock() |
|||
mock_response.data = mock_data_items |
|||
self.mock_embeddings.create.return_value = mock_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that create was called correctly |
|||
self.mock_embeddings.create.assert_called_once_with( |
|||
input=texts, |
|||
model='embedding-3' |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 2048) |
|||
|
|||
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = GLMEmbedding() |
|||
|
|||
# For embedding-3 |
|||
self.assertEqual(embedding.dimension, 2048) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,130 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import numpy as np |
|||
from deepsearcher.embedding import MilvusEmbedding |
|||
|
|||
|
|||
class TestMilvusEmbedding(unittest.TestCase): |
|||
"""Tests for the MilvusEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_pymilvus = MagicMock() |
|||
self.mock_model = MagicMock() |
|||
self.mock_default_embedding = MagicMock() |
|||
self.mock_jina_embedding = MagicMock() |
|||
self.mock_sentence_transformer = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_pymilvus.model = self.mock_model |
|||
self.mock_model.DefaultEmbeddingFunction = MagicMock(return_value=self.mock_default_embedding) |
|||
self.mock_model.dense = MagicMock() |
|||
self.mock_model.dense.JinaEmbeddingFunction = MagicMock(return_value=self.mock_jina_embedding) |
|||
self.mock_model.dense.SentenceTransformerEmbeddingFunction = MagicMock(return_value=self.mock_sentence_transformer) |
|||
|
|||
# Set up default dimensions and responses |
|||
self.mock_default_embedding.dim = 768 |
|||
self.mock_jina_embedding.dim = 1024 |
|||
self.mock_sentence_transformer.dim = 1024 |
|||
|
|||
# Set up mock responses for encoding |
|||
self.mock_default_embedding.encode_queries.return_value = [np.array([0.1] * 768)] |
|||
self.mock_default_embedding.encode_documents.return_value = [np.array([0.1] * 768)] |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'pymilvus': self.mock_pymilvus}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
embedding = MilvusEmbedding() |
|||
|
|||
# Check that default model was initialized |
|||
self.mock_model.DefaultEmbeddingFunction.assert_called_once() |
|||
self.assertEqual(embedding.model, self.mock_default_embedding) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_jina_model(self): |
|||
"""Test initialization with Jina model.""" |
|||
embedding = MilvusEmbedding(model='jina-embeddings-v3') |
|||
|
|||
# Check that Jina model was initialized |
|||
self.mock_model.dense.JinaEmbeddingFunction.assert_called_once_with('jina-embeddings-v3') |
|||
self.assertEqual(embedding.model, self.mock_jina_embedding) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_bge_model(self): |
|||
"""Test initialization with BGE model.""" |
|||
embedding = MilvusEmbedding(model='BAAI/bge-large-en-v1.5') |
|||
|
|||
# Check that SentenceTransformer model was initialized |
|||
self.mock_model.dense.SentenceTransformerEmbeddingFunction.assert_called_once_with('BAAI/bge-large-en-v1.5') |
|||
self.assertEqual(embedding.model, self.mock_sentence_transformer) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_invalid_model(self): |
|||
"""Test initialization with invalid model raises error.""" |
|||
with self.assertRaises(ValueError): |
|||
MilvusEmbedding(model='invalid-model') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
embedding = MilvusEmbedding() |
|||
query = "This is a test query" |
|||
|
|||
result = embedding.embed_query(query) |
|||
|
|||
# Check that encode_queries was called correctly |
|||
self.mock_default_embedding.encode_queries.assert_called_once_with([query]) |
|||
|
|||
# Convert numpy array to list for comparison |
|||
expected = [0.1] * 768 |
|||
np.testing.assert_array_almost_equal(result, expected) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
embedding = MilvusEmbedding() |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_embeddings = [np.array([0.1 * (i + 1)] * 768) for i in range(3)] |
|||
self.mock_default_embedding.encode_documents.return_value = mock_embeddings |
|||
|
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that encode_documents was called correctly |
|||
self.mock_default_embedding.encode_documents.assert_called_once_with(texts) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
expected = [0.1 * (i + 1)] * 768 |
|||
np.testing.assert_array_almost_equal(result, expected) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# For default model |
|||
embedding = MilvusEmbedding() |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# For Jina model |
|||
embedding = MilvusEmbedding(model='jina-embeddings-v3') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# For BGE model |
|||
embedding = MilvusEmbedding(model='BAAI/bge-large-en-v1.5') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,193 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import requests |
|||
from deepsearcher.embedding import NovitaEmbedding |
|||
|
|||
|
|||
class TestNovitaEmbedding(unittest.TestCase): |
|||
"""Tests for the NovitaEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for requests |
|||
self.requests_patcher = patch('requests.request') |
|||
self.mock_request = self.requests_patcher.start() |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': 0, 'embedding': [0.1] * 1024} # baai/bge-m3 has 1024 dimensions |
|||
] |
|||
} |
|||
self.mock_response.raise_for_status = MagicMock() |
|||
self.mock_request.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.requests_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = NovitaEmbedding() |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
self.assertEqual(embedding.api_key, 'fake-api-key') |
|||
self.assertEqual(embedding.batch_size, 32) |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = NovitaEmbedding(model='baai/bge-m3') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = NovitaEmbedding(model_name='baai/bge-m3') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = NovitaEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that the API key was set correctly |
|||
self.assertEqual(embedding.api_key, 'test-api-key') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_without_api_key(self): |
|||
"""Test initialization without API key raises error.""" |
|||
with self.assertRaises(RuntimeError): |
|||
NovitaEmbedding() |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = NovitaEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.novita.ai/v3/openai/embeddings', |
|||
json={ |
|||
'model': 'baai/bge-m3', |
|||
'input': query, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = NovitaEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024} |
|||
for i in range(3) |
|||
] |
|||
} |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.novita.ai/v3/openai/embeddings', |
|||
json={ |
|||
'model': 'baai/bge-m3', |
|||
'input': texts, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents_with_batching(self): |
|||
"""Test embedding documents with batching.""" |
|||
# Create the embedder |
|||
embedding = NovitaEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text " + str(i) for i in range(50)] # More than batch_size |
|||
|
|||
# Set up mock response for batched documents |
|||
def mock_batch_response(*args, **kwargs): |
|||
batch_input = kwargs['json']['input'] |
|||
mock_resp = MagicMock() |
|||
mock_resp.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1] * 1024} |
|||
for i in range(len(batch_input)) |
|||
] |
|||
} |
|||
mock_resp.raise_for_status = MagicMock() |
|||
return mock_resp |
|||
|
|||
self.mock_request.side_effect = mock_batch_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that request was called multiple times |
|||
self.assertTrue(self.mock_request.call_count > 1) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 50) |
|||
for result in results: |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = NovitaEmbedding() |
|||
|
|||
# For baai/bge-m3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,239 +0,0 @@ |
|||
import unittest |
|||
import sys |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from deepsearcher.embedding import OllamaEmbedding |
|||
|
|||
|
|||
class TestOllamaEmbedding(unittest.TestCase): |
|||
"""Tests for the OllamaEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module for ollama |
|||
mock_ollama_module = MagicMock() |
|||
|
|||
# Create mock Client class |
|||
self.mock_ollama_client = MagicMock() |
|||
mock_ollama_module.Client = self.mock_ollama_client |
|||
|
|||
# Add the mock module to sys.modules |
|||
self.module_patcher = patch.dict('sys.modules', {'ollama': mock_ollama_module}) |
|||
self.module_patcher.start() |
|||
|
|||
# Set up mock client instance |
|||
self.mock_client = MagicMock() |
|||
self.mock_ollama_client.return_value = self.mock_client |
|||
|
|||
# Configure mock embed method |
|||
self.mock_client.embed.return_value = {"embeddings": [[0.1] * 1024]} |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Check that Client was initialized correctly |
|||
self.mock_ollama_client.assert_called_once_with(host="http://localhost:11434/") |
|||
|
|||
# Check instance attributes |
|||
self.assertEqual(embedding.model, "bge-m3") |
|||
self.assertEqual(embedding.dim, 1024) |
|||
self.assertEqual(embedding.batch_size, 32) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Reset mock |
|||
self.mock_ollama_client.reset_mock() |
|||
|
|||
# Create embedding with custom base URL |
|||
embedding = OllamaEmbedding(base_url="http://custom-ollama-server:11434/") |
|||
|
|||
# Check that Client was initialized with custom base URL |
|||
self.mock_ollama_client.assert_called_with(host="http://custom-ollama-server:11434/") |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Reset mock |
|||
self.mock_ollama_client.reset_mock() |
|||
|
|||
# Create embedding with model_name |
|||
embedding = OllamaEmbedding(model_name="mxbai-embed-large") |
|||
|
|||
# Check model attribute |
|||
self.assertEqual(embedding.model, "mxbai-embed-large") |
|||
# Check dimension is set correctly based on model |
|||
self.assertEqual(embedding.dim, 768) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_dimension(self): |
|||
"""Test initialization with custom dimension.""" |
|||
# Reset mock |
|||
self.mock_ollama_client.reset_mock() |
|||
|
|||
# Create embedding with custom dimension |
|||
embedding = OllamaEmbedding(dimension=512) |
|||
|
|||
# Check dimension attribute |
|||
self.assertEqual(embedding.dim, 512) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Set up mock response |
|||
self.mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3] * 341 + [0.4]]} # 1024 dimensions |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query("test query") |
|||
|
|||
# Verify embed was called correctly |
|||
self.mock_client.embed.assert_called_once_with(model="bge-m3", input="test query") |
|||
|
|||
# Check the result |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, [0.1, 0.2, 0.3] * 341 + [0.4]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_small_batch(self): |
|||
"""Test embedding documents with a small batch (less than batch size).""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_embeddings = [ |
|||
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions |
|||
[0.4, 0.5, 0.6] * 341 + [0.7], |
|||
[0.7, 0.8, 0.9] * 341 + [0.1] |
|||
] |
|||
self.mock_client.embed.return_value = {"embeddings": mock_embeddings} |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify embed was called correctly |
|||
self.mock_client.embed.assert_called_once_with(model="bge-m3", input=texts) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, mock_embeddings[i]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_large_batch(self): |
|||
"""Test embedding documents with a large batch (more than batch size).""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Set a smaller batch size for testing |
|||
embedding.batch_size = 2 |
|||
|
|||
# Set up mock responses for batches |
|||
batch1_embeddings = [ |
|||
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions |
|||
[0.4, 0.5, 0.6] * 341 + [0.7] |
|||
] |
|||
batch2_embeddings = [ |
|||
[0.7, 0.8, 0.9] * 341 + [0.1] |
|||
] |
|||
|
|||
# Configure mock to return different responses for each call |
|||
self.mock_client.embed.side_effect = [ |
|||
{"embeddings": batch1_embeddings}, |
|||
{"embeddings": batch2_embeddings} |
|||
] |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify embed was called twice with the right batches |
|||
self.assertEqual(self.mock_client.embed.call_count, 2) |
|||
self.mock_client.embed.assert_any_call(model="bge-m3", input=["text 1", "text 2"]) |
|||
self.mock_client.embed.assert_any_call(model="bge-m3", input=["text 3"]) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
self.assertEqual(results[0], batch1_embeddings[0]) |
|||
self.assertEqual(results[1], batch1_embeddings[1]) |
|||
self.assertEqual(results[2], batch2_embeddings[0]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_no_batching(self): |
|||
"""Test embedding documents with batching disabled.""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Disable batching |
|||
embedding.batch_size = 0 |
|||
|
|||
# Mock the embed_query method |
|||
original_embed_query = embedding.embed_query |
|||
embed_query_calls = [] |
|||
|
|||
def mock_embed_query(text): |
|||
embed_query_calls.append(text) |
|||
return [0.1] * 1024 # Return a simple mock embedding |
|||
|
|||
embedding.embed_query = mock_embed_query |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that embed_query was called for each text |
|||
self.assertEqual(len(embed_query_calls), 3) |
|||
self.assertEqual(embed_query_calls, texts) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for result in results: |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
# Restore original method |
|||
embedding.embed_query = original_embed_query |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create instance to test |
|||
embedding = OllamaEmbedding(model="bge-m3") |
|||
|
|||
# Check dimension for bge-m3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# Test with different models |
|||
self.mock_ollama_client.reset_mock() |
|||
embedding = OllamaEmbedding(model="mxbai-embed-large") |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
self.mock_ollama_client.reset_mock() |
|||
embedding = OllamaEmbedding(model="nomic-embed-text") |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# Test with custom dimension |
|||
self.mock_ollama_client.reset_mock() |
|||
embedding = OllamaEmbedding(dimension=512) |
|||
self.assertEqual(embedding.dimension, 512) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,272 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock, ANY |
|||
|
|||
from openai._types import NOT_GIVEN |
|||
from deepsearcher.embedding import OpenAIEmbedding |
|||
|
|||
|
|||
class TestOpenAIEmbedding(unittest.TestCase): |
|||
"""Tests for the OpenAIEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for OpenAI classes |
|||
self.openai_patcher = patch('openai.OpenAI') |
|||
self.mock_openai = self.openai_patcher.start() |
|||
|
|||
# Set up mock client |
|||
self.mock_client = MagicMock() |
|||
self.mock_openai.return_value = self.mock_client |
|||
|
|||
# Set up mock embeddings |
|||
self.mock_embeddings = MagicMock() |
|||
self.mock_client.embeddings = self.mock_embeddings |
|||
|
|||
# Set up mock response for embed_query |
|||
mock_data_item = MagicMock() |
|||
mock_data_item.embedding = [0.1] * 1536 |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.data = [mock_data_item] |
|||
self.mock_embeddings.create.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.openai_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = OpenAIEmbedding() |
|||
|
|||
# Check that OpenAI was initialized correctly |
|||
self.mock_openai.assert_called_once_with(api_key='fake-api-key', base_url=None) |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'text-embedding-ada-002') |
|||
self.assertEqual(embedding.dim, 1536) |
|||
self.assertFalse(embedding.is_azure) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = OpenAIEmbedding(model='text-embedding-3-large') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'text-embedding-3-large') |
|||
self.assertEqual(embedding.dim, 3072) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = OpenAIEmbedding(model_name='text-embedding-3-small') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'text-embedding-3-small') |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_dimension(self): |
|||
"""Test initialization with specified dimension.""" |
|||
# Initialize with custom dimension |
|||
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512) |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.dim, 512) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = OpenAIEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that OpenAI was initialized with the provided API key |
|||
self.mock_openai.assert_called_with(api_key='test-api-key', base_url=None) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_base_url(self): |
|||
"""Test initialization with base URL parameter.""" |
|||
# Initialize with base URL |
|||
embedding = OpenAIEmbedding(base_url='https://test-openai-api.com') |
|||
|
|||
# Check that OpenAI was initialized with the provided base URL |
|||
self.mock_openai.assert_called_with(api_key='fake-api-key', base_url='https://test-openai-api.com') |
|||
|
|||
@patch('openai.AzureOpenAI') |
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_azure(self, mock_azure_openai): |
|||
"""Test initialization with Azure OpenAI.""" |
|||
# Set up mock Azure client |
|||
mock_azure_client = MagicMock() |
|||
mock_azure_openai.return_value = mock_azure_client |
|||
|
|||
# Initialize with Azure endpoint |
|||
embedding = OpenAIEmbedding( |
|||
azure_endpoint='https://test-azure.openai.azure.com', |
|||
api_key='test-azure-key', |
|||
api_version='2023-05-15' |
|||
) |
|||
|
|||
# Check that AzureOpenAI was initialized correctly |
|||
mock_azure_openai.assert_called_once_with( |
|||
api_key='test-azure-key', |
|||
api_version='2023-05-15', |
|||
azure_endpoint='https://test-azure.openai.azure.com' |
|||
) |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'text-embedding-ada-002') |
|||
self.assertEqual(embedding.client, mock_azure_client) |
|||
self.assertTrue(embedding.is_azure) |
|||
self.assertEqual(embedding.deployment, 'text-embedding-ada-002') |
|||
|
|||
@patch('openai.AzureOpenAI') |
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_azure_deployment(self, mock_azure_openai): |
|||
"""Test initialization with Azure OpenAI and custom deployment.""" |
|||
# Set up mock Azure client |
|||
mock_azure_client = MagicMock() |
|||
mock_azure_openai.return_value = mock_azure_client |
|||
|
|||
# Initialize with Azure endpoint and deployment |
|||
embedding = OpenAIEmbedding( |
|||
azure_endpoint='https://test-azure.openai.azure.com', |
|||
azure_deployment='test-deployment' |
|||
) |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.deployment, 'test-deployment') |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_get_dim(self): |
|||
"""Test the _get_dim method.""" |
|||
# Create the embedder |
|||
embedding = OpenAIEmbedding() |
|||
|
|||
# For text-embedding-ada-002 |
|||
self.assertIs(embedding._get_dim(), NOT_GIVEN) |
|||
|
|||
# For text-embedding-3-small |
|||
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512) |
|||
self.assertEqual(embedding._get_dim(), 512) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = OpenAIEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that create was called correctly |
|||
self.mock_embeddings.create.assert_called_once_with( |
|||
input=[query], |
|||
model='text-embedding-ada-002', |
|||
dimensions=ANY |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 1536) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query_azure(self): |
|||
"""Test embedding a single query with Azure.""" |
|||
# Set up Azure embedding |
|||
with patch('openai.AzureOpenAI') as mock_azure_openai: |
|||
# Set up mock Azure client |
|||
mock_azure_client = MagicMock() |
|||
mock_azure_openai.return_value = mock_azure_client |
|||
|
|||
# Set up mock embeddings |
|||
mock_azure_embeddings = MagicMock() |
|||
mock_azure_client.embeddings = mock_azure_embeddings |
|||
|
|||
# Set up mock response |
|||
mock_data_item = MagicMock() |
|||
mock_data_item.embedding = [0.2] * 1536 |
|||
mock_response = MagicMock() |
|||
mock_response.data = [mock_data_item] |
|||
mock_azure_embeddings.create.return_value = mock_response |
|||
|
|||
# Initialize with Azure endpoint |
|||
embedding = OpenAIEmbedding( |
|||
azure_endpoint='https://test-azure.openai.azure.com', |
|||
azure_deployment='test-deployment' |
|||
) |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that create was called correctly |
|||
mock_azure_embeddings.create.assert_called_once_with( |
|||
input=[query], |
|||
model='text-embedding-ada-002' # For Azure, this is the deployment name |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.2] * 1536) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = OpenAIEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_data_items = [] |
|||
for i in range(3): |
|||
mock_data_item = MagicMock() |
|||
mock_data_item.embedding = [0.1 * (i + 1)] * 1536 |
|||
mock_data_items.append(mock_data_item) |
|||
|
|||
mock_response = MagicMock() |
|||
mock_response.data = mock_data_items |
|||
self.mock_embeddings.create.return_value = mock_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that create was called correctly |
|||
self.mock_embeddings.create.assert_called_once_with( |
|||
input=texts, |
|||
model='text-embedding-ada-002', |
|||
dimensions=ANY |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 1536) |
|||
|
|||
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = OpenAIEmbedding() |
|||
|
|||
# For text-embedding-ada-002 |
|||
self.assertEqual(embedding.dimension, 1536) |
|||
|
|||
# For text-embedding-3-small |
|||
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512) |
|||
self.assertEqual(embedding.dimension, 512) |
|||
|
|||
# For text-embedding-3-large |
|||
embedding = OpenAIEmbedding(model='text-embedding-3-large') |
|||
self.assertEqual(embedding.dimension, 3072) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,191 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import requests |
|||
from deepsearcher.embedding import PPIOEmbedding |
|||
|
|||
|
|||
class TestPPIOEmbedding(unittest.TestCase): |
|||
"""Tests for the PPIOEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for requests |
|||
self.requests_patcher = patch('requests.request') |
|||
self.mock_request = self.requests_patcher.start() |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': 0, 'embedding': [0.1] * 1024} # baai/bge-m3 has 1024 dimensions |
|||
] |
|||
} |
|||
self.mock_response.raise_for_status = MagicMock() |
|||
self.mock_request.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.requests_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = PPIOEmbedding() |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
self.assertEqual(embedding.api_key, 'fake-api-key') |
|||
self.assertEqual(embedding.batch_size, 32) |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = PPIOEmbedding(model='baai/bge-m3') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = PPIOEmbedding(model_name='baai/bge-m3') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'baai/bge-m3') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = PPIOEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that the API key was set correctly |
|||
self.assertEqual(embedding.api_key, 'test-api-key') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_without_api_key(self): |
|||
"""Test initialization without API key raises error.""" |
|||
with self.assertRaises(RuntimeError): |
|||
PPIOEmbedding() |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = PPIOEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.ppinfra.com/v3/openai/embeddings', |
|||
json={ |
|||
'model': 'baai/bge-m3', |
|||
'input': [query] |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = PPIOEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024} |
|||
for i in range(3) |
|||
] |
|||
} |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.ppinfra.com/v3/openai/embeddings', |
|||
json={ |
|||
'model': 'baai/bge-m3', |
|||
'input': texts |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents_with_batching(self): |
|||
"""Test embedding documents with batching.""" |
|||
# Create the embedder |
|||
embedding = PPIOEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text " + str(i) for i in range(50)] # More than batch_size |
|||
|
|||
# Set up mock response for batched documents |
|||
def mock_batch_response(*args, **kwargs): |
|||
batch_input = kwargs['json']['input'] |
|||
mock_resp = MagicMock() |
|||
mock_resp.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1] * 1024} |
|||
for i in range(len(batch_input)) |
|||
] |
|||
} |
|||
mock_resp.raise_for_status = MagicMock() |
|||
return mock_resp |
|||
|
|||
self.mock_request.side_effect = mock_batch_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that request was called multiple times |
|||
self.assertTrue(self.mock_request.call_count > 1) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 50) |
|||
for result in results: |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = PPIOEmbedding() |
|||
|
|||
# For baai/bge-m3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,213 +0,0 @@ |
|||
import unittest |
|||
import sys |
|||
import logging |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.embedding import SentenceTransformerEmbedding |
|||
|
|||
|
|||
class TestSentenceTransformerEmbedding(unittest.TestCase): |
|||
"""Tests for the SentenceTransformerEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module for sentence_transformers |
|||
mock_st_module = MagicMock() |
|||
|
|||
# Create mock SentenceTransformer class |
|||
self.mock_sentence_transformer = MagicMock() |
|||
mock_st_module.SentenceTransformer = self.mock_sentence_transformer |
|||
|
|||
# Add the mock module to sys.modules |
|||
self.module_patcher = patch.dict('sys.modules', {'sentence_transformers': mock_st_module}) |
|||
self.module_patcher.start() |
|||
|
|||
# Set up mock instance |
|||
self.mock_model = MagicMock() |
|||
self.mock_sentence_transformer.return_value = self.mock_model |
|||
|
|||
# Configure mock encode method |
|||
mock_embedding = [[0.1, 0.2, 0.3] * 341 + [0.4]] # 1024 dimensions |
|||
self.mock_model.encode.return_value = MagicMock() |
|||
self.mock_model.encode.return_value.tolist.return_value = mock_embedding |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
# Create instance to test |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3") |
|||
|
|||
# Check that SentenceTransformer was called with the right model |
|||
self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-m3") |
|||
|
|||
# Check that model and client were set correctly |
|||
self.assertEqual(embedding.model, "BAAI/bge-m3") |
|||
self.assertEqual(embedding.client, self.mock_model) |
|||
|
|||
# Check batch size default |
|||
self.assertEqual(embedding.batch_size, 32) |
|||
|
|||
# Test with model_name parameter |
|||
self.mock_sentence_transformer.reset_mock() |
|||
embedding = SentenceTransformerEmbedding(model_name="BAAI/bge-large-zh-v1.5") |
|||
self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-large-zh-v1.5") |
|||
self.assertEqual(embedding.model, "BAAI/bge-large-zh-v1.5") |
|||
|
|||
# Test with custom batch size |
|||
self.mock_sentence_transformer.reset_mock() |
|||
embedding = SentenceTransformerEmbedding(batch_size=64) |
|||
self.assertEqual(embedding.batch_size, 64) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create instance to test |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3") |
|||
|
|||
# Mock the encode response for a single query |
|||
single_embedding = [0.1, 0.2, 0.3] * 341 + [0.4] # 1024 dimensions |
|||
self.mock_model.encode.return_value = MagicMock() |
|||
self.mock_model.encode.return_value.tolist.return_value = [single_embedding] |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query("test query") |
|||
|
|||
# Verify encode was called correctly |
|||
self.mock_model.encode.assert_called_once_with("test query") |
|||
|
|||
# Check the result |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, single_embedding) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_small_batch(self): |
|||
"""Test embedding documents with a small batch (less than batch size).""" |
|||
# Create instance to test |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3") |
|||
|
|||
# Mock the encode response for documents |
|||
batch_embeddings = [ |
|||
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions |
|||
[0.4, 0.5, 0.6] * 341 + [0.7], |
|||
[0.7, 0.8, 0.9] * 341 + [0.1] |
|||
] |
|||
self.mock_model.encode.return_value = MagicMock() |
|||
self.mock_model.encode.return_value.tolist.return_value = batch_embeddings |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify encode was called correctly |
|||
self.mock_model.encode.assert_called_once_with(texts) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, batch_embeddings[i]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_large_batch(self): |
|||
"""Test embedding documents with a large batch (more than batch size).""" |
|||
# Create instance to test with small batch size |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=2) |
|||
|
|||
# Mock the encode response for the first batch |
|||
batch1_embeddings = [ |
|||
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions |
|||
[0.4, 0.5, 0.6] * 341 + [0.7] |
|||
] |
|||
# Mock the encode response for the second batch |
|||
batch2_embeddings = [ |
|||
[0.7, 0.8, 0.9] * 341 + [0.1] |
|||
] |
|||
|
|||
# Set up the mock to return different values on each call |
|||
self.mock_model.encode.side_effect = [ |
|||
MagicMock(tolist=lambda: batch1_embeddings), |
|||
MagicMock(tolist=lambda: batch2_embeddings) |
|||
] |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify encode was called twice with the right batches |
|||
self.assertEqual(self.mock_model.encode.call_count, 2) |
|||
self.mock_model.encode.assert_any_call(["text 1", "text 2"]) |
|||
self.mock_model.encode.assert_any_call(["text 3"]) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
self.assertEqual(results[0], batch1_embeddings[0]) |
|||
self.assertEqual(results[1], batch1_embeddings[1]) |
|||
self.assertEqual(results[2], batch2_embeddings[0]) |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_embed_documents_no_batching(self): |
|||
"""Test embedding documents with batching disabled.""" |
|||
# Create instance to test with batching disabled |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=0) |
|||
|
|||
# Mock the embed_query method |
|||
original_embed_query = embedding.embed_query |
|||
embed_query_calls = [] |
|||
|
|||
def mock_embed_query(text): |
|||
embed_query_calls.append(text) |
|||
return [0.1] * 1024 # Return a simple mock embedding |
|||
|
|||
embedding.embed_query = mock_embed_query |
|||
|
|||
# Create test texts |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that embed_query was called for each text |
|||
self.assertEqual(len(embed_query_calls), 3) |
|||
self.assertEqual(embed_query_calls, texts) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for result in results: |
|||
self.assertEqual(len(result), 1024) |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
# Restore original method |
|||
embedding.embed_query = original_embed_query |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create instance to test |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3") |
|||
|
|||
# Check dimension for BAAI/bge-m3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# Test with different models |
|||
self.mock_sentence_transformer.reset_mock() |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-zh-v1.5") |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
self.mock_sentence_transformer.reset_mock() |
|||
embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-en-v1.5") |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,201 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import requests |
|||
from deepsearcher.embedding import SiliconflowEmbedding |
|||
|
|||
|
|||
class TestSiliconflowEmbedding(unittest.TestCase): |
|||
"""Tests for the SiliconflowEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for requests |
|||
self.requests_patcher = patch('requests.request') |
|||
self.mock_request = self.requests_patcher.start() |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': 0, 'embedding': [0.1] * 1024} # BAAI/bge-m3 has 1024 dimensions |
|||
] |
|||
} |
|||
self.mock_response.raise_for_status = MagicMock() |
|||
self.mock_request.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.requests_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = SiliconflowEmbedding() |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'BAAI/bge-m3') |
|||
self.assertEqual(embedding.api_key, 'fake-api-key') |
|||
self.assertEqual(embedding.batch_size, 32) |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = SiliconflowEmbedding(model='netease-youdao/bce-embedding-base_v1') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'netease-youdao/bce-embedding-base_v1') |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = SiliconflowEmbedding(model_name='BAAI/bge-large-zh-v1.5') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'BAAI/bge-large-zh-v1.5') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = SiliconflowEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that the API key was set correctly |
|||
self.assertEqual(embedding.api_key, 'test-api-key') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_without_api_key(self): |
|||
"""Test initialization without API key raises error.""" |
|||
with self.assertRaises(RuntimeError): |
|||
SiliconflowEmbedding() |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = SiliconflowEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.siliconflow.cn/v1/embeddings', |
|||
json={ |
|||
'model': 'BAAI/bge-m3', |
|||
'input': query, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = SiliconflowEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024} |
|||
for i in range(3) |
|||
] |
|||
} |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://api.siliconflow.cn/v1/embeddings', |
|||
json={ |
|||
'model': 'BAAI/bge-m3', |
|||
'input': texts, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents_with_batching(self): |
|||
"""Test embedding documents with batching.""" |
|||
# Create the embedder |
|||
embedding = SiliconflowEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text " + str(i) for i in range(50)] # More than batch_size |
|||
|
|||
# Set up mock response for batched documents |
|||
def mock_batch_response(*args, **kwargs): |
|||
batch_input = kwargs['json']['input'] |
|||
mock_resp = MagicMock() |
|||
mock_resp.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1] * 1024} |
|||
for i in range(len(batch_input)) |
|||
] |
|||
} |
|||
mock_resp.raise_for_status = MagicMock() |
|||
return mock_resp |
|||
|
|||
self.mock_request.side_effect = mock_batch_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that request was called multiple times |
|||
self.assertTrue(self.mock_request.call_count > 1) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 50) |
|||
for result in results: |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = SiliconflowEmbedding() |
|||
|
|||
# For BAAI/bge-m3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# For netease-youdao/bce-embedding-base_v1 |
|||
embedding = SiliconflowEmbedding(model='netease-youdao/bce-embedding-base_v1') |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# For BAAI/bge-large-zh-v1.5 |
|||
embedding = SiliconflowEmbedding(model='BAAI/bge-large-zh-v1.5') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,201 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import requests |
|||
from deepsearcher.embedding import VolcengineEmbedding |
|||
|
|||
|
|||
class TestVolcengineEmbedding(unittest.TestCase): |
|||
"""Tests for the VolcengineEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for requests |
|||
self.requests_patcher = patch('requests.request') |
|||
self.mock_request = self.requests_patcher.start() |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': 0, 'embedding': [0.1] * 4096} # doubao-embedding-large-text-240915 has 4096 dimensions |
|||
] |
|||
} |
|||
self.mock_response.raise_for_status = MagicMock() |
|||
self.mock_request.return_value = self.mock_response |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.requests_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = VolcengineEmbedding() |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'doubao-embedding-large-text-240915') |
|||
self.assertEqual(embedding.api_key, 'fake-api-key') |
|||
self.assertEqual(embedding.batch_size, 256) |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = VolcengineEmbedding(model='doubao-embedding-text-240515') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'doubao-embedding-text-240515') |
|||
self.assertEqual(embedding.dimension, 2048) |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = VolcengineEmbedding(model_name='doubao-embedding-text-240715') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'doubao-embedding-text-240715') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = VolcengineEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that the API key was set correctly |
|||
self.assertEqual(embedding.api_key, 'test-api-key') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_without_api_key(self): |
|||
"""Test initialization without API key raises error.""" |
|||
with self.assertRaises(RuntimeError): |
|||
VolcengineEmbedding() |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = VolcengineEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://ark.cn-beijing.volces.com/api/v3/embeddings', |
|||
json={ |
|||
'model': 'doubao-embedding-large-text-240915', |
|||
'input': query, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 4096) |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = VolcengineEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
self.mock_response.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1 * (i + 1)] * 4096} |
|||
for i in range(3) |
|||
] |
|||
} |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that request was called correctly |
|||
self.mock_request.assert_called_once_with( |
|||
'POST', |
|||
'https://ark.cn-beijing.volces.com/api/v3/embeddings', |
|||
json={ |
|||
'model': 'doubao-embedding-large-text-240915', |
|||
'input': texts, |
|||
'encoding_format': 'float' |
|||
}, |
|||
headers={ |
|||
'Authorization': 'Bearer fake-api-key', |
|||
'Content-Type': 'application/json' |
|||
} |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 4096) |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents_with_batching(self): |
|||
"""Test embedding documents with batching.""" |
|||
# Create the embedder |
|||
embedding = VolcengineEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text " + str(i) for i in range(300)] # More than batch_size |
|||
|
|||
# Set up mock response for batched documents |
|||
def mock_batch_response(*args, **kwargs): |
|||
batch_input = kwargs['json']['input'] |
|||
mock_resp = MagicMock() |
|||
mock_resp.json.return_value = { |
|||
'data': [ |
|||
{'index': i, 'embedding': [0.1] * 4096} |
|||
for i in range(len(batch_input)) |
|||
] |
|||
} |
|||
mock_resp.raise_for_status = MagicMock() |
|||
return mock_resp |
|||
|
|||
self.mock_request.side_effect = mock_batch_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Check that request was called multiple times |
|||
self.assertTrue(self.mock_request.call_count > 1) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 300) |
|||
for result in results: |
|||
self.assertEqual(result, [0.1] * 4096) |
|||
|
|||
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = VolcengineEmbedding() |
|||
|
|||
# For doubao-embedding-large-text-240915 |
|||
self.assertEqual(embedding.dimension, 4096) |
|||
|
|||
# For doubao-embedding-text-240715 |
|||
embedding = VolcengineEmbedding(model='doubao-embedding-text-240715') |
|||
self.assertEqual(embedding.dimension, 2560) |
|||
|
|||
# For doubao-embedding-text-240515 |
|||
embedding = VolcengineEmbedding(model='doubao-embedding-text-240515') |
|||
self.assertEqual(embedding.dimension, 2048) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,144 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from deepsearcher.embedding import VoyageEmbedding |
|||
|
|||
|
|||
class TestVoyageEmbedding(unittest.TestCase): |
|||
"""Tests for the VoyageEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create a mock module |
|||
self.mock_voyageai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
|
|||
# Set up mock response for embed |
|||
mock_response = MagicMock() |
|||
mock_response.embeddings = [[0.1] * 1024] # voyage-3 has 1024 dimensions |
|||
self.mock_client.embed.return_value = mock_response |
|||
|
|||
# Set up the mock module |
|||
self.mock_voyageai.Client.return_value = self.mock_client |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'voyageai': self.mock_voyageai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Create the embedder |
|||
embedding = VoyageEmbedding() |
|||
|
|||
# Check that voyageai was initialized correctly |
|||
self.mock_voyageai.Client.assert_called_once() |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'voyage-3') |
|||
self.assertEqual(embedding.voyageai_api_key, 'fake-api-key') |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model(self): |
|||
"""Test initialization with specified model.""" |
|||
# Initialize with a different model |
|||
embedding = VoyageEmbedding(model='voyage-3-lite') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'voyage-3-lite') |
|||
self.assertEqual(embedding.dimension, 512) # voyage-3-lite has 512 dimensions |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_init_with_model_name(self): |
|||
"""Test initialization with model_name parameter.""" |
|||
# Initialize with model_name |
|||
embedding = VoyageEmbedding(model_name='voyage-3-large') |
|||
|
|||
# Check attributes |
|||
self.assertEqual(embedding.model, 'voyage-3-large') |
|||
|
|||
@patch.dict('os.environ', {}, clear=True) |
|||
def test_init_with_api_key(self): |
|||
"""Test initialization with API key parameter.""" |
|||
# Initialize with API key |
|||
embedding = VoyageEmbedding(api_key='test-api-key') |
|||
|
|||
# Check that the API key was set correctly |
|||
self.assertEqual(embedding.voyageai_api_key, 'test-api-key') |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_query(self): |
|||
"""Test embedding a single query.""" |
|||
# Create the embedder |
|||
embedding = VoyageEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that embed was called correctly |
|||
self.mock_client.embed.assert_called_once_with( |
|||
[query], |
|||
model='voyage-3', |
|||
input_type='query' |
|||
) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_embed_documents(self): |
|||
"""Test embedding multiple documents.""" |
|||
# Create the embedder |
|||
embedding = VoyageEmbedding() |
|||
|
|||
# Create test documents |
|||
texts = ["text 1", "text 2", "text 3"] |
|||
|
|||
# Set up mock response for multiple documents |
|||
mock_response = MagicMock() |
|||
mock_response.embeddings = [[0.1 * (i + 1)] * 1024 for i in range(3)] |
|||
self.mock_client.embed.return_value = mock_response |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(texts) |
|||
|
|||
# Verify that embed was called correctly |
|||
self.mock_client.embed.assert_called_once_with( |
|||
texts, |
|||
model='voyage-3', |
|||
input_type='document' |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
for i, result in enumerate(results): |
|||
self.assertEqual(result, [0.1 * (i + 1)] * 1024) |
|||
|
|||
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True) |
|||
def test_dimension_property(self): |
|||
"""Test the dimension property.""" |
|||
# Create the embedder |
|||
embedding = VoyageEmbedding() |
|||
|
|||
# For voyage-3 |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
# For voyage-3-lite |
|||
embedding = VoyageEmbedding(model='voyage-3-lite') |
|||
self.assertEqual(embedding.dimension, 512) |
|||
|
|||
# For voyage-3-large |
|||
embedding = VoyageEmbedding(model='voyage-3-large') |
|||
self.assertEqual(embedding.dimension, 1024) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,284 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import MagicMock, patch, ANY |
|||
import os |
|||
|
|||
class TestWatsonXEmbedding(unittest.TestCase): |
|||
"""Test cases for WatsonXEmbedding class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Mock the ibm_watsonx_ai imports |
|||
self.mock_credentials = MagicMock() |
|||
self.mock_embeddings = MagicMock() |
|||
|
|||
# Create a mock client |
|||
self.mock_client = MagicMock() |
|||
|
|||
# Set up mock response for embed_query |
|||
self.mock_client.embed_query.return_value = { |
|||
'results': [ |
|||
{'embedding': [0.1] * 768} |
|||
] |
|||
} |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_init_with_env_vars(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test initialization with environment variables.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
embedding = WatsonXEmbedding() |
|||
|
|||
# Check that Credentials was called with correct parameters |
|||
mock_credentials_class.assert_called_once_with( |
|||
url='https://test.watsonx.com', |
|||
api_key='test-api-key' |
|||
) |
|||
|
|||
# Check that Embeddings was called with correct parameters |
|||
mock_embeddings_class.assert_called_once_with( |
|||
model_id='ibm/slate-125m-english-rtrvr-v2', |
|||
credentials=mock_credentials_instance, |
|||
project_id='test-project-id' |
|||
) |
|||
|
|||
# Check default model and dimension |
|||
self.assertEqual(embedding.model, 'ibm/slate-125m-english-rtrvr-v2') |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com' |
|||
}) |
|||
def test_init_with_space_id(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test initialization with space_id instead of project_id.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
embedding = WatsonXEmbedding(space_id='test-space-id') |
|||
|
|||
# Check that Embeddings was called with space_id |
|||
mock_embeddings_class.assert_called_once_with( |
|||
model_id='ibm/slate-125m-english-rtrvr-v2', |
|||
credentials=mock_credentials_instance, |
|||
space_id='test-space-id' |
|||
) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
def test_init_missing_api_key(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test initialization with missing API key.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
with patch.dict(os.environ, {}, clear=True): |
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonXEmbedding() |
|||
|
|||
self.assertIn("WATSONX_APIKEY", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key' |
|||
}) |
|||
def test_init_missing_url(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test initialization with missing URL.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonXEmbedding() |
|||
|
|||
self.assertIn("WATSONX_URL", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com' |
|||
}) |
|||
def test_init_missing_project_and_space_id(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test initialization with missing both project_id and space_id.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonXEmbedding() |
|||
|
|||
self.assertIn("WATSONX_PROJECT_ID", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_embed_query(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test embedding a single query.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
# WatsonX embed_query returns the embedding vector directly, not wrapped in a dict |
|||
mock_embeddings_instance.embed_query.return_value = [0.1] * 768 |
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
# Create the embedder |
|||
embedding = WatsonXEmbedding() |
|||
|
|||
# Create a test query |
|||
query = "This is a test query" |
|||
|
|||
# Call the method |
|||
result = embedding.embed_query(query) |
|||
|
|||
# Verify that embed_query was called correctly |
|||
mock_embeddings_instance.embed_query.assert_called_once_with(text=query) |
|||
|
|||
# Check the result |
|||
self.assertEqual(result, [0.1] * 768) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_embed_documents(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test embedding multiple documents.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
# WatsonX embed_documents returns a list of embedding vectors directly |
|||
mock_embeddings_instance.embed_documents.return_value = [ |
|||
[0.1] * 768, |
|||
[0.2] * 768, |
|||
[0.3] * 768 |
|||
] |
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
# Create the embedder |
|||
embedding = WatsonXEmbedding() |
|||
|
|||
# Create test documents |
|||
documents = ["Document 1", "Document 2", "Document 3"] |
|||
|
|||
# Call the method |
|||
results = embedding.embed_documents(documents) |
|||
|
|||
# Verify that embed_documents was called correctly |
|||
mock_embeddings_instance.embed_documents.assert_called_once_with(texts=documents) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(results), 3) |
|||
self.assertEqual(results[0], [0.1] * 768) |
|||
self.assertEqual(results[1], [0.2] * 768) |
|||
self.assertEqual(results[2], [0.3] * 768) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_dimension_property(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test the dimension property for different models.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
# Test default model |
|||
embedding = WatsonXEmbedding() |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
# Test different model |
|||
embedding = WatsonXEmbedding(model='ibm/slate-30m-english-rtrvr') |
|||
self.assertEqual(embedding.dimension, 384) |
|||
|
|||
# Test unknown model (should default to 768) |
|||
embedding = WatsonXEmbedding(model='unknown-model') |
|||
self.assertEqual(embedding.dimension, 768) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_embed_query_error_handling(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test error handling in embed_query.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
mock_embeddings_instance.embed_query.side_effect = Exception("API Error") |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
# Create the embedder |
|||
embedding = WatsonXEmbedding() |
|||
|
|||
# Test that the exception is properly wrapped |
|||
with self.assertRaises(RuntimeError) as context: |
|||
embedding.embed_query("test") |
|||
|
|||
self.assertIn("Error embedding query with WatsonX", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings') |
|||
@patch('deepsearcher.embedding.watsonx_embedding.Credentials') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_embed_documents_error_handling(self, mock_credentials_class, mock_embeddings_class): |
|||
"""Test error handling in embed_documents.""" |
|||
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_embeddings_instance = MagicMock() |
|||
mock_embeddings_instance.embed_documents.side_effect = Exception("API Error") |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_embeddings_class.return_value = mock_embeddings_instance |
|||
|
|||
# Create the embedder |
|||
embedding = WatsonXEmbedding() |
|||
|
|||
# Test that the exception is properly wrapped |
|||
with self.assertRaises(RuntimeError) as context: |
|||
embedding.embed_documents(["test"]) |
|||
|
|||
self.assertIn("Error embedding documents with WatsonX", str(context.exception)) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher.llm package |
@ -1,164 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Aliyun |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestAliyun(unittest.TestCase): |
|||
"""Tests for the Aliyun LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Aliyun() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek-r1") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": api_key}): |
|||
llm = Aliyun() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = Aliyun(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "qwen-max" |
|||
llm = Aliyun(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.aliyun.api" |
|||
llm = Aliyun(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Aliyun instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Aliyun() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-r1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Aliyun instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Aliyun() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-r1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Aliyun instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Aliyun() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("Aliyun API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Aliyun API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,169 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
from typing import List |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Anthropic |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class ContentItem: |
|||
"""Mock content item for Anthropic response.""" |
|||
def __init__(self, text: str): |
|||
self.text = text |
|||
|
|||
|
|||
class TestAnthropic(unittest.TestCase): |
|||
"""Tests for the Anthropic LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_anthropic = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
|
|||
# Set up response content with proper structure |
|||
content_item = ContentItem("Test response") |
|||
self.mock_response.content = [content_item] |
|||
self.mock_response.usage.input_tokens = 50 |
|||
self.mock_response.usage.output_tokens = 50 |
|||
|
|||
# Set up the mock module structure and response |
|||
self.mock_client.messages.create.return_value = self.mock_response |
|||
self.mock_anthropic.Anthropic.return_value = self.mock_client |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'anthropic': self.mock_anthropic}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Anthropic() |
|||
# Check that Anthropic client was initialized correctly |
|||
self.mock_anthropic.Anthropic.assert_called_once_with( |
|||
api_key=None, |
|||
base_url=None |
|||
) |
|||
|
|||
# Check default attributes |
|||
self.assertEqual(llm.model, "claude-sonnet-4-0") |
|||
self.assertEqual(llm.max_tokens, 8192) |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": api_key}): |
|||
llm = Anthropic() |
|||
self.mock_anthropic.Anthropic.assert_called_with( |
|||
api_key=api_key, |
|||
base_url=None |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = Anthropic(api_key=api_key) |
|||
self.mock_anthropic.Anthropic.assert_called_with( |
|||
api_key=api_key, |
|||
base_url=None |
|||
) |
|||
|
|||
def test_init_with_custom_model_and_tokens(self): |
|||
"""Test initialization with custom model and max tokens.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "claude-3-opus-20240229" |
|||
max_tokens = 4096 |
|||
llm = Anthropic(model=model, max_tokens=max_tokens) |
|||
self.assertEqual(llm.model, model) |
|||
self.assertEqual(llm.max_tokens, max_tokens) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.anthropic.api" |
|||
llm = Anthropic(base_url=base_url) |
|||
self.mock_anthropic.Anthropic.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Anthropic instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Anthropic() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that messages.create was called correctly |
|||
self.mock_client.messages.create.assert_called_once() |
|||
call_args = self.mock_client.messages.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "claude-sonnet-4-0") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
self.assertEqual(call_args[1]["max_tokens"], 8192) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) # 50 input + 50 output |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Anthropic instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Anthropic() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that messages.create was called correctly |
|||
self.mock_client.messages.create.assert_called_once() |
|||
call_args = self.mock_client.messages.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "claude-sonnet-4-0") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
self.assertEqual(call_args[1]["max_tokens"], 8192) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) # 50 input + 50 output |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Anthropic instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Anthropic() |
|||
|
|||
# Mock an error response |
|||
self.mock_client.messages.create.side_effect = Exception("Anthropic API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Anthropic API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,170 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import AzureOpenAI |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestAzureOpenAI(unittest.TestCase): |
|||
"""Tests for the Azure OpenAI LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.AzureOpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
# Test parameters |
|||
self.test_model = "gpt-4" |
|||
self.test_endpoint = "https://test.openai.azure.com" |
|||
self.test_api_key = "test_api_key" |
|||
self.test_api_version = "2024-02-15" |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_with_parameters(self): |
|||
"""Test initialization with explicit parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = AzureOpenAI( |
|||
model=self.test_model, |
|||
azure_endpoint=self.test_endpoint, |
|||
api_key=self.test_api_key, |
|||
api_version=self.test_api_version |
|||
) |
|||
# Check that Azure OpenAI client was initialized correctly |
|||
self.mock_openai.AzureOpenAI.assert_called_once_with( |
|||
azure_endpoint=self.test_endpoint, |
|||
api_key=self.test_api_key, |
|||
api_version=self.test_api_version |
|||
) |
|||
|
|||
# Check model attribute |
|||
self.assertEqual(llm.model, self.test_model) |
|||
|
|||
def test_init_with_env_variables(self): |
|||
"""Test initialization with environment variables.""" |
|||
env_endpoint = "https://env.openai.azure.com" |
|||
env_api_key = "env_api_key" |
|||
|
|||
with patch.dict(os.environ, { |
|||
"AZURE_OPENAI_ENDPOINT": env_endpoint, |
|||
"AZURE_OPENAI_KEY": env_api_key |
|||
}): |
|||
llm = AzureOpenAI(model=self.test_model) |
|||
self.mock_openai.AzureOpenAI.assert_called_with( |
|||
azure_endpoint=env_endpoint, |
|||
api_key=env_api_key, |
|||
api_version=None |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Azure OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = AzureOpenAI( |
|||
model=self.test_model, |
|||
azure_endpoint=self.test_endpoint, |
|||
api_key=self.test_api_key, |
|||
api_version=self.test_api_version |
|||
) |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], self.test_model) |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Azure OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = AzureOpenAI( |
|||
model=self.test_model, |
|||
azure_endpoint=self.test_endpoint, |
|||
api_key=self.test_api_key, |
|||
api_version=self.test_api_version |
|||
) |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], self.test_model) |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Azure OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = AzureOpenAI( |
|||
model=self.test_model, |
|||
azure_endpoint=self.test_endpoint, |
|||
api_key=self.test_api_key, |
|||
api_version=self.test_api_version |
|||
) |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("Azure OpenAI API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Azure OpenAI API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,154 +0,0 @@ |
|||
import unittest |
|||
from deepsearcher.llm.base import BaseLLM, ChatResponse |
|||
from unittest.mock import patch |
|||
|
|||
|
|||
class TestBaseLLM(unittest.TestCase): |
|||
"""Tests for the BaseLLM abstract base class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Clear environment variables temporarily |
|||
self.env_patcher = patch.dict('os.environ', {}, clear=True) |
|||
self.env_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.env_patcher.stop() |
|||
|
|||
def test_chat_response_init(self): |
|||
"""Test ChatResponse initialization and representation.""" |
|||
content = "Test content" |
|||
total_tokens = 100 |
|||
response = ChatResponse(content=content, total_tokens=total_tokens) |
|||
|
|||
self.assertEqual(response.content, content) |
|||
self.assertEqual(response.total_tokens, total_tokens) |
|||
self.assertEqual( |
|||
repr(response), |
|||
f"ChatResponse(content={content}, total_tokens={total_tokens})" |
|||
) |
|||
|
|||
def test_literal_eval_python_code_block(self): |
|||
"""Test literal_eval with Python code block.""" |
|||
content = '''```python |
|||
{"key": "value", "number": 42} |
|||
```''' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value", "number": 42}) |
|||
|
|||
def test_literal_eval_json_code_block(self): |
|||
"""Test literal_eval with JSON code block.""" |
|||
content = '''```json |
|||
{"key": "value", "number": 42} |
|||
```''' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value", "number": 42}) |
|||
|
|||
def test_literal_eval_str_code_block(self): |
|||
"""Test literal_eval with str code block.""" |
|||
content = '''```str |
|||
{"key": "value", "number": 42} |
|||
```''' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value", "number": 42}) |
|||
|
|||
def test_literal_eval_plain_code_block(self): |
|||
"""Test literal_eval with plain code block.""" |
|||
content = '''``` |
|||
{"key": "value", "number": 42} |
|||
```''' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value", "number": 42}) |
|||
|
|||
def test_literal_eval_raw_dict(self): |
|||
"""Test literal_eval with raw dictionary string.""" |
|||
content = '{"key": "value", "number": 42}' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value", "number": 42}) |
|||
|
|||
def test_literal_eval_raw_list(self): |
|||
"""Test literal_eval with raw list string.""" |
|||
content = '[1, 2, "three", {"four": 4}]' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, [1, 2, "three", {"four": 4}]) |
|||
|
|||
def test_literal_eval_with_whitespace(self): |
|||
"""Test literal_eval with extra whitespace.""" |
|||
content = ''' |
|||
|
|||
{"key": "value"} |
|||
|
|||
''' |
|||
result = BaseLLM.literal_eval(content) |
|||
self.assertEqual(result, {"key": "value"}) |
|||
|
|||
def test_literal_eval_nested_structures(self): |
|||
"""Test literal_eval with nested data structures.""" |
|||
content = ''' |
|||
{ |
|||
"string": "value", |
|||
"number": 42, |
|||
"list": [1, 2, 3], |
|||
"dict": {"nested": "value"}, |
|||
"mixed": [1, {"key": "value"}, [2, 3]] |
|||
} |
|||
''' |
|||
result = BaseLLM.literal_eval(content) |
|||
expected = { |
|||
"string": "value", |
|||
"number": 42, |
|||
"list": [1, 2, 3], |
|||
"dict": {"nested": "value"}, |
|||
"mixed": [1, {"key": "value"}, [2, 3]] |
|||
} |
|||
self.assertEqual(result, expected) |
|||
|
|||
def test_literal_eval_invalid_format(self): |
|||
"""Test literal_eval with invalid format.""" |
|||
invalid_contents = [ |
|||
"Not a valid Python literal", |
|||
"{invalid: json}", |
|||
"[1, 2, 3", # Unclosed bracket |
|||
'{"key": undefined}', # undefined is not a valid Python literal |
|||
] |
|||
for content in invalid_contents: |
|||
with self.assertRaises(ValueError): |
|||
BaseLLM.literal_eval(content) |
|||
|
|||
def test_remove_think_with_tags(self): |
|||
"""Test remove_think with think tags.""" |
|||
content = '''<think> |
|||
This is the reasoning process. |
|||
Multiple lines of thought. |
|||
</think> |
|||
This is the actual response.''' |
|||
result = BaseLLM.remove_think(content) |
|||
self.assertEqual(result.strip(), "This is the actual response.") |
|||
|
|||
def test_remove_think_without_tags(self): |
|||
"""Test remove_think without think tags.""" |
|||
content = "This is a response without think tags." |
|||
result = BaseLLM.remove_think(content) |
|||
self.assertEqual(result.strip(), content.strip()) |
|||
|
|||
def test_remove_think_multiple_tags(self): |
|||
"""Test remove_think with multiple think tags - should only remove first block.""" |
|||
content = '''<think>First think block</think> |
|||
Actual response |
|||
<think>Second think block</think>''' |
|||
result = BaseLLM.remove_think(content) |
|||
self.assertEqual( |
|||
result.strip(), |
|||
"Actual response\n <think>Second think block</think>" |
|||
) |
|||
|
|||
def test_remove_think_empty_tags(self): |
|||
"""Test remove_think with empty think tags.""" |
|||
content = "<think></think>Response" |
|||
result = BaseLLM.remove_think(content) |
|||
self.assertEqual(result.strip(), "Response") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,196 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Bedrock |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestBedrock(unittest.TestCase): |
|||
"""Tests for the AWS Bedrock LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_boto3 = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_boto3.client = MagicMock(return_value=self.mock_client) |
|||
|
|||
# Set up mock response |
|||
self.mock_response = { |
|||
"output": { |
|||
"message": { |
|||
"content": [{"text": "Test response\nwith newline"}] |
|||
} |
|||
}, |
|||
"usage": { |
|||
"totalTokens": 100 |
|||
} |
|||
} |
|||
self.mock_client.converse.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'boto3': self.mock_boto3}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock() |
|||
# Check that client was initialized correctly |
|||
self.mock_boto3.client.assert_called_once_with( |
|||
"bedrock-runtime", |
|||
region_name="us-west-2", |
|||
aws_access_key_id=None, |
|||
aws_secret_access_key=None, |
|||
aws_session_token=None |
|||
) |
|||
|
|||
# Check default attributes |
|||
self.assertEqual(llm.model, "us.deepseek.r1-v1:0") |
|||
self.assertEqual(llm.max_tokens, 20000) |
|||
|
|||
def test_init_with_aws_credentials_from_env(self): |
|||
"""Test initialization with AWS credentials from environment variables.""" |
|||
credentials = { |
|||
"AWS_ACCESS_KEY_ID": "test_access_key", |
|||
"AWS_SECRET_ACCESS_KEY": "test_secret_key", |
|||
"AWS_SESSION_TOKEN": "test_session_token" |
|||
} |
|||
with patch.dict(os.environ, credentials): |
|||
llm = Bedrock() |
|||
self.mock_boto3.client.assert_called_with( |
|||
"bedrock-runtime", |
|||
region_name="us-west-2", |
|||
aws_access_key_id="test_access_key", |
|||
aws_secret_access_key="test_secret_key", |
|||
aws_session_token="test_session_token" |
|||
) |
|||
|
|||
def test_init_with_aws_credentials_parameters(self): |
|||
"""Test initialization with AWS credentials as parameters.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock( |
|||
aws_access_key_id="param_access_key", |
|||
aws_secret_access_key="param_secret_key", |
|||
aws_session_token="param_session_token" |
|||
) |
|||
self.mock_boto3.client.assert_called_with( |
|||
"bedrock-runtime", |
|||
region_name="us-west-2", |
|||
aws_access_key_id="param_access_key", |
|||
aws_secret_access_key="param_secret_key", |
|||
aws_session_token="param_session_token" |
|||
) |
|||
|
|||
def test_init_with_custom_model_and_tokens(self): |
|||
"""Test initialization with custom model and max tokens.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock(model="custom.model", max_tokens=1000) |
|||
self.assertEqual(llm.model, "custom.model") |
|||
self.assertEqual(llm.max_tokens, 1000) |
|||
|
|||
def test_init_with_custom_region(self): |
|||
"""Test initialization with custom region.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock(region_name="us-east-1") |
|||
self.mock_boto3.client.assert_called_with( |
|||
"bedrock-runtime", |
|||
region_name="us-east-1", |
|||
aws_access_key_id=None, |
|||
aws_secret_access_key=None, |
|||
aws_session_token=None |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Bedrock instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock() |
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that converse was called correctly |
|||
self.mock_client.converse.assert_called_once() |
|||
call_args = self.mock_client.converse.call_args |
|||
self.assertEqual(call_args[1]["modelId"], "us.deepseek.r1-v1:0") |
|||
self.assertEqual(call_args[1]["messages"], [ |
|||
{"role": "user", "content": [{"text": "Hello"}]} |
|||
]) |
|||
self.assertEqual(call_args[1]["inferenceConfig"], {"maxTokens": 20000}) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test responsewith newline") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Bedrock instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock() |
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that converse was called correctly |
|||
self.mock_client.converse.assert_called_once() |
|||
call_args = self.mock_client.converse.call_args |
|||
|
|||
expected_messages = [ |
|||
{"role": "system", "content": [{"text": "You are a helpful assistant"}]}, |
|||
{"role": "user", "content": [{"text": "Hello"}]}, |
|||
{"role": "assistant", "content": [{"text": "Hi there!"}]}, |
|||
{"role": "user", "content": [{"text": "How are you?"}]} |
|||
] |
|||
self.assertEqual(call_args[1]["messages"], expected_messages) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Bedrock instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock() |
|||
# Mock an error response |
|||
self.mock_client.converse.side_effect = Exception("AWS Bedrock Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "AWS Bedrock Error") |
|||
|
|||
def test_chat_with_preformatted_messages(self): |
|||
"""Test chat with messages that are already in the correct format.""" |
|||
# Create Bedrock instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Bedrock() |
|||
messages = [ |
|||
{ |
|||
"role": "user", |
|||
"content": [{"text": "Hello"}] |
|||
} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that the message format was preserved |
|||
call_args = self.mock_client.converse.call_args |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,169 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import DeepSeek |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestDeepSeek(unittest.TestCase): |
|||
"""Tests for the DeepSeek LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = DeepSeek() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://api.deepseek.com" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek-reasoner") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
base_url = "https://custom.deepseek.api" |
|||
with patch.dict(os.environ, { |
|||
"DEEPSEEK_API_KEY": api_key, |
|||
"DEEPSEEK_BASE_URL": base_url |
|||
}): |
|||
llm = DeepSeek() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
api_key = "test_api_key_param" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = DeepSeek(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.deepseek.com" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "deepseek-chat" |
|||
llm = DeepSeek(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.deepseek.api" |
|||
llm = DeepSeek(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create DeepSeek instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = DeepSeek() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-reasoner") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create DeepSeek instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = DeepSeek() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-reasoner") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create DeepSeek instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = DeepSeek() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("DeepSeek API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "DeepSeek API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,136 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Gemini |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestGemini(unittest.TestCase): |
|||
"""Tests for the Gemini LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_genai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_response = MagicMock() |
|||
self.mock_metadata = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_genai.Client = MagicMock(return_value=self.mock_client) |
|||
|
|||
# Set up mock response |
|||
self.mock_response.text = "Test response" |
|||
self.mock_metadata.total_token_count = 100 |
|||
self.mock_response.usage_metadata = self.mock_metadata |
|||
self.mock_client.models.generate_content.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'google.genai': self.mock_genai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Gemini() |
|||
# Check that Client was initialized correctly |
|||
self.mock_genai.Client.assert_called_once_with(api_key=None) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "gemini-2.0-flash") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"GEMINI_API_KEY": api_key}): |
|||
llm = Gemini() |
|||
self.mock_genai.Client.assert_called_with(api_key=api_key) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = Gemini(api_key=api_key) |
|||
self.mock_genai.Client.assert_called_with(api_key=api_key) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "gemini-pro" |
|||
llm = Gemini(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Gemini instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Gemini() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that generate_content was called correctly |
|||
self.mock_client.models.generate_content.assert_called_once() |
|||
call_args = self.mock_client.models.generate_content.call_args |
|||
self.assertEqual(call_args[1]["model"], "gemini-2.0-flash") |
|||
self.assertEqual(call_args[1]["contents"], "Hello") |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Gemini instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Gemini() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that generate_content was called correctly |
|||
self.mock_client.models.generate_content.assert_called_once() |
|||
call_args = self.mock_client.models.generate_content.call_args |
|||
self.assertEqual(call_args[1]["model"], "gemini-2.0-flash") |
|||
expected_content = "You are a helpful assistant\nHello\nHi there!\nHow are you?" |
|||
self.assertEqual(call_args[1]["contents"], expected_content) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Gemini instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Gemini() |
|||
|
|||
# Mock an error response |
|||
self.mock_client.models.generate_content.side_effect = Exception("API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import GLM |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestGLM(unittest.TestCase): |
|||
"""Tests for the GLM LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_zhipuai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_zhipuai.ZhipuAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'zhipuai': self.mock_zhipuai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = GLM() |
|||
# Check that ZhipuAI client was initialized correctly |
|||
self.mock_zhipuai.ZhipuAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://open.bigmodel.cn/api/paas/v4/" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "glm-4-plus") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"GLM_API_KEY": api_key}): |
|||
llm = GLM() |
|||
self.mock_zhipuai.ZhipuAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://open.bigmodel.cn/api/paas/v4/" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = GLM(api_key=api_key) |
|||
self.mock_zhipuai.ZhipuAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://open.bigmodel.cn/api/paas/v4/" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "glm-3-turbo" |
|||
llm = GLM(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.glm.api" |
|||
llm = GLM(base_url=base_url) |
|||
self.mock_zhipuai.ZhipuAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create GLM instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = GLM() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "glm-4-plus") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create GLM instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = GLM() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "glm-4-plus") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create GLM instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = GLM() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("GLM API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "GLM API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Novita |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestNovita(unittest.TestCase): |
|||
"""Tests for the Novita LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Novita() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://api.novita.ai/v3/openai" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "qwen/qwq-32b") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"NOVITA_API_KEY": api_key}): |
|||
llm = Novita() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.novita.ai/v3/openai" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = Novita(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.novita.ai/v3/openai" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "qwen/qwq-72b" |
|||
llm = Novita(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.novita.api" |
|||
llm = Novita(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Novita instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Novita() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "qwen/qwq-32b") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Novita instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Novita() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "qwen/qwq-32b") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Novita instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Novita() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("Novita API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Novita API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,136 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import logging |
|||
import os |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Ollama |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestOllama(unittest.TestCase): |
|||
"""Tests for the Ollama LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_ollama = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_ollama.Client = MagicMock(return_value=self.mock_client) |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_response.message = self.mock_message |
|||
self.mock_response.prompt_eval_count = 50 |
|||
self.mock_response.eval_count = 50 |
|||
|
|||
self.mock_client.chat.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'ollama': self.mock_ollama}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Ollama() |
|||
# Check that Ollama client was initialized correctly |
|||
self.mock_ollama.Client.assert_called_once_with( |
|||
host="http://localhost:11434" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "qwq") |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "llama2" |
|||
llm = Ollama(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "http://custom.ollama:11434" |
|||
llm = Ollama(base_url=base_url) |
|||
self.mock_ollama.Client.assert_called_with( |
|||
host=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Ollama instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Ollama() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that chat was called correctly |
|||
self.mock_client.chat.assert_called_once_with( |
|||
model="qwq", |
|||
messages=messages |
|||
) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Ollama instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Ollama() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that chat was called correctly |
|||
self.mock_client.chat.assert_called_once_with( |
|||
model="qwq", |
|||
messages=messages |
|||
) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Ollama instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Ollama() |
|||
|
|||
# Mock an error response |
|||
self.mock_client.chat.side_effect = Exception("Ollama API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Ollama API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,167 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import OpenAI |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestOpenAI(unittest.TestCase): |
|||
"""Tests for the OpenAI LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = OpenAI() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url=None |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "o1-mini") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
base_url = "https://api.openai.com/v1" |
|||
with patch.dict(os.environ, { |
|||
"OPENAI_API_KEY": api_key, |
|||
"OPENAI_BASE_URL": base_url |
|||
}): |
|||
llm = OpenAI() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
api_key = "test_api_key_param" |
|||
llm = OpenAI(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url=None |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
model = "gpt-4" |
|||
llm = OpenAI(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.openai.api" |
|||
llm = OpenAI(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = OpenAI() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "o1-mini") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = OpenAI() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "o1-mini") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create OpenAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = OpenAI() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("OpenAI API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "OpenAI API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import PPIO |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestPPIO(unittest.TestCase): |
|||
"""Tests for the PPIO LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = PPIO() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://api.ppinfra.com/v3/openai" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek/deepseek-r1-turbo") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"PPIO_API_KEY": api_key}): |
|||
llm = PPIO() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.ppinfra.com/v3/openai" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = PPIO(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.ppinfra.com/v3/openai" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "deepseek/deepseek-r1-max" |
|||
llm = PPIO(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.ppio.api" |
|||
llm = PPIO(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create PPIO instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = PPIO() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek/deepseek-r1-turbo") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create PPIO instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = PPIO() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek/deepseek-r1-turbo") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create PPIO instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = PPIO() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("PPIO API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "PPIO API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import SiliconFlow |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestSiliconFlow(unittest.TestCase): |
|||
"""Tests for the SiliconFlow LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = SiliconFlow() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://api.siliconflow.cn/v1" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek-ai/DeepSeek-R1") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"SILICONFLOW_API_KEY": api_key}): |
|||
llm = SiliconFlow() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.siliconflow.cn/v1" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = SiliconFlow(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.siliconflow.cn/v1" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "deepseek-ai/DeepSeek-R2" |
|||
llm = SiliconFlow(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.siliconflow.api" |
|||
llm = SiliconFlow(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create SiliconFlow instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = SiliconFlow() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create SiliconFlow instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = SiliconFlow() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create SiliconFlow instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = SiliconFlow() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("SiliconFlow API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "SiliconFlow API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,151 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import TogetherAI |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestTogetherAI(unittest.TestCase): |
|||
"""Tests for the TogetherAI LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_together = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_together.Together = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'together': self.mock_together}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = TogetherAI() |
|||
# Check that Together client was initialized correctly |
|||
self.mock_together.Together.assert_called_once_with( |
|||
api_key=None |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek-ai/DeepSeek-R1") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"TOGETHER_API_KEY": api_key}): |
|||
llm = TogetherAI() |
|||
self.mock_together.Together.assert_called_with( |
|||
api_key=api_key |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = TogetherAI(api_key=api_key) |
|||
self.mock_together.Together.assert_called_with( |
|||
api_key=api_key |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "mistralai/Mixtral-8x7B-Instruct-v0.1" |
|||
llm = TogetherAI(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create TogetherAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = TogetherAI() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create TogetherAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = TogetherAI() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create TogetherAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = TogetherAI() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("Together API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Together API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import Volcengine |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestVolcengine(unittest.TestCase): |
|||
"""Tests for the Volcengine LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Volcengine() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://ark.cn-beijing.volces.com/api/v3" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "deepseek-r1-250120") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"VOLCENGINE_API_KEY": api_key}): |
|||
llm = Volcengine() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://ark.cn-beijing.volces.com/api/v3" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = Volcengine(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://ark.cn-beijing.volces.com/api/v3" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "deepseek-r2-250120" |
|||
llm = Volcengine(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.volcengine.api" |
|||
llm = Volcengine(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create Volcengine instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Volcengine() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-r1-250120") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create Volcengine instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Volcengine() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "deepseek-r1-250120") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create Volcengine instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = Volcengine() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("Volcengine API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "Volcengine API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,421 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import MagicMock, patch |
|||
import os |
|||
|
|||
class TestWatsonX(unittest.TestCase): |
|||
"""Test cases for WatsonX class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
pass |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_init_with_env_vars(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with environment variables.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
# Check that Credentials was called with correct parameters |
|||
mock_credentials_class.assert_called_once_with( |
|||
url='https://test.watsonx.com', |
|||
api_key='test-api-key' |
|||
) |
|||
|
|||
# Check that ModelInference was called with correct parameters |
|||
mock_model_inference_class.assert_called_once_with( |
|||
model_id='ibm/granite-3-3-8b-instruct', |
|||
credentials=mock_credentials_instance, |
|||
project_id='test-project-id' |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, 'ibm/granite-3-3-8b-instruct') |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com' |
|||
}) |
|||
def test_init_with_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with space_id instead of project_id.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX(space_id='test-space-id') |
|||
|
|||
# Check that ModelInference was called with space_id |
|||
mock_model_inference_class.assert_called_once_with( |
|||
model_id='ibm/granite-3-3-8b-instruct', |
|||
credentials=mock_credentials_instance, |
|||
space_id='test-space-id' |
|||
) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_init_with_custom_model(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with custom model.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX(model='ibm/granite-13b-chat-v2') |
|||
|
|||
# Check that ModelInference was called with custom model |
|||
mock_model_inference_class.assert_called_once_with( |
|||
model_id='ibm/granite-13b-chat-v2', |
|||
credentials=mock_credentials_instance, |
|||
project_id='test-project-id' |
|||
) |
|||
|
|||
self.assertEqual(llm.model, 'ibm/granite-13b-chat-v2') |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_init_with_custom_params(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with custom generation parameters.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX( |
|||
max_new_tokens=500, |
|||
temperature=0.7, |
|||
top_p=0.9, |
|||
top_k=40 |
|||
) |
|||
|
|||
# Check that generation parameters were set correctly |
|||
expected_params = { |
|||
'max_new_tokens': 500, |
|||
'temperature': 0.7, |
|||
'top_p': 0.9, |
|||
'top_k': 40 |
|||
} |
|||
self.assertEqual(llm.generation_params, expected_params) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
def test_init_missing_api_key(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with missing API key.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
with patch.dict(os.environ, {}, clear=True): |
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonX() |
|||
|
|||
self.assertIn("WATSONX_APIKEY", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key' |
|||
}) |
|||
def test_init_missing_url(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with missing URL.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonX() |
|||
|
|||
self.assertIn("WATSONX_URL", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com' |
|||
}) |
|||
def test_init_missing_project_and_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test initialization with missing both project_id and space_id.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
with self.assertRaises(ValueError) as context: |
|||
WatsonX() |
|||
|
|||
self.assertIn("WATSONX_PROJECT_ID", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_chat_simple_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test chat with a simple message.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
mock_model_inference_instance.generate_text.return_value = "This is a test response from WatsonX." |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
messages = [ |
|||
{"role": "user", "content": "Hello, how are you?"} |
|||
] |
|||
|
|||
response = llm.chat(messages) |
|||
|
|||
# Check that generate_text was called |
|||
mock_model_inference_instance.generate_text.assert_called_once() |
|||
call_args = mock_model_inference_instance.generate_text.call_args |
|||
|
|||
# Check the prompt format |
|||
expected_prompt = "Human: Hello, how are you?\n\nAssistant:" |
|||
self.assertEqual(call_args[1]['prompt'], expected_prompt) |
|||
|
|||
# Check response |
|||
self.assertEqual(response.content, "This is a test response from WatsonX.") |
|||
self.assertIsInstance(response.total_tokens, int) |
|||
self.assertGreater(response.total_tokens, 0) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_chat_with_system_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test chat with system and user messages.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
mock_model_inference_instance.generate_text.return_value = "4" |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant."}, |
|||
{"role": "user", "content": "What is 2+2?"} |
|||
] |
|||
|
|||
response = llm.chat(messages) |
|||
|
|||
# Check that generate_text was called |
|||
mock_model_inference_instance.generate_text.assert_called_once() |
|||
call_args = mock_model_inference_instance.generate_text.call_args |
|||
|
|||
# Check the prompt format |
|||
expected_prompt = "System: You are a helpful assistant.\n\nHuman: What is 2+2?\n\nAssistant:" |
|||
self.assertEqual(call_args[1]['prompt'], expected_prompt) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_chat_conversation_history(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test chat with conversation history.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
mock_model_inference_instance.generate_text.return_value = "6" |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant."}, |
|||
{"role": "user", "content": "What is 2+2?"}, |
|||
{"role": "assistant", "content": "2+2 equals 4."}, |
|||
{"role": "user", "content": "What about 3+3?"} |
|||
] |
|||
|
|||
response = llm.chat(messages) |
|||
|
|||
# Check that generate_text was called |
|||
mock_model_inference_instance.generate_text.assert_called_once() |
|||
call_args = mock_model_inference_instance.generate_text.call_args |
|||
|
|||
# Check the prompt format includes conversation history |
|||
expected_prompt = ("System: You are a helpful assistant.\n\n" |
|||
"Human: What is 2+2?\n\n" |
|||
"Assistant: 2+2 equals 4.\n\n" |
|||
"Human: What about 3+3?\n\n" |
|||
"Assistant:") |
|||
self.assertEqual(call_args[1]['prompt'], expected_prompt) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_chat_error_handling(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test error handling in chat method.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
mock_model_inference_instance.generate_text.side_effect = Exception("API Error") |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
|
|||
# Test that the exception is properly wrapped |
|||
with self.assertRaises(RuntimeError) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertIn("Error generating response with WatsonX", str(context.exception)) |
|||
|
|||
@patch('deepsearcher.llm.watsonx.ModelInference') |
|||
@patch('deepsearcher.llm.watsonx.Credentials') |
|||
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames') |
|||
@patch.dict('os.environ', { |
|||
'WATSONX_APIKEY': 'test-api-key', |
|||
'WATSONX_URL': 'https://test.watsonx.com', |
|||
'WATSONX_PROJECT_ID': 'test-project-id' |
|||
}) |
|||
def test_messages_to_prompt(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class): |
|||
"""Test the _messages_to_prompt method.""" |
|||
from deepsearcher.llm.watsonx import WatsonX |
|||
|
|||
mock_credentials_instance = MagicMock() |
|||
mock_model_inference_instance = MagicMock() |
|||
|
|||
mock_credentials_class.return_value = mock_credentials_instance |
|||
mock_model_inference_class.return_value = mock_model_inference_instance |
|||
|
|||
# Mock the GenTextParamsMetaNames attributes |
|||
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens' |
|||
mock_gen_text_params_class.TEMPERATURE = 'temperature' |
|||
mock_gen_text_params_class.TOP_P = 'top_p' |
|||
mock_gen_text_params_class.TOP_K = 'top_k' |
|||
|
|||
llm = WatsonX() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "System message"}, |
|||
{"role": "user", "content": "User message"}, |
|||
{"role": "assistant", "content": "Assistant message"}, |
|||
{"role": "user", "content": "Another user message"} |
|||
] |
|||
|
|||
prompt = llm._messages_to_prompt(messages) |
|||
|
|||
expected_prompt = ("System: System message\n\n" |
|||
"Human: User message\n\n" |
|||
"Assistant: Assistant message\n\n" |
|||
"Human: Another user message\n\n" |
|||
"Assistant:") |
|||
|
|||
self.assertEqual(prompt, expected_prompt) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
@ -1,165 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import os |
|||
import logging |
|||
|
|||
# Disable logging for tests |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
from deepsearcher.llm import XAI |
|||
from deepsearcher.llm.base import ChatResponse |
|||
|
|||
|
|||
class TestXAI(unittest.TestCase): |
|||
"""Tests for the X.AI LLM provider.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock module and components |
|||
self.mock_openai = MagicMock() |
|||
self.mock_client = MagicMock() |
|||
self.mock_chat = MagicMock() |
|||
self.mock_completions = MagicMock() |
|||
|
|||
# Set up the mock module structure |
|||
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client) |
|||
self.mock_client.chat = self.mock_chat |
|||
self.mock_chat.completions = self.mock_completions |
|||
|
|||
# Set up mock response |
|||
self.mock_response = MagicMock() |
|||
self.mock_choice = MagicMock() |
|||
self.mock_message = MagicMock() |
|||
self.mock_usage = MagicMock() |
|||
|
|||
self.mock_message.content = "Test response" |
|||
self.mock_choice.message = self.mock_message |
|||
self.mock_usage.total_tokens = 100 |
|||
|
|||
self.mock_response.choices = [self.mock_choice] |
|||
self.mock_response.usage = self.mock_usage |
|||
self.mock_completions.create.return_value = self.mock_response |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai}) |
|||
self.module_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default parameters.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = XAI() |
|||
# Check that OpenAI client was initialized correctly |
|||
self.mock_openai.OpenAI.assert_called_once_with( |
|||
api_key=None, |
|||
base_url="https://api.x.ai/v1" |
|||
) |
|||
|
|||
# Check default model |
|||
self.assertEqual(llm.model, "grok-2-latest") |
|||
|
|||
def test_init_with_api_key_from_env(self): |
|||
"""Test initialization with API key from environment variable.""" |
|||
api_key = "test_api_key_from_env" |
|||
with patch.dict(os.environ, {"XAI_API_KEY": api_key}): |
|||
llm = XAI() |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.x.ai/v1" |
|||
) |
|||
|
|||
def test_init_with_api_key_parameter(self): |
|||
"""Test initialization with API key as parameter.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
api_key = "test_api_key_param" |
|||
llm = XAI(api_key=api_key) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=api_key, |
|||
base_url="https://api.x.ai/v1" |
|||
) |
|||
|
|||
def test_init_with_custom_model(self): |
|||
"""Test initialization with custom model.""" |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
model = "grok-1" |
|||
llm = XAI(model=model) |
|||
self.assertEqual(llm.model, model) |
|||
|
|||
def test_init_with_custom_base_url(self): |
|||
"""Test initialization with custom base URL.""" |
|||
# Clear environment variables temporarily |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
base_url = "https://custom.x.ai" |
|||
llm = XAI(base_url=base_url) |
|||
self.mock_openai.OpenAI.assert_called_with( |
|||
api_key=None, |
|||
base_url=base_url |
|||
) |
|||
|
|||
def test_chat_single_message(self): |
|||
"""Test chat with a single message.""" |
|||
# Create XAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = XAI() |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "grok-2-latest") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_multiple_messages(self): |
|||
"""Test chat with multiple messages.""" |
|||
# Create XAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = XAI() |
|||
|
|||
messages = [ |
|||
{"role": "system", "content": "You are a helpful assistant"}, |
|||
{"role": "user", "content": "Hello"}, |
|||
{"role": "assistant", "content": "Hi there!"}, |
|||
{"role": "user", "content": "How are you?"} |
|||
] |
|||
response = llm.chat(messages) |
|||
|
|||
# Check that completions.create was called correctly |
|||
self.mock_completions.create.assert_called_once() |
|||
call_args = self.mock_completions.create.call_args |
|||
self.assertEqual(call_args[1]["model"], "grok-2-latest") |
|||
self.assertEqual(call_args[1]["messages"], messages) |
|||
|
|||
# Check response |
|||
self.assertIsInstance(response, ChatResponse) |
|||
self.assertEqual(response.content, "Test response") |
|||
self.assertEqual(response.total_tokens, 100) |
|||
|
|||
def test_chat_with_error(self): |
|||
"""Test chat when an error occurs.""" |
|||
# Create XAI instance with mocked environment |
|||
with patch.dict('os.environ', {}, clear=True): |
|||
llm = XAI() |
|||
|
|||
# Mock an error response |
|||
self.mock_completions.create.side_effect = Exception("XAI API Error") |
|||
|
|||
messages = [{"role": "user", "content": "Hello"}] |
|||
with self.assertRaises(Exception) as context: |
|||
llm.chat(messages) |
|||
|
|||
self.assertEqual(str(context.exception), "XAI API Error") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher.loader package |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher.loader.file_loader package |
@ -1,69 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import tempfile |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
from deepsearcher.loader.file_loader.base import BaseLoader |
|||
|
|||
|
|||
class TestBaseLoader(unittest.TestCase): |
|||
"""Tests for the BaseLoader class.""" |
|||
|
|||
def test_abstract_methods(self): |
|||
"""Test that BaseLoader defines abstract methods.""" |
|||
# For abstract base classes, we can check if methods are defined |
|||
# but not implemented in the base class |
|||
self.assertTrue(hasattr(BaseLoader, 'load_file')) |
|||
self.assertTrue(hasattr(BaseLoader, 'supported_file_types')) |
|||
|
|||
def test_load_directory(self): |
|||
"""Test the load_directory method.""" |
|||
# Create a subclass of BaseLoader for testing |
|||
class TestLoader(BaseLoader): |
|||
@property |
|||
def supported_file_types(self): |
|||
return [".txt", ".md"] |
|||
|
|||
def load_file(self, file_path): |
|||
# Mock implementation that returns a simple Document |
|||
return [Document(page_content=f"Content of {file_path}", metadata={"reference": file_path})] |
|||
|
|||
# Create a temporary directory with test files |
|||
with tempfile.TemporaryDirectory() as temp_dir: |
|||
# Create test files |
|||
file_paths = [ |
|||
os.path.join(temp_dir, "test1.txt"), |
|||
os.path.join(temp_dir, "test2.md"), |
|||
os.path.join(temp_dir, "test3.pdf"), # Unsupported format |
|||
os.path.join(temp_dir, "subdir", "test4.txt") |
|||
] |
|||
|
|||
# Create subdirectory |
|||
os.makedirs(os.path.join(temp_dir, "subdir"), exist_ok=True) |
|||
|
|||
# Create files |
|||
for path in file_paths: |
|||
# Skip the file if it's in a subdirectory that doesn't exist |
|||
if not os.path.exists(os.path.dirname(path)): |
|||
continue |
|||
with open(path, 'w') as f: |
|||
f.write(f"Content of {path}") |
|||
|
|||
# Test loading the directory |
|||
loader = TestLoader() |
|||
documents = loader.load_directory(temp_dir) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(documents), 3) # Should find 3 supported files |
|||
|
|||
# Verify each document |
|||
references = [doc.metadata["reference"] for doc in documents] |
|||
self.assertIn(file_paths[0], references) # test1.txt |
|||
self.assertIn(file_paths[1], references) # test2.md |
|||
self.assertNotIn(file_paths[2], references) # test3.pdf (unsupported) |
|||
self.assertIn(file_paths[3], references) # subdir/test4.txt |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,185 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import tempfile |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.file_loader import DoclingLoader |
|||
|
|||
|
|||
class TestDoclingLoader(unittest.TestCase): |
|||
"""Tests for the DoclingLoader class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create patches for the docling modules |
|||
self.docling_patcher = patch.dict('sys.modules', { |
|||
'docling': MagicMock(), |
|||
'docling.document_converter': MagicMock(), |
|||
'docling_core': MagicMock(), |
|||
'docling_core.transforms': MagicMock(), |
|||
'docling_core.transforms.chunker': MagicMock() |
|||
}) |
|||
self.docling_patcher.start() |
|||
|
|||
# Create mocks for the classes |
|||
self.mock_document_converter = MagicMock() |
|||
self.mock_hierarchical_chunker = MagicMock() |
|||
|
|||
# Add the mocks to the modules |
|||
import sys |
|||
sys.modules['docling.document_converter'].DocumentConverter = self.mock_document_converter |
|||
sys.modules['docling_core.transforms.chunker'].HierarchicalChunker = self.mock_hierarchical_chunker |
|||
|
|||
# Set up mock instances |
|||
self.mock_converter_instance = MagicMock() |
|||
self.mock_chunker_instance = MagicMock() |
|||
self.mock_document_converter.return_value = self.mock_converter_instance |
|||
self.mock_hierarchical_chunker.return_value = self.mock_chunker_instance |
|||
|
|||
# Create a temporary directory |
|||
self.temp_dir = tempfile.TemporaryDirectory() |
|||
|
|||
# Create a test markdown file |
|||
self.md_file_path = os.path.join(self.temp_dir.name, "test.md") |
|||
with open(self.md_file_path, "w", encoding="utf-8") as f: |
|||
f.write("# Test Markdown\nThis is a test markdown file.") |
|||
|
|||
# Create a test unsupported file |
|||
self.unsupported_file_path = os.path.join(self.temp_dir.name, "test.xyz") |
|||
with open(self.unsupported_file_path, "w", encoding="utf-8") as f: |
|||
f.write("This is an unsupported file type.") |
|||
|
|||
# Create a subdirectory with a test file |
|||
self.sub_dir = os.path.join(self.temp_dir.name, "subdir") |
|||
os.makedirs(self.sub_dir, exist_ok=True) |
|||
self.sub_file_path = os.path.join(self.sub_dir, "subfile.md") |
|||
with open(self.sub_file_path, "w", encoding="utf-8") as f: |
|||
f.write("# Subdir Test\nThis is a test markdown file in a subdirectory.") |
|||
|
|||
# Create the loader |
|||
self.loader = DoclingLoader() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.docling_patcher.stop() |
|||
self.temp_dir.cleanup() |
|||
|
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
# Verify instances were created |
|||
self.mock_document_converter.assert_called_once() |
|||
self.mock_hierarchical_chunker.assert_called_once() |
|||
|
|||
# Check that the instances were assigned correctly |
|||
self.assertEqual(self.loader.converter, self.mock_converter_instance) |
|||
self.assertEqual(self.loader.chunker, self.mock_chunker_instance) |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
file_types = self.loader.supported_file_types |
|||
|
|||
# Check that the common file types are included |
|||
common_types = ["pdf", "docx", "md", "html", "csv", "jpg"] |
|||
for file_type in common_types: |
|||
self.assertIn(file_type, file_types) |
|||
|
|||
def test_load_file(self): |
|||
"""Test loading a single file.""" |
|||
# Set up mock document and chunks |
|||
mock_document = MagicMock() |
|||
mock_conversion_result = MagicMock() |
|||
mock_conversion_result.document = mock_document |
|||
|
|||
# Set up three mock chunks |
|||
mock_chunks = [] |
|||
for i in range(3): |
|||
chunk = MagicMock() |
|||
chunk.text = f"Chunk {i} content" |
|||
mock_chunks.append(chunk) |
|||
|
|||
# Configure mock converter and chunker |
|||
self.mock_converter_instance.convert.return_value = mock_conversion_result |
|||
self.mock_chunker_instance.chunk.return_value = mock_chunks |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_file(self.md_file_path) |
|||
|
|||
# Verify converter was called correctly |
|||
self.mock_converter_instance.convert.assert_called_once_with(self.md_file_path) |
|||
|
|||
# Verify chunker was called correctly |
|||
self.mock_chunker_instance.chunk.assert_called_once_with(mock_document) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 3) |
|||
|
|||
# Check each document |
|||
for i, document in enumerate(documents): |
|||
self.assertEqual(document.page_content, f"Chunk {i} content") |
|||
self.assertEqual(document.metadata["reference"], self.md_file_path) |
|||
self.assertEqual(document.metadata["text"], f"Chunk {i} content") |
|||
|
|||
def test_load_file_not_found(self): |
|||
"""Test loading a non-existent file.""" |
|||
non_existent_file = os.path.join(self.temp_dir.name, "non_existent.md") |
|||
with self.assertRaises(FileNotFoundError): |
|||
self.loader.load_file(non_existent_file) |
|||
|
|||
def test_load_unsupported_file_type(self): |
|||
"""Test loading a file with unsupported extension.""" |
|||
with self.assertRaises(ValueError): |
|||
self.loader.load_file(self.unsupported_file_path) |
|||
|
|||
def test_load_file_error(self): |
|||
"""Test error handling when loading a file.""" |
|||
# Configure converter to raise an exception |
|||
self.mock_converter_instance.convert.side_effect = Exception("Test error") |
|||
|
|||
# Verify that the error is propagated |
|||
with self.assertRaises(IOError): |
|||
self.loader.load_file(self.md_file_path) |
|||
|
|||
def test_load_directory(self): |
|||
"""Test loading a directory.""" |
|||
# Set up mock document and chunks |
|||
mock_document = MagicMock() |
|||
mock_conversion_result = MagicMock() |
|||
mock_conversion_result.document = mock_document |
|||
|
|||
# Set up a single mock chunk |
|||
mock_chunk = MagicMock() |
|||
mock_chunk.text = "Test chunk content" |
|||
|
|||
# Configure mock converter and chunker |
|||
self.mock_converter_instance.convert.return_value = mock_conversion_result |
|||
self.mock_chunker_instance.chunk.return_value = [mock_chunk] |
|||
|
|||
# Load the directory |
|||
documents = self.loader.load_directory(self.temp_dir.name) |
|||
|
|||
# Verify converter was called twice (once for each MD file) |
|||
self.assertEqual(self.mock_converter_instance.convert.call_count, 2) |
|||
|
|||
# Verify converter was called with both MD files |
|||
self.mock_converter_instance.convert.assert_any_call(self.md_file_path) |
|||
self.mock_converter_instance.convert.assert_any_call(self.sub_file_path) |
|||
|
|||
# Check results - should have two documents (one from each MD file) |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Check each document |
|||
for document in documents: |
|||
self.assertEqual(document.page_content, "Test chunk content") |
|||
self.assertEqual(document.metadata["text"], "Test chunk content") |
|||
self.assertIn(document.metadata["reference"], [self.md_file_path, self.sub_file_path]) |
|||
|
|||
def test_load_not_a_directory(self): |
|||
"""Test loading a path that is not a directory.""" |
|||
with self.assertRaises(NotADirectoryError): |
|||
self.loader.load_directory(self.md_file_path) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,124 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import json |
|||
import tempfile |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.file_loader import JsonFileLoader |
|||
|
|||
|
|||
class TestJsonFileLoader(unittest.TestCase): |
|||
"""Tests for the JsonFileLoader class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up the test environment.""" |
|||
# Create a temporary directory for test files |
|||
self.temp_dir = tempfile.TemporaryDirectory() |
|||
|
|||
# Sample JSON data |
|||
self.json_data = [ |
|||
{"id": 1, "text": "This is the first document.", "author": "John Doe"}, |
|||
{"id": 2, "text": "This is the second document.", "author": "Jane Smith"} |
|||
] |
|||
|
|||
# Create JSON test file |
|||
self.json_file_path = os.path.join(self.temp_dir.name, "test.json") |
|||
with open(self.json_file_path, "w", encoding="utf-8") as f: |
|||
json.dump(self.json_data, f) |
|||
|
|||
# Create JSONL test file |
|||
self.jsonl_file_path = os.path.join(self.temp_dir.name, "test.jsonl") |
|||
with open(self.jsonl_file_path, "w", encoding="utf-8") as f: |
|||
for item in self.json_data: |
|||
f.write(json.dumps(item) + "\n") |
|||
|
|||
# Create invalid JSON file (not a list) |
|||
self.invalid_json_file_path = os.path.join(self.temp_dir.name, "invalid.json") |
|||
with open(self.invalid_json_file_path, "w", encoding="utf-8") as f: |
|||
json.dump({"id": 1, "text": "This is not a list.", "author": "John Doe"}, f) |
|||
|
|||
# Create invalid JSONL file |
|||
self.invalid_jsonl_file_path = os.path.join(self.temp_dir.name, "invalid.jsonl") |
|||
with open(self.invalid_jsonl_file_path, "w", encoding="utf-8") as f: |
|||
f.write("This is not valid JSON\n") |
|||
f.write(json.dumps({"id": 2, "text": "This is valid JSON", "author": "Jane Smith"}) + "\n") |
|||
|
|||
# Initialize the loader |
|||
self.loader = JsonFileLoader(text_key="text") |
|||
|
|||
# Patch the _read_json_file method to fix the file handling |
|||
original_read_json_file = self.loader._read_json_file |
|||
|
|||
def patched_read_json_file(file_path): |
|||
with open(file_path, 'r') as f: |
|||
json_data = json.load(f) |
|||
if not isinstance(json_data, list): |
|||
raise ValueError("JSON file must contain a list of dictionaries.") |
|||
return json_data |
|||
|
|||
self.loader._read_json_file = patched_read_json_file |
|||
|
|||
def tearDown(self): |
|||
"""Clean up the test environment.""" |
|||
self.temp_dir.cleanup() |
|||
|
|||
def test_load_json_file(self): |
|||
"""Test loading a JSON file.""" |
|||
documents = self.loader.load_file(self.json_file_path) |
|||
|
|||
# Check that we got the right number of documents |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Check the content and metadata of each document |
|||
self.assertEqual(documents[0].page_content, "This is the first document.") |
|||
self.assertEqual(documents[0].metadata["id"], 1) |
|||
self.assertEqual(documents[0].metadata["author"], "John Doe") |
|||
self.assertEqual(documents[0].metadata["reference"], self.json_file_path) |
|||
|
|||
self.assertEqual(documents[1].page_content, "This is the second document.") |
|||
self.assertEqual(documents[1].metadata["id"], 2) |
|||
self.assertEqual(documents[1].metadata["author"], "Jane Smith") |
|||
self.assertEqual(documents[1].metadata["reference"], self.json_file_path) |
|||
|
|||
def test_load_jsonl_file(self): |
|||
"""Test loading a JSONL file.""" |
|||
documents = self.loader.load_file(self.jsonl_file_path) |
|||
|
|||
# Check that we got the right number of documents |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Check the content and metadata of each document |
|||
self.assertEqual(documents[0].page_content, "This is the first document.") |
|||
self.assertEqual(documents[0].metadata["id"], 1) |
|||
self.assertEqual(documents[0].metadata["author"], "John Doe") |
|||
self.assertEqual(documents[0].metadata["reference"], self.jsonl_file_path) |
|||
|
|||
self.assertEqual(documents[1].page_content, "This is the second document.") |
|||
self.assertEqual(documents[1].metadata["id"], 2) |
|||
self.assertEqual(documents[1].metadata["author"], "Jane Smith") |
|||
self.assertEqual(documents[1].metadata["reference"], self.jsonl_file_path) |
|||
|
|||
def test_invalid_json_file(self): |
|||
"""Test loading an invalid JSON file (not a list).""" |
|||
with self.assertRaises(ValueError): |
|||
self.loader.load_file(self.invalid_json_file_path) |
|||
|
|||
def test_invalid_jsonl_file(self): |
|||
"""Test loading a JSONL file with invalid lines.""" |
|||
documents = self.loader.load_file(self.invalid_jsonl_file_path) |
|||
|
|||
# Only the valid line should be loaded |
|||
self.assertEqual(len(documents), 1) |
|||
self.assertEqual(documents[0].page_content, "This is valid JSON") |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
file_types = self.loader.supported_file_types |
|||
self.assertIsInstance(file_types, list) |
|||
self.assertIn("txt", file_types) |
|||
self.assertIn("md", file_types) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,142 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import tempfile |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.file_loader import PDFLoader |
|||
|
|||
|
|||
class TestPDFLoader(unittest.TestCase): |
|||
"""Tests for the PDFLoader class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up the test environment.""" |
|||
# Create a temporary directory |
|||
self.temp_dir = tempfile.TemporaryDirectory() |
|||
|
|||
# Create a text file for testing |
|||
self.text_file_path = os.path.join(self.temp_dir.name, "test.txt") |
|||
with open(self.text_file_path, "w", encoding="utf-8") as f: |
|||
f.write("This is a test text file.") |
|||
|
|||
# Create a markdown file for testing |
|||
self.md_file_path = os.path.join(self.temp_dir.name, "test.md") |
|||
with open(self.md_file_path, "w", encoding="utf-8") as f: |
|||
f.write("# Test Markdown\nThis is a test markdown file.") |
|||
|
|||
# PDF file path (will be mocked) |
|||
self.pdf_file_path = os.path.join(self.temp_dir.name, "test.pdf") |
|||
|
|||
# Create the loader |
|||
self.loader = PDFLoader() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up the test environment.""" |
|||
self.temp_dir.cleanup() |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
file_types = self.loader.supported_file_types |
|||
self.assertIsInstance(file_types, list) |
|||
self.assertIn("pdf", file_types) |
|||
self.assertIn("md", file_types) |
|||
self.assertIn("txt", file_types) |
|||
|
|||
def test_load_text_file(self): |
|||
"""Test loading a text file.""" |
|||
documents = self.loader.load_file(self.text_file_path) |
|||
|
|||
# Check that we got one document |
|||
self.assertEqual(len(documents), 1) |
|||
|
|||
# Check the document content |
|||
document = documents[0] |
|||
self.assertEqual(document.page_content, "This is a test text file.") |
|||
|
|||
# Check the metadata |
|||
self.assertEqual(document.metadata["reference"], self.text_file_path) |
|||
|
|||
def test_load_markdown_file(self): |
|||
"""Test loading a markdown file.""" |
|||
documents = self.loader.load_file(self.md_file_path) |
|||
|
|||
# Check that we got one document |
|||
self.assertEqual(len(documents), 1) |
|||
|
|||
# Check the document content |
|||
document = documents[0] |
|||
self.assertEqual(document.page_content, "# Test Markdown\nThis is a test markdown file.") |
|||
|
|||
# Check the metadata |
|||
self.assertEqual(document.metadata["reference"], self.md_file_path) |
|||
|
|||
@patch("pdfplumber.open") |
|||
def test_load_pdf_file(self, mock_pdf_open): |
|||
"""Test loading a PDF file.""" |
|||
# Set up mock PDF pages |
|||
mock_page1 = MagicMock() |
|||
mock_page1.extract_text.return_value = "Page 1 content" |
|||
|
|||
mock_page2 = MagicMock() |
|||
mock_page2.extract_text.return_value = "Page 2 content" |
|||
|
|||
# Set up mock PDF file |
|||
mock_pdf = MagicMock() |
|||
mock_pdf.pages = [mock_page1, mock_page2] |
|||
mock_pdf.__enter__.return_value = mock_pdf |
|||
mock_pdf.__exit__.return_value = None |
|||
|
|||
# Configure the mock to return our mock PDF |
|||
mock_pdf_open.return_value = mock_pdf |
|||
|
|||
# Create a dummy PDF file |
|||
with open(self.pdf_file_path, "w") as f: |
|||
f.write("dummy pdf content") |
|||
|
|||
# Load the PDF file |
|||
documents = self.loader.load_file(self.pdf_file_path) |
|||
|
|||
# Verify pdfplumber.open was called |
|||
mock_pdf_open.assert_called_once_with(self.pdf_file_path) |
|||
|
|||
# Check that we got one document |
|||
self.assertEqual(len(documents), 1) |
|||
|
|||
# Check the document content |
|||
document = documents[0] |
|||
self.assertEqual(document.page_content, "Page 1 content\n\nPage 2 content") |
|||
|
|||
# Check the metadata |
|||
self.assertEqual(document.metadata["reference"], self.pdf_file_path) |
|||
|
|||
def test_load_directory(self): |
|||
"""Test loading a directory with mixed file types.""" |
|||
# Create the loader |
|||
loader = PDFLoader() |
|||
|
|||
# Mock the load_file method to track calls |
|||
original_load_file = loader.load_file |
|||
calls = [] |
|||
|
|||
def mock_load_file(file_path): |
|||
calls.append(file_path) |
|||
return original_load_file(file_path) |
|||
|
|||
loader.load_file = mock_load_file |
|||
|
|||
# Load the directory |
|||
documents = loader.load_directory(self.temp_dir.name) |
|||
|
|||
# Check that we processed both text and markdown files |
|||
self.assertEqual(len(calls), 2) # text and markdown files |
|||
self.assertIn(self.text_file_path, calls) |
|||
self.assertIn(self.md_file_path, calls) |
|||
|
|||
# Check that we got two documents |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,82 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import tempfile |
|||
|
|||
from deepsearcher.loader.file_loader import TextLoader |
|||
|
|||
|
|||
class TestTextLoader(unittest.TestCase): |
|||
"""Tests for the TextLoader class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up the test environment.""" |
|||
self.loader = TextLoader() |
|||
|
|||
# Create a temporary directory and file for testing |
|||
self.temp_dir = tempfile.TemporaryDirectory() |
|||
self.test_file_path = os.path.join(self.temp_dir.name, "test.txt") |
|||
self.test_content = "This is a test file content.\nWith multiple lines." |
|||
|
|||
# Write test content to the file |
|||
with open(self.test_file_path, "w", encoding="utf-8") as f: |
|||
f.write(self.test_content) |
|||
|
|||
def tearDown(self): |
|||
"""Clean up the test environment.""" |
|||
self.temp_dir.cleanup() |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
supported_types = self.loader.supported_file_types |
|||
self.assertIsInstance(supported_types, list) |
|||
self.assertIn("txt", supported_types) |
|||
self.assertIn("md", supported_types) |
|||
|
|||
def test_load_file(self): |
|||
"""Test loading a text file.""" |
|||
documents = self.loader.load_file(self.test_file_path) |
|||
|
|||
# Check that we got a list with one document |
|||
self.assertIsInstance(documents, list) |
|||
self.assertEqual(len(documents), 1) |
|||
|
|||
# Check the document content |
|||
document = documents[0] |
|||
self.assertEqual(document.page_content, self.test_content) |
|||
|
|||
# Check the metadata |
|||
self.assertIn("reference", document.metadata) |
|||
self.assertEqual(document.metadata["reference"], self.test_file_path) |
|||
|
|||
def test_load_directory(self): |
|||
"""Test loading a directory with text files.""" |
|||
# Create additional test files |
|||
md_file_path = os.path.join(self.temp_dir.name, "test.md") |
|||
with open(md_file_path, "w", encoding="utf-8") as f: |
|||
f.write("# Markdown Test\nThis is a markdown file.") |
|||
|
|||
# Create a non-supported file |
|||
pdf_file_path = os.path.join(self.temp_dir.name, "test.pdf") |
|||
with open(pdf_file_path, "w", encoding="utf-8") as f: |
|||
f.write("PDF content") |
|||
|
|||
# Load the directory |
|||
documents = self.loader.load_directory(self.temp_dir.name) |
|||
|
|||
# Check that we got documents for supported files only |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Get references |
|||
references = [doc.metadata["reference"] for doc in documents] |
|||
|
|||
# Check that supported files were loaded |
|||
self.assertIn(self.test_file_path, references) |
|||
self.assertIn(md_file_path, references) |
|||
|
|||
# Check that unsupported file was not loaded |
|||
for doc in documents: |
|||
self.assertNotEqual(doc.metadata["reference"], pdf_file_path) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,203 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
import shutil |
|||
import tempfile |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.file_loader import UnstructuredLoader |
|||
|
|||
|
|||
class TestUnstructuredLoader(unittest.TestCase): |
|||
"""Tests for the UnstructuredLoader class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create a temporary directory for tests |
|||
self.temp_dir = tempfile.TemporaryDirectory() |
|||
|
|||
# Create a test file |
|||
self.test_file_path = os.path.join(self.temp_dir.name, "test.txt") |
|||
with open(self.test_file_path, "w", encoding="utf-8") as f: |
|||
f.write("This is a test file.") |
|||
|
|||
# Path for mock processed outputs |
|||
self.mock_output_dir = os.path.join(self.temp_dir.name, "mock_outputs") |
|||
os.makedirs(self.mock_output_dir, exist_ok=True) |
|||
|
|||
# Create a mock JSON output file |
|||
self.mock_json_path = os.path.join(self.mock_output_dir, "test_output.json") |
|||
with open(self.mock_json_path, "w", encoding="utf-8") as f: |
|||
f.write('{"elements": [{"text": "This is extracted text.", "metadata": {"filename": "test.txt"}}]}') |
|||
|
|||
# Set up patches for unstructured modules |
|||
self.unstructured_modules = { |
|||
'unstructured_ingest': MagicMock(), |
|||
'unstructured_ingest.interfaces': MagicMock(), |
|||
'unstructured_ingest.pipeline': MagicMock(), |
|||
'unstructured_ingest.pipeline.pipeline': MagicMock(), |
|||
'unstructured_ingest.processes': MagicMock(), |
|||
'unstructured_ingest.processes.connectors': MagicMock(), |
|||
'unstructured_ingest.processes.connectors.local': MagicMock(), |
|||
'unstructured_ingest.processes.partitioner': MagicMock(), |
|||
'unstructured': MagicMock(), |
|||
'unstructured.staging': MagicMock(), |
|||
'unstructured.staging.base': MagicMock(), |
|||
} |
|||
|
|||
self.patches = [] |
|||
for module_name, mock_module in self.unstructured_modules.items(): |
|||
patcher = patch.dict('sys.modules', {module_name: mock_module}) |
|||
patcher.start() |
|||
self.patches.append(patcher) |
|||
|
|||
# Create mock Pipeline class |
|||
self.mock_pipeline = MagicMock() |
|||
self.unstructured_modules['unstructured_ingest.pipeline.pipeline'].Pipeline = self.mock_pipeline |
|||
self.mock_pipeline.from_configs.return_value = self.mock_pipeline |
|||
|
|||
# Create mock Element class |
|||
self.mock_element = MagicMock() |
|||
self.mock_element.text = "This is extracted text." |
|||
self.mock_element.metadata = MagicMock() |
|||
self.mock_element.metadata.to_dict.return_value = {"filename": "test.txt"} |
|||
|
|||
# Set up elements_from_json mock |
|||
self.unstructured_modules['unstructured.staging.base'].elements_from_json = MagicMock() |
|||
self.unstructured_modules['unstructured.staging.base'].elements_from_json.return_value = [self.mock_element] |
|||
|
|||
# Patch makedirs and rmtree but don't assert on them |
|||
with patch('os.makedirs'): |
|||
with patch('shutil.rmtree'): |
|||
self.loader = UnstructuredLoader() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
# Stop all patches |
|||
for patcher in self.patches: |
|||
patcher.stop() |
|||
|
|||
# Remove temporary directory |
|||
self.temp_dir.cleanup() |
|||
|
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
self.assertEqual(self.loader.directory_with_results, "./pdf_processed_outputs") |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
file_types = self.loader.supported_file_types |
|||
|
|||
# Check that common file types are included |
|||
common_types = ["pdf", "docx", "txt", "html", "md", "jpg"] |
|||
for file_type in common_types: |
|||
self.assertIn(file_type, file_types) |
|||
|
|||
# Check total number of supported types (should be extensive) |
|||
self.assertGreater(len(file_types), 20) |
|||
|
|||
@patch('os.listdir') |
|||
def test_load_file(self, mock_listdir): |
|||
"""Test loading a single file.""" |
|||
# Configure mocks |
|||
mock_listdir.return_value = ["test_output.json"] |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_file(self.test_file_path) |
|||
|
|||
# Verify Pipeline.from_configs was called |
|||
self.mock_pipeline.from_configs.assert_called_once() |
|||
self.mock_pipeline.run.assert_called_once() |
|||
|
|||
# Verify elements_from_json was called |
|||
self.unstructured_modules['unstructured.staging.base'].elements_from_json.assert_called_once() |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 1) |
|||
self.assertEqual(documents[0].page_content, "This is extracted text.") |
|||
self.assertEqual(documents[0].metadata["reference"], self.test_file_path) |
|||
self.assertEqual(documents[0].metadata["filename"], "test.txt") |
|||
|
|||
@patch('os.listdir') |
|||
def test_load_directory(self, mock_listdir): |
|||
"""Test loading a directory.""" |
|||
# Configure mocks |
|||
mock_listdir.return_value = ["test_output.json"] |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_directory(self.temp_dir.name) |
|||
|
|||
# Verify Pipeline.from_configs was called |
|||
self.mock_pipeline.from_configs.assert_called_once() |
|||
self.mock_pipeline.run.assert_called_once() |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 1) |
|||
self.assertEqual(documents[0].page_content, "This is extracted text.") |
|||
self.assertEqual(documents[0].metadata["reference"], self.temp_dir.name) |
|||
|
|||
@patch('os.listdir') |
|||
def test_load_with_api(self, mock_listdir): |
|||
"""Test loading with API environment variables.""" |
|||
# Create a mock for os.environ.get |
|||
with patch('os.environ.get') as mock_env_get: |
|||
# Configure environment variables |
|||
mock_env_get.side_effect = lambda key, default=None: { |
|||
"UNSTRUCTURED_API_KEY": "test-key", |
|||
"UNSTRUCTURED_API_URL": "https://api.example.com" |
|||
}.get(key, default) |
|||
|
|||
# Configure listdir mock |
|||
mock_listdir.return_value = ["test_output.json"] |
|||
|
|||
# Create a mock for PartitionerConfig |
|||
mock_partitioner_config = MagicMock() |
|||
self.unstructured_modules['unstructured_ingest.processes.partitioner'].PartitionerConfig = mock_partitioner_config |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_file(self.test_file_path) |
|||
|
|||
# Verify Pipeline.from_configs was called |
|||
self.mock_pipeline.from_configs.assert_called_once() |
|||
|
|||
# Check that PartitionerConfig was called with correct parameters |
|||
mock_partitioner_config.assert_called_once() |
|||
args, kwargs = mock_partitioner_config.call_args |
|||
self.assertTrue(kwargs.get('partition_by_api')) |
|||
self.assertEqual(kwargs.get('api_key'), "test-key") |
|||
self.assertEqual(kwargs.get('partition_endpoint'), "https://api.example.com") |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 1) |
|||
|
|||
@patch('os.listdir') |
|||
def test_empty_output(self, mock_listdir): |
|||
"""Test handling of empty output directory.""" |
|||
# Configure listdir to return no JSON files |
|||
mock_listdir.return_value = [] |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_file(self.test_file_path) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 0) |
|||
|
|||
@patch('os.listdir') |
|||
def test_error_reading_json(self, mock_listdir): |
|||
"""Test handling of errors when reading JSON files.""" |
|||
# Configure listdir mock |
|||
mock_listdir.return_value = ["test_output.json"] |
|||
|
|||
# Configure elements_from_json to raise an IOError |
|||
self.unstructured_modules['unstructured.staging.base'].elements_from_json.side_effect = IOError("Test error") |
|||
|
|||
# Call the method |
|||
documents = self.loader.load_file(self.test_file_path) |
|||
|
|||
# Check results (should be empty) |
|||
self.assertEqual(len(documents), 0) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,98 +0,0 @@ |
|||
import unittest |
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.splitter import Chunk, split_docs_to_chunks, _sentence_window_split |
|||
|
|||
|
|||
class TestSplitter(unittest.TestCase): |
|||
"""Tests for the splitter module.""" |
|||
|
|||
def test_chunk_init(self): |
|||
"""Test initialization of Chunk class.""" |
|||
# Test with minimal parameters |
|||
chunk = Chunk(text="Test text", reference="test_ref") |
|||
self.assertEqual(chunk.text, "Test text") |
|||
self.assertEqual(chunk.reference, "test_ref") |
|||
self.assertEqual(chunk.metadata, {}) |
|||
self.assertIsNone(chunk.embedding) |
|||
|
|||
# Test with all parameters |
|||
metadata = {"key": "value"} |
|||
embedding = [0.1, 0.2, 0.3] |
|||
chunk = Chunk(text="Test text", reference="test_ref", metadata=metadata, embedding=embedding) |
|||
self.assertEqual(chunk.text, "Test text") |
|||
self.assertEqual(chunk.reference, "test_ref") |
|||
self.assertEqual(chunk.metadata, metadata) |
|||
self.assertEqual(chunk.embedding, embedding) |
|||
|
|||
def test_sentence_window_split(self): |
|||
"""Test _sentence_window_split function.""" |
|||
# Create a test document |
|||
original_text = "This is a test document. It has multiple sentences. This is for testing the splitter." |
|||
original_doc = Document(page_content=original_text, metadata={"reference": "test_doc"}) |
|||
|
|||
# Create split documents |
|||
split_docs = [ |
|||
Document(page_content="This is a test document.", metadata={"reference": "test_doc"}), |
|||
Document(page_content="It has multiple sentences.", metadata={"reference": "test_doc"}), |
|||
Document(page_content="This is for testing the splitter.", metadata={"reference": "test_doc"}) |
|||
] |
|||
|
|||
# Test with default offset |
|||
chunks = _sentence_window_split(split_docs, original_doc) |
|||
|
|||
# Verify the results |
|||
self.assertEqual(len(chunks), 3) |
|||
for i, chunk in enumerate(chunks): |
|||
self.assertEqual(chunk.text, split_docs[i].page_content) |
|||
self.assertEqual(chunk.reference, "test_doc") |
|||
self.assertIn("wider_text", chunk.metadata) |
|||
# The wider text should contain the original text since our test document is short |
|||
self.assertEqual(chunk.metadata["wider_text"], original_text) |
|||
|
|||
# Test with smaller offset |
|||
chunks = _sentence_window_split(split_docs, original_doc, offset=10) |
|||
|
|||
# Verify the results with smaller context windows |
|||
self.assertEqual(len(chunks), 3) |
|||
for chunk in chunks: |
|||
# With smaller offset, wider_text should be shorter than the full original text |
|||
self.assertLessEqual(len(chunk.metadata["wider_text"]), len(original_text)) |
|||
|
|||
def test_split_docs_to_chunks(self): |
|||
"""Test split_docs_to_chunks function.""" |
|||
# Create test documents |
|||
docs = [ |
|||
Document( |
|||
page_content="This is document one. It has some content for testing.", |
|||
metadata={"reference": "doc1"} |
|||
), |
|||
Document( |
|||
page_content="This is document two. It also has content for testing purposes.", |
|||
metadata={"reference": "doc2"} |
|||
) |
|||
] |
|||
|
|||
# Test with default parameters |
|||
chunks = split_docs_to_chunks(docs) |
|||
|
|||
# Verify the results |
|||
self.assertGreater(len(chunks), 0) |
|||
for chunk in chunks: |
|||
self.assertIsInstance(chunk, Chunk) |
|||
self.assertIn(chunk.reference, ["doc1", "doc2"]) |
|||
self.assertIn("wider_text", chunk.metadata) |
|||
|
|||
# Test with custom chunk size and overlap |
|||
chunks = split_docs_to_chunks(docs, chunk_size=10, chunk_overlap=2) |
|||
|
|||
# With small chunk size, we should get more chunks |
|||
self.assertGreater(len(chunks), 2) |
|||
for chunk in chunks: |
|||
self.assertIsInstance(chunk, Chunk) |
|||
self.assertIn(chunk.reference, ["doc1", "doc2"]) |
|||
self.assertIn("wider_text", chunk.metadata) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1 +0,0 @@ |
|||
# Tests for the deepsearcher.loader.web_crawler package |
@ -1,53 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from deepsearcher.loader.web_crawler.base import BaseCrawler |
|||
|
|||
|
|||
class TestBaseCrawler(unittest.TestCase): |
|||
"""Tests for the BaseCrawler class.""" |
|||
|
|||
def test_abstract_methods(self): |
|||
"""Test that BaseCrawler defines abstract methods.""" |
|||
# For abstract base classes, we can check if methods are defined |
|||
# but not implemented in the base class |
|||
self.assertTrue(hasattr(BaseCrawler, 'crawl_url')) |
|||
|
|||
def test_crawl_urls(self): |
|||
"""Test the crawl_urls method.""" |
|||
# Create a subclass of BaseCrawler for testing |
|||
class TestCrawler(BaseCrawler): |
|||
def crawl_url(self, url, **kwargs): |
|||
# Mock implementation that returns a list of documents |
|||
from langchain_core.documents import Document |
|||
return [Document( |
|||
page_content=f"Content from {url}", |
|||
metadata={"reference": url, "kwargs": kwargs} |
|||
)] |
|||
|
|||
# Create test URLs |
|||
urls = [ |
|||
"https://example.com", |
|||
"https://example.org", |
|||
"https://example.net" |
|||
] |
|||
|
|||
# Test crawling multiple URLs |
|||
crawler = TestCrawler() |
|||
documents = crawler.crawl_urls(urls, param1="value1") |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(documents), 3) # One document per URL |
|||
|
|||
# Verify each document |
|||
references = [doc.metadata["reference"] for doc in documents] |
|||
for url in urls: |
|||
self.assertIn(url, references) |
|||
|
|||
# Check that kwargs were passed correctly |
|||
for doc in documents: |
|||
self.assertEqual(doc.metadata["kwargs"]["param1"], "value1") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,157 +0,0 @@ |
|||
import unittest |
|||
import asyncio |
|||
from unittest.mock import patch, MagicMock |
|||
import warnings |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.web_crawler import Crawl4AICrawler |
|||
|
|||
|
|||
class TestCrawl4AICrawler(unittest.TestCase): |
|||
"""Tests for the Crawl4AICrawler class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create a mock for the crawl4ai module |
|||
warnings.filterwarnings('ignore', message='coroutine.*never awaited') |
|||
self.crawl4ai_patcher = patch.dict('sys.modules', {'crawl4ai': MagicMock()}) |
|||
self.crawl4ai_patcher.start() |
|||
|
|||
# Create mocks for the classes |
|||
self.mock_async_web_crawler = MagicMock() |
|||
self.mock_browser_config = MagicMock() |
|||
|
|||
# Set up the from_kwargs method |
|||
self.mock_config_instance = MagicMock() |
|||
self.mock_browser_config.from_kwargs.return_value = self.mock_config_instance |
|||
|
|||
# Add the mocks to the crawl4ai module |
|||
import sys |
|||
sys.modules['crawl4ai'].AsyncWebCrawler = self.mock_async_web_crawler |
|||
sys.modules['crawl4ai'].BrowserConfig = self.mock_browser_config |
|||
|
|||
# Set up mock instances |
|||
self.mock_crawler_instance = MagicMock() |
|||
self.mock_async_web_crawler.return_value = self.mock_crawler_instance |
|||
|
|||
# For context manager behavior |
|||
self.mock_crawler_instance.__aenter__.return_value = self.mock_crawler_instance |
|||
self.mock_crawler_instance.__aexit__.return_value = None |
|||
|
|||
# Create test browser_config |
|||
self.test_browser_config = {"headless": True} |
|||
|
|||
# Create the crawler |
|||
self.crawler = Crawl4AICrawler(browser_config=self.test_browser_config) |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.crawl4ai_patcher.stop() |
|||
|
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
# Verify that the browser_config was stored |
|||
self.assertEqual(self.crawler.browser_config, self.test_browser_config) |
|||
|
|||
# Verify that the crawler is not initialized |
|||
self.assertIsNone(self.crawler.crawler) |
|||
|
|||
def test_lazy_init(self): |
|||
"""Test the lazy initialization of the crawler.""" |
|||
# Call _lazy_init method |
|||
self.crawler._lazy_init() |
|||
|
|||
# Verify BrowserConfig.from_kwargs was called |
|||
self.mock_browser_config.from_kwargs.assert_called_once_with(self.test_browser_config) |
|||
|
|||
# Verify AsyncWebCrawler was initialized |
|||
self.mock_async_web_crawler.assert_called_once_with(config=self.mock_config_instance) |
|||
|
|||
# Verify that the crawler is now set |
|||
self.assertEqual(self.crawler.crawler, self.mock_crawler_instance) |
|||
|
|||
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run') |
|||
def test_crawl_url(self, mock_asyncio_run): |
|||
"""Test crawling a single URL.""" |
|||
url = "https://example.com" |
|||
|
|||
# Set up mock document |
|||
mock_document = Document( |
|||
page_content="# Example Page\nThis is a test page.", |
|||
metadata={"reference": url, "title": "Example Page"} |
|||
) |
|||
|
|||
# Configure asyncio.run to return a document |
|||
mock_asyncio_run.return_value = mock_document |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_url(url) |
|||
|
|||
# Verify asyncio.run was called with _async_crawl |
|||
mock_asyncio_run.assert_called_once() |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 1) |
|||
self.assertEqual(documents[0], mock_document) |
|||
|
|||
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run') |
|||
def test_crawl_url_error(self, mock_asyncio_run): |
|||
"""Test error handling when crawling a URL.""" |
|||
url = "https://example.com" |
|||
|
|||
# Configure asyncio.run to raise an exception |
|||
mock_asyncio_run.side_effect = Exception("Test error") |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_url(url) |
|||
|
|||
# Should return empty list on error |
|||
self.assertEqual(documents, []) |
|||
|
|||
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run') |
|||
def test_crawl_urls(self, mock_asyncio_run): |
|||
"""Test crawling multiple URLs.""" |
|||
urls = ["https://example.com", "https://example.org"] |
|||
|
|||
# Set up mock documents |
|||
mock_documents = [ |
|||
Document( |
|||
page_content="# Example Page 1\nThis is test page 1.", |
|||
metadata={"reference": urls[0], "title": "Example Page 1"} |
|||
), |
|||
Document( |
|||
page_content="# Example Page 2\nThis is test page 2.", |
|||
metadata={"reference": urls[1], "title": "Example Page 2"} |
|||
) |
|||
] |
|||
|
|||
# Configure asyncio.run to return documents |
|||
mock_asyncio_run.return_value = mock_documents |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_urls(urls) |
|||
|
|||
# Verify asyncio.run was called with _async_crawl_many |
|||
mock_asyncio_run.assert_called_once() |
|||
|
|||
# Check results |
|||
self.assertEqual(documents, mock_documents) |
|||
|
|||
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run') |
|||
def test_crawl_urls_error(self, mock_asyncio_run): |
|||
"""Test error handling when crawling multiple URLs.""" |
|||
urls = ["https://example.com", "https://example.org"] |
|||
|
|||
# Configure asyncio.run to raise an exception |
|||
mock_asyncio_run.side_effect = Exception("Test error") |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_urls(urls) |
|||
|
|||
# Should return empty list on error |
|||
self.assertEqual(documents, []) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,157 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.web_crawler import DoclingCrawler |
|||
|
|||
|
|||
class TestDoclingCrawler(unittest.TestCase): |
|||
"""Tests for the DoclingCrawler class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mocks for the docling modules |
|||
self.docling_patcher = patch.dict('sys.modules', { |
|||
'docling': MagicMock(), |
|||
'docling.document_converter': MagicMock(), |
|||
'docling_core': MagicMock(), |
|||
'docling_core.transforms': MagicMock(), |
|||
'docling_core.transforms.chunker': MagicMock() |
|||
}) |
|||
self.docling_patcher.start() |
|||
|
|||
# Create mocks for the classes |
|||
self.mock_document_converter = MagicMock() |
|||
self.mock_hierarchical_chunker = MagicMock() |
|||
|
|||
# Add the mocks to the modules |
|||
import sys |
|||
sys.modules['docling.document_converter'].DocumentConverter = self.mock_document_converter |
|||
sys.modules['docling_core.transforms.chunker'].HierarchicalChunker = self.mock_hierarchical_chunker |
|||
|
|||
# Set up mock instances |
|||
self.mock_converter_instance = MagicMock() |
|||
self.mock_chunker_instance = MagicMock() |
|||
self.mock_document_converter.return_value = self.mock_converter_instance |
|||
self.mock_hierarchical_chunker.return_value = self.mock_chunker_instance |
|||
|
|||
# Create the crawler |
|||
self.crawler = DoclingCrawler() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.docling_patcher.stop() |
|||
|
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
# Verify instances were created |
|||
self.mock_document_converter.assert_called_once() |
|||
self.mock_hierarchical_chunker.assert_called_once() |
|||
|
|||
# Check that the instances were assigned correctly |
|||
self.assertEqual(self.crawler.converter, self.mock_converter_instance) |
|||
self.assertEqual(self.crawler.chunker, self.mock_chunker_instance) |
|||
|
|||
def test_crawl_url(self): |
|||
"""Test crawling a URL.""" |
|||
url = "https://example.com" |
|||
|
|||
# Set up mock document and chunks |
|||
mock_document = MagicMock() |
|||
mock_conversion_result = MagicMock() |
|||
mock_conversion_result.document = mock_document |
|||
|
|||
# Set up three mock chunks |
|||
mock_chunks = [] |
|||
for i in range(3): |
|||
chunk = MagicMock() |
|||
chunk.text = f"Chunk {i} content" |
|||
mock_chunks.append(chunk) |
|||
|
|||
# Configure mock converter and chunker |
|||
self.mock_converter_instance.convert.return_value = mock_conversion_result |
|||
self.mock_chunker_instance.chunk.return_value = mock_chunks |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_url(url) |
|||
|
|||
# Verify converter was called correctly |
|||
self.mock_converter_instance.convert.assert_called_once_with(url) |
|||
|
|||
# Verify chunker was called correctly |
|||
self.mock_chunker_instance.chunk.assert_called_once_with(mock_document) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 3) |
|||
|
|||
# Check each document |
|||
for i, document in enumerate(documents): |
|||
self.assertEqual(document.page_content, f"Chunk {i} content") |
|||
self.assertEqual(document.metadata["reference"], url) |
|||
self.assertEqual(document.metadata["text"], f"Chunk {i} content") |
|||
|
|||
def test_crawl_url_error(self): |
|||
"""Test error handling when crawling a URL.""" |
|||
url = "https://example.com" |
|||
|
|||
# Configure converter to raise an exception |
|||
self.mock_converter_instance.convert.side_effect = Exception("Test error") |
|||
|
|||
# Verify that the error is propagated |
|||
with self.assertRaises(IOError): |
|||
self.crawler.crawl_url(url) |
|||
|
|||
def test_supported_file_types(self): |
|||
"""Test the supported_file_types property.""" |
|||
file_types = self.crawler.supported_file_types |
|||
|
|||
# Check that all expected file types are included |
|||
expected_types = [ |
|||
"pdf", "docx", "xlsx", "pptx", "md", "adoc", "asciidoc", |
|||
"html", "xhtml", "csv", "png", "jpg", "jpeg", "tif", "tiff", "bmp" |
|||
] |
|||
|
|||
for file_type in expected_types: |
|||
self.assertIn(file_type, file_types) |
|||
|
|||
# Check that the count matches |
|||
self.assertEqual(len(file_types), len(expected_types)) |
|||
|
|||
def test_crawl_urls(self): |
|||
"""Test crawling multiple URLs.""" |
|||
urls = ["https://example.com", "https://example.org"] |
|||
|
|||
# Set up mock document and chunks for each URL |
|||
mock_document = MagicMock() |
|||
mock_conversion_result = MagicMock() |
|||
mock_conversion_result.document = mock_document |
|||
|
|||
# Set up one mock chunk per URL |
|||
mock_chunk = MagicMock() |
|||
mock_chunk.text = "Test chunk content" |
|||
|
|||
# Configure mock converter and chunker |
|||
self.mock_converter_instance.convert.return_value = mock_conversion_result |
|||
self.mock_chunker_instance.chunk.return_value = [mock_chunk] |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_urls(urls) |
|||
|
|||
# Verify converter was called for each URL |
|||
self.assertEqual(self.mock_converter_instance.convert.call_count, 2) |
|||
|
|||
# Verify chunker was called for each document |
|||
self.assertEqual(self.mock_chunker_instance.chunk.call_count, 2) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Each URL should have generated one document (with one chunk) |
|||
for document in documents: |
|||
self.assertEqual(document.page_content, "Test chunk content") |
|||
self.assertIn(document.metadata["reference"], urls) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,135 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.web_crawler import FireCrawlCrawler |
|||
|
|||
|
|||
class TestFireCrawlCrawler(unittest.TestCase): |
|||
"""Tests for the FireCrawlCrawler class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Patch the environment variable |
|||
self.env_patcher = patch.dict('os.environ', {'FIRECRAWL_API_KEY': 'fake-api-key'}) |
|||
self.env_patcher.start() |
|||
|
|||
# Create a mock for the FirecrawlApp |
|||
self.firecrawl_app_patcher = patch('deepsearcher.loader.web_crawler.firecrawl_crawler.FirecrawlApp') |
|||
self.mock_firecrawl_app = self.firecrawl_app_patcher.start() |
|||
|
|||
# Set up mock instances |
|||
self.mock_app_instance = MagicMock() |
|||
self.mock_firecrawl_app.return_value = self.mock_app_instance |
|||
|
|||
# Create the crawler |
|||
self.crawler = FireCrawlCrawler() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.env_patcher.stop() |
|||
self.firecrawl_app_patcher.stop() |
|||
|
|||
def test_init(self): |
|||
"""Test initialization.""" |
|||
self.assertIsNone(self.crawler.app) |
|||
|
|||
def test_crawl_url_single_page(self): |
|||
"""Test crawling a single URL.""" |
|||
url = "https://example.com" |
|||
|
|||
# Set up mock response for scrape_url |
|||
mock_response = MagicMock() |
|||
mock_response.model_dump.return_value = { |
|||
"markdown": "# Example Page\nThis is a test page.", |
|||
"metadata": {"title": "Example Page", "url": url} |
|||
} |
|||
self.mock_app_instance.scrape_url.return_value = mock_response |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_url(url) |
|||
|
|||
# Verify FirecrawlApp was initialized |
|||
self.mock_firecrawl_app.assert_called_once_with(api_key='fake-api-key') |
|||
|
|||
# Verify scrape_url was called correctly |
|||
self.mock_app_instance.scrape_url.assert_called_once_with(url=url, formats=["markdown"]) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 1) |
|||
document = documents[0] |
|||
self.assertEqual(document.page_content, "# Example Page\nThis is a test page.") |
|||
self.assertEqual(document.metadata["reference"], url) |
|||
self.assertEqual(document.metadata["title"], "Example Page") |
|||
|
|||
def test_crawl_url_multiple_pages(self): |
|||
"""Test crawling multiple pages recursively.""" |
|||
url = "https://example.com" |
|||
max_depth = 3 |
|||
limit = 10 |
|||
|
|||
# Set up mock response for crawl_url |
|||
mock_response = MagicMock() |
|||
mock_response.model_dump.return_value = { |
|||
"data": [ |
|||
{ |
|||
"markdown": "# Page 1\nContent 1", |
|||
"metadata": {"title": "Page 1", "url": "https://example.com/page1"} |
|||
}, |
|||
{ |
|||
"markdown": "# Page 2\nContent 2", |
|||
"metadata": {"title": "Page 2", "url": "https://example.com/page2"} |
|||
} |
|||
] |
|||
} |
|||
self.mock_app_instance.crawl_url.return_value = mock_response |
|||
|
|||
# Call the method |
|||
documents = self.crawler.crawl_url(url, max_depth=max_depth, limit=limit) |
|||
|
|||
# Verify FirecrawlApp was initialized |
|||
self.mock_firecrawl_app.assert_called_once_with(api_key='fake-api-key') |
|||
|
|||
# Verify crawl_url was called correctly |
|||
self.mock_app_instance.crawl_url.assert_called_once() |
|||
call_kwargs = self.mock_app_instance.crawl_url.call_args[1] |
|||
self.assertEqual(call_kwargs['url'], url) |
|||
self.assertEqual(call_kwargs['max_depth'], max_depth) |
|||
self.assertEqual(call_kwargs['limit'], limit) |
|||
|
|||
# Check results |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Check first document |
|||
self.assertEqual(documents[0].page_content, "# Page 1\nContent 1") |
|||
self.assertEqual(documents[0].metadata["reference"], "https://example.com/page1") |
|||
self.assertEqual(documents[0].metadata["title"], "Page 1") |
|||
|
|||
# Check second document |
|||
self.assertEqual(documents[1].page_content, "# Page 2\nContent 2") |
|||
self.assertEqual(documents[1].metadata["reference"], "https://example.com/page2") |
|||
self.assertEqual(documents[1].metadata["title"], "Page 2") |
|||
|
|||
def test_crawl_url_with_default_params(self): |
|||
"""Test crawling with default parameters.""" |
|||
url = "https://example.com" |
|||
|
|||
# Set up mock response for crawl_url |
|||
mock_response = MagicMock() |
|||
mock_response.model_dump.return_value = {"data": []} |
|||
self.mock_app_instance.crawl_url.return_value = mock_response |
|||
|
|||
# Call the method with only max_depth |
|||
self.crawler.crawl_url(url, max_depth=2) |
|||
|
|||
# Verify default values were used |
|||
call_kwargs = self.mock_app_instance.crawl_url.call_args[1] |
|||
self.assertEqual(call_kwargs['limit'], 20) # Default limit |
|||
self.assertEqual(call_kwargs['max_depth'], 2) # Provided max_depth |
|||
self.assertEqual(call_kwargs['allow_backward_links'], False) # Default allow_backward_links |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,112 +0,0 @@ |
|||
import unittest |
|||
import os |
|||
from unittest.mock import patch, MagicMock |
|||
|
|||
import requests |
|||
from langchain_core.documents import Document |
|||
|
|||
from deepsearcher.loader.web_crawler import JinaCrawler |
|||
|
|||
|
|||
class TestJinaCrawler(unittest.TestCase): |
|||
"""Tests for the JinaCrawler class.""" |
|||
|
|||
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"}) |
|||
def test_init_with_token(self): |
|||
"""Test initialization with API token in environment.""" |
|||
crawler = JinaCrawler() |
|||
self.assertEqual(crawler.jina_api_token, "fake-token") |
|||
|
|||
@patch.dict(os.environ, {"JINAAI_API_KEY": "fake-key"}) |
|||
def test_init_with_alternative_key(self): |
|||
"""Test initialization with alternative API key in environment.""" |
|||
crawler = JinaCrawler() |
|||
self.assertEqual(crawler.jina_api_token, "fake-key") |
|||
|
|||
@patch.dict(os.environ, {}, clear=True) |
|||
def test_init_without_token(self): |
|||
"""Test initialization without API token raises ValueError.""" |
|||
with self.assertRaises(ValueError): |
|||
JinaCrawler() |
|||
|
|||
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"}) |
|||
@patch("requests.get") |
|||
def test_crawl_url(self, mock_get): |
|||
"""Test crawling a URL.""" |
|||
# Set up the mock response |
|||
mock_response = MagicMock() |
|||
mock_response.text = "# Markdown Content\nThis is a test." |
|||
mock_response.status_code = 200 |
|||
mock_response.headers = {"Content-Type": "text/markdown"} |
|||
mock_get.return_value = mock_response |
|||
|
|||
# Create the crawler and crawl a test URL |
|||
crawler = JinaCrawler() |
|||
url = "https://example.com" |
|||
documents = crawler.crawl_url(url) |
|||
|
|||
# Check that requests.get was called correctly |
|||
mock_get.assert_called_once_with( |
|||
f"https://r.jina.ai/{url}", |
|||
headers={ |
|||
"Authorization": "Bearer fake-token", |
|||
"X-Return-Format": "markdown", |
|||
} |
|||
) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(documents), 1) |
|||
document = documents[0] |
|||
|
|||
# Check the content |
|||
self.assertEqual(document.page_content, mock_response.text) |
|||
|
|||
# Check the metadata |
|||
self.assertEqual(document.metadata["reference"], url) |
|||
self.assertEqual(document.metadata["status_code"], 200) |
|||
self.assertEqual(document.metadata["headers"], {"Content-Type": "text/markdown"}) |
|||
|
|||
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"}) |
|||
@patch("requests.get") |
|||
def test_crawl_url_http_error(self, mock_get): |
|||
"""Test handling of HTTP errors.""" |
|||
# Set up the mock response to raise an HTTPError |
|||
mock_get.side_effect = requests.exceptions.HTTPError("404 Client Error") |
|||
|
|||
# Create the crawler |
|||
crawler = JinaCrawler() |
|||
|
|||
# Crawl a URL and check that the error is propagated |
|||
with self.assertRaises(requests.exceptions.HTTPError): |
|||
crawler.crawl_url("https://example.com") |
|||
|
|||
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"}) |
|||
@patch("requests.get") |
|||
def test_crawl_urls(self, mock_get): |
|||
"""Test crawling multiple URLs.""" |
|||
# Set up the mock response |
|||
mock_response = MagicMock() |
|||
mock_response.text = "# Markdown Content\nThis is a test." |
|||
mock_response.status_code = 200 |
|||
mock_response.headers = {"Content-Type": "text/markdown"} |
|||
mock_get.return_value = mock_response |
|||
|
|||
# Create the crawler and crawl multiple URLs |
|||
crawler = JinaCrawler() |
|||
urls = ["https://example.com", "https://example.org"] |
|||
documents = crawler.crawl_urls(urls) |
|||
|
|||
# Check that requests.get was called twice |
|||
self.assertEqual(mock_get.call_count, 2) |
|||
|
|||
# Check the results |
|||
self.assertEqual(len(documents), 2) |
|||
|
|||
# Check that each document has the correct reference |
|||
references = [doc.metadata["reference"] for doc in documents] |
|||
for url in urls: |
|||
self.assertIn(url, references) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,172 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock, call |
|||
import logging |
|||
from termcolor import colored |
|||
|
|||
from deepsearcher.utils import log |
|||
from deepsearcher.utils.log import ColoredFormatter |
|||
|
|||
|
|||
class TestColoredFormatter(unittest.TestCase): |
|||
"""Tests for the ColoredFormatter class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
self.formatter = ColoredFormatter("%(levelname)s - %(message)s") |
|||
|
|||
def test_format_debug(self): |
|||
"""Test formatting debug level messages.""" |
|||
record = logging.LogRecord( |
|||
"test", logging.DEBUG, "test.py", 1, "Debug message", (), None |
|||
) |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("DEBUG - Debug message", "cyan") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
def test_format_info(self): |
|||
"""Test formatting info level messages.""" |
|||
record = logging.LogRecord( |
|||
"test", logging.INFO, "test.py", 1, "Info message", (), None |
|||
) |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("INFO - Info message", "green") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
def test_format_warning(self): |
|||
"""Test formatting warning level messages.""" |
|||
record = logging.LogRecord( |
|||
"test", logging.WARNING, "test.py", 1, "Warning message", (), None |
|||
) |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("WARNING - Warning message", "yellow") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
def test_format_error(self): |
|||
"""Test formatting error level messages.""" |
|||
record = logging.LogRecord( |
|||
"test", logging.ERROR, "test.py", 1, "Error message", (), None |
|||
) |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("ERROR - Error message", "red") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
def test_format_critical(self): |
|||
"""Test formatting critical level messages.""" |
|||
record = logging.LogRecord( |
|||
"test", logging.CRITICAL, "test.py", 1, "Critical message", (), None |
|||
) |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("CRITICAL - Critical message", "magenta") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
def test_format_unknown_level(self): |
|||
"""Test formatting messages with unknown log level.""" |
|||
record = logging.LogRecord( |
|||
"test", 60, "test.py", 1, "Custom level message", (), None |
|||
) |
|||
record.levelname = "CUSTOM" |
|||
formatted = self.formatter.format(record) |
|||
expected = colored("CUSTOM - Custom level message", "white") |
|||
self.assertEqual(formatted, expected) |
|||
|
|||
|
|||
class TestLogFunctions(unittest.TestCase): |
|||
"""Tests for the logging functions.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Reset dev mode before each test |
|||
log.set_dev_mode(False) |
|||
|
|||
# Create mock for dev_logger |
|||
self.mock_dev_logger = MagicMock() |
|||
self.dev_logger_patcher = patch("deepsearcher.utils.log.dev_logger", self.mock_dev_logger) |
|||
self.dev_logger_patcher.start() |
|||
|
|||
# Create mock for progress_logger |
|||
self.mock_progress_logger = MagicMock() |
|||
self.progress_logger_patcher = patch("deepsearcher.utils.log.progress_logger", self.mock_progress_logger) |
|||
self.progress_logger_patcher.start() |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.dev_logger_patcher.stop() |
|||
self.progress_logger_patcher.stop() |
|||
|
|||
def test_set_dev_mode(self): |
|||
"""Test setting development mode.""" |
|||
self.assertFalse(log.dev_mode) |
|||
log.set_dev_mode(True) |
|||
self.assertTrue(log.dev_mode) |
|||
log.set_dev_mode(False) |
|||
self.assertFalse(log.dev_mode) |
|||
|
|||
def test_set_level(self): |
|||
"""Test setting log level.""" |
|||
log.set_level(logging.DEBUG) |
|||
self.mock_dev_logger.setLevel.assert_called_once_with(logging.DEBUG) |
|||
|
|||
def test_debug_in_dev_mode(self): |
|||
"""Test debug logging in dev mode.""" |
|||
log.set_dev_mode(True) |
|||
log.debug("Test debug") |
|||
self.mock_dev_logger.debug.assert_called_once_with("Test debug") |
|||
|
|||
def test_debug_not_in_dev_mode(self): |
|||
"""Test debug logging not in dev mode.""" |
|||
log.set_dev_mode(False) |
|||
log.debug("Test debug") |
|||
self.mock_dev_logger.debug.assert_not_called() |
|||
|
|||
def test_info_in_dev_mode(self): |
|||
"""Test info logging in dev mode.""" |
|||
log.set_dev_mode(True) |
|||
log.info("Test info") |
|||
self.mock_dev_logger.info.assert_called_once_with("Test info") |
|||
|
|||
def test_info_not_in_dev_mode(self): |
|||
"""Test info logging not in dev mode.""" |
|||
log.set_dev_mode(False) |
|||
log.info("Test info") |
|||
self.mock_dev_logger.info.assert_not_called() |
|||
|
|||
def test_warning_in_dev_mode(self): |
|||
"""Test warning logging in dev mode.""" |
|||
log.set_dev_mode(True) |
|||
log.warning("Test warning") |
|||
self.mock_dev_logger.warning.assert_called_once_with("Test warning") |
|||
|
|||
def test_warning_not_in_dev_mode(self): |
|||
"""Test warning logging not in dev mode.""" |
|||
log.set_dev_mode(False) |
|||
log.warning("Test warning") |
|||
self.mock_dev_logger.warning.assert_not_called() |
|||
|
|||
def test_error_in_dev_mode(self): |
|||
"""Test error logging in dev mode.""" |
|||
log.set_dev_mode(True) |
|||
log.error("Test error") |
|||
self.mock_dev_logger.error.assert_called_once_with("Test error") |
|||
|
|||
def test_error_not_in_dev_mode(self): |
|||
"""Test error logging not in dev mode.""" |
|||
log.set_dev_mode(False) |
|||
log.error("Test error") |
|||
self.mock_dev_logger.error.assert_not_called() |
|||
|
|||
def test_critical(self): |
|||
"""Test critical logging and exception raising.""" |
|||
with self.assertRaises(RuntimeError) as context: |
|||
log.critical("Test critical") |
|||
|
|||
self.mock_dev_logger.critical.assert_called_once_with("Test critical") |
|||
self.assertEqual(str(context.exception), "Test critical") |
|||
|
|||
def test_color_print(self): |
|||
"""Test color print function.""" |
|||
log.color_print("Test message") |
|||
self.mock_progress_logger.info.assert_called_once_with("Test message") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,237 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import numpy as np |
|||
import sys |
|||
|
|||
from deepsearcher.vector_db import AzureSearch |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
|
|||
class TestAzureSearch(unittest.TestCase): |
|||
"""Tests for the Azure Search vector database implementation.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock modules |
|||
self.mock_azure = MagicMock() |
|||
self.mock_search = MagicMock() |
|||
self.mock_indexes = MagicMock() |
|||
self.mock_models = MagicMock() |
|||
self.mock_credentials = MagicMock() |
|||
self.mock_exceptions = MagicMock() |
|||
|
|||
# Setup nested structure |
|||
self.mock_azure.search = self.mock_search |
|||
self.mock_search.documents = self.mock_search |
|||
self.mock_search.documents.indexes = self.mock_indexes |
|||
self.mock_indexes.models = self.mock_models |
|||
self.mock_azure.core = self.mock_credentials |
|||
self.mock_azure.core.credentials = self.mock_credentials |
|||
self.mock_azure.core.exceptions = self.mock_exceptions |
|||
|
|||
# Mock specific models needed for init_collection |
|||
self.mock_models.SearchableField = MagicMock() |
|||
self.mock_models.SimpleField = MagicMock() |
|||
self.mock_models.SearchField = MagicMock() |
|||
self.mock_models.SearchIndex = MagicMock() |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', { |
|||
'azure': self.mock_azure, |
|||
'azure.core': self.mock_credentials, |
|||
'azure.core.credentials': self.mock_credentials, |
|||
'azure.core.exceptions': self.mock_exceptions, |
|||
'azure.search': self.mock_search, |
|||
'azure.search.documents': self.mock_search, |
|||
'azure.search.documents.indexes': self.mock_indexes, |
|||
'azure.search.documents.indexes.models': self.mock_models |
|||
}) |
|||
|
|||
# Start the patcher |
|||
self.module_patcher.start() |
|||
|
|||
# Import after mocking |
|||
from deepsearcher.vector_db import AzureSearch |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
self.AzureSearch = AzureSearch |
|||
self.RetrievalResult = RetrievalResult |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init(self): |
|||
"""Test basic initialization.""" |
|||
# Setup mock |
|||
mock_client = MagicMock() |
|||
self.mock_search.SearchClient.return_value = mock_client |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
# Verify initialization |
|||
self.assertEqual(azure_search.index_name, "test-index") |
|||
self.assertEqual(azure_search.endpoint, "https://test-search.search.windows.net") |
|||
self.assertEqual(azure_search.api_key, "test-key") |
|||
self.assertEqual(azure_search.vector_field, "content_vector") |
|||
self.assertIsNotNone(azure_search.client) |
|||
|
|||
def test_init_collection(self): |
|||
"""Test collection initialization.""" |
|||
# Setup mock |
|||
mock_index_client = MagicMock() |
|||
self.mock_indexes.SearchIndexClient.return_value = mock_index_client |
|||
mock_index_client.create_index.return_value = None |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
azure_search.init_collection() |
|||
self.assertTrue(mock_index_client.create_index.called) |
|||
|
|||
def test_insert_data(self): |
|||
"""Test inserting data.""" |
|||
# Setup mock |
|||
mock_client = MagicMock() |
|||
self.mock_search.SearchClient.return_value = mock_client |
|||
|
|||
# Mock successful upload result |
|||
mock_result = [MagicMock(succeeded=True) for _ in range(2)] |
|||
mock_client.upload_documents.return_value = mock_result |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
# Create test data |
|||
d = 1536 # Azure Search expects 1536 dimensions |
|||
rng = np.random.default_rng(seed=42) |
|||
|
|||
test_docs = [ |
|||
{ |
|||
"text": "hello world", |
|||
"vector": rng.random(d).tolist(), |
|||
"id": "doc1" |
|||
}, |
|||
{ |
|||
"text": "hello azure search", |
|||
"vector": rng.random(d).tolist(), |
|||
"id": "doc2" |
|||
} |
|||
] |
|||
|
|||
results = azure_search.insert_data(documents=test_docs) |
|||
self.assertEqual(len(results), 2) |
|||
self.assertTrue(all(results)) |
|||
|
|||
def test_search_data(self): |
|||
"""Test search functionality.""" |
|||
# Setup mock |
|||
mock_client = MagicMock() |
|||
self.mock_search.SearchClient.return_value = mock_client |
|||
|
|||
# Mock search results |
|||
d = 1536 |
|||
rng = np.random.default_rng(seed=42) |
|||
|
|||
mock_results = MagicMock() |
|||
mock_results.results = [ |
|||
{ |
|||
"content": "hello world", |
|||
"id": "doc1", |
|||
"@search.score": 0.95 |
|||
}, |
|||
{ |
|||
"content": "hello azure search", |
|||
"id": "doc2", |
|||
"@search.score": 0.85 |
|||
} |
|||
] |
|||
mock_client._client.documents.search_post.return_value = mock_results |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
# Test search |
|||
query_vector = rng.random(d).tolist() |
|||
results = azure_search.search_data( |
|||
collection="test-index", |
|||
vector=query_vector, |
|||
top_k=2 |
|||
) |
|||
|
|||
self.assertIsInstance(results, list) |
|||
self.assertEqual(len(results), 2) |
|||
# Verify results are RetrievalResult objects |
|||
for result in results: |
|||
self.assertIsInstance(result, self.RetrievalResult) |
|||
|
|||
def test_clear_db(self): |
|||
"""Test clearing database.""" |
|||
# Setup mock |
|||
mock_client = MagicMock() |
|||
self.mock_search.SearchClient.return_value = mock_client |
|||
|
|||
# Mock search results for documents to delete |
|||
mock_client.search.return_value = [ |
|||
{"id": "doc1"}, |
|||
{"id": "doc2"} |
|||
] |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
deleted_count = azure_search.clear_db() |
|||
self.assertEqual(deleted_count, 2) |
|||
|
|||
def test_list_collections(self): |
|||
"""Test listing collections.""" |
|||
# Setup mock |
|||
mock_index_client = MagicMock() |
|||
self.mock_indexes.SearchIndexClient.return_value = mock_index_client |
|||
|
|||
# Mock list_indexes response |
|||
mock_index1 = MagicMock() |
|||
mock_index1.name = "test-index-1" |
|||
mock_index1.fields = ["field1", "field2"] |
|||
|
|||
mock_index2 = MagicMock() |
|||
mock_index2.name = "test-index-2" |
|||
mock_index2.fields = ["field1", "field2", "field3"] |
|||
|
|||
mock_index_client.list_indexes.return_value = [mock_index1, mock_index2] |
|||
|
|||
azure_search = self.AzureSearch( |
|||
endpoint="https://test-search.search.windows.net", |
|||
index_name="test-index", |
|||
api_key="test-key", |
|||
vector_field="content_vector" |
|||
) |
|||
|
|||
collections = azure_search.list_collections() |
|||
self.assertIsInstance(collections, list) |
|||
self.assertEqual(len(collections), 2) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,157 +0,0 @@ |
|||
import unittest |
|||
import numpy as np |
|||
from typing import List |
|||
|
|||
from deepsearcher.vector_db.base import ( |
|||
RetrievalResult, |
|||
deduplicate_results, |
|||
CollectionInfo, |
|||
BaseVectorDB, |
|||
) |
|||
from deepsearcher.loader.splitter import Chunk |
|||
|
|||
|
|||
class TestRetrievalResult(unittest.TestCase): |
|||
"""Tests for the RetrievalResult class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
self.embedding = np.array([0.1, 0.2, 0.3]) |
|||
self.text = "Test text" |
|||
self.reference = "test.txt" |
|||
self.metadata = {"key": "value"} |
|||
self.score = 0.95 |
|||
|
|||
def test_init(self): |
|||
"""Test initialization of RetrievalResult.""" |
|||
result = RetrievalResult( |
|||
embedding=self.embedding, |
|||
text=self.text, |
|||
reference=self.reference, |
|||
metadata=self.metadata, |
|||
score=self.score, |
|||
) |
|||
|
|||
self.assertTrue(np.array_equal(result.embedding, self.embedding)) |
|||
self.assertEqual(result.text, self.text) |
|||
self.assertEqual(result.reference, self.reference) |
|||
self.assertEqual(result.metadata, self.metadata) |
|||
self.assertEqual(result.score, self.score) |
|||
|
|||
def test_init_default_score(self): |
|||
"""Test initialization of RetrievalResult with default score.""" |
|||
result = RetrievalResult( |
|||
embedding=self.embedding, |
|||
text=self.text, |
|||
reference=self.reference, |
|||
metadata=self.metadata, |
|||
) |
|||
self.assertEqual(result.score, 0.0) |
|||
|
|||
def test_repr(self): |
|||
"""Test string representation of RetrievalResult.""" |
|||
result = RetrievalResult( |
|||
embedding=self.embedding, |
|||
text=self.text, |
|||
reference=self.reference, |
|||
metadata=self.metadata, |
|||
score=self.score, |
|||
) |
|||
expected = f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}" |
|||
self.assertEqual(repr(result), expected) |
|||
|
|||
|
|||
class TestDeduplicateResults(unittest.TestCase): |
|||
"""Tests for the deduplicate_results function.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
self.embedding1 = np.array([0.1, 0.2, 0.3]) |
|||
self.embedding2 = np.array([0.4, 0.5, 0.6]) |
|||
self.text1 = "Text 1" |
|||
self.text2 = "Text 2" |
|||
self.reference = "test.txt" |
|||
self.metadata = {"key": "value"} |
|||
|
|||
def test_no_duplicates(self): |
|||
"""Test deduplication with no duplicate results.""" |
|||
results = [ |
|||
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata), |
|||
RetrievalResult(self.embedding2, self.text2, self.reference, self.metadata), |
|||
] |
|||
deduplicated = deduplicate_results(results) |
|||
self.assertEqual(len(deduplicated), 2) |
|||
self.assertEqual(deduplicated, results) |
|||
|
|||
def test_with_duplicates(self): |
|||
"""Test deduplication with duplicate results.""" |
|||
results = [ |
|||
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata), |
|||
RetrievalResult(self.embedding2, self.text2, self.reference, self.metadata), |
|||
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata), |
|||
] |
|||
deduplicated = deduplicate_results(results) |
|||
self.assertEqual(len(deduplicated), 2) |
|||
self.assertEqual(deduplicated[0].text, self.text1) |
|||
self.assertEqual(deduplicated[1].text, self.text2) |
|||
|
|||
def test_empty_list(self): |
|||
"""Test deduplication with empty list.""" |
|||
results = [] |
|||
deduplicated = deduplicate_results(results) |
|||
self.assertEqual(len(deduplicated), 0) |
|||
|
|||
|
|||
class TestCollectionInfo(unittest.TestCase): |
|||
"""Tests for the CollectionInfo class.""" |
|||
|
|||
def test_init(self): |
|||
"""Test initialization of CollectionInfo.""" |
|||
name = "test_collection" |
|||
description = "Test collection description" |
|||
collection_info = CollectionInfo(name, description) |
|||
|
|||
self.assertEqual(collection_info.collection_name, name) |
|||
self.assertEqual(collection_info.description, description) |
|||
|
|||
|
|||
class MockVectorDB(BaseVectorDB): |
|||
"""Mock implementation of BaseVectorDB for testing.""" |
|||
|
|||
def init_collection(self, dim, collection, description, force_new_collection=False, *args, **kwargs): |
|||
pass |
|||
|
|||
def insert_data(self, collection, chunks, *args, **kwargs): |
|||
pass |
|||
|
|||
def search_data(self, collection, vector, *args, **kwargs) -> List[RetrievalResult]: |
|||
return [] |
|||
|
|||
def clear_db(self, *args, **kwargs): |
|||
pass |
|||
|
|||
|
|||
class TestBaseVectorDB(unittest.TestCase): |
|||
"""Tests for the BaseVectorDB class.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
self.db = MockVectorDB() |
|||
|
|||
def test_init_default(self): |
|||
"""Test initialization with default collection name.""" |
|||
self.assertEqual(self.db.default_collection, "deepsearcher") |
|||
|
|||
def test_init_custom_collection(self): |
|||
"""Test initialization with custom collection name.""" |
|||
custom_collection = "custom_collection" |
|||
db = MockVectorDB(default_collection=custom_collection) |
|||
self.assertEqual(db.default_collection, custom_collection) |
|||
|
|||
def test_list_collections_default(self): |
|||
"""Test default list_collections implementation.""" |
|||
self.assertIsNone(self.db.list_collections()) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,141 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import numpy as np |
|||
import warnings |
|||
|
|||
# Filter out the pkg_resources deprecation warning from milvus_lite |
|||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") |
|||
|
|||
from deepsearcher.vector_db import Milvus |
|||
from deepsearcher.loader.splitter import Chunk |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
|
|||
class TestMilvus(unittest.TestCase): |
|||
"""Simple tests for the Milvus vector database implementation.""" |
|||
|
|||
def test_init(self): |
|||
"""Test basic initialization.""" |
|||
milvus = Milvus( |
|||
default_collection="test_collection", |
|||
uri="./milvus.db", |
|||
hybrid=False |
|||
) |
|||
|
|||
# Verify initialization - just check basic properties |
|||
self.assertEqual(milvus.default_collection, "test_collection") |
|||
self.assertFalse(milvus.hybrid) |
|||
self.assertIsNotNone(milvus.client) |
|||
|
|||
def test_init_collection(self): |
|||
"""Test collection initialization.""" |
|||
milvus = Milvus(uri="./milvus.db") |
|||
|
|||
# Test collection initialization |
|||
d = 8 |
|||
collection = "hello_deepsearcher" |
|||
|
|||
try: |
|||
milvus.init_collection(dim=d, collection=collection) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "init_collection should work") |
|||
|
|||
def test_insert_data_with_retrieval_results(self): |
|||
"""Test inserting data using RetrievalResult objects.""" |
|||
milvus = Milvus(uri="./milvus.db") |
|||
|
|||
# Create test data |
|||
d = 8 |
|||
collection = "hello_deepsearcher" |
|||
rng = np.random.default_rng(seed=19530) |
|||
|
|||
# Create RetrievalResult objects |
|||
test_data = [ |
|||
RetrievalResult( |
|||
embedding=rng.random((1, d))[0], |
|||
text="hello world", |
|||
reference="local file: hi.txt", |
|||
metadata={"a": 1}, |
|||
), |
|||
RetrievalResult( |
|||
embedding=rng.random((1, d))[0], |
|||
text="hello milvus", |
|||
reference="local file: hi.txt", |
|||
metadata={"a": 1}, |
|||
), |
|||
] |
|||
|
|||
try: |
|||
milvus.insert_data(collection=collection, chunks=test_data) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "insert_data should work with RetrievalResult objects") |
|||
|
|||
def test_search_data(self): |
|||
"""Test search functionality.""" |
|||
milvus = Milvus(uri="./milvus.db") |
|||
|
|||
# Test search |
|||
d = 8 |
|||
collection = "hello_deepsearcher" |
|||
rng = np.random.default_rng(seed=19530) |
|||
query_vector = rng.random((1, d))[0] |
|||
|
|||
try: |
|||
top_2 = milvus.search_data( |
|||
collection=collection, |
|||
vector=query_vector, |
|||
top_k=2 |
|||
) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "search_data should work") |
|||
if test_passed: |
|||
self.assertIsInstance(top_2, list) |
|||
# Note: In an empty collection, we might not get 2 results |
|||
self.assertIsInstance(top_2[0], RetrievalResult) if top_2 else None |
|||
|
|||
def test_clear_collection(self): |
|||
"""Test clearing collection.""" |
|||
milvus = Milvus(uri="./milvus.db") |
|||
|
|||
collection = "hello_deepsearcher" |
|||
|
|||
try: |
|||
milvus.clear_db(collection=collection) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "clear_db should work") |
|||
|
|||
def test_list_collections(self): |
|||
"""Test listing collections.""" |
|||
milvus = Milvus(uri="./milvus.db") |
|||
|
|||
try: |
|||
collections = milvus.list_collections() |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "list_collections should work") |
|||
if test_passed: |
|||
self.assertIsInstance(collections, list) |
|||
self.assertGreaterEqual(len(collections), 0) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,255 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import numpy as np |
|||
import sys |
|||
import json |
|||
|
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
from deepsearcher.loader.splitter import Chunk |
|||
import logging |
|||
logging.disable(logging.CRITICAL) |
|||
|
|||
class TestOracleDB(unittest.TestCase): |
|||
"""Tests for the Oracle vector database implementation.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock modules |
|||
self.mock_oracledb = MagicMock() |
|||
|
|||
# Setup mock DB_TYPE_VECTOR |
|||
self.mock_oracledb.DB_TYPE_VECTOR = "VECTOR" |
|||
self.mock_oracledb.defaults = MagicMock() |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', { |
|||
'oracledb': self.mock_oracledb |
|||
}) |
|||
|
|||
# Start the patcher |
|||
self.module_patcher.start() |
|||
|
|||
# Import after mocking |
|||
from deepsearcher.vector_db import OracleDB |
|||
self.OracleDB = OracleDB |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
def test_init(self): |
|||
"""Test basic initialization.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd", |
|||
default_collection="test_collection" |
|||
) |
|||
|
|||
# Verify initialization |
|||
self.assertEqual(oracle_db.default_collection, "test_collection") |
|||
self.assertIsNotNone(oracle_db.client) |
|||
self.mock_oracledb.create_pool.assert_called_once() |
|||
self.assertTrue(self.mock_oracledb.defaults.fetch_lobs is False) |
|||
|
|||
def test_insert_data(self): |
|||
"""Test inserting data.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
mock_connection = MagicMock() |
|||
mock_cursor = MagicMock() |
|||
|
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
mock_pool.acquire.return_value.__enter__.return_value = mock_connection |
|||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd" |
|||
) |
|||
|
|||
# Create test data |
|||
d = 8 |
|||
rng = np.random.default_rng(seed=42) |
|||
test_chunks = [ |
|||
Chunk( |
|||
embedding=rng.random(d).tolist(), |
|||
text="hello world", |
|||
reference="test.txt", |
|||
metadata={"key": "value1"} |
|||
), |
|||
Chunk( |
|||
embedding=rng.random(d).tolist(), |
|||
text="hello oracle", |
|||
reference="test.txt", |
|||
metadata={"key": "value2"} |
|||
) |
|||
] |
|||
|
|||
oracle_db.insert_data(collection="test_collection", chunks=test_chunks) |
|||
self.assertTrue(mock_cursor.execute.called) |
|||
self.assertTrue(mock_connection.commit.called) |
|||
|
|||
def test_search_data(self): |
|||
"""Test search functionality.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
mock_connection = MagicMock() |
|||
mock_cursor = MagicMock() |
|||
|
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
mock_pool.acquire.return_value.__enter__.return_value = mock_connection |
|||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor |
|||
|
|||
# Mock search results |
|||
mock_cursor.description = [("embedding",), ("text",), ("reference",), ("distance",), ("metadata",)] |
|||
mock_cursor.fetchall.return_value = [ |
|||
( |
|||
np.array([0.1, 0.2, 0.3]), |
|||
"hello world", |
|||
"test.txt", |
|||
0.95, |
|||
json.dumps({"key": "value1"}) |
|||
), |
|||
( |
|||
np.array([0.4, 0.5, 0.6]), |
|||
"hello oracle", |
|||
"test.txt", |
|||
0.85, |
|||
json.dumps({"key": "value2"}) |
|||
) |
|||
] |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd" |
|||
) |
|||
|
|||
# Test search |
|||
d = 8 |
|||
rng = np.random.default_rng(seed=42) |
|||
query_vector = rng.random(d) |
|||
|
|||
results = oracle_db.search_data( |
|||
collection="test_collection", |
|||
vector=query_vector, |
|||
top_k=2 |
|||
) |
|||
|
|||
self.assertIsInstance(results, list) |
|||
self.assertEqual(len(results), 2) |
|||
for result in results: |
|||
self.assertIsInstance(result, RetrievalResult) |
|||
|
|||
def test_list_collections(self): |
|||
"""Test listing collections.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
mock_connection = MagicMock() |
|||
mock_cursor = MagicMock() |
|||
|
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
mock_pool.acquire.return_value.__enter__.return_value = mock_connection |
|||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor |
|||
|
|||
# Mock list_collections response |
|||
mock_cursor.description = [("collection",), ("description",)] |
|||
mock_cursor.fetchall.return_value = [ |
|||
("test_collection_1", "Test collection 1"), |
|||
("test_collection_2", "Test collection 2") |
|||
] |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd" |
|||
) |
|||
|
|||
collections = oracle_db.list_collections() |
|||
self.assertIsInstance(collections, list) |
|||
self.assertEqual(len(collections), 2) |
|||
|
|||
def test_clear_db(self): |
|||
"""Test clearing database.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
mock_connection = MagicMock() |
|||
mock_cursor = MagicMock() |
|||
|
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
mock_pool.acquire.return_value.__enter__.return_value = mock_connection |
|||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd" |
|||
) |
|||
|
|||
oracle_db.clear_db("test_collection") |
|||
self.assertTrue(mock_cursor.execute.called) |
|||
self.assertTrue(mock_connection.commit.called) |
|||
|
|||
def test_has_collection(self): |
|||
"""Test checking if collection exists.""" |
|||
# Setup mock |
|||
mock_pool = MagicMock() |
|||
mock_connection = MagicMock() |
|||
mock_cursor = MagicMock() |
|||
|
|||
self.mock_oracledb.create_pool.return_value = mock_pool |
|||
mock_pool.acquire.return_value.__enter__.return_value = mock_connection |
|||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor |
|||
|
|||
# Mock check_table response first (called during init) |
|||
mock_cursor.description = [("table_name",)] |
|||
mock_cursor.fetchall.return_value = [ |
|||
("DEEPSEARCHER_COLLECTION_INFO",), |
|||
("DEEPSEARCHER_COLLECTION_ITEM",) |
|||
] |
|||
|
|||
oracle_db = self.OracleDB( |
|||
user="test_user", |
|||
password="test_password", |
|||
dsn="test_dsn", |
|||
config_dir="/test/config", |
|||
wallet_location="/test/wallet", |
|||
wallet_password="test_wallet_pwd" |
|||
) |
|||
|
|||
# Now mock has_collection response - collection exists |
|||
mock_cursor.description = [("rowcnt",)] |
|||
mock_cursor.fetchall.return_value = [(1,)] # Return tuple, not dict |
|||
|
|||
result = oracle_db.has_collection("test_collection") |
|||
self.assertTrue(result) |
|||
|
|||
# Test collection doesn't exist |
|||
mock_cursor.fetchall.return_value = [(0,)] # Return tuple, not dict |
|||
result = oracle_db.has_collection("nonexistent_collection") |
|||
self.assertFalse(result) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
@ -1,192 +0,0 @@ |
|||
import unittest |
|||
from unittest.mock import patch, MagicMock |
|||
import numpy as np |
|||
import sys |
|||
|
|||
class TestQdrant(unittest.TestCase): |
|||
"""Tests for the Qdrant vector database implementation.""" |
|||
|
|||
def setUp(self): |
|||
"""Set up test fixtures.""" |
|||
# Create mock modules |
|||
self.mock_qdrant = MagicMock() |
|||
self.mock_models = MagicMock() |
|||
self.mock_qdrant.models = self.mock_models |
|||
|
|||
# Create the module patcher |
|||
self.module_patcher = patch.dict('sys.modules', { |
|||
'qdrant_client': self.mock_qdrant, |
|||
'qdrant_client.models': self.mock_models |
|||
}) |
|||
self.module_patcher.start() |
|||
|
|||
# Import after mocking |
|||
from deepsearcher.vector_db import Qdrant |
|||
from deepsearcher.loader.splitter import Chunk |
|||
from deepsearcher.vector_db.base import RetrievalResult |
|||
|
|||
self.Qdrant = Qdrant |
|||
self.Chunk = Chunk |
|||
self.RetrievalResult = RetrievalResult |
|||
|
|||
def tearDown(self): |
|||
"""Clean up test fixtures.""" |
|||
self.module_patcher.stop() |
|||
|
|||
@patch('qdrant_client.QdrantClient') |
|||
def test_init(self, mock_client_class): |
|||
"""Test basic initialization.""" |
|||
mock_client = MagicMock() |
|||
mock_client_class.return_value = mock_client |
|||
|
|||
qdrant = self.Qdrant( |
|||
location="memory", |
|||
url="http://custom:6333", |
|||
port=6333, |
|||
api_key="test_key", |
|||
default_collection="custom" |
|||
) |
|||
|
|||
# Verify initialization - just check basic properties |
|||
self.assertEqual(qdrant.default_collection, "custom") |
|||
self.assertIsNotNone(qdrant.client) |
|||
|
|||
@patch('qdrant_client.QdrantClient') |
|||
def test_init_collection(self, mock_client_class): |
|||
"""Test collection initialization.""" |
|||
mock_client = MagicMock() |
|||
mock_client_class.return_value = mock_client |
|||
mock_client.collection_exists.return_value = False |
|||
|
|||
qdrant = self.Qdrant() |
|||
|
|||
# Test collection initialization |
|||
d = 8 |
|||
collection = "test_collection" |
|||
|
|||
try: |
|||
qdrant.init_collection(dim=d, collection=collection) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "init_collection should work") |
|||
|
|||
@patch('qdrant_client.QdrantClient') |
|||
def test_insert_data(self, mock_client_class): |
|||
"""Test inserting data.""" |
|||
mock_client = MagicMock() |
|||
mock_client_class.return_value = mock_client |
|||
mock_client.upsert.return_value = None |
|||
|
|||
qdrant = self.Qdrant() |
|||
|
|||
# Create test data |
|||
d = 8 |
|||
collection = "test_collection" |
|||
rng = np.random.default_rng(seed=42) |
|||
|
|||
# Create test chunks with numpy arrays converted to lists |
|||
chunks = [ |
|||
self.Chunk( |
|||
embedding=rng.random(d).tolist(), # Convert to list |
|||
text="hello world", |
|||
reference="test.txt", |
|||
metadata={"key": "value1"} |
|||
), |
|||
self.Chunk( |
|||
embedding=rng.random(d).tolist(), # Convert to list |
|||
text="hello qdrant", |
|||
reference="test.txt", |
|||
metadata={"key": "value2"} |
|||
) |
|||
] |
|||
|
|||
try: |
|||
qdrant.insert_data(collection=collection, chunks=chunks) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "insert_data should work") |
|||
|
|||
@patch('qdrant_client.QdrantClient') |
|||
def test_search_data(self, mock_client_class): |
|||
"""Test search functionality.""" |
|||
mock_client = MagicMock() |
|||
mock_client_class.return_value = mock_client |
|||
|
|||
# Mock search results |
|||
d = 8 |
|||
rng = np.random.default_rng(seed=42) |
|||
mock_point1 = MagicMock() |
|||
mock_point1.vector = rng.random(d) |
|||
mock_point1.payload = { |
|||
"text": "hello world", |
|||
"reference": "test.txt", |
|||
"metadata": {"key": "value1"} |
|||
} |
|||
mock_point1.score = 0.95 |
|||
|
|||
mock_point2 = MagicMock() |
|||
mock_point2.vector = rng.random(d) |
|||
mock_point2.payload = { |
|||
"text": "hello qdrant", |
|||
"reference": "test.txt", |
|||
"metadata": {"key": "value2"} |
|||
} |
|||
mock_point2.score = 0.85 |
|||
|
|||
mock_response = MagicMock() |
|||
mock_response.points = [mock_point1, mock_point2] |
|||
mock_client.query_points.return_value = mock_response |
|||
|
|||
qdrant = self.Qdrant() |
|||
|
|||
# Test search |
|||
collection = "test_collection" |
|||
query_vector = rng.random(d) |
|||
|
|||
try: |
|||
results = qdrant.search_data( |
|||
collection=collection, |
|||
vector=query_vector, |
|||
top_k=2 |
|||
) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "search_data should work") |
|||
if test_passed: |
|||
self.assertIsInstance(results, list) |
|||
self.assertEqual(len(results), 2) |
|||
# Verify results are RetrievalResult objects |
|||
for result in results: |
|||
self.assertIsInstance(result, self.RetrievalResult) |
|||
|
|||
@patch('qdrant_client.QdrantClient') |
|||
def test_clear_collection(self, mock_client_class): |
|||
"""Test clearing collection.""" |
|||
mock_client = MagicMock() |
|||
mock_client_class.return_value = mock_client |
|||
mock_client.delete_collection.return_value = None |
|||
|
|||
qdrant = self.Qdrant() |
|||
collection = "test_collection" |
|||
|
|||
try: |
|||
qdrant.clear_db(collection=collection) |
|||
test_passed = True |
|||
except Exception as e: |
|||
test_passed = False |
|||
print(f"Error: {e}") |
|||
|
|||
self.assertTrue(test_passed, "clear_db should work") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
unittest.main() |
Loading…
Reference in new issue