
Remove irrelevant files

main · tanxing committed 2 weeks ago · commit dee0de1087
Deleted files (line counts shown are lines removed):

  1. Dockerfile (19 lines)
  2. LICENSE (201 lines)
  3. Makefile (7 lines)
  4. assets/pic/deep-searcher-arch.png (binary)
  5. assets/pic/demo.gif (binary)
  6. assets/pic/logo.png (binary)
  7. evaluation/README.md (53 lines)
  8. evaluation/eval_config.yaml (119 lines)
  9. evaluation/evaluate.py (329 lines)
  10. evaluation/plot_results/max_iter_vs_avg_token_usage.png (binary)
  11. evaluation/plot_results/max_iter_vs_error_num.png (binary)
  12. evaluation/plot_results/max_iter_vs_recall.png (binary)
  13. tests/__init__.py (1 line)
  14. tests/agent/__init__.py (1 line)
  15. tests/agent/test_base.py (149 lines)
  16. tests/agent/test_chain_of_rag.py (237 lines)
  17. tests/agent/test_collection_router.py (154 lines)
  18. tests/agent/test_deep_search.py (235 lines)
  19. tests/agent/test_naive_rag.py (130 lines)
  20. tests/agent/test_rag_router.py (162 lines)
  21. tests/embedding/__init__.py (1 line)
  22. tests/embedding/test_base.py (105 lines)
  23. tests/embedding/test_bedrock_embedding.py (148 lines)
  24. tests/embedding/test_fastembed_embedding.py (144 lines)
  25. tests/embedding/test_gemini_embedding.py (266 lines)
  26. tests/embedding/test_glm_embedding.py (143 lines)
  27. tests/embedding/test_milvus_embedding.py (130 lines)
  28. tests/embedding/test_novita_embedding.py (193 lines)
  29. tests/embedding/test_ollama_embedding.py (239 lines)
  30. tests/embedding/test_openai_embedding.py (272 lines)
  31. tests/embedding/test_ppio_embedding.py (191 lines)
  32. tests/embedding/test_sentence_transformer_embedding.py (213 lines)
  33. tests/embedding/test_siliconflow_embedding.py (201 lines)
  34. tests/embedding/test_volcengine_embedding.py (201 lines)
  35. tests/embedding/test_voyage_embedding.py (144 lines)
  36. tests/embedding/test_watsonx_embedding.py (284 lines)
  37. tests/llm/__init__.py (1 line)
  38. tests/llm/test_aliyun.py (164 lines)
  39. tests/llm/test_anthropic.py (169 lines)
  40. tests/llm/test_azure_openai.py (170 lines)
  41. tests/llm/test_base.py (154 lines)
  42. tests/llm/test_bedrock.py (196 lines)
  43. tests/llm/test_deepseek.py (169 lines)
  44. tests/llm/test_gemini.py (136 lines)
  45. tests/llm/test_glm.py (165 lines)
  46. tests/llm/test_novita.py (165 lines)
  47. tests/llm/test_ollama.py (136 lines)
  48. tests/llm/test_openai.py (167 lines)
  49. tests/llm/test_ppio.py (165 lines)
  50. tests/llm/test_siliconflow.py (165 lines)
  51. tests/llm/test_together_ai.py (151 lines)
  52. tests/llm/test_volcengine.py (165 lines)
  53. tests/llm/test_watsonx.py (421 lines)
  54. tests/llm/test_xai.py (165 lines)
  55. tests/loader/__init__.py (1 line)
  56. tests/loader/file_loader/__init__.py (1 line)
  57. tests/loader/file_loader/test_base.py (69 lines)
  58. tests/loader/file_loader/test_docling_loader.py (185 lines)
  59. tests/loader/file_loader/test_json_loader.py (124 lines)
  60. tests/loader/file_loader/test_pdf_loader.py (142 lines)
  61. tests/loader/file_loader/test_text_loader.py (82 lines)
  62. tests/loader/file_loader/test_unstructured_loader.py (203 lines)
  63. tests/loader/test_splitter.py (98 lines)
  64. tests/loader/web_crawler/__init__.py (1 line)
  65. tests/loader/web_crawler/test_base.py (53 lines)
  66. tests/loader/web_crawler/test_crawl4ai_crawler.py (157 lines)
  67. tests/loader/web_crawler/test_docling_crawler.py (157 lines)
  68. tests/loader/web_crawler/test_firecrawl_crawler.py (135 lines)
  69. tests/loader/web_crawler/test_jina_crawler.py (112 lines)
  70. tests/utils/test_log.py (172 lines)
  71. tests/vector_db/test_azure_search.py (237 lines)
  72. tests/vector_db/test_base.py (157 lines)
  73. tests/vector_db/test_milvus.py (141 lines)
  74. tests/vector_db/test_oracle.py (255 lines)
  75. tests/vector_db/test_qdrant.py (192 lines)

Dockerfile (19 lines removed)

@@ -1,19 +0,0 @@
FROM ghcr.io/astral-sh/uv:python3.10-bookworm-slim
WORKDIR /app
RUN mkdir -p /tmp/uv-cache /app/data /app/logs
COPY pyproject.toml uv.lock LICENSE README.md ./
COPY deepsearcher/ ./deepsearcher/
RUN uv sync
COPY . .
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/docs || exit 1
CMD ["uv", "run", "python", "main.py", "--enable-cors", "true"]

LICENSE (201 lines removed)

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019 Zilliz
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Makefile (7 lines removed)

@@ -1,7 +0,0 @@
lint:
	uv run ruff format --diff
	uv run ruff check

format:
	uv run ruff format
	uv run ruff check --fix

assets/pic/deep-searcher-arch.png (binary, 307 KiB before deletion)

Binary file not shown.

assets/pic/demo.gif (binary, 3.4 MiB before deletion)

Binary file not shown.

assets/pic/logo.png (binary, 54 KiB before deletion)

Binary file not shown.

evaluation/README.md (53 lines removed)

@@ -1,53 +0,0 @@
# Evaluation of DeepSearcher
## Introduction
DeepSearcher excels at answering complex queries. Here we provide scripts to evaluate the performance of DeepSearcher against naive RAG.
The evaluation is based on the Recall metric:
> Recall@K: The percentage of relevant documents that are retrieved among the top K documents returned by the search engine.
Currently, we support the multi-hop question answering dataset [2WikiMultiHopQA](https://paperswithcode.com/dataset/2wikimultihopqa). More datasets will be added in the future.
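To make the Recall@K metric concrete, the snippet below is a minimal sketch of how it can be computed for a single sample; the function name, the example titles, and the exact matching rule (checking gold titles against the top-K retrieved titles) are illustrative assumptions, not the evaluation script's verbatim code.

```python
from typing import Dict, Sequence

def recall_at_k(gold_titles: Sequence[str], retrieved_titles: Sequence[str],
                k_values: Sequence[int] = (2, 5)) -> Dict[int, float]:
    """Fraction of gold titles that appear among the top-K retrieved titles."""
    gold = set(gold_titles)
    return {
        k: round(sum(1 for title in gold if title in retrieved_titles[:k]) / len(gold), 4)
        for k in k_values
    }

# Hypothetical sample: one of two gold titles is in the top 2, both are in the top 5.
print(recall_at_k(["Albert Einstein", "Mileva Maric"],
                  ["Albert Einstein", "Niels Bohr", "Mileva Maric"]))
# -> {2: 0.5, 5: 1.0}
```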
## Evaluation Script
The main evaluation script is `evaluate.py`.
You can provide a config file, say `eval_config.yaml`, to specify the LLM, embedding model, and other providers and parameters.
```shell
python evaluate.py \
--dataset 2wikimultihopqa \
--config_yaml ./eval_config.yaml \
--pre_num 5 \
--output_dir ./eval_output
```
`pre_num` is the number of samples to evaluate. More samples give more accurate results, but also consume more time and more of your LLM API tokens.
After the dataset has been loaded into the vector DB on the first run, you can set the `--skip_load` flag to skip loading it again.
For details on the other arguments, run
```shell
python evaluate.py --help
```
## Evaluation Results
We conducted tests using the commonly used 2WikiMultiHopQA dataset. (Because testing consumes a large number of API tokens, we only evaluated the first 50 samples. This may introduce some fluctuation compared to evaluating the entire dataset, but it still gives a rough picture of overall performance.)
### Recall Comparison between Naive RAG and DeepSearcher with Different Models
With Max Iterations on the horizontal axis and Recall on the vertical axis, the following chart compares the recall rates of Deep Searcher and naive RAG.
![](plot_results/max_iter_vs_recall.png)
#### Performance Improvement with Iterations
As the number of max iterations increases, the recall of Deep Searcher improves significantly, and all Deep Searcher results are well above those of naive RAG.
#### Diminishing Returns
However, the marginal gains shrink as the number of iterations grows, suggesting a limit beyond which additional reflection iterations yield little further improvement.
#### Model Performance Comparison
Claude-3-7-sonnet (red line) demonstrates superior performance throughout, achieving nearly perfect recall at 7 iterations. Most models show significant improvement as iterations increase, with the steepest gains occurring between 2-4 iterations. Models like o1-mini (yellow) and deepseek-r1 (green) exhibit strong performance at higher iteration counts. Since our sample number for testing is limited, the results of each test may vary somewhat.
Overall, reasoning models generally perform better than non-reasoning models.
#### Limitations of Non-Reasoning Models
Additionally, in our tests, weaker and smaller non-reasoning models sometimes failed to complete the entire agent query pipeline, due to their inadequate instruction-following capabilities.
### Token Consumption
We plotted the graph below with the number of iterations on the horizontal axis and the average token consumption per sample on the vertical axis:
![](plot_results/max_iter_vs_avg_token_usage.png)
It is evident that as the number of iterations increases, the token consumption of Deep Searcher rises linearly. Based on this approximate token consumption, you can check the pricing on your model provider's website to estimate the cost of running evaluations with different iteration settings.
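As a rough illustration, the sketch below turns a per-sample token figure (such as `token_usage_per_sample` in the `statistics.json` written by the evaluation script) into a cost estimate; all the numbers are hypothetical placeholders, not measured values or real prices.

```python
# Back-of-the-envelope cost estimate; every number here is a hypothetical placeholder.
avg_tokens_per_sample = 120_000        # e.g. statistics["deepsearcher"]["token_usage_per_sample"]
num_samples = 50                       # samples you plan to evaluate
price_per_million_tokens = 1.10        # USD, check your model provider's pricing page

estimated_cost = avg_tokens_per_sample * num_samples / 1_000_000 * price_per_million_tokens
print(f"Estimated cost: ${estimated_cost:.2f}")  # -> Estimated cost: $6.60
```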

evaluation/eval_config.yaml (119 lines removed)

@@ -1,119 +0,0 @@
provide_settings:
  llm:
    provider: "OpenAI"
    config:
      model: "o1-mini"
      # api_key: "sk-xxxx" # Uncomment to override the `OPENAI_API_KEY` set in the environment variable
      # base_url: ""
#    provider: "DeepSeek"
#    config:
#      model: "deepseek-reasoner"
##      api_key: "sk-xxxx" # Uncomment to override the `DEEPSEEK_API_KEY` set in the environment variable
##      base_url: ""
#    provider: "SiliconFlow"
#    config:
#      model: "deepseek-ai/DeepSeek-R1"
##      api_key: "xxxx" # Uncomment to override the `SILICONFLOW_API_KEY` set in the environment variable
##      base_url: ""
#    provider: "PPIO"
#    config:
#      model: "deepseek/deepseek-r1-turbo"
##      api_key: "xxxx" # Uncomment to override the `PPIO_API_KEY` set in the environment variable
##      base_url: ""
#    provider: "TogetherAI"
#    config:
#      model: "deepseek-ai/DeepSeek-R1"
##      api_key: "xxxx" # Uncomment to override the `TOGETHER_API_KEY` set in the environment variable
#    provider: "AzureOpenAI"
#    config:
#      model: ""
#      api_version: ""
##      azure_endpoint: "xxxx" # Uncomment to override the `AZURE_OPENAI_ENDPOINT` set in the environment variable
##      api_key: "xxxx" # Uncomment to override the `AZURE_OPENAI_KEY` set in the environment variable
#    provider: "Ollama"
#    config:
#      model: "qwq"
##      base_url: ""
#    provider: "Novita"
#    config:
#      model: "deepseek/deepseek-v3-0324"
##      api_key: "xxxx" # Uncomment to override the `NOVITA_API_KEY` set in the environment variable
##      base_url: ""
  embedding:
    provider: "OpenAIEmbedding"
    config:
      model: "text-embedding-ada-002"
      # api_key: "" # Uncomment to override the `OPENAI_API_KEY` set in the environment variable
#    provider: "MilvusEmbedding"
#    config:
#      model: "default"
#    provider: "VoyageEmbedding"
#    config:
#      model: "voyage-3"
##      api_key: "" # Uncomment to override the `VOYAGE_API_KEY` set in the environment variable
#    provider: "BedrockEmbedding"
#    config:
#      model: "amazon.titan-embed-text-v2:0"
##      aws_access_key_id: "" # Uncomment to override the `AWS_ACCESS_KEY_ID` set in the environment variable
##      aws_secret_access_key: "" # Uncomment to override the `AWS_SECRET_ACCESS_KEY` set in the environment variable
#    provider: "SiliconflowEmbedding"
#    config:
#      model: "BAAI/bge-m3"
##      api_key: "" # Uncomment to override the `SILICONFLOW_API_KEY` set in the environment variable
#    provider: "NovitaEmbedding"
#    config:
#      model: "baai/bge-m3"
##      api_key: "" # Uncomment to override the `NOVITA_API_KEY` set in the environment variable
  file_loader:
#    provider: "PDFLoader"
#    config: {}
    provider: "JsonFileLoader"
    config:
      text_key: "text"
#    provider: "TextLoader"
#    config: {}
#    provider: "UnstructuredLoader"
#    config: {}
  web_crawler:
    provider: "FireCrawlCrawler"
    config: {}
#    provider: "Crawl4AICrawler"
#    config: {}
#    provider: "JinaCrawler"
#    config: {}
  vector_db:
    provider: "Milvus"
    config:
      default_collection: "deepsearcher"
      uri: "./milvus.db"
      token: "root:Milvus"
      db: "default"
query_settings:
  max_iter: 3
load_settings:
  chunk_size: 1500
  chunk_overlap: 100
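For reference, this config file is consumed by `evaluate.py` through `Configuration` and `init_config` (see the script below). A minimal sketch of loading the same config outside the evaluation script might look like this; the example question is hypothetical.

```python
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.online_query import retrieve

# Load the YAML above and install it as the global DeepSearcher configuration.
config = Configuration(config_path="./eval_config.yaml")
init_config(config=config)

# Run a single retrieval with the configured LLM, embedding model, and vector DB.
retrieved_results, _, consumed_tokens = retrieve("Who directed the film Inception?", max_iter=3)
```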

evaluation/evaluate.py (329 lines removed)

@@ -1,329 +0,0 @@
# Some test datasets and evaluation methods are adapted from https://github.com/OSU-NLP-Group/HippoRAG/tree/main/data , many thanks
################################################################################
# Note: This evaluation script will cost a lot of LLM token usage, please make sure you have enough token budget.
################################################################################
import argparse
import ast
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from typing import List, Tuple
import pandas as pd
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files
from deepsearcher.online_query import naive_retrieve, retrieve
httpx_logger = logging.getLogger("httpx") # disable openai's logger output
httpx_logger.setLevel(logging.WARNING)
warnings.simplefilter(action="ignore", category=FutureWarning) # disable warning output
current_dir = os.path.dirname(os.path.abspath(__file__))
k_list = [2, 5]
def _deepsearch_retrieve_titles(
question: str,
retry_num: int = 4,
base_wait_time: int = 4,
max_iter: int = 3,
) -> Tuple[List[str], int, bool]:
"""
Retrieve document titles using DeepSearcher with retry mechanism.
Args:
question (str): The query question.
retry_num (int, optional): Number of retry attempts. Defaults to 4.
base_wait_time (int, optional): Base wait time between retries in seconds. Defaults to 4.
max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.
Returns:
Tuple[List[str], int, bool]: A tuple containing:
- List of retrieved document titles
- Number of tokens consumed
- Boolean indicating whether the retrieval failed
"""
retrieved_results = []
consume_tokens = 0
for i in range(retry_num):
try:
retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter)
break
except Exception:
wait_time = base_wait_time * (2**i)
print(f"Parse LLM's output failed, retry again after {wait_time} seconds...")
time.sleep(wait_time)
if retrieved_results:
retrieved_titles = [
retrieved_result.metadata["title"] for retrieved_result in retrieved_results
]
fail = False
else:
print("Pipeline error, no retrieved results.")
retrieved_titles = []
fail = True
return retrieved_titles, consume_tokens, fail
def _naive_retrieve_titles(question: str) -> List[str]:
"""
Retrieve document titles using naive retrieval method.
Args:
question (str): The query question.
Returns:
List[str]: List of retrieved document titles.
"""
retrieved_results = naive_retrieve(question)
retrieved_titles = [
retrieved_result.metadata["title"] for retrieved_result in retrieved_results
]
return retrieved_titles
def _calcu_recall(sample, retrieved_titles, dataset) -> dict:
"""
Calculate recall metrics for retrieved titles.
Args:
sample: The sample data containing ground truth information.
retrieved_titles: List of retrieved document titles.
dataset (str): The name of the dataset being evaluated.
Returns:
dict: Dictionary containing recall values at different k values.
Raises:
NotImplementedError: If the dataset is not supported.
"""
if dataset in ["2wikimultihopqa"]:
gold_passages = [item for item in sample["supporting_facts"]]
gold_items = set([item[0] for item in gold_passages])
retrieved_items = retrieved_titles
else:
raise NotImplementedError
recall = dict()
for k in k_list:
recall[k] = round(
sum(1 for t in gold_items if t in retrieved_items[:k]) / len(gold_items), 4
)
return recall
def _print_recall_line(recall: dict, pre_str="", post_str="\n"):
"""
Print recall metrics in a formatted line.
Args:
recall (dict): Dictionary containing recall values at different k values.
pre_str (str, optional): String to print before recall values. Defaults to "".
post_str (str, optional): String to print after recall values. Defaults to "\n".
"""
print(pre_str, end="")
for k in k_list:
print(f"R@{k}: {recall[k]:.3f} ", end="")
print(post_str, end="")
def evaluate(
dataset: str,
output_root: str,
pre_num: int = 10,
max_iter: int = 3,
skip_load=False,
flag: str = "result",
):
"""
Evaluate the retrieval performance on a dataset.
Args:
dataset (str): Name of the dataset to evaluate.
output_root (str): Root directory for output files.
pre_num (int, optional): Number of samples to evaluate. Defaults to 10.
max_iter (int, optional): Maximum number of iterations for retrieval. Defaults to 3.
skip_load (bool, optional): Whether to skip loading the dataset. Defaults to False.
flag (str, optional): Flag for the evaluation run. Defaults to "result".
"""
corpus_file = os.path.join(current_dir, f"../examples/data/{dataset}_corpus.json")
if not skip_load:
# set chunk size to a large number to avoid chunking, because the dataset was chunked already.
load_from_local_files(
corpus_file, force_new_collection=True, chunk_size=999999, chunk_overlap=0
)
eval_output_subdir = os.path.join(output_root, flag)
os.makedirs(eval_output_subdir, exist_ok=True)
csv_file_path = os.path.join(eval_output_subdir, "details.csv")
statistics_file_path = os.path.join(eval_output_subdir, "statistics.json")
data_with_gt_file_path = os.path.join(current_dir, f"../examples/data/{dataset}.json")
data_with_gt = json.load(open(data_with_gt_file_path, "r"))
if not pre_num:
pre_num = len(data_with_gt)
pipeline_error_num = 0
end_ind = min(pre_num, len(data_with_gt))
start_ind = 0
existing_df = pd.DataFrame()
existing_statistics = defaultdict(dict)
existing_token_usage = 0
existing_error_num = 0
existing_sample_num = 0
if os.path.exists(csv_file_path):
existing_df = pd.read_csv(csv_file_path)
start_ind = len(existing_df)
print(f"Loading results from {csv_file_path}, start_index = {start_ind}")
if os.path.exists(statistics_file_path):
existing_statistics = json.load(open(statistics_file_path, "r"))
print(
f"Loading statistics from {statistics_file_path}, will recalculate the statistics based on both new and existing results."
)
existing_token_usage = existing_statistics["deepsearcher"]["token_usage"]
existing_error_num = existing_statistics["deepsearcher"].get("error_num", 0)
existing_sample_num = existing_statistics["deepsearcher"].get("sample_num", 0)
for sample_idx, sample in enumerate(data_with_gt[start_ind:end_ind]):
global_idx = sample_idx + start_ind
question = sample["question"]
retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(
question, max_iter=max_iter
)
retrieved_titles_naive = _naive_retrieve_titles(question)
if fail:
pipeline_error_num += 1
print(
f"Pipeline error, no retrieved results. Current pipeline_error_num = {pipeline_error_num}"
)
print(f"idx: {global_idx}: ")
recall = _calcu_recall(sample, retrieved_titles, dataset)
recall_naive = _calcu_recall(sample, retrieved_titles_naive, dataset)
current_result = [
{
"idx": global_idx,
"question": question,
"recall": recall,
"recall_naive": recall_naive,
"gold_titles": [item[0] for item in sample["supporting_facts"]],
"retrieved_titles": retrieved_titles,
"retrieved_titles_naive": retrieved_titles_naive,
}
]
current_df = pd.DataFrame(current_result)
existing_df = pd.concat([existing_df, current_df], ignore_index=True)
existing_df.to_csv(csv_file_path, index=False)
average_recall = dict()
average_recall_naive = dict()
for k in k_list:
average_recall[k] = sum(
[
ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
for d in existing_df["recall"]
]
) / len(existing_df)
average_recall_naive[k] = sum(
[
ast.literal_eval(d).get(k) if isinstance(d, str) else d.get(k)
for d in existing_df["recall_naive"]
]
) / len(existing_df)
_print_recall_line(average_recall, pre_str="Average recall of DeepSearcher: ")
_print_recall_line(average_recall_naive, pre_str="Average recall of naive RAG : ")
existing_token_usage += consume_tokens
existing_error_num += 1 if fail else 0
existing_sample_num += 1
existing_statistics["deepsearcher"]["average_recall"] = average_recall
existing_statistics["deepsearcher"]["token_usage"] = existing_token_usage
existing_statistics["deepsearcher"]["error_num"] = existing_error_num
existing_statistics["deepsearcher"]["sample_num"] = existing_sample_num
existing_statistics["deepsearcher"]["token_usage_per_sample"] = (
existing_token_usage / existing_sample_num
)
existing_statistics["naive_rag"]["average_recall"] = average_recall_naive
json.dump(existing_statistics, open(statistics_file_path, "w"), indent=4)
print("")
print("Finish results to save.")
def main_eval():
"""
Main function for running the evaluation from command line.
This function parses command line arguments and calls the evaluate function
with the appropriate parameters.
"""
parser = argparse.ArgumentParser(prog="evaluate", description="Deep Searcher evaluation.")
parser.add_argument(
"--dataset",
type=str,
default="2wikimultihopqa",
help="Dataset name, default is `2wikimultihopqa`. More datasets will be supported in the future.",
)
parser.add_argument(
"--config_yaml",
type=str,
default="./eval_config.yaml",
help="Configuration yaml file path, default is `./eval_config.yaml`",
)
parser.add_argument(
"--pre_num",
type=int,
default=30,
help="Number of samples to evaluate, default is 30",
)
parser.add_argument(
"--max_iter",
type=int,
default=3,
help="Max iterations of reflection. Default is 3. It will overwrite the one in config yaml file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="./eval_output",
help="Output root directory, default is `./eval_output`",
)
parser.add_argument(
"--skip_load",
action="store_true",
help="Whether to skip loading the dataset. Default it don't skip loading. If you want to skip loading, please set this flag.",
)
parser.add_argument(
"--flag",
type=str,
default="result",
help="Flag for evaluation, default is `result`",
)
args = parser.parse_args()
config = Configuration(config_path=args.config_yaml)
init_config(config=config)
evaluate(
dataset=args.dataset,
output_root=args.output_dir,
pre_num=args.pre_num,
max_iter=args.max_iter,
skip_load=args.skip_load,
flag=args.flag,
)
if __name__ == "__main__":
main_eval()

evaluation/plot_results/max_iter_vs_avg_token_usage.png (binary, 124 KiB before deletion)

Binary file not shown.

evaluation/plot_results/max_iter_vs_error_num.png (binary, 92 KiB before deletion)

Binary file not shown.

evaluation/plot_results/max_iter_vs_recall.png (binary, 130 KiB before deletion)

Binary file not shown.

tests/__init__.py (1 line removed)

@@ -1 +0,0 @@
# Tests for the deepsearcher package

tests/agent/__init__.py (1 line removed)

@@ -1 +0,0 @@
# Tests for the agent module

tests/agent/test_base.py (149 lines removed)

@@ -1,149 +0,0 @@
import unittest
from unittest.mock import MagicMock
import numpy as np
from deepsearcher.llm.base import BaseLLM, ChatResponse
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, CollectionInfo
class MockLLM(BaseLLM):
"""Mock LLM implementation for testing agents."""
def __init__(self, predefined_responses=None):
"""
Initialize the MockLLM.
Args:
predefined_responses: Dictionary mapping prompt substrings to responses
"""
self.chat_called = False
self.last_messages = None
self.predefined_responses = predefined_responses or {}
def chat(self, messages, **kwargs):
"""Mock implementation of chat that returns predefined responses or a default response."""
self.chat_called = True
self.last_messages = messages
if self.predefined_responses:
message_content = messages[0]["content"] if messages else ""
for key, response in self.predefined_responses.items():
if key in message_content:
return ChatResponse(content=response, total_tokens=10)
# Default response for RERANK_PROMPT - treat all chunks as relevant
if "Based on the query questions and the retrieved chunks" in message_content:
# Count the number of chunks in the message
chunk_count = message_content.count("<chunk_")
# Return a list with "YES" for each chunk
return ChatResponse(content=str(["YES"] * chunk_count), total_tokens=10)
return ChatResponse(content="This is a test answer", total_tokens=10)
def literal_eval(self, text):
"""Mock implementation of literal_eval."""
# Default implementation returns a list with test_collection
# Override this in specific tests if needed
if text.strip().startswith("[") and text.strip().endswith("]"):
# Return the list as is if it's already in list format
try:
import ast
return ast.literal_eval(text)
except:
pass
return ["test_collection"]
class MockEmbedding(BaseEmbedding):
"""Mock embedding model implementation for testing agents."""
def __init__(self, dimension=8):
"""Initialize the MockEmbedding with a specific dimension."""
self._dimension = dimension
@property
def dimension(self):
"""Return the dimension of the embedding model."""
return self._dimension
def embed_query(self, text):
"""Mock implementation that returns a random vector of the specified dimension."""
return np.random.random(self._dimension).tolist()
def embed_documents(self, documents):
"""Mock implementation that returns random vectors for each document."""
return [np.random.random(self._dimension).tolist() for _ in documents]
class MockVectorDB(BaseVectorDB):
"""Mock vector database implementation for testing agents."""
def __init__(self, collections=None):
"""
Initialize the MockVectorDB.
Args:
collections: List of collection names to initialize with
"""
self.default_collection = "test_collection"
self.search_called = False
self.insert_called = False
self._collections = []
if collections:
for collection in collections:
self._collections.append(
CollectionInfo(collection_name=collection, description=f"Test collection {collection}")
)
else:
self._collections = [
CollectionInfo(collection_name="test_collection", description="Test collection for testing")
]
def search_data(self, collection, vector, top_k=10, **kwargs):
"""Mock implementation that returns test results."""
self.search_called = True
self.last_search_collection = collection
self.last_search_vector = vector
self.last_search_top_k = top_k
return [
RetrievalResult(
embedding=vector,
text=f"Test result {i} for collection {collection}",
reference=f"test_reference_{collection}_{i}",
metadata={"a": i, "wider_text": f"Wider context for test result {i} in collection {collection}"}
)
for i in range(min(3, top_k))
]
def insert_data(self, collection, chunks):
"""Mock implementation of insert_data."""
self.insert_called = True
self.last_insert_collection = collection
self.last_insert_chunks = chunks
return True
def init_collection(self, dim, collection, **kwargs):
"""Mock implementation of init_collection."""
return True
def list_collections(self, dim=None):
"""Mock implementation that returns the list of collections."""
return self._collections
def clear_db(self, collection):
"""Mock implementation of clear_db."""
return True
class BaseAgentTest(unittest.TestCase):
"""Base test class for agent tests with common setup."""
def setUp(self):
"""Set up test fixtures for agent tests."""
self.llm = MockLLM()
self.embedding_model = MockEmbedding(dimension=8)
self.vector_db = MockVectorDB()

tests/agent/test_chain_of_rag.py (237 lines removed)

@@ -1,237 +0,0 @@
from unittest.mock import MagicMock, patch
from deepsearcher.agent import ChainOfRAG
from deepsearcher.vector_db.base import RetrievalResult
from deepsearcher.llm.base import ChatResponse
from tests.agent.test_base import BaseAgentTest
class TestChainOfRAG(BaseAgentTest):
"""Test class for ChainOfRAG agent."""
def setUp(self):
"""Set up test fixtures for ChainOfRAG tests."""
super().setUp()
# Set up predefined responses for the LLM for exact prompt substrings
self.llm.predefined_responses = {
"previous queries and answers, generate a new simple follow-up question": "What is the significance of deep learning?",
"Given the following documents, generate an appropriate answer": "Deep learning is a subset of machine learning that uses neural networks with multiple layers.",
"given the following intermediate queries and answers, judge whether you have enough information": "Yes",
"Given a list of agent indexes and corresponding descriptions": "1",
"Given the following documents, select the ones that are support the Q-A pair": "[0, 1]",
"Given the following intermediate queries and answers, generate a final answer": "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers."
}
self.chain_of_rag = ChainOfRAG(
llm=self.llm,
embedding_model=self.embedding_model,
vector_db=self.vector_db,
max_iter=3,
early_stopping=True,
route_collection=True,
text_window_splitter=True
)
def test_init(self):
"""Test the initialization of ChainOfRAG."""
self.assertEqual(self.chain_of_rag.llm, self.llm)
self.assertEqual(self.chain_of_rag.embedding_model, self.embedding_model)
self.assertEqual(self.chain_of_rag.vector_db, self.vector_db)
self.assertEqual(self.chain_of_rag.max_iter, 3)
self.assertEqual(self.chain_of_rag.early_stopping, True)
self.assertEqual(self.chain_of_rag.route_collection, True)
self.assertEqual(self.chain_of_rag.text_window_splitter, True)
def test_reflect_get_subquery(self):
"""Test the _reflect_get_subquery method."""
query = "What is deep learning?"
intermediate_context = ["Previous query: What is AI?", "Previous answer: AI is artificial intelligence."]
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="What is the significance of deep learning?",
total_tokens=10
))
subquery, tokens = self.chain_of_rag._reflect_get_subquery(query, intermediate_context)
self.assertEqual(subquery, "What is the significance of deep learning?")
self.assertEqual(tokens, 10)
self.assertTrue(self.llm.chat.called)
def test_retrieve_and_answer(self):
"""Test the _retrieve_and_answer method."""
query = "What is deep learning?"
# Mock the collection_router.invoke method
self.chain_of_rag.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5))
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Deep learning is a subset of machine learning that uses neural networks with multiple layers.",
total_tokens=10
))
answer, results, tokens = self.chain_of_rag._retrieve_and_answer(query)
# Check if correct methods were called
self.chain_of_rag.collection_router.invoke.assert_called_once()
self.assertTrue(self.vector_db.search_called)
# Check the results
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.")
self.assertEqual(tokens, 15) # 5 from collection_router + 10 from LLM
def test_get_supported_docs(self):
"""Test the _get_supported_docs method."""
results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i}
)
for i in range(3)
]
query = "What is deep learning?"
answer = "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
# Mock the literal_eval method to return indices as integers
self.llm.literal_eval = MagicMock(return_value=[0, 1])
supported_docs, tokens = self.chain_of_rag._get_supported_docs(results, query, answer)
self.assertEqual(len(supported_docs), 2) # Based on our mock response of [0, 1]
self.assertEqual(tokens, 10)
def test_check_has_enough_info(self):
"""Test the _check_has_enough_info method."""
query = "What is deep learning?"
intermediate_contexts = [
"Intermediate query1: What is deep learning?",
"Intermediate answer1: Deep learning is a subset of machine learning that uses neural networks with multiple layers."
]
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Yes",
total_tokens=10
))
has_enough, tokens = self.chain_of_rag._check_has_enough_info(query, intermediate_contexts)
self.assertTrue(has_enough) # Based on our mock response of "Yes"
self.assertEqual(tokens, 10)
def test_retrieve(self):
"""Test the retrieve method."""
query = "What is deep learning?"
# Mock all the methods that retrieve calls
self.chain_of_rag._reflect_get_subquery = MagicMock(return_value=("What is the significance of deep learning?", 5))
self.chain_of_rag._retrieve_and_answer = MagicMock(
return_value=("Deep learning is important in AI", [RetrievalResult(
embedding=[0.1] * 8,
text="Test result",
reference="test_reference",
metadata={"a": 1}
)], 10)
)
self.chain_of_rag._get_supported_docs = MagicMock(return_value=([RetrievalResult(
embedding=[0.1] * 8,
text="Test result",
reference="test_reference",
metadata={"a": 1}
)], 5))
self.chain_of_rag._check_has_enough_info = MagicMock(return_value=(True, 5))
results, tokens, metadata = self.chain_of_rag.retrieve(query)
# Check if methods were called
self.chain_of_rag._reflect_get_subquery.assert_called_once()
self.chain_of_rag._retrieve_and_answer.assert_called_once()
self.chain_of_rag._get_supported_docs.assert_called_once()
# With early stopping, it should check if we have enough info
self.chain_of_rag._check_has_enough_info.assert_called_once()
# Check results
self.assertEqual(len(results), 1)
self.assertEqual(tokens, 25) # 5 + 10 + 5 + 5
self.assertIn("intermediate_context", metadata)
def test_query(self):
"""Test the query method."""
query = "What is deep learning?"
# Mock the retrieve method
retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
)
for i in range(3)
]
self.chain_of_rag.retrieve = MagicMock(
return_value=(retrieved_results, 20, {"intermediate_context": ["Some context"]})
)
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.",
total_tokens=10
))
answer, results, tokens = self.chain_of_rag.query(query)
# Check if methods were called
self.chain_of_rag.retrieve.assert_called_once_with(query)
self.assertTrue(self.llm.chat.called)
# Check results
self.assertEqual(answer, "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.")
self.assertEqual(results, retrieved_results)
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM
def test_format_retrieved_results(self):
"""Test the _format_retrieved_results method."""
retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Test result 1",
reference="test_reference",
metadata={"a": 1, "wider_text": "Wider context for test result 1"}
),
RetrievalResult(
embedding=[0.1] * 8,
text="Test result 2",
reference="test_reference",
metadata={"a": 2, "wider_text": "Wider context for test result 2"}
)
]
# Test with text_window_splitter enabled
self.chain_of_rag.text_window_splitter = True
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)
self.assertIn("Wider context for test result 1", formatted)
self.assertIn("Wider context for test result 2", formatted)
# Test with text_window_splitter disabled
self.chain_of_rag.text_window_splitter = False
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)
self.assertIn("Test result 1", formatted)
self.assertIn("Test result 2", formatted)
self.assertNotIn("Wider context for test result 1", formatted)
if __name__ == "__main__":
import unittest
unittest.main()

tests/agent/test_collection_router.py (154 lines removed)

@@ -1,154 +0,0 @@
from unittest.mock import MagicMock, patch
from deepsearcher.agent.collection_router import CollectionRouter
from deepsearcher.llm.base import ChatResponse
from deepsearcher.vector_db.base import CollectionInfo
from tests.agent.test_base import BaseAgentTest
class TestCollectionRouter(BaseAgentTest):
"""Test class for CollectionRouter."""
def setUp(self):
"""Set up test fixtures for CollectionRouter tests."""
super().setUp()
# Create mock collections
self.collection_infos = [
CollectionInfo(collection_name="books", description="Collection of book summaries"),
CollectionInfo(collection_name="science", description="Scientific articles and papers"),
CollectionInfo(collection_name="news", description="Recent news articles")
]
# Configure vector_db mock
self.vector_db.list_collections = MagicMock(return_value=self.collection_infos)
self.vector_db.default_collection = "books"
# Create the CollectionRouter
self.collection_router = CollectionRouter(
llm=self.llm,
vector_db=self.vector_db,
dim=8
)
def test_init(self):
"""Test the initialization of CollectionRouter."""
self.assertEqual(self.collection_router.llm, self.llm)
self.assertEqual(self.collection_router.vector_db, self.vector_db)
self.assertEqual(
self.collection_router.all_collections,
["books", "science", "news"]
)
def test_invoke_with_multiple_collections(self):
"""Test the invoke method with multiple collections."""
query = "What are the latest scientific breakthroughs?"
# Mock LLM to return specific collections based on query
self.llm.chat = MagicMock(return_value=ChatResponse(
content='["science", "news"]',
total_tokens=10
))
# Disable log output for testing
with patch('deepsearcher.utils.log.color_print'):
selected_collections, tokens = self.collection_router.invoke(query, dim=8)
# Check results
self.assertTrue("science" in selected_collections)
self.assertTrue("news" in selected_collections)
self.assertTrue("books" in selected_collections) # Default collection is always included
self.assertEqual(tokens, 10)
# Verify that the LLM was called with the right prompt
self.llm.chat.assert_called_once()
self.assertIn(query, self.llm.chat.call_args[1]["messages"][0]["content"])
self.assertIn("collection_name", self.llm.chat.call_args[1]["messages"][0]["content"])
def test_invoke_with_empty_response(self):
"""Test the invoke method when LLM returns an empty list."""
query = "Something completely unrelated"
# Mock LLM to return empty list
self.llm.chat = MagicMock(return_value=ChatResponse(
content='[]',
total_tokens=5
))
# Disable log output for testing
with patch('deepsearcher.utils.log.color_print'):
selected_collections, tokens = self.collection_router.invoke(query, dim=8)
# Only default collection should be included
self.assertEqual(len(selected_collections), 1)
self.assertEqual(selected_collections[0], "books")
self.assertEqual(tokens, 5)
def test_invoke_with_no_collections(self):
"""Test the invoke method when no collections are available."""
query = "Test query"
# Mock vector_db to return empty list
self.vector_db.list_collections = MagicMock(return_value=[])
# Disable log warnings for testing
with patch('deepsearcher.utils.log.warning'):
with patch('deepsearcher.utils.log.color_print'):
selected_collections, tokens = self.collection_router.invoke(query, dim=8)
# Should return empty list and zero tokens
self.assertEqual(selected_collections, [])
self.assertEqual(tokens, 0)
def test_invoke_with_single_collection(self):
"""Test the invoke method when only one collection is available."""
query = "Test query"
# Create a fresh mock for llm.chat to verify it's not called
mock_chat = MagicMock(return_value=ChatResponse(content='[]', total_tokens=0))
self.llm.chat = mock_chat
# Mock vector_db to return single collection
single_collection = [CollectionInfo(collection_name="single", description="The only collection")]
self.vector_db.list_collections = MagicMock(return_value=single_collection)
# Disable log output for testing
with patch('deepsearcher.utils.log.color_print'):
selected_collections, tokens = self.collection_router.invoke(query, dim=8)
# Should return the only collection without calling LLM
self.assertEqual(selected_collections, ["single"])
self.assertEqual(tokens, 0)
mock_chat.assert_not_called()
def test_invoke_with_no_description(self):
"""Test the invoke method when a collection has no description."""
query = "Test query"
# Create collections with one having no description
collections_with_no_desc = [
CollectionInfo(collection_name="with_desc", description="Has description"),
CollectionInfo(collection_name="no_desc", description="")
]
self.vector_db.list_collections = MagicMock(return_value=collections_with_no_desc)
self.vector_db.default_collection = "with_desc"
# Mock LLM to return only the first collection
self.llm.chat = MagicMock(return_value=ChatResponse(
content='["with_desc"]',
total_tokens=5
))
# Disable log output for testing
with patch('deepsearcher.utils.log.color_print'):
selected_collections, tokens = self.collection_router.invoke(query, dim=8)
# Both collections should be included (one from LLM, one with no description)
self.assertEqual(set(selected_collections), {"with_desc", "no_desc"})
self.assertEqual(tokens, 5)
if __name__ == "__main__":
import unittest
unittest.main()

tests/agent/test_deep_search.py (235 lines removed)

@@ -1,235 +0,0 @@
from unittest.mock import MagicMock, patch
import asyncio
from deepsearcher.agent import DeepSearch
from deepsearcher.vector_db.base import RetrievalResult
from tests.agent.test_base import BaseAgentTest
class TestDeepSearch(BaseAgentTest):
"""Test class for DeepSearch agent."""
def setUp(self):
"""Set up test fixtures for DeepSearch tests."""
super().setUp()
# Set up predefined responses for the LLM for exact prompt substrings
self.llm.predefined_responses = {
"Original Question:": '["What is deep learning?", "How does deep learning work?", "What are applications of deep learning?"]',
"Is the chunk helpful": "YES",
"Respond exclusively in valid List": '["What are limitations of deep learning?"]',
"You are a AI content analysis expert": "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
}
self.deep_search = DeepSearch(
llm=self.llm,
embedding_model=self.embedding_model,
vector_db=self.vector_db,
max_iter=2,
route_collection=True,
text_window_splitter=True
)
def test_init(self):
"""Test the initialization of DeepSearch."""
self.assertEqual(self.deep_search.llm, self.llm)
self.assertEqual(self.deep_search.embedding_model, self.embedding_model)
self.assertEqual(self.deep_search.vector_db, self.vector_db)
self.assertEqual(self.deep_search.max_iter, 2)
self.assertEqual(self.deep_search.route_collection, True)
self.assertEqual(self.deep_search.text_window_splitter, True)
def test_generate_sub_queries(self):
"""Test the _generate_sub_queries method."""
query = "Tell me about deep learning"
sub_queries, tokens = self.deep_search._generate_sub_queries(query)
self.assertEqual(len(sub_queries), 3)
self.assertEqual(sub_queries[0], "What is deep learning?")
self.assertEqual(sub_queries[1], "How does deep learning work?")
self.assertEqual(sub_queries[2], "What are applications of deep learning?")
self.assertEqual(tokens, 10)
self.assertTrue(self.llm.chat_called)
def test_search_chunks_from_vectordb(self):
"""Test the _search_chunks_from_vectordb method."""
query = "What is deep learning?"
sub_queries = ["What is deep learning?", "How does deep learning work?"]
# Mock the collection_router.invoke method
self.deep_search.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5))
# Run the async method using asyncio.run
results, tokens = asyncio.run(
self.deep_search._search_chunks_from_vectordb(query, sub_queries)
)
# Check if correct methods were called
self.deep_search.collection_router.invoke.assert_called_once()
self.assertTrue(self.vector_db.search_called)
self.assertTrue(self.llm.chat_called)
# With our mock returning "YES" for RERANK_PROMPT, all chunks should be accepted
self.assertEqual(len(results), 3) # 3 mock results from MockVectorDB
self.assertEqual(tokens, 15) # 5 from collection_router + 10*1 from LLM calls for reranking (batch)
def test_generate_gap_queries(self):
"""Test the _generate_gap_queries method."""
query = "Tell me about deep learning"
all_sub_queries = ["What is deep learning?", "How does deep learning work?"]
all_chunks = [
RetrievalResult(
embedding=[0.1] * 8,
text="Deep learning is a subset of machine learning",
reference="test_reference",
metadata={"a": 1}
),
RetrievalResult(
embedding=[0.1] * 8,
text="Deep learning uses neural networks",
reference="test_reference",
metadata={"a": 2}
)
]
gap_queries, tokens = self.deep_search._generate_gap_queries(query, all_sub_queries, all_chunks)
self.assertEqual(len(gap_queries), 1)
self.assertEqual(gap_queries[0], "What are limitations of deep learning?")
self.assertEqual(tokens, 10)
def test_retrieve(self):
"""Test the retrieve method."""
query = "Tell me about deep learning"
# Mock async method to run synchronously
async def mock_async_retrieve(*args, **kwargs):
# Create some test results
results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Deep learning is a subset of machine learning",
reference="test_reference",
metadata={"a": 1}
),
RetrievalResult(
embedding=[0.1] * 8,
text="Deep learning uses neural networks",
reference="test_reference",
metadata={"a": 2}
)
]
# Return the results, token count, and additional info
return results, 30, {"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]}
# Replace the async method with our mock
self.deep_search.async_retrieve = mock_async_retrieve
results, tokens, metadata = self.deep_search.retrieve(query)
# Check results
self.assertEqual(len(results), 2)
self.assertEqual(tokens, 30)
self.assertIn("all_sub_queries", metadata)
self.assertEqual(len(metadata["all_sub_queries"]), 2)
def test_async_retrieve(self):
"""Test the async_retrieve method."""
query = "Tell me about deep learning"
# Create mock results
mock_results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Deep learning is a subset of machine learning",
reference="test_reference",
metadata={"a": 1}
)
]
# Create a mock async_retrieve result
mock_retrieve_result = (
mock_results,
20,
{"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]}
)
# Mock the async_retrieve method
async def mock_async_retrieve(*args, **kwargs):
return mock_retrieve_result
self.deep_search.async_retrieve = mock_async_retrieve
# Run the async method using asyncio.run
results, tokens, metadata = asyncio.run(self.deep_search.async_retrieve(query))
# Check results
self.assertEqual(len(results), 1)
self.assertEqual(tokens, 20)
self.assertIn("all_sub_queries", metadata)
def test_query(self):
"""Test the query method."""
query = "Tell me about deep learning"
# Mock the retrieve method
retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
)
for i in range(3)
]
self.deep_search.retrieve = MagicMock(
return_value=(retrieved_results, 20, {"all_sub_queries": ["What is deep learning?"]})
)
answer, results, tokens = self.deep_search.query(query)
# Check if methods were called
self.deep_search.retrieve.assert_called_once_with(query)
self.assertTrue(self.llm.chat_called)
# Check results
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.")
self.assertEqual(results, retrieved_results)
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM
def test_query_no_results(self):
"""Test the query method when no results are found."""
query = "Tell me about deep learning"
# Mock the retrieve method to return no results
self.deep_search.retrieve = MagicMock(
return_value=([], 10, {"all_sub_queries": ["What is deep learning?"]})
)
answer, results, tokens = self.deep_search.query(query)
# Should return a message saying no results found
self.assertIn("No relevant information found", answer)
self.assertEqual(results, [])
self.assertEqual(tokens, 10) # Only tokens from retrieve
def test_format_chunk_texts(self):
"""Test the _format_chunk_texts method."""
chunk_texts = ["Text 1", "Text 2", "Text 3"]
formatted = self.deep_search._format_chunk_texts(chunk_texts)
self.assertIn("<chunk_0>", formatted)
self.assertIn("Text 1", formatted)
self.assertIn("<chunk_1>", formatted)
self.assertIn("Text 2", formatted)
self.assertIn("<chunk_2>", formatted)
self.assertIn("Text 3", formatted)
if __name__ == "__main__":
import unittest
unittest.main()
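The assertions in test_format_chunk_texts only pin down that each chunk ends up wrapped in numbered <chunk_i> tags. A minimal sketch of a helper that would satisfy them, assuming the real DeepSearch._format_chunk_texts does little more than that wrapping:

def format_chunk_texts(chunk_texts):
    # Hypothetical stand-in: wrap each chunk in <chunk_i>...</chunk_i> markers,
    # which is all the test above actually checks.
    return "".join(
        f"<chunk_{i}>\n{text}\n</chunk_{i}>\n" for i, text in enumerate(chunk_texts)
    )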

130
tests/agent/test_naive_rag.py

@@ -1,130 +0,0 @@
from unittest.mock import MagicMock
from deepsearcher.agent import NaiveRAG
from deepsearcher.vector_db.base import RetrievalResult
from tests.agent.test_base import BaseAgentTest
class TestNaiveRAG(BaseAgentTest):
"""Test class for NaiveRAG agent."""
def setUp(self):
"""Set up test fixtures for NaiveRAG tests."""
super().setUp()
self.naive_rag = NaiveRAG(
llm=self.llm,
embedding_model=self.embedding_model,
vector_db=self.vector_db,
top_k=5,
route_collection=True,
text_window_splitter=True
)
def test_init(self):
"""Test the initialization of NaiveRAG."""
self.assertEqual(self.naive_rag.llm, self.llm)
self.assertEqual(self.naive_rag.embedding_model, self.embedding_model)
self.assertEqual(self.naive_rag.vector_db, self.vector_db)
self.assertEqual(self.naive_rag.top_k, 5)
self.assertEqual(self.naive_rag.route_collection, True)
self.assertEqual(self.naive_rag.text_window_splitter, True)
def test_retrieve(self):
"""Test the retrieve method."""
query = "Test query"
# Mock the collection_router.invoke method
self.naive_rag.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5))
results, tokens, metadata = self.naive_rag.retrieve(query)
# Check if correct methods were called
self.naive_rag.collection_router.invoke.assert_called_once()
self.assertTrue(self.vector_db.search_called)
# Check the results
self.assertIsInstance(results, list)
self.assertEqual(len(results), 3) # Should match our mock return of 3 results
for result in results:
self.assertIsInstance(result, RetrievalResult)
# Check token count
self.assertEqual(tokens, 5) # From our mocked collection_router.invoke
def test_retrieve_without_routing(self):
"""Test retrieve method with routing disabled."""
self.naive_rag.route_collection = False
query = "Test query without routing"
results, tokens, metadata = self.naive_rag.retrieve(query)
# Check that routing was not called
self.assertTrue(self.vector_db.search_called)
# Check the results
self.assertIsInstance(results, list)
for result in results:
self.assertIsInstance(result, RetrievalResult)
# Check token count
self.assertEqual(tokens, 0) # No tokens used for routing
def test_query(self):
"""Test the query method."""
query = "Test query for full RAG"
# Mock the retrieve method
mock_results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
)
for i in range(3)
]
self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))
answer, retrieved_results, tokens = self.naive_rag.query(query)
# Check if correct methods were called
self.naive_rag.retrieve.assert_called_once_with(query)
self.assertTrue(self.llm.chat_called)
# Check the messages sent to LLM
self.assertIn("content", self.llm.last_messages[0])
self.assertIn(query, self.llm.last_messages[0]["content"])
# Check the results
self.assertEqual(answer, "This is a test answer")
self.assertEqual(retrieved_results, mock_results)
self.assertEqual(tokens, 15) # 5 from retrieve + 10 from LLM
def test_with_window_splitter_disabled(self):
"""Test with text window splitter disabled."""
self.naive_rag.text_window_splitter = False
query = "Test query with window splitter off"
# Mock the retrieve method
mock_results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
)
for i in range(3)
]
self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))
answer, retrieved_results, tokens = self.naive_rag.query(query)
# Check that regular text is used instead of wider_text
self.assertIn("Test result 0", self.llm.last_messages[0]["content"])
self.assertNotIn("Wider context", self.llm.last_messages[0]["content"])
if __name__ == "__main__":
import unittest
unittest.main()
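The last two tests describe the text_window_splitter switch purely by its effect on the prompt: with it enabled the wider_text from metadata reaches the LLM, with it disabled only the chunk text appears. A hedged sketch of that selection step (assumed helper name, not the verified NaiveRAG code):

def _select_texts(self, retrieved_results):
    # Prefer the wider context stored in metadata when the window splitter is on.
    if self.text_window_splitter:
        return [r.metadata.get("wider_text", r.text) for r in retrieved_results]
    return [r.text for r in retrieved_results]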

162
tests/agent/test_rag_router.py

@@ -1,162 +0,0 @@
from unittest.mock import MagicMock, patch
from deepsearcher.agent import NaiveRAG, ChainOfRAG, DeepSearch
from deepsearcher.agent.rag_router import RAGRouter
from deepsearcher.vector_db.base import RetrievalResult
from deepsearcher.llm.base import ChatResponse
from tests.agent.test_base import BaseAgentTest
class TestRAGRouter(BaseAgentTest):
"""Test class for RAGRouter agent."""
def setUp(self):
"""Set up test fixtures for RAGRouter tests."""
super().setUp()
# Create mock agent instances
self.naive_rag = MagicMock(spec=NaiveRAG)
self.chain_of_rag = MagicMock(spec=ChainOfRAG)
# Create the RAGRouter with the mock agents
self.rag_router = RAGRouter(
llm=self.llm,
rag_agents=[self.naive_rag, self.chain_of_rag],
agent_descriptions=[
"This agent is suitable for simple factual queries",
"This agent is suitable for complex multi-hop questions"
]
)
def test_init(self):
"""Test the initialization of RAGRouter."""
self.assertEqual(self.rag_router.llm, self.llm)
self.assertEqual(len(self.rag_router.rag_agents), 2)
self.assertEqual(len(self.rag_router.agent_descriptions), 2)
self.assertEqual(self.rag_router.agent_descriptions[0], "This agent is suitable for simple factual queries")
self.assertEqual(self.rag_router.agent_descriptions[1], "This agent is suitable for complex multi-hop questions")
def test_route(self):
"""Test the _route method."""
query = "What is the capital of France?"
# Directly mock the chat method to return a numeric response
self.llm.chat = MagicMock(return_value=ChatResponse(content="1", total_tokens=10))
agent, tokens = self.rag_router._route(query)
# Should select the first agent based on our mock response
self.assertEqual(agent, self.naive_rag)
self.assertEqual(tokens, 10)
self.assertTrue(self.llm.chat.called)
def test_route_with_non_numeric_response(self):
"""Test the _route method with a non-numeric response from LLM."""
query = "What is the history of deep learning?"
# Mock the LLM to return a response with a trailing digit
self.llm.chat = MagicMock(return_value=ChatResponse(content="I recommend agent 2", total_tokens=10))
self.rag_router.find_last_digit = MagicMock(return_value="2")
agent, tokens = self.rag_router._route(query)
# Should select the second agent based on our mock response
self.assertEqual(agent, self.chain_of_rag)
self.assertTrue(self.rag_router.find_last_digit.called)
def test_retrieve(self):
"""Test the retrieve method."""
query = "What is the capital of France?"
# Mock the _route method to return the first agent
mock_retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Paris is the capital of France",
reference="test_reference",
metadata={"a": 1}
)
]
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5))
self.naive_rag.retrieve = MagicMock(return_value=(mock_retrieved_results, 10, {"some": "metadata"}))
results, tokens, metadata = self.rag_router.retrieve(query)
# Check if methods were called
self.rag_router._route.assert_called_once_with(query)
self.naive_rag.retrieve.assert_called_once_with(query)
# Check results
self.assertEqual(results, mock_retrieved_results)
self.assertEqual(tokens, 15) # 5 from route + 10 from retrieve
self.assertEqual(metadata, {"some": "metadata"})
def test_query(self):
"""Test the query method."""
query = "What is the capital of France?"
# Mock the _route method to return the first agent
mock_retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Paris is the capital of France",
reference="test_reference",
metadata={"a": 1}
)
]
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5))
self.naive_rag.query = MagicMock(return_value=("Paris is the capital of France", mock_retrieved_results, 10))
answer, results, tokens = self.rag_router.query(query)
# Check if methods were called
self.rag_router._route.assert_called_once_with(query)
self.naive_rag.query.assert_called_once_with(query)
# Check results
self.assertEqual(answer, "Paris is the capital of France")
self.assertEqual(results, mock_retrieved_results)
self.assertEqual(tokens, 15) # 5 from route + 10 from query
def test_find_last_digit(self):
"""Test the find_last_digit method."""
self.assertEqual(self.rag_router.find_last_digit("Agent 2 is better"), "2")
self.assertEqual(self.rag_router.find_last_digit("I recommend agent number 1"), "1")
self.assertEqual(self.rag_router.find_last_digit("Choose 3"), "3")
# Test with no digit
with self.assertRaises(ValueError):
self.rag_router.find_last_digit("No digits here")
def test_auto_description_fallback(self):
"""Test that RAGRouter falls back to __description__ when no descriptions provided."""
# Create classes with __description__ attribute
class MockAgent1:
__description__ = "Auto description 1"
class MockAgent2:
__description__ = "Auto description 2"
# Create instances of these classes
mock_agent1 = MagicMock(spec=MockAgent1)
mock_agent1.__class__ = MockAgent1
mock_agent2 = MagicMock(spec=MockAgent2)
mock_agent2.__class__ = MockAgent2
# Create the RAGRouter without explicit descriptions
router = RAGRouter(
llm=self.llm,
rag_agents=[mock_agent1, mock_agent2]
)
# Check that descriptions were pulled from the class attributes
self.assertEqual(len(router.agent_descriptions), 2)
self.assertEqual(router.agent_descriptions[0], "Auto description 1")
self.assertEqual(router.agent_descriptions[1], "Auto description 2")
if __name__ == "__main__":
import unittest
unittest.main()
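test_find_last_digit fully specifies the helper's observable behaviour: return the last digit character of the LLM response and raise ValueError when there is none. A sketch consistent with those assertions (an assumed implementation, shown only for reference):

def find_last_digit(text: str) -> str:
    # Scan from the end and return the first digit encountered.
    for ch in reversed(text):
        if ch.isdigit():
            return ch
    raise ValueError("No digit found in the response")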

1
tests/embedding/__init__.py

@@ -1 +0,0 @@
# Tests for the deepsearcher.embedding package

105
tests/embedding/test_base.py

@@ -1,105 +0,0 @@
import unittest
from typing import List
from unittest.mock import patch, MagicMock
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.loader.splitter import Chunk
class ConcreteEmbedding(BaseEmbedding):
"""A concrete implementation of BaseEmbedding for testing."""
def __init__(self, dimension=768):
self._dimension = dimension
def embed_query(self, text: str) -> List[float]:
"""Simple implementation that returns a vector of the given dimension."""
return [0.1] * self._dimension
@property
def dimension(self) -> int:
return self._dimension
class TestBaseEmbedding(unittest.TestCase):
"""Tests for the BaseEmbedding base class."""
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test the embed_query method."""
embedding = ConcreteEmbedding()
result = embedding.embed_query("test text")
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test the embed_documents method."""
embedding = ConcreteEmbedding()
texts = ["text 1", "text 2", "text 3"]
results = embedding.embed_documents(texts)
# Check we got the right number of embeddings
self.assertEqual(len(results), 3)
# Check each embedding
for result in results:
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch('deepsearcher.embedding.base.tqdm')
@patch.dict('os.environ', {}, clear=True)
def test_embed_chunks(self, mock_tqdm):
"""Test the embed_chunks method."""
embedding = ConcreteEmbedding()
# Set up mock tqdm to just return the iterable
mock_tqdm.return_value = lambda x, desc: x
# Create test chunks
chunks = [
Chunk(text="text 1", reference="ref1"),
Chunk(text="text 2", reference="ref2"),
Chunk(text="text 3", reference="ref3")
]
# Create a spy on embed_documents
original_embed_documents = embedding.embed_documents
embed_documents_calls = []
def mock_embed_documents(texts):
embed_documents_calls.append(texts)
return original_embed_documents(texts)
embedding.embed_documents = mock_embed_documents
# Mock tqdm to return the batch_texts directly
mock_tqdm.side_effect = lambda x, **kwargs: x
# Call the method
result_chunks = embedding.embed_chunks(chunks, batch_size=2)
# Verify embed_documents was called correctly
self.assertEqual(len(embed_documents_calls), 2) # Should be called twice with batch_size=2
self.assertEqual(embed_documents_calls[0], ["text 1", "text 2"])
self.assertEqual(embed_documents_calls[1], ["text 3"])
# Verify chunks were updated with embeddings
self.assertEqual(len(result_chunks), 3)
for chunk in result_chunks:
self.assertEqual(len(chunk.embedding), 768)
self.assertEqual(chunk.embedding, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
embedding = ConcreteEmbedding()
self.assertEqual(embedding.dimension, 768)
# Test with different dimension
embedding = ConcreteEmbedding(dimension=1024)
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()
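test_embed_chunks expects embed_documents to be called once per batch and the resulting vectors to be written back onto the same chunk objects. A sketch of that batching loop, under the assumption that BaseEmbedding.embed_chunks works roughly like this:

def embed_chunks(self, chunks, batch_size=256):
    texts = [chunk.text for chunk in chunks]
    embeddings = []
    # Embed the texts batch by batch, as the spy on embed_documents verifies.
    for start in range(0, len(texts), batch_size):
        embeddings.extend(self.embed_documents(texts[start:start + batch_size]))
    # Attach each vector to its chunk and return the chunks otherwise unchanged.
    for chunk, embedding in zip(chunks, embeddings):
        chunk.embedding = embedding
    return chunks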

148
tests/embedding/test_bedrock_embedding.py

@@ -1,148 +0,0 @@
import unittest
import json
import os
from unittest.mock import patch, MagicMock
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.embedding import BedrockEmbedding
from deepsearcher.embedding.bedrock_embedding import (
MODEL_ID_TITAN_TEXT_V2,
MODEL_ID_TITAN_TEXT_G1,
MODEL_ID_COHERE_ENGLISH_V3,
)
class TestBedrockEmbedding(unittest.TestCase):
"""Tests for the BedrockEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_boto3 = MagicMock()
self.mock_client = MagicMock()
self.mock_boto3.client = MagicMock(return_value=self.mock_client)
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'boto3': self.mock_boto3})
self.module_patcher.start()
# Configure mock response
self.mock_response = {
"body": MagicMock(),
"ResponseMetadata": {"HTTPStatusCode": 200}
}
self.mock_response["body"].read.return_value = json.dumps({"embedding": [0.1] * 1024})
self.mock_client.invoke_model.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create instance to test
embedding = BedrockEmbedding()
# Check that boto3 client was created correctly
self.mock_boto3.client.assert_called_once_with(
"bedrock-runtime",
region_name="us-east-1",
aws_access_key_id=None,
aws_secret_access_key=None
)
# Check default model
self.assertEqual(embedding.model, MODEL_ID_TITAN_TEXT_V2)
# Re-apply the mocked response configured in setUp
self.mock_client.invoke_model.return_value = self.mock_response
@patch.dict('os.environ', {
'AWS_ACCESS_KEY_ID': 'test_key',
'AWS_SECRET_ACCESS_KEY': 'test_secret'
}, clear=True)
def test_init_with_credentials(self):
"""Test initialization with AWS credentials."""
embedding = BedrockEmbedding()
self.mock_boto3.client.assert_called_with(
"bedrock-runtime",
region_name="us-east-1",
aws_access_key_id="test_key",
aws_secret_access_key="test_secret"
)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_different_models(self):
"""Test initialization with different models."""
# Test Titan Text G1
embedding = BedrockEmbedding(model=MODEL_ID_TITAN_TEXT_G1)
self.assertEqual(embedding.model, MODEL_ID_TITAN_TEXT_G1)
# Test Cohere English V3
embedding = BedrockEmbedding(model=MODEL_ID_COHERE_ENGLISH_V3)
self.assertEqual(embedding.model, MODEL_ID_COHERE_ENGLISH_V3)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create instance to test
embedding = BedrockEmbedding()
query = "test query"
result = embedding.embed_query(query)
# Check that invoke_model was called correctly
self.mock_client.invoke_model.assert_called_once_with(
modelId=MODEL_ID_TITAN_TEXT_V2,
body=json.dumps({"inputText": query})
)
# Check result
self.assertEqual(len(result), 1024)
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create instance to test
embedding = BedrockEmbedding()
texts = ["text 1", "text 2", "text 3"]
results = embedding.embed_documents(texts)
# Check that invoke_model was called for each text
self.assertEqual(self.mock_client.invoke_model.call_count, 3)
for text in texts:
self.mock_client.invoke_model.assert_any_call(
modelId=MODEL_ID_TITAN_TEXT_V2,
body=json.dumps({"inputText": text})
)
# Check results
self.assertEqual(len(results), 3)
for result in results:
self.assertEqual(len(result), 1024)
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property for different models."""
# Create instance to test with Titan Text V2
embedding = BedrockEmbedding()
self.assertEqual(embedding.dimension, 1024)
# Test Titan Text G1
embedding = BedrockEmbedding(model=MODEL_ID_TITAN_TEXT_G1)
self.assertEqual(embedding.dimension, 1536)
# Test Cohere English V3
embedding = BedrockEmbedding(model=MODEL_ID_COHERE_ENGLISH_V3)
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()
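The Bedrock tests fix the request and response shape exactly: invoke_model receives the model id and a JSON body with inputText, and the embedding is read from the JSON response body. A hedged sketch of that call path (attribute names are assumptions):

import json

def embed_query(self, text):
    response = self.client.invoke_model(
        modelId=self.model,
        body=json.dumps({"inputText": text}),
    )
    # The response body is a stream; decode it and pull out the vector.
    return json.loads(response["body"].read())["embedding"]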

144
tests/embedding/test_fastembed_embedding.py

@@ -1,144 +0,0 @@
import unittest
import numpy as np
from unittest.mock import patch, MagicMock
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.embedding import FastEmbedEmbedding
class TestFastEmbedEmbedding(unittest.TestCase):
"""Tests for the FastEmbedEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_fastembed = MagicMock()
self.mock_text_embedding = MagicMock()
self.mock_fastembed.TextEmbedding = MagicMock(return_value=self.mock_text_embedding)
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'fastembed': self.mock_fastembed})
self.module_patcher.start()
# Set up mock embeddings
self.mock_embedding = np.array([0.1] * 384) # BGE-small has 384 dimensions
self.mock_text_embedding.query_embed.return_value = iter([self.mock_embedding])
self.mock_text_embedding.embed.return_value = [self.mock_embedding] * 3
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create instance to test
embedding = FastEmbedEmbedding()
# Access a method to trigger lazy loading
embedding.embed_query("test")
# Check that TextEmbedding was initialized correctly
self.mock_fastembed.TextEmbedding.assert_called_once_with(
model_name="BAAI/bge-small-en-v1.5"
)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
custom_model = "custom/model-name"
embedding = FastEmbedEmbedding(model=custom_model)
# Access a method to trigger lazy loading
embedding.embed_query("test")
self.mock_fastembed.TextEmbedding.assert_called_with(
model_name=custom_model
)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_kwargs(self):
"""Test initialization with additional kwargs."""
kwargs = {"batch_size": 32, "max_length": 512}
embedding = FastEmbedEmbedding(**kwargs)
# Access a method to trigger lazy loading
embedding.embed_query("test")
self.mock_fastembed.TextEmbedding.assert_called_with(
model_name="BAAI/bge-small-en-v1.5",
**kwargs
)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create instance to test
embedding = FastEmbedEmbedding()
query = "test query"
result = embedding.embed_query(query)
# Check that query_embed was called correctly
self.mock_text_embedding.query_embed.assert_called_once_with([query])
# Check result
self.assertEqual(len(result), 384)
np.testing.assert_array_equal(result, [0.1] * 384)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create instance to test
embedding = FastEmbedEmbedding()
texts = ["text 1", "text 2", "text 3"]
results = embedding.embed_documents(texts)
# Check that embed was called correctly
self.mock_text_embedding.embed.assert_called_once_with(texts)
# Check results
self.assertEqual(len(results), 3)
for result in results:
self.assertEqual(len(result), 384)
np.testing.assert_array_equal(result, [0.1] * 384)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create instance to test
embedding = FastEmbedEmbedding()
# Mock a sample embedding
sample_embedding = np.array([0.1] * 384)
self.mock_text_embedding.query_embed.return_value = iter([sample_embedding])
# Check dimension
self.assertEqual(embedding.dimension, 384)
# Verify that query_embed was called with sample text
self.mock_text_embedding.query_embed.assert_called_with(["SAMPLE TEXT"])
@patch.dict('os.environ', {}, clear=True)
def test_lazy_loading(self):
"""Test that the model is loaded lazily."""
# Create a new instance
embedding = FastEmbedEmbedding()
# Check that TextEmbedding wasn't called during initialization
self.mock_fastembed.TextEmbedding.reset_mock()
self.mock_fastembed.TextEmbedding.assert_not_called()
# Access a method that requires the model
embedding.embed_query("test")
# Now TextEmbedding should have been called
self.mock_fastembed.TextEmbedding.assert_called_once()
if __name__ == "__main__":
unittest.main()
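test_lazy_loading documents that fastembed.TextEmbedding is not constructed in __init__ but on first use. One way to get that behaviour, sketched with assumed attribute names:

def _load_model(self):
    # Instantiate the fastembed model only when it is first needed.
    if self._model is None:
        from fastembed import TextEmbedding
        self._model = TextEmbedding(model_name=self.model_name, **self.model_kwargs)
    return self._model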

266
tests/embedding/test_gemini_embedding.py

@@ -1,266 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import logging
import warnings
import multiprocessing.resource_tracker
# Disable logging for tests
logging.disable(logging.CRITICAL)
# Suppress resource tracker warning
warnings.filterwarnings("ignore", category=ResourceWarning)
# Patch resource tracker to avoid warnings
def _resource_tracker():
pass
multiprocessing.resource_tracker._resource_tracker = _resource_tracker
from deepsearcher.embedding import GeminiEmbedding
from deepsearcher.embedding.gemini_embedding import GEMINI_MODEL_DIM_MAP
class TestGeminiEmbedding(unittest.TestCase):
"""Tests for the GeminiEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_genai = MagicMock()
self.mock_client = MagicMock()
self.mock_types = MagicMock()
# Set up the mock module structure
self.mock_genai.Client = MagicMock(return_value=self.mock_client)
self.mock_genai.types = self.mock_types
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'google.genai': self.mock_genai})
self.module_patcher.start()
# Set up mock response for embed_content
self.mock_response = MagicMock()
self.mock_response.embeddings = [
MagicMock(values=[0.1] * 768) # Default embedding for text-embedding-004
]
self.mock_client.models.embed_content.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create instance to test
embedding = GeminiEmbedding()
# Check that Client was initialized correctly
self.mock_genai.Client.assert_called_once_with(api_key=None)
# Check default model and dimension
self.assertEqual(embedding.model, "text-embedding-004")
self.assertEqual(embedding.dim, 768)
self.assertEqual(embedding.dimension, 768)
@patch.dict('os.environ', {'GEMINI_API_KEY': 'test_api_key_from_env'}, clear=True)
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
embedding = GeminiEmbedding()
self.mock_genai.Client.assert_called_with(api_key='test_api_key_from_env')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
api_key = "test_api_key_param"
embedding = GeminiEmbedding(api_key=api_key)
self.mock_genai.Client.assert_called_with(api_key=api_key)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
model = "gemini-embedding-exp-03-07"
embedding = GeminiEmbedding(model=model)
self.assertEqual(embedding.model, model)
self.assertEqual(embedding.dim, GEMINI_MODEL_DIM_MAP[model])
self.assertEqual(embedding.dimension, 3072)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_custom_dimension(self):
"""Test initialization with custom dimension."""
custom_dim = 1024
embedding = GeminiEmbedding(dimension=custom_dim)
self.assertEqual(embedding.dim, custom_dim)
self.assertEqual(embedding.dimension, custom_dim)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query_single_char(self):
"""Test embedding a single character query."""
# Create instance to test
embedding = GeminiEmbedding()
query = "a"
result = embedding.embed_query(query)
# Check that embed_content was called correctly
self.mock_client.models.embed_content.assert_called_once()
call_args = self.mock_client.models.embed_content.call_args
# For single character, it should be passed as is
self.assertEqual(call_args[1]["model"], "text-embedding-004")
self.assertEqual(call_args[1]["contents"], query)
# Check result
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query_multi_char(self):
"""Test embedding a multi-character query."""
# Create instance to test
embedding = GeminiEmbedding()
query = "test query"
result = embedding.embed_query(query)
# Check that embed_content was called correctly
self.mock_client.models.embed_content.assert_called_once()
call_args = self.mock_client.models.embed_content.call_args
# For multi-character string, it should be joined with spaces
self.assertEqual(call_args[1]["model"], "text-embedding-004")
self.assertEqual(call_args[1]["contents"], "t e s t q u e r y")
# Check result
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create instance to test
embedding = GeminiEmbedding()
# Set up mock response for multiple documents
mock_embeddings = [
MagicMock(values=[0.1] * 768),
MagicMock(values=[0.2] * 768),
MagicMock(values=[0.3] * 768)
]
self.mock_response.embeddings = mock_embeddings
texts = ["text 1", "text 2", "text 3"]
results = embedding.embed_documents(texts)
# Check that embed_content was called correctly
self.mock_client.models.embed_content.assert_called_once()
call_args = self.mock_client.models.embed_content.call_args
self.assertEqual(call_args[1]["model"], "text-embedding-004")
self.assertEqual(call_args[1]["contents"], texts)
# Check that EmbedContentConfig was used
config_arg = call_args[1]["config"]
self.mock_types.EmbedContentConfig.assert_called_once_with(output_dimensionality=768)
# Check results
self.assertEqual(len(results), 3)
expected_results = [[0.1] * 768, [0.2] * 768, [0.3] * 768]
for i, result in enumerate(results):
self.assertEqual(len(result), 768)
self.assertEqual(result, expected_results[i])
@patch.dict('os.environ', {}, clear=True)
def test_embed_chunks(self):
"""Test embedding chunks with batch processing."""
# Create instance to test
embedding = GeminiEmbedding()
# Set up mock response for batched documents
batch1_embeddings = [MagicMock(values=[0.1] * 768)] * 100
batch2_embeddings = [MagicMock(values=[0.2] * 768)] * 50
# Mock multiple calls to embed_content
self.mock_client.models.embed_content.side_effect = [
MagicMock(embeddings=batch1_embeddings),
MagicMock(embeddings=batch2_embeddings)
]
# Create mock chunks
class MockChunk:
def __init__(self, text: str):
self.text = text
self.embedding = None
chunks = [MockChunk(f"text {i}") for i in range(150)]
results = embedding.embed_chunks(chunks, batch_size=100)
# Check that embed_content was called twice (150 chunks split into 2 batches)
self.assertEqual(self.mock_client.models.embed_content.call_count, 2)
# Check that the same chunk objects are returned
self.assertEqual(len(results), 150)
self.assertEqual(results, chunks)
# Check that each chunk has an embedding
for i, chunk in enumerate(results):
self.assertIsNotNone(chunk.embedding)
if i < 100:
self.assertEqual(chunk.embedding, [0.1] * 768)
else:
self.assertEqual(chunk.embedding, [0.2] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property_different_models(self):
"""Test the dimension property for different models."""
# Create instance to test
embedding = GeminiEmbedding()
# Test default model
self.assertEqual(embedding.dimension, 768)
# Test experimental model
embedding_exp = GeminiEmbedding(model="gemini-embedding-exp-03-07")
self.assertEqual(embedding_exp.dimension, 3072)
# Test custom dimension
embedding_custom = GeminiEmbedding(dimension=512)
self.assertEqual(embedding_custom.dimension, 512)
@patch.dict('os.environ', {}, clear=True)
def test_get_dim_method(self):
"""Test the private _get_dim method."""
# Create instance to test
embedding = GeminiEmbedding()
# Test default dimension
self.assertEqual(embedding._get_dim(), 768)
# Test custom dimension
embedding_custom = GeminiEmbedding(dimension=1024)
self.assertEqual(embedding_custom._get_dim(), 1024)
@patch.dict('os.environ', {}, clear=True)
def test_embed_content_method(self):
"""Test the private _embed_content method."""
# Create instance to test
embedding = GeminiEmbedding()
texts = ["test text 1", "test text 2"]
result = embedding._embed_content(texts)
# Check that embed_content was called with correct parameters
self.mock_client.models.embed_content.assert_called_once()
call_args = self.mock_client.models.embed_content.call_args
self.assertEqual(call_args[1]["model"], "text-embedding-004")
self.assertEqual(call_args[1]["contents"], texts)
self.mock_types.EmbedContentConfig.assert_called_once_with(output_dimensionality=768)
# Check that the response embeddings are returned
self.assertEqual(result, self.mock_response.embeddings)
if __name__ == "__main__":
unittest.main()
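test_embed_content_method asserts that the Gemini client receives the model name, the list of texts, and an EmbedContentConfig carrying the output dimensionality, and that the response's embeddings come straight back. A hedged sketch of that wrapper (assumed attribute names):

def _embed_content(self, texts):
    from google.genai import types
    response = self.client.models.embed_content(
        model=self.model,
        contents=texts,
        config=types.EmbedContentConfig(output_dimensionality=self.dim),
    )
    return response.embeddings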

143
tests/embedding/test_glm_embedding.py

@@ -1,143 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
from deepsearcher.embedding import GLMEmbedding
class TestGLMEmbedding(unittest.TestCase):
"""Tests for the GLMEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_zhipuai = MagicMock()
self.mock_client = MagicMock()
self.mock_embeddings = MagicMock()
# Set up mock response
mock_data_item = MagicMock()
mock_data_item.embedding = [0.1] * 2048 # embedding-3 has 2048 dimensions
mock_response = MagicMock()
mock_response.data = [mock_data_item]
self.mock_embeddings.create.return_value = mock_response
# Set up the mock module structure
self.mock_zhipuai.ZhipuAI.return_value = self.mock_client
self.mock_client.embeddings = self.mock_embeddings
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'zhipuai': self.mock_zhipuai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = GLMEmbedding()
# Check that ZhipuAI was initialized correctly
self.mock_zhipuai.ZhipuAI.assert_called_once_with(
api_key='fake-api-key',
base_url='https://open.bigmodel.cn/api/paas/v4/'
)
# Check attributes
self.assertEqual(embedding.model, 'embedding-3')
self.assertEqual(embedding.client, self.mock_client)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = GLMEmbedding(api_key='test-api-key')
# Check that ZhipuAI was initialized with the provided API key
self.mock_zhipuai.ZhipuAI.assert_called_with(
api_key='test-api-key',
base_url='https://open.bigmodel.cn/api/paas/v4/'
)
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_base_url(self):
"""Test initialization with base URL parameter."""
# Initialize with base URL
embedding = GLMEmbedding(base_url='https://custom-api.example.com')
# Check that ZhipuAI was initialized with the provided base URL
self.mock_zhipuai.ZhipuAI.assert_called_with(
api_key='fake-api-key',
base_url='https://custom-api.example.com'
)
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = GLMEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that create was called correctly
self.mock_embeddings.create.assert_called_once_with(
input=[query],
model='embedding-3'
)
# Check the result
self.assertEqual(result, [0.1] * 2048)
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = GLMEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
mock_data_items = []
for i in range(3):
mock_data_item = MagicMock()
mock_data_item.embedding = [0.1 * (i + 1)] * 2048
mock_data_items.append(mock_data_item)
mock_response = MagicMock()
mock_response.data = mock_data_items
self.mock_embeddings.create.return_value = mock_response
# Call the method
results = embedding.embed_documents(texts)
# Verify that create was called correctly
self.mock_embeddings.create.assert_called_once_with(
input=texts,
model='embedding-3'
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 2048)
@patch.dict('os.environ', {'GLM_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = GLMEmbedding()
# For embedding-3
self.assertEqual(embedding.dimension, 2048)
if __name__ == "__main__":
unittest.main()

130
tests/embedding/test_milvus_embedding.py

@@ -1,130 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
from deepsearcher.embedding import MilvusEmbedding
class TestMilvusEmbedding(unittest.TestCase):
"""Tests for the MilvusEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_pymilvus = MagicMock()
self.mock_model = MagicMock()
self.mock_default_embedding = MagicMock()
self.mock_jina_embedding = MagicMock()
self.mock_sentence_transformer = MagicMock()
# Set up the mock module structure
self.mock_pymilvus.model = self.mock_model
self.mock_model.DefaultEmbeddingFunction = MagicMock(return_value=self.mock_default_embedding)
self.mock_model.dense = MagicMock()
self.mock_model.dense.JinaEmbeddingFunction = MagicMock(return_value=self.mock_jina_embedding)
self.mock_model.dense.SentenceTransformerEmbeddingFunction = MagicMock(return_value=self.mock_sentence_transformer)
# Set up default dimensions and responses
self.mock_default_embedding.dim = 768
self.mock_jina_embedding.dim = 1024
self.mock_sentence_transformer.dim = 1024
# Set up mock responses for encoding
self.mock_default_embedding.encode_queries.return_value = [np.array([0.1] * 768)]
self.mock_default_embedding.encode_documents.return_value = [np.array([0.1] * 768)]
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'pymilvus': self.mock_pymilvus})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
embedding = MilvusEmbedding()
# Check that default model was initialized
self.mock_model.DefaultEmbeddingFunction.assert_called_once()
self.assertEqual(embedding.model, self.mock_default_embedding)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_jina_model(self):
"""Test initialization with Jina model."""
embedding = MilvusEmbedding(model='jina-embeddings-v3')
# Check that Jina model was initialized
self.mock_model.dense.JinaEmbeddingFunction.assert_called_once_with('jina-embeddings-v3')
self.assertEqual(embedding.model, self.mock_jina_embedding)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_bge_model(self):
"""Test initialization with BGE model."""
embedding = MilvusEmbedding(model='BAAI/bge-large-en-v1.5')
# Check that SentenceTransformer model was initialized
self.mock_model.dense.SentenceTransformerEmbeddingFunction.assert_called_once_with('BAAI/bge-large-en-v1.5')
self.assertEqual(embedding.model, self.mock_sentence_transformer)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_invalid_model(self):
"""Test initialization with invalid model raises error."""
with self.assertRaises(ValueError):
MilvusEmbedding(model='invalid-model')
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
embedding = MilvusEmbedding()
query = "This is a test query"
result = embedding.embed_query(query)
# Check that encode_queries was called correctly
self.mock_default_embedding.encode_queries.assert_called_once_with([query])
# Convert numpy array to list for comparison
expected = [0.1] * 768
np.testing.assert_array_almost_equal(result, expected)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
embedding = MilvusEmbedding()
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
mock_embeddings = [np.array([0.1 * (i + 1)] * 768) for i in range(3)]
self.mock_default_embedding.encode_documents.return_value = mock_embeddings
results = embedding.embed_documents(texts)
# Check that encode_documents was called correctly
self.mock_default_embedding.encode_documents.assert_called_once_with(texts)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
expected = [0.1 * (i + 1)] * 768
np.testing.assert_array_almost_equal(result, expected)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# For default model
embedding = MilvusEmbedding()
self.assertEqual(embedding.dimension, 768)
# For Jina model
embedding = MilvusEmbedding(model='jina-embeddings-v3')
self.assertEqual(embedding.dimension, 1024)
# For BGE model
embedding = MilvusEmbedding(model='BAAI/bge-large-en-v1.5')
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()

193
tests/embedding/test_novita_embedding.py

@@ -1,193 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import requests
from deepsearcher.embedding import NovitaEmbedding
class TestNovitaEmbedding(unittest.TestCase):
"""Tests for the NovitaEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for requests
self.requests_patcher = patch('requests.request')
self.mock_request = self.requests_patcher.start()
# Set up mock response
self.mock_response = MagicMock()
self.mock_response.json.return_value = {
'data': [
{'index': 0, 'embedding': [0.1] * 1024} # baai/bge-m3 has 1024 dimensions
]
}
self.mock_response.raise_for_status = MagicMock()
self.mock_request.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.requests_patcher.stop()
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = NovitaEmbedding()
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
self.assertEqual(embedding.api_key, 'fake-api-key')
self.assertEqual(embedding.batch_size, 32)
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = NovitaEmbedding(model='baai/bge-m3')
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
self.assertEqual(embedding.dimension, 1024)
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = NovitaEmbedding(model_name='baai/bge-m3')
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = NovitaEmbedding(api_key='test-api-key')
# Check that the API key was set correctly
self.assertEqual(embedding.api_key, 'test-api-key')
@patch.dict('os.environ', {}, clear=True)
def test_init_without_api_key(self):
"""Test initialization without API key raises error."""
with self.assertRaises(RuntimeError):
NovitaEmbedding()
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = NovitaEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.novita.ai/v3/openai/embeddings',
json={
'model': 'baai/bge-m3',
'input': query,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the result
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = NovitaEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
self.mock_response.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024}
for i in range(3)
]
}
# Call the method
results = embedding.embed_documents(texts)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.novita.ai/v3/openai/embeddings',
json={
'model': 'baai/bge-m3',
'input': texts,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 1024)
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents_with_batching(self):
"""Test embedding documents with batching."""
# Create the embedder
embedding = NovitaEmbedding()
# Create test documents
texts = ["text " + str(i) for i in range(50)] # More than batch_size
# Set up mock response for batched documents
def mock_batch_response(*args, **kwargs):
batch_input = kwargs['json']['input']
mock_resp = MagicMock()
mock_resp.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1] * 1024}
for i in range(len(batch_input))
]
}
mock_resp.raise_for_status = MagicMock()
return mock_resp
self.mock_request.side_effect = mock_batch_response
# Call the method
results = embedding.embed_documents(texts)
# Check that request was called multiple times
self.assertTrue(self.mock_request.call_count > 1)
# Check the results
self.assertEqual(len(results), 50)
for result in results:
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'NOVITA_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = NovitaEmbedding()
# For baai/bge-m3
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()
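The Novita tests pin the HTTP contract down completely: a POST to https://api.novita.ai/v3/openai/embeddings with model, input, and encoding_format in the JSON body, plus a bearer-token Authorization header. A sketch of the single-query call that matches those assertions (assumed method placement and attributes):

import requests

def embed_query(self, text):
    response = requests.request(
        "POST",
        "https://api.novita.ai/v3/openai/embeddings",
        json={"model": self.model, "input": text, "encoding_format": "float"},
        headers={
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        },
    )
    response.raise_for_status()
    return response.json()["data"][0]["embedding"]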

239
tests/embedding/test_ollama_embedding.py

@@ -1,239 +0,0 @@
import unittest
import sys
from unittest.mock import patch, MagicMock
from deepsearcher.embedding import OllamaEmbedding
class TestOllamaEmbedding(unittest.TestCase):
"""Tests for the OllamaEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module for ollama
mock_ollama_module = MagicMock()
# Create mock Client class
self.mock_ollama_client = MagicMock()
mock_ollama_module.Client = self.mock_ollama_client
# Add the mock module to sys.modules
self.module_patcher = patch.dict('sys.modules', {'ollama': mock_ollama_module})
self.module_patcher.start()
# Set up mock client instance
self.mock_client = MagicMock()
self.mock_ollama_client.return_value = self.mock_client
# Configure mock embed method
self.mock_client.embed.return_value = {"embeddings": [[0.1] * 1024]}
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Check that Client was initialized correctly
self.mock_ollama_client.assert_called_once_with(host="http://localhost:11434/")
# Check instance attributes
self.assertEqual(embedding.model, "bge-m3")
self.assertEqual(embedding.dim, 1024)
self.assertEqual(embedding.batch_size, 32)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_base_url(self):
"""Test initialization with custom base URL."""
# Reset mock
self.mock_ollama_client.reset_mock()
# Create embedding with custom base URL
embedding = OllamaEmbedding(base_url="http://custom-ollama-server:11434/")
# Check that Client was initialized with custom base URL
self.mock_ollama_client.assert_called_with(host="http://custom-ollama-server:11434/")
@patch.dict('os.environ', {}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Reset mock
self.mock_ollama_client.reset_mock()
# Create embedding with model_name
embedding = OllamaEmbedding(model_name="mxbai-embed-large")
# Check model attribute
self.assertEqual(embedding.model, "mxbai-embed-large")
# Check dimension is set correctly based on model
self.assertEqual(embedding.dim, 768)
@patch.dict('os.environ', {}, clear=True)
def test_init_with_dimension(self):
"""Test initialization with custom dimension."""
# Reset mock
self.mock_ollama_client.reset_mock()
# Create embedding with custom dimension
embedding = OllamaEmbedding(dimension=512)
# Check dimension attribute
self.assertEqual(embedding.dim, 512)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Set up mock response
self.mock_client.embed.return_value = {"embeddings": [[0.1, 0.2, 0.3] * 341 + [0.4]]} # 1024 dimensions
# Call the method
result = embedding.embed_query("test query")
# Verify embed was called correctly
self.mock_client.embed.assert_called_once_with(model="bge-m3", input="test query")
# Check the result
self.assertEqual(len(result), 1024)
self.assertEqual(result, [0.1, 0.2, 0.3] * 341 + [0.4])
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_small_batch(self):
"""Test embedding documents with a small batch (less than batch size)."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Set up mock response for multiple documents
mock_embeddings = [
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions
[0.4, 0.5, 0.6] * 341 + [0.7],
[0.7, 0.8, 0.9] * 341 + [0.1]
]
self.mock_client.embed.return_value = {"embeddings": mock_embeddings}
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Verify embed was called correctly
self.mock_client.embed.assert_called_once_with(model="bge-m3", input=texts)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(len(result), 1024)
self.assertEqual(result, mock_embeddings[i])
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_large_batch(self):
"""Test embedding documents with a large batch (more than batch size)."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Set a smaller batch size for testing
embedding.batch_size = 2
# Set up mock responses for batches
batch1_embeddings = [
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions
[0.4, 0.5, 0.6] * 341 + [0.7]
]
batch2_embeddings = [
[0.7, 0.8, 0.9] * 341 + [0.1]
]
# Configure mock to return different responses for each call
self.mock_client.embed.side_effect = [
{"embeddings": batch1_embeddings},
{"embeddings": batch2_embeddings}
]
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Verify embed was called twice with the right batches
self.assertEqual(self.mock_client.embed.call_count, 2)
self.mock_client.embed.assert_any_call(model="bge-m3", input=["text 1", "text 2"])
self.mock_client.embed.assert_any_call(model="bge-m3", input=["text 3"])
# Check the results
self.assertEqual(len(results), 3)
self.assertEqual(results[0], batch1_embeddings[0])
self.assertEqual(results[1], batch1_embeddings[1])
self.assertEqual(results[2], batch2_embeddings[0])
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_no_batching(self):
"""Test embedding documents with batching disabled."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Disable batching
embedding.batch_size = 0
# Mock the embed_query method
original_embed_query = embedding.embed_query
embed_query_calls = []
def mock_embed_query(text):
embed_query_calls.append(text)
return [0.1] * 1024 # Return a simple mock embedding
embedding.embed_query = mock_embed_query
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Check that embed_query was called for each text
self.assertEqual(len(embed_query_calls), 3)
self.assertEqual(embed_query_calls, texts)
# Check the results
self.assertEqual(len(results), 3)
for result in results:
self.assertEqual(len(result), 1024)
self.assertEqual(result, [0.1] * 1024)
# Restore original method
embedding.embed_query = original_embed_query
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create instance to test
embedding = OllamaEmbedding(model="bge-m3")
# Check dimension for bge-m3
self.assertEqual(embedding.dimension, 1024)
# Test with different models
self.mock_ollama_client.reset_mock()
embedding = OllamaEmbedding(model="mxbai-embed-large")
self.assertEqual(embedding.dimension, 768)
self.mock_ollama_client.reset_mock()
embedding = OllamaEmbedding(model="nomic-embed-text")
self.assertEqual(embedding.dimension, 768)
# Test with custom dimension
self.mock_ollama_client.reset_mock()
embedding = OllamaEmbedding(dimension=512)
self.assertEqual(embedding.dimension, 512)
if __name__ == "__main__":
unittest.main()
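The Ollama tests exercise three paths in embed_documents: one client.embed call when the input fits in a batch, several calls when it exceeds batch_size, and a per-text embed_query fallback when batch_size is 0. A sketch consistent with those assertions (an assumption, not the verified implementation):

def embed_documents(self, texts):
    if not self.batch_size:
        # Batching disabled: embed one text at a time.
        return [self.embed_query(text) for text in texts]
    embeddings = []
    for start in range(0, len(texts), self.batch_size):
        batch = texts[start:start + self.batch_size]
        embeddings.extend(self.client.embed(model=self.model, input=batch)["embeddings"])
    return embeddings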

272
tests/embedding/test_openai_embedding.py

@@ -1,272 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock, ANY
from openai._types import NOT_GIVEN
from deepsearcher.embedding import OpenAIEmbedding
class TestOpenAIEmbedding(unittest.TestCase):
"""Tests for the OpenAIEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for OpenAI classes
self.openai_patcher = patch('openai.OpenAI')
self.mock_openai = self.openai_patcher.start()
# Set up mock client
self.mock_client = MagicMock()
self.mock_openai.return_value = self.mock_client
# Set up mock embeddings
self.mock_embeddings = MagicMock()
self.mock_client.embeddings = self.mock_embeddings
# Set up mock response for embed_query
mock_data_item = MagicMock()
mock_data_item.embedding = [0.1] * 1536
self.mock_response = MagicMock()
self.mock_response.data = [mock_data_item]
self.mock_embeddings.create.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.openai_patcher.stop()
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = OpenAIEmbedding()
# Check that OpenAI was initialized correctly
self.mock_openai.assert_called_once_with(api_key='fake-api-key', base_url=None)
# Check attributes
self.assertEqual(embedding.model, 'text-embedding-ada-002')
self.assertEqual(embedding.dim, 1536)
self.assertFalse(embedding.is_azure)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = OpenAIEmbedding(model='text-embedding-3-large')
# Check attributes
self.assertEqual(embedding.model, 'text-embedding-3-large')
self.assertEqual(embedding.dim, 3072)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = OpenAIEmbedding(model_name='text-embedding-3-small')
# Check attributes
self.assertEqual(embedding.model, 'text-embedding-3-small')
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_dimension(self):
"""Test initialization with specified dimension."""
# Initialize with custom dimension
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512)
# Check attributes
self.assertEqual(embedding.dim, 512)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = OpenAIEmbedding(api_key='test-api-key')
# Check that OpenAI was initialized with the provided API key
self.mock_openai.assert_called_with(api_key='test-api-key', base_url=None)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_base_url(self):
"""Test initialization with base URL parameter."""
# Initialize with base URL
embedding = OpenAIEmbedding(base_url='https://test-openai-api.com')
# Check that OpenAI was initialized with the provided base URL
self.mock_openai.assert_called_with(api_key='fake-api-key', base_url='https://test-openai-api.com')
@patch('openai.AzureOpenAI')
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_azure(self, mock_azure_openai):
"""Test initialization with Azure OpenAI."""
# Set up mock Azure client
mock_azure_client = MagicMock()
mock_azure_openai.return_value = mock_azure_client
# Initialize with Azure endpoint
embedding = OpenAIEmbedding(
azure_endpoint='https://test-azure.openai.azure.com',
api_key='test-azure-key',
api_version='2023-05-15'
)
# Check that AzureOpenAI was initialized correctly
mock_azure_openai.assert_called_once_with(
api_key='test-azure-key',
api_version='2023-05-15',
azure_endpoint='https://test-azure.openai.azure.com'
)
# Check attributes
self.assertEqual(embedding.model, 'text-embedding-ada-002')
self.assertEqual(embedding.client, mock_azure_client)
self.assertTrue(embedding.is_azure)
self.assertEqual(embedding.deployment, 'text-embedding-ada-002')
@patch('openai.AzureOpenAI')
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_azure_deployment(self, mock_azure_openai):
"""Test initialization with Azure OpenAI and custom deployment."""
# Set up mock Azure client
mock_azure_client = MagicMock()
mock_azure_openai.return_value = mock_azure_client
# Initialize with Azure endpoint and deployment
embedding = OpenAIEmbedding(
azure_endpoint='https://test-azure.openai.azure.com',
azure_deployment='test-deployment'
)
# Check attributes
self.assertEqual(embedding.deployment, 'test-deployment')
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_get_dim(self):
"""Test the _get_dim method."""
# Create the embedder
embedding = OpenAIEmbedding()
# For text-embedding-ada-002
self.assertIs(embedding._get_dim(), NOT_GIVEN)
# For text-embedding-3-small
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512)
self.assertEqual(embedding._get_dim(), 512)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = OpenAIEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that create was called correctly
self.mock_embeddings.create.assert_called_once_with(
input=[query],
model='text-embedding-ada-002',
dimensions=ANY
)
# Check the result
self.assertEqual(result, [0.1] * 1536)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query_azure(self):
"""Test embedding a single query with Azure."""
# Set up Azure embedding
with patch('openai.AzureOpenAI') as mock_azure_openai:
# Set up mock Azure client
mock_azure_client = MagicMock()
mock_azure_openai.return_value = mock_azure_client
# Set up mock embeddings
mock_azure_embeddings = MagicMock()
mock_azure_client.embeddings = mock_azure_embeddings
# Set up mock response
mock_data_item = MagicMock()
mock_data_item.embedding = [0.2] * 1536
mock_response = MagicMock()
mock_response.data = [mock_data_item]
mock_azure_embeddings.create.return_value = mock_response
# Initialize with Azure endpoint
embedding = OpenAIEmbedding(
azure_endpoint='https://test-azure.openai.azure.com',
azure_deployment='test-deployment'
)
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that create was called correctly
mock_azure_embeddings.create.assert_called_once_with(
input=[query],
model='text-embedding-ada-002' # the test expects the model name here even when an Azure deployment is configured
)
# Check the result
self.assertEqual(result, [0.2] * 1536)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = OpenAIEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
mock_data_items = []
for i in range(3):
mock_data_item = MagicMock()
mock_data_item.embedding = [0.1 * (i + 1)] * 1536
mock_data_items.append(mock_data_item)
mock_response = MagicMock()
mock_response.data = mock_data_items
self.mock_embeddings.create.return_value = mock_response
# Call the method
results = embedding.embed_documents(texts)
# Verify that create was called correctly
self.mock_embeddings.create.assert_called_once_with(
input=texts,
model='text-embedding-ada-002',
dimensions=ANY
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 1536)
@patch.dict('os.environ', {'OPENAI_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = OpenAIEmbedding()
# For text-embedding-ada-002
self.assertEqual(embedding.dimension, 1536)
# For text-embedding-3-small
embedding = OpenAIEmbedding(model='text-embedding-3-small', dimension=512)
self.assertEqual(embedding.dimension, 512)
# For text-embedding-3-large
embedding = OpenAIEmbedding(model='text-embedding-3-large')
self.assertEqual(embedding.dimension, 3072)
if __name__ == "__main__":
unittest.main()
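For orientation, here is a minimal sketch of the embed_query call shape that test_embed_query and test_embed_documents above pin down; the standalone helper and its parameters are illustrative, not the project's actual implementation.
from openai import OpenAI, NOT_GIVEN
def embed_query(client: OpenAI, model: str, text: str, dim=NOT_GIVEN) -> list[float]:
    # Mirrors the asserted call: a single-element input list, the model name,
    # and an optional `dimensions` argument (NOT_GIVEN for ada-002, an int for the -3 models).
    response = client.embeddings.create(input=[text], model=model, dimensions=dim)
    return response.data[0].embedding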

191
tests/embedding/test_ppio_embedding.py

@@ -1,191 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import requests
from deepsearcher.embedding import PPIOEmbedding
class TestPPIOEmbedding(unittest.TestCase):
"""Tests for the PPIOEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for requests
self.requests_patcher = patch('requests.request')
self.mock_request = self.requests_patcher.start()
# Set up mock response
self.mock_response = MagicMock()
self.mock_response.json.return_value = {
'data': [
{'index': 0, 'embedding': [0.1] * 1024} # baai/bge-m3 has 1024 dimensions
]
}
self.mock_response.raise_for_status = MagicMock()
self.mock_request.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.requests_patcher.stop()
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = PPIOEmbedding()
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
self.assertEqual(embedding.api_key, 'fake-api-key')
self.assertEqual(embedding.batch_size, 32)
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = PPIOEmbedding(model='baai/bge-m3')
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
self.assertEqual(embedding.dimension, 1024)
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = PPIOEmbedding(model_name='baai/bge-m3')
# Check attributes
self.assertEqual(embedding.model, 'baai/bge-m3')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = PPIOEmbedding(api_key='test-api-key')
# Check that the API key was set correctly
self.assertEqual(embedding.api_key, 'test-api-key')
@patch.dict('os.environ', {}, clear=True)
def test_init_without_api_key(self):
"""Test initialization without API key raises error."""
with self.assertRaises(RuntimeError):
PPIOEmbedding()
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = PPIOEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.ppinfra.com/v3/openai/embeddings',
json={
'model': 'baai/bge-m3',
'input': [query]
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the result
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = PPIOEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
self.mock_response.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024}
for i in range(3)
]
}
# Call the method
results = embedding.embed_documents(texts)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.ppinfra.com/v3/openai/embeddings',
json={
'model': 'baai/bge-m3',
'input': texts
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 1024)
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents_with_batching(self):
"""Test embedding documents with batching."""
# Create the embedder
embedding = PPIOEmbedding()
# Create test documents
texts = ["text " + str(i) for i in range(50)] # More than batch_size
# Set up mock response for batched documents
def mock_batch_response(*args, **kwargs):
batch_input = kwargs['json']['input']
mock_resp = MagicMock()
mock_resp.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1] * 1024}
for i in range(len(batch_input))
]
}
mock_resp.raise_for_status = MagicMock()
return mock_resp
self.mock_request.side_effect = mock_batch_response
# Call the method
results = embedding.embed_documents(texts)
# Check that request was called multiple times
self.assertTrue(self.mock_request.call_count > 1)
# Check the results
self.assertEqual(len(results), 50)
for result in results:
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'PPIO_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = PPIOEmbedding()
# For baai/bge-m3
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()
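As a reference for the assertions above, this is a hedged sketch of the HTTP request the PPIO tests (and the similar SiliconFlow and Volcengine tests below) mock out; the helper name is illustrative, and only the endpoint, payload keys, and headers come from the test expectations.
import requests
def request_ppio_embeddings(api_key: str, model: str, texts: list[str]) -> list[list[float]]:
    # Same call shape as asserted in test_embed_documents: POST with a Bearer token
    # and a JSON body containing the model name and the batch of inputs.
    response = requests.request(
        "POST",
        "https://api.ppinfra.com/v3/openai/embeddings",
        json={"model": model, "input": texts},
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    )
    response.raise_for_status()
    return [item["embedding"] for item in response.json()["data"]]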

213
tests/embedding/test_sentence_transformer_embedding.py

@@ -1,213 +0,0 @@
import unittest
import sys
import logging
from unittest.mock import patch, MagicMock
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.embedding import SentenceTransformerEmbedding
class TestSentenceTransformerEmbedding(unittest.TestCase):
"""Tests for the SentenceTransformerEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module for sentence_transformers
mock_st_module = MagicMock()
# Create mock SentenceTransformer class
self.mock_sentence_transformer = MagicMock()
mock_st_module.SentenceTransformer = self.mock_sentence_transformer
# Add the mock module to sys.modules
self.module_patcher = patch.dict('sys.modules', {'sentence_transformers': mock_st_module})
self.module_patcher.start()
# Set up mock instance
self.mock_model = MagicMock()
self.mock_sentence_transformer.return_value = self.mock_model
# Configure mock encode method
mock_embedding = [[0.1, 0.2, 0.3] * 341 + [0.4]] # 1024 dimensions
self.mock_model.encode.return_value = MagicMock()
self.mock_model.encode.return_value.tolist.return_value = mock_embedding
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {}, clear=True)
def test_init(self):
"""Test initialization."""
# Create instance to test
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")
# Check that SentenceTransformer was called with the right model
self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-m3")
# Check that model and client were set correctly
self.assertEqual(embedding.model, "BAAI/bge-m3")
self.assertEqual(embedding.client, self.mock_model)
# Check batch size default
self.assertEqual(embedding.batch_size, 32)
# Test with model_name parameter
self.mock_sentence_transformer.reset_mock()
embedding = SentenceTransformerEmbedding(model_name="BAAI/bge-large-zh-v1.5")
self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-large-zh-v1.5")
self.assertEqual(embedding.model, "BAAI/bge-large-zh-v1.5")
# Test with custom batch size
self.mock_sentence_transformer.reset_mock()
embedding = SentenceTransformerEmbedding(batch_size=64)
self.assertEqual(embedding.batch_size, 64)
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create instance to test
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")
# Mock the encode response for a single query
single_embedding = [0.1, 0.2, 0.3] * 341 + [0.4] # 1024 dimensions
self.mock_model.encode.return_value = MagicMock()
self.mock_model.encode.return_value.tolist.return_value = [single_embedding]
# Call the method
result = embedding.embed_query("test query")
# Verify encode was called correctly
self.mock_model.encode.assert_called_once_with("test query")
# Check the result
self.assertEqual(len(result), 1024)
self.assertEqual(result, single_embedding)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_small_batch(self):
"""Test embedding documents with a small batch (less than batch size)."""
# Create instance to test
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")
# Mock the encode response for documents
batch_embeddings = [
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions
[0.4, 0.5, 0.6] * 341 + [0.7],
[0.7, 0.8, 0.9] * 341 + [0.1]
]
self.mock_model.encode.return_value = MagicMock()
self.mock_model.encode.return_value.tolist.return_value = batch_embeddings
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Verify encode was called correctly
self.mock_model.encode.assert_called_once_with(texts)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(len(result), 1024)
self.assertEqual(result, batch_embeddings[i])
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_large_batch(self):
"""Test embedding documents with a large batch (more than batch size)."""
# Create instance to test with small batch size
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=2)
# Mock the encode response for the first batch
batch1_embeddings = [
[0.1, 0.2, 0.3] * 341 + [0.4], # 1024 dimensions
[0.4, 0.5, 0.6] * 341 + [0.7]
]
# Mock the encode response for the second batch
batch2_embeddings = [
[0.7, 0.8, 0.9] * 341 + [0.1]
]
# Set up the mock to return different values on each call
self.mock_model.encode.side_effect = [
MagicMock(tolist=lambda: batch1_embeddings),
MagicMock(tolist=lambda: batch2_embeddings)
]
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Verify encode was called twice with the right batches
self.assertEqual(self.mock_model.encode.call_count, 2)
self.mock_model.encode.assert_any_call(["text 1", "text 2"])
self.mock_model.encode.assert_any_call(["text 3"])
# Check the results
self.assertEqual(len(results), 3)
self.assertEqual(results[0], batch1_embeddings[0])
self.assertEqual(results[1], batch1_embeddings[1])
self.assertEqual(results[2], batch2_embeddings[0])
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents_no_batching(self):
"""Test embedding documents with batching disabled."""
# Create instance to test with batching disabled
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=0)
# Mock the embed_query method
original_embed_query = embedding.embed_query
embed_query_calls = []
def mock_embed_query(text):
embed_query_calls.append(text)
return [0.1] * 1024 # Return a simple mock embedding
embedding.embed_query = mock_embed_query
# Create test texts
texts = ["text 1", "text 2", "text 3"]
# Call the method
results = embedding.embed_documents(texts)
# Check that embed_query was called for each text
self.assertEqual(len(embed_query_calls), 3)
self.assertEqual(embed_query_calls, texts)
# Check the results
self.assertEqual(len(results), 3)
for result in results:
self.assertEqual(len(result), 1024)
self.assertEqual(result, [0.1] * 1024)
# Restore original method
embedding.embed_query = original_embed_query
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create instance to test
embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")
# Check dimension for BAAI/bge-m3
self.assertEqual(embedding.dimension, 1024)
# Test with different models
self.mock_sentence_transformer.reset_mock()
embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-zh-v1.5")
self.assertEqual(embedding.dimension, 1024)
self.mock_sentence_transformer.reset_mock()
embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-en-v1.5")
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()
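The large-batch and no-batching tests above together describe the chunking behaviour; the loop below is a minimal sketch of it, assuming the wrapped client is a sentence_transformers.SentenceTransformer, and is not the project's actual code.
def embed_in_batches(st_model, texts: list[str], batch_size: int = 32) -> list[list[float]]:
    # Encode in chunks of batch_size, as test_embed_documents_large_batch asserts
    # (two encode() calls for three texts when batch_size=2).
    if batch_size <= 0:
        # A batch size of 0 disables batching; the test falls back to one call per text.
        return [st_model.encode(t).tolist() for t in texts]
    embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        embeddings.extend(st_model.encode(batch).tolist())
    return embeddings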

201
tests/embedding/test_siliconflow_embedding.py

@@ -1,201 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import requests
from deepsearcher.embedding import SiliconflowEmbedding
class TestSiliconflowEmbedding(unittest.TestCase):
"""Tests for the SiliconflowEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for requests
self.requests_patcher = patch('requests.request')
self.mock_request = self.requests_patcher.start()
# Set up mock response
self.mock_response = MagicMock()
self.mock_response.json.return_value = {
'data': [
{'index': 0, 'embedding': [0.1] * 1024} # BAAI/bge-m3 has 1024 dimensions
]
}
self.mock_response.raise_for_status = MagicMock()
self.mock_request.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.requests_patcher.stop()
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = SiliconflowEmbedding()
# Check attributes
self.assertEqual(embedding.model, 'BAAI/bge-m3')
self.assertEqual(embedding.api_key, 'fake-api-key')
self.assertEqual(embedding.batch_size, 32)
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = SiliconflowEmbedding(model='netease-youdao/bce-embedding-base_v1')
# Check attributes
self.assertEqual(embedding.model, 'netease-youdao/bce-embedding-base_v1')
self.assertEqual(embedding.dimension, 768)
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = SiliconflowEmbedding(model_name='BAAI/bge-large-zh-v1.5')
# Check attributes
self.assertEqual(embedding.model, 'BAAI/bge-large-zh-v1.5')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = SiliconflowEmbedding(api_key='test-api-key')
# Check that the API key was set correctly
self.assertEqual(embedding.api_key, 'test-api-key')
@patch.dict('os.environ', {}, clear=True)
def test_init_without_api_key(self):
"""Test initialization without API key raises error."""
with self.assertRaises(RuntimeError):
SiliconflowEmbedding()
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = SiliconflowEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.siliconflow.cn/v1/embeddings',
json={
'model': 'BAAI/bge-m3',
'input': query,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the result
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = SiliconflowEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
self.mock_response.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1 * (i + 1)] * 1024}
for i in range(3)
]
}
# Call the method
results = embedding.embed_documents(texts)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://api.siliconflow.cn/v1/embeddings',
json={
'model': 'BAAI/bge-m3',
'input': texts,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 1024)
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents_with_batching(self):
"""Test embedding documents with batching."""
# Create the embedder
embedding = SiliconflowEmbedding()
# Create test documents
texts = ["text " + str(i) for i in range(50)] # More than batch_size
# Set up mock response for batched documents
def mock_batch_response(*args, **kwargs):
batch_input = kwargs['json']['input']
mock_resp = MagicMock()
mock_resp.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1] * 1024}
for i in range(len(batch_input))
]
}
mock_resp.raise_for_status = MagicMock()
return mock_resp
self.mock_request.side_effect = mock_batch_response
# Call the method
results = embedding.embed_documents(texts)
# Check that request was called multiple times
self.assertTrue(self.mock_request.call_count > 1)
# Check the results
self.assertEqual(len(results), 50)
for result in results:
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'SILICONFLOW_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = SiliconflowEmbedding()
# For BAAI/bge-m3
self.assertEqual(embedding.dimension, 1024)
# For netease-youdao/bce-embedding-base_v1
embedding = SiliconflowEmbedding(model='netease-youdao/bce-embedding-base_v1')
self.assertEqual(embedding.dimension, 768)
# For BAAI/bge-large-zh-v1.5
embedding = SiliconflowEmbedding(model='BAAI/bge-large-zh-v1.5')
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()

201
tests/embedding/test_volcengine_embedding.py

@@ -1,201 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import requests
from deepsearcher.embedding import VolcengineEmbedding
class TestVolcengineEmbedding(unittest.TestCase):
"""Tests for the VolcengineEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for requests
self.requests_patcher = patch('requests.request')
self.mock_request = self.requests_patcher.start()
# Set up mock response
self.mock_response = MagicMock()
self.mock_response.json.return_value = {
'data': [
{'index': 0, 'embedding': [0.1] * 4096} # doubao-embedding-large-text-240915 has 4096 dimensions
]
}
self.mock_response.raise_for_status = MagicMock()
self.mock_request.return_value = self.mock_response
def tearDown(self):
"""Clean up test fixtures."""
self.requests_patcher.stop()
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = VolcengineEmbedding()
# Check attributes
self.assertEqual(embedding.model, 'doubao-embedding-large-text-240915')
self.assertEqual(embedding.api_key, 'fake-api-key')
self.assertEqual(embedding.batch_size, 256)
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = VolcengineEmbedding(model='doubao-embedding-text-240515')
# Check attributes
self.assertEqual(embedding.model, 'doubao-embedding-text-240515')
self.assertEqual(embedding.dimension, 2048)
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = VolcengineEmbedding(model_name='doubao-embedding-text-240715')
# Check attributes
self.assertEqual(embedding.model, 'doubao-embedding-text-240715')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = VolcengineEmbedding(api_key='test-api-key')
# Check that the API key was set correctly
self.assertEqual(embedding.api_key, 'test-api-key')
@patch.dict('os.environ', {}, clear=True)
def test_init_without_api_key(self):
"""Test initialization without API key raises error."""
with self.assertRaises(RuntimeError):
VolcengineEmbedding()
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = VolcengineEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://ark.cn-beijing.volces.com/api/v3/embeddings',
json={
'model': 'doubao-embedding-large-text-240915',
'input': query,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the result
self.assertEqual(result, [0.1] * 4096)
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = VolcengineEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
self.mock_response.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1 * (i + 1)] * 4096}
for i in range(3)
]
}
# Call the method
results = embedding.embed_documents(texts)
# Verify that request was called correctly
self.mock_request.assert_called_once_with(
'POST',
'https://ark.cn-beijing.volces.com/api/v3/embeddings',
json={
'model': 'doubao-embedding-large-text-240915',
'input': texts,
'encoding_format': 'float'
},
headers={
'Authorization': 'Bearer fake-api-key',
'Content-Type': 'application/json'
}
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 4096)
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents_with_batching(self):
"""Test embedding documents with batching."""
# Create the embedder
embedding = VolcengineEmbedding()
# Create test documents
texts = ["text " + str(i) for i in range(300)] # More than batch_size
# Set up mock response for batched documents
def mock_batch_response(*args, **kwargs):
batch_input = kwargs['json']['input']
mock_resp = MagicMock()
mock_resp.json.return_value = {
'data': [
{'index': i, 'embedding': [0.1] * 4096}
for i in range(len(batch_input))
]
}
mock_resp.raise_for_status = MagicMock()
return mock_resp
self.mock_request.side_effect = mock_batch_response
# Call the method
results = embedding.embed_documents(texts)
# Check that request was called multiple times
self.assertTrue(self.mock_request.call_count > 1)
# Check the results
self.assertEqual(len(results), 300)
for result in results:
self.assertEqual(result, [0.1] * 4096)
@patch.dict('os.environ', {'VOLCENGINE_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = VolcengineEmbedding()
# For doubao-embedding-large-text-240915
self.assertEqual(embedding.dimension, 4096)
# For doubao-embedding-text-240715
embedding = VolcengineEmbedding(model='doubao-embedding-text-240715')
self.assertEqual(embedding.dimension, 2560)
# For doubao-embedding-text-240515
embedding = VolcengineEmbedding(model='doubao-embedding-text-240515')
self.assertEqual(embedding.dimension, 2048)
if __name__ == "__main__":
unittest.main()

144
tests/embedding/test_voyage_embedding.py

@@ -1,144 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
from deepsearcher.embedding import VoyageEmbedding
class TestVoyageEmbedding(unittest.TestCase):
"""Tests for the VoyageEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Create a mock module
self.mock_voyageai = MagicMock()
self.mock_client = MagicMock()
# Set up mock response for embed
mock_response = MagicMock()
mock_response.embeddings = [[0.1] * 1024] # voyage-3 has 1024 dimensions
self.mock_client.embed.return_value = mock_response
# Set up the mock module
self.mock_voyageai.Client.return_value = self.mock_client
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'voyageai': self.mock_voyageai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_default(self):
"""Test initialization with default parameters."""
# Create the embedder
embedding = VoyageEmbedding()
# Check that voyageai was initialized correctly
self.mock_voyageai.Client.assert_called_once()
# Check attributes
self.assertEqual(embedding.model, 'voyage-3')
self.assertEqual(embedding.voyageai_api_key, 'fake-api-key')
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model(self):
"""Test initialization with specified model."""
# Initialize with a different model
embedding = VoyageEmbedding(model='voyage-3-lite')
# Check attributes
self.assertEqual(embedding.model, 'voyage-3-lite')
self.assertEqual(embedding.dimension, 512) # voyage-3-lite has 512 dimensions
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_init_with_model_name(self):
"""Test initialization with model_name parameter."""
# Initialize with model_name
embedding = VoyageEmbedding(model_name='voyage-3-large')
# Check attributes
self.assertEqual(embedding.model, 'voyage-3-large')
@patch.dict('os.environ', {}, clear=True)
def test_init_with_api_key(self):
"""Test initialization with API key parameter."""
# Initialize with API key
embedding = VoyageEmbedding(api_key='test-api-key')
# Check that the API key was set correctly
self.assertEqual(embedding.voyageai_api_key, 'test-api-key')
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_query(self):
"""Test embedding a single query."""
# Create the embedder
embedding = VoyageEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that embed was called correctly
self.mock_client.embed.assert_called_once_with(
[query],
model='voyage-3',
input_type='query'
)
# Check the result
self.assertEqual(result, [0.1] * 1024)
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_embed_documents(self):
"""Test embedding multiple documents."""
# Create the embedder
embedding = VoyageEmbedding()
# Create test documents
texts = ["text 1", "text 2", "text 3"]
# Set up mock response for multiple documents
mock_response = MagicMock()
mock_response.embeddings = [[0.1 * (i + 1)] * 1024 for i in range(3)]
self.mock_client.embed.return_value = mock_response
# Call the method
results = embedding.embed_documents(texts)
# Verify that embed was called correctly
self.mock_client.embed.assert_called_once_with(
texts,
model='voyage-3',
input_type='document'
)
# Check the results
self.assertEqual(len(results), 3)
for i, result in enumerate(results):
self.assertEqual(result, [0.1 * (i + 1)] * 1024)
@patch.dict('os.environ', {'VOYAGE_API_KEY': 'fake-api-key'}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
# Create the embedder
embedding = VoyageEmbedding()
# For voyage-3
self.assertEqual(embedding.dimension, 1024)
# For voyage-3-lite
embedding = VoyageEmbedding(model='voyage-3-lite')
self.assertEqual(embedding.dimension, 512)
# For voyage-3-large
embedding = VoyageEmbedding(model='voyage-3-large')
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()

284
tests/embedding/test_watsonx_embedding.py

@@ -1,284 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, ANY
import os
class TestWatsonXEmbedding(unittest.TestCase):
"""Test cases for WatsonXEmbedding class."""
def setUp(self):
"""Set up test fixtures."""
# Mock the ibm_watsonx_ai imports
self.mock_credentials = MagicMock()
self.mock_embeddings = MagicMock()
# Create a mock client
self.mock_client = MagicMock()
# Set up mock response for embed_query
self.mock_client.embed_query.return_value = {
'results': [
{'embedding': [0.1] * 768}
]
}
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_env_vars(self, mock_credentials_class, mock_embeddings_class):
"""Test initialization with environment variables."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
embedding = WatsonXEmbedding()
# Check that Credentials was called with correct parameters
mock_credentials_class.assert_called_once_with(
url='https://test.watsonx.com',
api_key='test-api-key'
)
# Check that Embeddings was called with correct parameters
mock_embeddings_class.assert_called_once_with(
model_id='ibm/slate-125m-english-rtrvr-v2',
credentials=mock_credentials_instance,
project_id='test-project-id'
)
# Check default model and dimension
self.assertEqual(embedding.model, 'ibm/slate-125m-english-rtrvr-v2')
self.assertEqual(embedding.dimension, 768)
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_with_space_id(self, mock_credentials_class, mock_embeddings_class):
"""Test initialization with space_id instead of project_id."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
embedding = WatsonXEmbedding(space_id='test-space-id')
# Check that Embeddings was called with space_id
mock_embeddings_class.assert_called_once_with(
model_id='ibm/slate-125m-english-rtrvr-v2',
credentials=mock_credentials_instance,
space_id='test-space-id'
)
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
def test_init_missing_api_key(self, mock_credentials_class, mock_embeddings_class):
"""Test initialization with missing API key."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
with patch.dict(os.environ, {}, clear=True):
with self.assertRaises(ValueError) as context:
WatsonXEmbedding()
self.assertIn("WATSONX_APIKEY", str(context.exception))
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key'
})
def test_init_missing_url(self, mock_credentials_class, mock_embeddings_class):
"""Test initialization with missing URL."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
with self.assertRaises(ValueError) as context:
WatsonXEmbedding()
self.assertIn("WATSONX_URL", str(context.exception))
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_missing_project_and_space_id(self, mock_credentials_class, mock_embeddings_class):
"""Test initialization with missing both project_id and space_id."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
with self.assertRaises(ValueError) as context:
WatsonXEmbedding()
self.assertIn("WATSONX_PROJECT_ID", str(context.exception))
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_embed_query(self, mock_credentials_class, mock_embeddings_class):
"""Test embedding a single query."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
# WatsonX embed_query returns the embedding vector directly, not wrapped in a dict
mock_embeddings_instance.embed_query.return_value = [0.1] * 768
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
# Create the embedder
embedding = WatsonXEmbedding()
# Create a test query
query = "This is a test query"
# Call the method
result = embedding.embed_query(query)
# Verify that embed_query was called correctly
mock_embeddings_instance.embed_query.assert_called_once_with(text=query)
# Check the result
self.assertEqual(result, [0.1] * 768)
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_embed_documents(self, mock_credentials_class, mock_embeddings_class):
"""Test embedding multiple documents."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
# WatsonX embed_documents returns a list of embedding vectors directly
mock_embeddings_instance.embed_documents.return_value = [
[0.1] * 768,
[0.2] * 768,
[0.3] * 768
]
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
# Create the embedder
embedding = WatsonXEmbedding()
# Create test documents
documents = ["Document 1", "Document 2", "Document 3"]
# Call the method
results = embedding.embed_documents(documents)
# Verify that embed_documents was called correctly
mock_embeddings_instance.embed_documents.assert_called_once_with(texts=documents)
# Check the results
self.assertEqual(len(results), 3)
self.assertEqual(results[0], [0.1] * 768)
self.assertEqual(results[1], [0.2] * 768)
self.assertEqual(results[2], [0.3] * 768)
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_dimension_property(self, mock_credentials_class, mock_embeddings_class):
"""Test the dimension property for different models."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
# Test default model
embedding = WatsonXEmbedding()
self.assertEqual(embedding.dimension, 768)
# Test different model
embedding = WatsonXEmbedding(model='ibm/slate-30m-english-rtrvr')
self.assertEqual(embedding.dimension, 384)
# Test unknown model (should default to 768)
embedding = WatsonXEmbedding(model='unknown-model')
self.assertEqual(embedding.dimension, 768)
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_embed_query_error_handling(self, mock_credentials_class, mock_embeddings_class):
"""Test error handling in embed_query."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
mock_embeddings_instance.embed_query.side_effect = Exception("API Error")
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
# Create the embedder
embedding = WatsonXEmbedding()
# Test that the exception is properly wrapped
with self.assertRaises(RuntimeError) as context:
embedding.embed_query("test")
self.assertIn("Error embedding query with WatsonX", str(context.exception))
@patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
@patch('deepsearcher.embedding.watsonx_embedding.Credentials')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_embed_documents_error_handling(self, mock_credentials_class, mock_embeddings_class):
"""Test error handling in embed_documents."""
from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding
mock_credentials_instance = MagicMock()
mock_embeddings_instance = MagicMock()
mock_embeddings_instance.embed_documents.side_effect = Exception("API Error")
mock_credentials_class.return_value = mock_credentials_instance
mock_embeddings_class.return_value = mock_embeddings_instance
# Create the embedder
embedding = WatsonXEmbedding()
# Test that the exception is properly wrapped
with self.assertRaises(RuntimeError) as context:
embedding.embed_documents(["test"])
self.assertIn("Error embedding documents with WatsonX", str(context.exception))
if __name__ == '__main__':
unittest.main()
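The two error-handling tests above only require that client exceptions are re-raised as RuntimeError with a provider-specific message; a minimal sketch of that wrapping, with the function name and signature invented for illustration:
def safe_embed_query(watsonx_client, text: str) -> list[float]:
    # watsonx_client is assumed to be an ibm_watsonx_ai Embeddings instance.
    try:
        return watsonx_client.embed_query(text=text)
    except Exception as exc:
        raise RuntimeError(f"Error embedding query with WatsonX: {exc}") from exc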

1
tests/llm/__init__.py

@@ -1 +0,0 @@
# Tests for the deepsearcher.llm package

164
tests/llm/test_aliyun.py

@@ -1,164 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Aliyun
from deepsearcher.llm.base import ChatResponse
class TestAliyun(unittest.TestCase):
"""Tests for the Aliyun LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Aliyun()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
# Check default model
self.assertEqual(llm.model, "deepseek-r1")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": api_key}):
llm = Aliyun()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = Aliyun(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "qwen-max"
llm = Aliyun(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.aliyun.api"
llm = Aliyun(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Aliyun instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Aliyun()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-r1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Aliyun instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Aliyun()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-r1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Aliyun instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Aliyun()
# Mock an error response
self.mock_completions.create.side_effect = Exception("Aliyun API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Aliyun API Error")
if __name__ == "__main__":
unittest.main()
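The Aliyun chat tests above, and the other OpenAI-compatible provider tests that follow, all assert the same call shape; a hedged sketch of that shared pattern, using ChatResponse from deepsearcher.llm.base (the helper itself is illustrative):
from deepsearcher.llm.base import ChatResponse
def chat_once(client, model: str, messages: list[dict]) -> ChatResponse:
    # Mirrors the assertions: one completions.create call, content taken from the
    # first choice, token count taken from usage.total_tokens.
    completion = client.chat.completions.create(model=model, messages=messages)
    return ChatResponse(
        content=completion.choices[0].message.content,
        total_tokens=completion.usage.total_tokens,
    )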

169
tests/llm/test_anthropic.py

@@ -1,169 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
from typing import List
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Anthropic
from deepsearcher.llm.base import ChatResponse
class ContentItem:
"""Mock content item for Anthropic response."""
def __init__(self, text: str):
self.text = text
class TestAnthropic(unittest.TestCase):
"""Tests for the Anthropic LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_anthropic = MagicMock()
self.mock_client = MagicMock()
# Set up mock response
self.mock_response = MagicMock()
# Set up response content with proper structure
content_item = ContentItem("Test response")
self.mock_response.content = [content_item]
self.mock_response.usage.input_tokens = 50
self.mock_response.usage.output_tokens = 50
# Set up the mock module structure and response
self.mock_client.messages.create.return_value = self.mock_response
self.mock_anthropic.Anthropic.return_value = self.mock_client
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'anthropic': self.mock_anthropic})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Anthropic()
# Check that Anthropic client was initialized correctly
self.mock_anthropic.Anthropic.assert_called_once_with(
api_key=None,
base_url=None
)
# Check default attributes
self.assertEqual(llm.model, "claude-sonnet-4-0")
self.assertEqual(llm.max_tokens, 8192)
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": api_key}):
llm = Anthropic()
self.mock_anthropic.Anthropic.assert_called_with(
api_key=api_key,
base_url=None
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = Anthropic(api_key=api_key)
self.mock_anthropic.Anthropic.assert_called_with(
api_key=api_key,
base_url=None
)
def test_init_with_custom_model_and_tokens(self):
"""Test initialization with custom model and max tokens."""
with patch.dict('os.environ', {}, clear=True):
model = "claude-3-opus-20240229"
max_tokens = 4096
llm = Anthropic(model=model, max_tokens=max_tokens)
self.assertEqual(llm.model, model)
self.assertEqual(llm.max_tokens, max_tokens)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.anthropic.api"
llm = Anthropic(base_url=base_url)
self.mock_anthropic.Anthropic.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Anthropic instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Anthropic()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that messages.create was called correctly
self.mock_client.messages.create.assert_called_once()
call_args = self.mock_client.messages.create.call_args
self.assertEqual(call_args[1]["model"], "claude-sonnet-4-0")
self.assertEqual(call_args[1]["messages"], messages)
self.assertEqual(call_args[1]["max_tokens"], 8192)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100) # 50 input + 50 output
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Anthropic instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Anthropic()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that messages.create was called correctly
self.mock_client.messages.create.assert_called_once()
call_args = self.mock_client.messages.create.call_args
self.assertEqual(call_args[1]["model"], "claude-sonnet-4-0")
self.assertEqual(call_args[1]["messages"], messages)
self.assertEqual(call_args[1]["max_tokens"], 8192)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100) # 50 input + 50 output
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Anthropic instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Anthropic()
# Mock an error response
self.mock_client.messages.create.side_effect = Exception("Anthropic API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Anthropic API Error")
if __name__ == "__main__":
unittest.main()

170
tests/llm/test_azure_openai.py

@@ -1,170 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import AzureOpenAI
from deepsearcher.llm.base import ChatResponse
class TestAzureOpenAI(unittest.TestCase):
"""Tests for the Azure OpenAI LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.AzureOpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
# Test parameters
self.test_model = "gpt-4"
self.test_endpoint = "https://test.openai.azure.com"
self.test_api_key = "test_api_key"
self.test_api_version = "2024-02-15"
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_with_parameters(self):
"""Test initialization with explicit parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = AzureOpenAI(
model=self.test_model,
azure_endpoint=self.test_endpoint,
api_key=self.test_api_key,
api_version=self.test_api_version
)
# Check that Azure OpenAI client was initialized correctly
self.mock_openai.AzureOpenAI.assert_called_once_with(
azure_endpoint=self.test_endpoint,
api_key=self.test_api_key,
api_version=self.test_api_version
)
# Check model attribute
self.assertEqual(llm.model, self.test_model)
def test_init_with_env_variables(self):
"""Test initialization with environment variables."""
env_endpoint = "https://env.openai.azure.com"
env_api_key = "env_api_key"
with patch.dict(os.environ, {
"AZURE_OPENAI_ENDPOINT": env_endpoint,
"AZURE_OPENAI_KEY": env_api_key
}):
llm = AzureOpenAI(model=self.test_model)
self.mock_openai.AzureOpenAI.assert_called_with(
azure_endpoint=env_endpoint,
api_key=env_api_key,
api_version=None
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Azure OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = AzureOpenAI(
model=self.test_model,
azure_endpoint=self.test_endpoint,
api_key=self.test_api_key,
api_version=self.test_api_version
)
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], self.test_model)
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Azure OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = AzureOpenAI(
model=self.test_model,
azure_endpoint=self.test_endpoint,
api_key=self.test_api_key,
api_version=self.test_api_version
)
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], self.test_model)
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Azure OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = AzureOpenAI(
model=self.test_model,
azure_endpoint=self.test_endpoint,
api_key=self.test_api_key,
api_version=self.test_api_version
)
# Mock an error response
self.mock_completions.create.side_effect = Exception("Azure OpenAI API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Azure OpenAI API Error")
if __name__ == "__main__":
unittest.main()

154
tests/llm/test_base.py

@@ -1,154 +0,0 @@
import unittest
from deepsearcher.llm.base import BaseLLM, ChatResponse
from unittest.mock import patch
class TestBaseLLM(unittest.TestCase):
"""Tests for the BaseLLM abstract base class."""
def setUp(self):
"""Set up test fixtures."""
# Clear environment variables temporarily
self.env_patcher = patch.dict('os.environ', {}, clear=True)
self.env_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.env_patcher.stop()
def test_chat_response_init(self):
"""Test ChatResponse initialization and representation."""
content = "Test content"
total_tokens = 100
response = ChatResponse(content=content, total_tokens=total_tokens)
self.assertEqual(response.content, content)
self.assertEqual(response.total_tokens, total_tokens)
self.assertEqual(
repr(response),
f"ChatResponse(content={content}, total_tokens={total_tokens})"
)
def test_literal_eval_python_code_block(self):
"""Test literal_eval with Python code block."""
content = '''```python
{"key": "value", "number": 42}
```'''
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value", "number": 42})
def test_literal_eval_json_code_block(self):
"""Test literal_eval with JSON code block."""
content = '''```json
{"key": "value", "number": 42}
```'''
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value", "number": 42})
def test_literal_eval_str_code_block(self):
"""Test literal_eval with str code block."""
content = '''```str
{"key": "value", "number": 42}
```'''
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value", "number": 42})
def test_literal_eval_plain_code_block(self):
"""Test literal_eval with plain code block."""
content = '''```
{"key": "value", "number": 42}
```'''
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value", "number": 42})
def test_literal_eval_raw_dict(self):
"""Test literal_eval with raw dictionary string."""
content = '{"key": "value", "number": 42}'
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value", "number": 42})
def test_literal_eval_raw_list(self):
"""Test literal_eval with raw list string."""
content = '[1, 2, "three", {"four": 4}]'
result = BaseLLM.literal_eval(content)
self.assertEqual(result, [1, 2, "three", {"four": 4}])
def test_literal_eval_with_whitespace(self):
"""Test literal_eval with extra whitespace."""
content = '''
{"key": "value"}
'''
result = BaseLLM.literal_eval(content)
self.assertEqual(result, {"key": "value"})
def test_literal_eval_nested_structures(self):
"""Test literal_eval with nested data structures."""
content = '''
{
"string": "value",
"number": 42,
"list": [1, 2, 3],
"dict": {"nested": "value"},
"mixed": [1, {"key": "value"}, [2, 3]]
}
'''
result = BaseLLM.literal_eval(content)
expected = {
"string": "value",
"number": 42,
"list": [1, 2, 3],
"dict": {"nested": "value"},
"mixed": [1, {"key": "value"}, [2, 3]]
}
self.assertEqual(result, expected)
def test_literal_eval_invalid_format(self):
"""Test literal_eval with invalid format."""
invalid_contents = [
"Not a valid Python literal",
"{invalid: json}",
"[1, 2, 3", # Unclosed bracket
'{"key": undefined}', # undefined is not a valid Python literal
]
for content in invalid_contents:
with self.assertRaises(ValueError):
BaseLLM.literal_eval(content)
def test_remove_think_with_tags(self):
"""Test remove_think with think tags."""
content = '''<think>
This is the reasoning process.
Multiple lines of thought.
</think>
This is the actual response.'''
result = BaseLLM.remove_think(content)
self.assertEqual(result.strip(), "This is the actual response.")
def test_remove_think_without_tags(self):
"""Test remove_think without think tags."""
content = "This is a response without think tags."
result = BaseLLM.remove_think(content)
self.assertEqual(result.strip(), content.strip())
def test_remove_think_multiple_tags(self):
"""Test remove_think with multiple think tags - should only remove first block."""
content = '''<think>First think block</think>
Actual response
<think>Second think block</think>'''
result = BaseLLM.remove_think(content)
self.assertEqual(
result.strip(),
"Actual response\n <think>Second think block</think>"
)
def test_remove_think_empty_tags(self):
"""Test remove_think with empty think tags."""
content = "<think></think>Response"
result = BaseLLM.remove_think(content)
self.assertEqual(result.strip(), "Response")
if __name__ == "__main__":
unittest.main()

196
tests/llm/test_bedrock.py

@@ -1,196 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Bedrock
from deepsearcher.llm.base import ChatResponse
class TestBedrock(unittest.TestCase):
"""Tests for the AWS Bedrock LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_boto3 = MagicMock()
self.mock_client = MagicMock()
# Set up the mock module structure
self.mock_boto3.client = MagicMock(return_value=self.mock_client)
# Set up mock response
self.mock_response = {
"output": {
"message": {
"content": [{"text": "Test response\nwith newline"}]
}
},
"usage": {
"totalTokens": 100
}
}
self.mock_client.converse.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'boto3': self.mock_boto3})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock()
# Check that client was initialized correctly
self.mock_boto3.client.assert_called_once_with(
"bedrock-runtime",
region_name="us-west-2",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None
)
# Check default attributes
self.assertEqual(llm.model, "us.deepseek.r1-v1:0")
self.assertEqual(llm.max_tokens, 20000)
def test_init_with_aws_credentials_from_env(self):
"""Test initialization with AWS credentials from environment variables."""
credentials = {
"AWS_ACCESS_KEY_ID": "test_access_key",
"AWS_SECRET_ACCESS_KEY": "test_secret_key",
"AWS_SESSION_TOKEN": "test_session_token"
}
with patch.dict(os.environ, credentials):
llm = Bedrock()
self.mock_boto3.client.assert_called_with(
"bedrock-runtime",
region_name="us-west-2",
aws_access_key_id="test_access_key",
aws_secret_access_key="test_secret_key",
aws_session_token="test_session_token"
)
def test_init_with_aws_credentials_parameters(self):
"""Test initialization with AWS credentials as parameters."""
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock(
aws_access_key_id="param_access_key",
aws_secret_access_key="param_secret_key",
aws_session_token="param_session_token"
)
self.mock_boto3.client.assert_called_with(
"bedrock-runtime",
region_name="us-west-2",
aws_access_key_id="param_access_key",
aws_secret_access_key="param_secret_key",
aws_session_token="param_session_token"
)
def test_init_with_custom_model_and_tokens(self):
"""Test initialization with custom model and max tokens."""
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock(model="custom.model", max_tokens=1000)
self.assertEqual(llm.model, "custom.model")
self.assertEqual(llm.max_tokens, 1000)
def test_init_with_custom_region(self):
"""Test initialization with custom region."""
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock(region_name="us-east-1")
self.mock_boto3.client.assert_called_with(
"bedrock-runtime",
region_name="us-east-1",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Bedrock instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that converse was called correctly
self.mock_client.converse.assert_called_once()
call_args = self.mock_client.converse.call_args
self.assertEqual(call_args[1]["modelId"], "us.deepseek.r1-v1:0")
self.assertEqual(call_args[1]["messages"], [
{"role": "user", "content": [{"text": "Hello"}]}
])
self.assertEqual(call_args[1]["inferenceConfig"], {"maxTokens": 20000})
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test responsewith newline")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Bedrock instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that converse was called correctly
self.mock_client.converse.assert_called_once()
call_args = self.mock_client.converse.call_args
expected_messages = [
{"role": "system", "content": [{"text": "You are a helpful assistant"}]},
{"role": "user", "content": [{"text": "Hello"}]},
{"role": "assistant", "content": [{"text": "Hi there!"}]},
{"role": "user", "content": [{"text": "How are you?"}]}
]
self.assertEqual(call_args[1]["messages"], expected_messages)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Bedrock instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock()
# Mock an error response
self.mock_client.converse.side_effect = Exception("AWS Bedrock Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "AWS Bedrock Error")
def test_chat_with_preformatted_messages(self):
"""Test chat with messages that are already in the correct format."""
# Create Bedrock instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Bedrock()
messages = [
{
"role": "user",
"content": [{"text": "Hello"}]
}
]
response = llm.chat(messages)
# Check that the message format was preserved
call_args = self.mock_client.converse.call_args
self.assertEqual(call_args[1]["messages"], messages)
if __name__ == "__main__":
unittest.main()

169
tests/llm/test_deepseek.py

@@ -1,169 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import DeepSeek
from deepsearcher.llm.base import ChatResponse
class TestDeepSeek(unittest.TestCase):
"""Tests for the DeepSeek LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = DeepSeek()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://api.deepseek.com"
)
# Check default model
self.assertEqual(llm.model, "deepseek-reasoner")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
base_url = "https://custom.deepseek.api"
with patch.dict(os.environ, {
"DEEPSEEK_API_KEY": api_key,
"DEEPSEEK_BASE_URL": base_url
}):
llm = DeepSeek()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url=base_url
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
api_key = "test_api_key_param"
with patch.dict('os.environ', {}, clear=True):
llm = DeepSeek(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.deepseek.com"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "deepseek-chat"
llm = DeepSeek(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.deepseek.api"
llm = DeepSeek(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create DeepSeek instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = DeepSeek()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-reasoner")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create DeepSeek instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = DeepSeek()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-reasoner")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create DeepSeek instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = DeepSeek()
# Mock an error response
self.mock_completions.create.side_effect = Exception("DeepSeek API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "DeepSeek API Error")
if __name__ == "__main__":
unittest.main()

136
tests/llm/test_gemini.py

@@ -1,136 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Gemini
from deepsearcher.llm.base import ChatResponse
class TestGemini(unittest.TestCase):
"""Tests for the Gemini LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_genai = MagicMock()
self.mock_client = MagicMock()
self.mock_response = MagicMock()
self.mock_metadata = MagicMock()
# Set up the mock module structure
self.mock_genai.Client = MagicMock(return_value=self.mock_client)
# Set up mock response
self.mock_response.text = "Test response"
self.mock_metadata.total_token_count = 100
self.mock_response.usage_metadata = self.mock_metadata
self.mock_client.models.generate_content.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'google.genai': self.mock_genai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Gemini()
# Check that Client was initialized correctly
self.mock_genai.Client.assert_called_once_with(api_key=None)
# Check default model
self.assertEqual(llm.model, "gemini-2.0-flash")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"GEMINI_API_KEY": api_key}):
llm = Gemini()
self.mock_genai.Client.assert_called_with(api_key=api_key)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = Gemini(api_key=api_key)
self.mock_genai.Client.assert_called_with(api_key=api_key)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "gemini-pro"
llm = Gemini(model=model)
self.assertEqual(llm.model, model)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Gemini instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Gemini()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that generate_content was called correctly
self.mock_client.models.generate_content.assert_called_once()
call_args = self.mock_client.models.generate_content.call_args
self.assertEqual(call_args[1]["model"], "gemini-2.0-flash")
self.assertEqual(call_args[1]["contents"], "Hello")
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Gemini instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Gemini()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that generate_content was called correctly
self.mock_client.models.generate_content.assert_called_once()
call_args = self.mock_client.models.generate_content.call_args
self.assertEqual(call_args[1]["model"], "gemini-2.0-flash")
expected_content = "You are a helpful assistant\nHello\nHi there!\nHow are you?"
self.assertEqual(call_args[1]["contents"], expected_content)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Gemini instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Gemini()
# Mock an error response
self.mock_client.models.generate_content.side_effect = Exception("API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "API Error")
if __name__ == "__main__":
unittest.main()

165
tests/llm/test_glm.py

@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import GLM
from deepsearcher.llm.base import ChatResponse
class TestGLM(unittest.TestCase):
"""Tests for the GLM LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_zhipuai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_zhipuai.ZhipuAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'zhipuai': self.mock_zhipuai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = GLM()
# Check that ZhipuAI client was initialized correctly
self.mock_zhipuai.ZhipuAI.assert_called_once_with(
api_key=None,
base_url="https://open.bigmodel.cn/api/paas/v4/"
)
# Check default model
self.assertEqual(llm.model, "glm-4-plus")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"GLM_API_KEY": api_key}):
llm = GLM()
self.mock_zhipuai.ZhipuAI.assert_called_with(
api_key=api_key,
base_url="https://open.bigmodel.cn/api/paas/v4/"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = GLM(api_key=api_key)
self.mock_zhipuai.ZhipuAI.assert_called_with(
api_key=api_key,
base_url="https://open.bigmodel.cn/api/paas/v4/"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "glm-3-turbo"
llm = GLM(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.glm.api"
llm = GLM(base_url=base_url)
self.mock_zhipuai.ZhipuAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create GLM instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = GLM()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "glm-4-plus")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create GLM instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = GLM()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "glm-4-plus")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create GLM instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = GLM()
# Mock an error response
self.mock_completions.create.side_effect = Exception("GLM API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "GLM API Error")
if __name__ == "__main__":
unittest.main()

165
tests/llm/test_novita.py

@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Novita
from deepsearcher.llm.base import ChatResponse
class TestNovita(unittest.TestCase):
"""Tests for the Novita LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Novita()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://api.novita.ai/v3/openai"
)
# Check default model
self.assertEqual(llm.model, "qwen/qwq-32b")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"NOVITA_API_KEY": api_key}):
llm = Novita()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.novita.ai/v3/openai"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = Novita(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.novita.ai/v3/openai"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "qwen/qwq-72b"
llm = Novita(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.novita.api"
llm = Novita(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Novita instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Novita()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "qwen/qwq-32b")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Novita instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Novita()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "qwen/qwq-32b")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Novita instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Novita()
# Mock an error response
self.mock_completions.create.side_effect = Exception("Novita API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Novita API Error")
if __name__ == "__main__":
unittest.main()

136
tests/llm/test_ollama.py

@@ -1,136 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import logging
import os
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Ollama
from deepsearcher.llm.base import ChatResponse
class TestOllama(unittest.TestCase):
"""Tests for the Ollama LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_ollama = MagicMock()
self.mock_client = MagicMock()
# Set up the mock module structure
self.mock_ollama.Client = MagicMock(return_value=self.mock_client)
# Set up mock response
self.mock_response = MagicMock()
self.mock_message = MagicMock()
self.mock_message.content = "Test response"
self.mock_response.message = self.mock_message
self.mock_response.prompt_eval_count = 50
self.mock_response.eval_count = 50
self.mock_client.chat.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'ollama': self.mock_ollama})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Ollama()
# Check that Ollama client was initialized correctly
self.mock_ollama.Client.assert_called_once_with(
host="http://localhost:11434"
)
# Check default model
self.assertEqual(llm.model, "qwq")
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "llama2"
llm = Ollama(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "http://custom.ollama:11434"
llm = Ollama(base_url=base_url)
self.mock_ollama.Client.assert_called_with(
host=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Ollama instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Ollama()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that chat was called correctly
self.mock_client.chat.assert_called_once_with(
model="qwq",
messages=messages
)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Ollama instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Ollama()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that chat was called correctly
self.mock_client.chat.assert_called_once_with(
model="qwq",
messages=messages
)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Ollama instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Ollama()
# Mock an error response
self.mock_client.chat.side_effect = Exception("Ollama API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Ollama API Error")
if __name__ == "__main__":
unittest.main()

167
tests/llm/test_openai.py

@@ -1,167 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import OpenAI
from deepsearcher.llm.base import ChatResponse
class TestOpenAI(unittest.TestCase):
"""Tests for the OpenAI LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = OpenAI()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url=None
)
# Check default model
self.assertEqual(llm.model, "o1-mini")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
base_url = "https://api.openai.com/v1"
with patch.dict(os.environ, {
"OPENAI_API_KEY": api_key,
"OPENAI_BASE_URL": base_url
}):
llm = OpenAI()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url=base_url
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
api_key = "test_api_key_param"
llm = OpenAI(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url=None
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
model = "gpt-4"
llm = OpenAI(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.openai.api"
llm = OpenAI(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = OpenAI()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "o1-mini")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = OpenAI()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "o1-mini")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create OpenAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = OpenAI()
# Mock an error response
self.mock_completions.create.side_effect = Exception("OpenAI API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "OpenAI API Error")
if __name__ == "__main__":
unittest.main()

165
tests/llm/test_ppio.py

@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import PPIO
from deepsearcher.llm.base import ChatResponse
class TestPPIO(unittest.TestCase):
"""Tests for the PPIO LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = PPIO()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://api.ppinfra.com/v3/openai"
)
# Check default model
self.assertEqual(llm.model, "deepseek/deepseek-r1-turbo")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"PPIO_API_KEY": api_key}):
llm = PPIO()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.ppinfra.com/v3/openai"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = PPIO(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.ppinfra.com/v3/openai"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "deepseek/deepseek-r1-max"
llm = PPIO(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.ppio.api"
llm = PPIO(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create PPIO instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = PPIO()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek/deepseek-r1-turbo")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create PPIO instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = PPIO()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek/deepseek-r1-turbo")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create PPIO instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = PPIO()
# Mock an error response
self.mock_completions.create.side_effect = Exception("PPIO API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "PPIO API Error")
if __name__ == "__main__":
unittest.main()

165
tests/llm/test_siliconflow.py

@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import SiliconFlow
from deepsearcher.llm.base import ChatResponse
class TestSiliconFlow(unittest.TestCase):
"""Tests for the SiliconFlow LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = SiliconFlow()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://api.siliconflow.cn/v1"
)
# Check default model
self.assertEqual(llm.model, "deepseek-ai/DeepSeek-R1")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"SILICONFLOW_API_KEY": api_key}):
llm = SiliconFlow()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.siliconflow.cn/v1"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = SiliconFlow(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.siliconflow.cn/v1"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "deepseek-ai/DeepSeek-R2"
llm = SiliconFlow(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.siliconflow.api"
llm = SiliconFlow(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create SiliconFlow instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = SiliconFlow()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create SiliconFlow instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = SiliconFlow()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create SiliconFlow instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = SiliconFlow()
# Mock an error response
self.mock_completions.create.side_effect = Exception("SiliconFlow API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "SiliconFlow API Error")
if __name__ == "__main__":
unittest.main()

151
tests/llm/test_together_ai.py

@@ -1,151 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import TogetherAI
from deepsearcher.llm.base import ChatResponse
class TestTogetherAI(unittest.TestCase):
"""Tests for the TogetherAI LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_together = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_together.Together = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'together': self.mock_together})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = TogetherAI()
# Check that Together client was initialized correctly
self.mock_together.Together.assert_called_once_with(
api_key=None
)
# Check default model
self.assertEqual(llm.model, "deepseek-ai/DeepSeek-R1")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"TOGETHER_API_KEY": api_key}):
llm = TogetherAI()
self.mock_together.Together.assert_called_with(
api_key=api_key
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = TogetherAI(api_key=api_key)
self.mock_together.Together.assert_called_with(
api_key=api_key
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
llm = TogetherAI(model=model)
self.assertEqual(llm.model, model)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create TogetherAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = TogetherAI()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create TogetherAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = TogetherAI()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-ai/DeepSeek-R1")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create TogetherAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = TogetherAI()
# Mock an error response
self.mock_completions.create.side_effect = Exception("Together API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Together API Error")
if __name__ == "__main__":
unittest.main()

165
tests/llm/test_volcengine.py

@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import Volcengine
from deepsearcher.llm.base import ChatResponse
class TestVolcengine(unittest.TestCase):
"""Tests for the Volcengine LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = Volcengine()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://ark.cn-beijing.volces.com/api/v3"
)
# Check default model
self.assertEqual(llm.model, "deepseek-r1-250120")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"VOLCENGINE_API_KEY": api_key}):
llm = Volcengine()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://ark.cn-beijing.volces.com/api/v3"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = Volcengine(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://ark.cn-beijing.volces.com/api/v3"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "deepseek-r2-250120"
llm = Volcengine(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.volcengine.api"
llm = Volcengine(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create Volcengine instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Volcengine()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-r1-250120")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create Volcengine instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Volcengine()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "deepseek-r1-250120")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create Volcengine instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = Volcengine()
# Mock an error response
self.mock_completions.create.side_effect = Exception("Volcengine API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "Volcengine API Error")
if __name__ == "__main__":
unittest.main()
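
The assertions above pin down how this provider is wired: an OpenAI-compatible client pointed at the Ark endpoint, a default model of deepseek-r1-250120, and a response exposing content and total_tokens. The following is a minimal sketch of that wiring, assuming the openai SDK; VolcengineSketch and ChatResult are hypothetical names for illustration, not the removed module's code:

    import os
    from dataclasses import dataclass

    from openai import OpenAI


    @dataclass
    class ChatResult:  # stand-in for deepsearcher.llm.base.ChatResponse
        content: str
        total_tokens: int


    class VolcengineSketch:
        def __init__(self, model="deepseek-r1-250120", api_key=None,
                     base_url="https://ark.cn-beijing.volces.com/api/v3"):
            self.model = model
            # Falls back to the environment variable the tests patch.
            self.client = OpenAI(api_key=api_key or os.getenv("VOLCENGINE_API_KEY"),
                                 base_url=base_url)

        def chat(self, messages):
            completion = self.client.chat.completions.create(model=self.model, messages=messages)
            return ChatResult(content=completion.choices[0].message.content,
                              total_tokens=completion.usage.total_tokens)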

421 tests/llm/test_watsonx.py
@@ -1,421 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import os
class TestWatsonX(unittest.TestCase):
"""Test cases for WatsonX class."""
def setUp(self):
"""Set up test fixtures."""
pass
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_env_vars(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with environment variables."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
# Check that Credentials was called with correct parameters
mock_credentials_class.assert_called_once_with(
url='https://test.watsonx.com',
api_key='test-api-key'
)
# Check that ModelInference was called with correct parameters
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-3-3-8b-instruct',
credentials=mock_credentials_instance,
project_id='test-project-id'
)
# Check default model
self.assertEqual(llm.model, 'ibm/granite-3-3-8b-instruct')
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_with_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with space_id instead of project_id."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(space_id='test-space-id')
# Check that ModelInference was called with space_id
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-3-3-8b-instruct',
credentials=mock_credentials_instance,
space_id='test-space-id'
)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_custom_model(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with custom model."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(model='ibm/granite-13b-chat-v2')
# Check that ModelInference was called with custom model
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-13b-chat-v2',
credentials=mock_credentials_instance,
project_id='test-project-id'
)
self.assertEqual(llm.model, 'ibm/granite-13b-chat-v2')
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_custom_params(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with custom generation parameters."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(
max_new_tokens=500,
temperature=0.7,
top_p=0.9,
top_k=40
)
# Check that generation parameters were set correctly
expected_params = {
'max_new_tokens': 500,
'temperature': 0.7,
'top_p': 0.9,
'top_k': 40
}
self.assertEqual(llm.generation_params, expected_params)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
def test_init_missing_api_key(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing API key."""
from deepsearcher.llm.watsonx import WatsonX
with patch.dict(os.environ, {}, clear=True):
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_APIKEY", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key'
})
def test_init_missing_url(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing URL."""
from deepsearcher.llm.watsonx import WatsonX
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_URL", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_missing_project_and_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing both project_id and space_id."""
from deepsearcher.llm.watsonx import WatsonX
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_PROJECT_ID", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_simple_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with a simple message."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "This is a test response from WatsonX."
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "user", "content": "Hello, how are you?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format
expected_prompt = "Human: Hello, how are you?\n\nAssistant:"
self.assertEqual(call_args[1]['prompt'], expected_prompt)
# Check response
self.assertEqual(response.content, "This is a test response from WatsonX.")
self.assertIsInstance(response.total_tokens, int)
self.assertGreater(response.total_tokens, 0)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_with_system_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with system and user messages."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "4"
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format
expected_prompt = "System: You are a helpful assistant.\n\nHuman: What is 2+2?\n\nAssistant:"
self.assertEqual(call_args[1]['prompt'], expected_prompt)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_conversation_history(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with conversation history."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "6"
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
{"role": "user", "content": "What about 3+3?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format includes conversation history
expected_prompt = ("System: You are a helpful assistant.\n\n"
"Human: What is 2+2?\n\n"
"Assistant: 2+2 equals 4.\n\n"
"Human: What about 3+3?\n\n"
"Assistant:")
self.assertEqual(call_args[1]['prompt'], expected_prompt)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_error_handling(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test error handling in chat method."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.side_effect = Exception("API Error")
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [{"role": "user", "content": "Hello"}]
# Test that the exception is properly wrapped
with self.assertRaises(RuntimeError) as context:
llm.chat(messages)
self.assertIn("Error generating response with WatsonX", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_messages_to_prompt(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test the _messages_to_prompt method."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant message"},
{"role": "user", "content": "Another user message"}
]
prompt = llm._messages_to_prompt(messages)
expected_prompt = ("System: System message\n\n"
"Human: User message\n\n"
"Assistant: Assistant message\n\n"
"Human: Another user message\n\n"
"Assistant:")
self.assertEqual(prompt, expected_prompt)
if __name__ == '__main__':
unittest.main()
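
The expected_prompt strings above fully determine the prompt layout this suite verifies: each message rendered as "System:", "Human:", or "Assistant:", separated by blank lines, with a trailing "Assistant:" cue. A small sketch consistent with those assertions (illustrative only, not the removed _messages_to_prompt implementation):

    def messages_to_prompt_sketch(messages):
        # Reproduces the prompt layout the assertions above expect.
        labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
        parts = [f"{labels[m['role']]}: {m['content']}" for m in messages]
        parts.append("Assistant:")
        return "\n\n".join(parts)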

165 tests/llm/test_xai.py
@@ -1,165 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import XAI
from deepsearcher.llm.base import ChatResponse
class TestXAI(unittest.TestCase):
"""Tests for the X.AI LLM provider."""
def setUp(self):
"""Set up test fixtures."""
# Create mock module and components
self.mock_openai = MagicMock()
self.mock_client = MagicMock()
self.mock_chat = MagicMock()
self.mock_completions = MagicMock()
# Set up the mock module structure
self.mock_openai.OpenAI = MagicMock(return_value=self.mock_client)
self.mock_client.chat = self.mock_chat
self.mock_chat.completions = self.mock_completions
# Set up mock response
self.mock_response = MagicMock()
self.mock_choice = MagicMock()
self.mock_message = MagicMock()
self.mock_usage = MagicMock()
self.mock_message.content = "Test response"
self.mock_choice.message = self.mock_message
self.mock_usage.total_tokens = 100
self.mock_response.choices = [self.mock_choice]
self.mock_response.usage = self.mock_usage
self.mock_completions.create.return_value = self.mock_response
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {'openai': self.mock_openai})
self.module_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init_default(self):
"""Test initialization with default parameters."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
llm = XAI()
# Check that OpenAI client was initialized correctly
self.mock_openai.OpenAI.assert_called_once_with(
api_key=None,
base_url="https://api.x.ai/v1"
)
# Check default model
self.assertEqual(llm.model, "grok-2-latest")
def test_init_with_api_key_from_env(self):
"""Test initialization with API key from environment variable."""
api_key = "test_api_key_from_env"
with patch.dict(os.environ, {"XAI_API_KEY": api_key}):
llm = XAI()
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.x.ai/v1"
)
def test_init_with_api_key_parameter(self):
"""Test initialization with API key as parameter."""
with patch.dict('os.environ', {}, clear=True):
api_key = "test_api_key_param"
llm = XAI(api_key=api_key)
self.mock_openai.OpenAI.assert_called_with(
api_key=api_key,
base_url="https://api.x.ai/v1"
)
def test_init_with_custom_model(self):
"""Test initialization with custom model."""
with patch.dict('os.environ', {}, clear=True):
model = "grok-1"
llm = XAI(model=model)
self.assertEqual(llm.model, model)
def test_init_with_custom_base_url(self):
"""Test initialization with custom base URL."""
# Clear environment variables temporarily
with patch.dict('os.environ', {}, clear=True):
base_url = "https://custom.x.ai"
llm = XAI(base_url=base_url)
self.mock_openai.OpenAI.assert_called_with(
api_key=None,
base_url=base_url
)
def test_chat_single_message(self):
"""Test chat with a single message."""
# Create XAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = XAI()
messages = [{"role": "user", "content": "Hello"}]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "grok-2-latest")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_multiple_messages(self):
"""Test chat with multiple messages."""
# Create XAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = XAI()
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
response = llm.chat(messages)
# Check that completions.create was called correctly
self.mock_completions.create.assert_called_once()
call_args = self.mock_completions.create.call_args
self.assertEqual(call_args[1]["model"], "grok-2-latest")
self.assertEqual(call_args[1]["messages"], messages)
# Check response
self.assertIsInstance(response, ChatResponse)
self.assertEqual(response.content, "Test response")
self.assertEqual(response.total_tokens, 100)
def test_chat_with_error(self):
"""Test chat when an error occurs."""
# Create XAI instance with mocked environment
with patch.dict('os.environ', {}, clear=True):
llm = XAI()
# Mock an error response
self.mock_completions.create.side_effect = Exception("XAI API Error")
messages = [{"role": "user", "content": "Hello"}]
with self.assertRaises(Exception) as context:
llm.chat(messages)
self.assertEqual(str(context.exception), "XAI API Error")
if __name__ == "__main__":
unittest.main()
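
The XAI suite follows the same pattern as the other OpenAI-compatible provider suites above: the openai module is swapped out in sys.modules during setUp, so no API key or network access is needed. Any of these suites can be run on its own with the standard unittest runner, for example:

    python -m unittest -v tests.llm.test_xai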

1 tests/loader/__init__.py
@@ -1 +0,0 @@
# Tests for the deepsearcher.loader package

1 tests/loader/file_loader/__init__.py
@@ -1 +0,0 @@
# Tests for the deepsearcher.loader.file_loader package

69 tests/loader/file_loader/test_base.py
@@ -1,69 +0,0 @@
import unittest
import os
import tempfile
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.file_loader.base import BaseLoader


class TestBaseLoader(unittest.TestCase):
    """Tests for the BaseLoader class."""

    def test_abstract_methods(self):
        """Test that BaseLoader defines abstract methods."""
        # For abstract base classes, we can check if methods are defined
        # but not implemented in the base class
        self.assertTrue(hasattr(BaseLoader, 'load_file'))
        self.assertTrue(hasattr(BaseLoader, 'supported_file_types'))

    def test_load_directory(self):
        """Test the load_directory method."""
        # Create a subclass of BaseLoader for testing
        class TestLoader(BaseLoader):
            @property
            def supported_file_types(self):
                return [".txt", ".md"]

            def load_file(self, file_path):
                # Mock implementation that returns a simple Document
                return [Document(page_content=f"Content of {file_path}", metadata={"reference": file_path})]

        # Create a temporary directory with test files
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create test files
            file_paths = [
                os.path.join(temp_dir, "test1.txt"),
                os.path.join(temp_dir, "test2.md"),
                os.path.join(temp_dir, "test3.pdf"),  # Unsupported format
                os.path.join(temp_dir, "subdir", "test4.txt")
            ]
            # Create subdirectory
            os.makedirs(os.path.join(temp_dir, "subdir"), exist_ok=True)
            # Create files
            for path in file_paths:
                # Skip the file if it's in a subdirectory that doesn't exist
                if not os.path.exists(os.path.dirname(path)):
                    continue
                with open(path, 'w') as f:
                    f.write(f"Content of {path}")
            # Test loading the directory
            loader = TestLoader()
            documents = loader.load_directory(temp_dir)
            # Check the results
            self.assertEqual(len(documents), 3)  # Should find 3 supported files
            # Verify each document
            references = [doc.metadata["reference"] for doc in documents]
            self.assertIn(file_paths[0], references)  # test1.txt
            self.assertIn(file_paths[1], references)  # test2.md
            self.assertNotIn(file_paths[2], references)  # test3.pdf (unsupported)
            self.assertIn(file_paths[3], references)  # subdir/test4.txt


if __name__ == "__main__":
    unittest.main()
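
The test above expects load_directory to recurse into subdirectories and to skip files whose extension is not in supported_file_types. A plausible shape for that behavior, shown as a free function; load_directory_sketch is a hypothetical name and the suffix-based filter is an assumption, not the removed BaseLoader code:

    import os
    from typing import List

    from langchain_core.documents import Document


    def load_directory_sketch(loader, directory: str) -> List[Document]:
        # Walk the tree and delegate every file with a supported suffix to load_file.
        documents = []
        for root, _, files in os.walk(directory):
            for name in files:
                if any(name.endswith(suffix) for suffix in loader.supported_file_types):
                    documents.extend(loader.load_file(os.path.join(root, name)))
        return documents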

185 tests/loader/file_loader/test_docling_loader.py
@@ -1,185 +0,0 @@
import unittest
import os
import tempfile
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.file_loader import DoclingLoader
class TestDoclingLoader(unittest.TestCase):
"""Tests for the DoclingLoader class."""
def setUp(self):
"""Set up test fixtures."""
# Create patches for the docling modules
self.docling_patcher = patch.dict('sys.modules', {
'docling': MagicMock(),
'docling.document_converter': MagicMock(),
'docling_core': MagicMock(),
'docling_core.transforms': MagicMock(),
'docling_core.transforms.chunker': MagicMock()
})
self.docling_patcher.start()
# Create mocks for the classes
self.mock_document_converter = MagicMock()
self.mock_hierarchical_chunker = MagicMock()
# Add the mocks to the modules
import sys
sys.modules['docling.document_converter'].DocumentConverter = self.mock_document_converter
sys.modules['docling_core.transforms.chunker'].HierarchicalChunker = self.mock_hierarchical_chunker
# Set up mock instances
self.mock_converter_instance = MagicMock()
self.mock_chunker_instance = MagicMock()
self.mock_document_converter.return_value = self.mock_converter_instance
self.mock_hierarchical_chunker.return_value = self.mock_chunker_instance
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
# Create a test markdown file
self.md_file_path = os.path.join(self.temp_dir.name, "test.md")
with open(self.md_file_path, "w", encoding="utf-8") as f:
f.write("# Test Markdown\nThis is a test markdown file.")
# Create a test unsupported file
self.unsupported_file_path = os.path.join(self.temp_dir.name, "test.xyz")
with open(self.unsupported_file_path, "w", encoding="utf-8") as f:
f.write("This is an unsupported file type.")
# Create a subdirectory with a test file
self.sub_dir = os.path.join(self.temp_dir.name, "subdir")
os.makedirs(self.sub_dir, exist_ok=True)
self.sub_file_path = os.path.join(self.sub_dir, "subfile.md")
with open(self.sub_file_path, "w", encoding="utf-8") as f:
f.write("# Subdir Test\nThis is a test markdown file in a subdirectory.")
# Create the loader
self.loader = DoclingLoader()
def tearDown(self):
"""Clean up test fixtures."""
self.docling_patcher.stop()
self.temp_dir.cleanup()
def test_init(self):
"""Test initialization."""
# Verify instances were created
self.mock_document_converter.assert_called_once()
self.mock_hierarchical_chunker.assert_called_once()
# Check that the instances were assigned correctly
self.assertEqual(self.loader.converter, self.mock_converter_instance)
self.assertEqual(self.loader.chunker, self.mock_chunker_instance)
def test_supported_file_types(self):
"""Test the supported_file_types property."""
file_types = self.loader.supported_file_types
# Check that the common file types are included
common_types = ["pdf", "docx", "md", "html", "csv", "jpg"]
for file_type in common_types:
self.assertIn(file_type, file_types)
def test_load_file(self):
"""Test loading a single file."""
# Set up mock document and chunks
mock_document = MagicMock()
mock_conversion_result = MagicMock()
mock_conversion_result.document = mock_document
# Set up three mock chunks
mock_chunks = []
for i in range(3):
chunk = MagicMock()
chunk.text = f"Chunk {i} content"
mock_chunks.append(chunk)
# Configure mock converter and chunker
self.mock_converter_instance.convert.return_value = mock_conversion_result
self.mock_chunker_instance.chunk.return_value = mock_chunks
# Call the method
documents = self.loader.load_file(self.md_file_path)
# Verify converter was called correctly
self.mock_converter_instance.convert.assert_called_once_with(self.md_file_path)
# Verify chunker was called correctly
self.mock_chunker_instance.chunk.assert_called_once_with(mock_document)
# Check results
self.assertEqual(len(documents), 3)
# Check each document
for i, document in enumerate(documents):
self.assertEqual(document.page_content, f"Chunk {i} content")
self.assertEqual(document.metadata["reference"], self.md_file_path)
self.assertEqual(document.metadata["text"], f"Chunk {i} content")
def test_load_file_not_found(self):
"""Test loading a non-existent file."""
non_existent_file = os.path.join(self.temp_dir.name, "non_existent.md")
with self.assertRaises(FileNotFoundError):
self.loader.load_file(non_existent_file)
def test_load_unsupported_file_type(self):
"""Test loading a file with unsupported extension."""
with self.assertRaises(ValueError):
self.loader.load_file(self.unsupported_file_path)
def test_load_file_error(self):
"""Test error handling when loading a file."""
# Configure converter to raise an exception
self.mock_converter_instance.convert.side_effect = Exception("Test error")
# Verify that the error is propagated
with self.assertRaises(IOError):
self.loader.load_file(self.md_file_path)
def test_load_directory(self):
"""Test loading a directory."""
# Set up mock document and chunks
mock_document = MagicMock()
mock_conversion_result = MagicMock()
mock_conversion_result.document = mock_document
# Set up a single mock chunk
mock_chunk = MagicMock()
mock_chunk.text = "Test chunk content"
# Configure mock converter and chunker
self.mock_converter_instance.convert.return_value = mock_conversion_result
self.mock_chunker_instance.chunk.return_value = [mock_chunk]
# Load the directory
documents = self.loader.load_directory(self.temp_dir.name)
# Verify converter was called twice (once for each MD file)
self.assertEqual(self.mock_converter_instance.convert.call_count, 2)
# Verify converter was called with both MD files
self.mock_converter_instance.convert.assert_any_call(self.md_file_path)
self.mock_converter_instance.convert.assert_any_call(self.sub_file_path)
# Check results - should have two documents (one from each MD file)
self.assertEqual(len(documents), 2)
# Check each document
for document in documents:
self.assertEqual(document.page_content, "Test chunk content")
self.assertEqual(document.metadata["text"], "Test chunk content")
self.assertIn(document.metadata["reference"], [self.md_file_path, self.sub_file_path])
def test_load_not_a_directory(self):
"""Test loading a path that is not a directory."""
with self.assertRaises(NotADirectoryError):
self.loader.load_directory(self.md_file_path)
if __name__ == "__main__":
unittest.main()
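
The load_file assertions above describe a convert-then-chunk flow: the DocumentConverter produces a document, the HierarchicalChunker splits it, and each chunk becomes a Document carrying the source path as "reference" and its own text as "text". A minimal sketch of that flow (illustrative; the existence and file-type validation raising FileNotFoundError, ValueError, and IOError in the tests is omitted):

    from langchain_core.documents import Document


    def docling_load_file_sketch(converter, chunker, file_path):
        # Convert the source, chunk the resulting document, and wrap each chunk.
        result = converter.convert(file_path)
        chunks = chunker.chunk(result.document)
        return [
            Document(page_content=chunk.text, metadata={"reference": file_path, "text": chunk.text})
            for chunk in chunks
        ]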

124 tests/loader/file_loader/test_json_loader.py
@@ -1,124 +0,0 @@
import unittest
import os
import json
import tempfile
from langchain_core.documents import Document
from deepsearcher.loader.file_loader import JsonFileLoader
class TestJsonFileLoader(unittest.TestCase):
"""Tests for the JsonFileLoader class."""
def setUp(self):
"""Set up the test environment."""
# Create a temporary directory for test files
self.temp_dir = tempfile.TemporaryDirectory()
# Sample JSON data
self.json_data = [
{"id": 1, "text": "This is the first document.", "author": "John Doe"},
{"id": 2, "text": "This is the second document.", "author": "Jane Smith"}
]
# Create JSON test file
self.json_file_path = os.path.join(self.temp_dir.name, "test.json")
with open(self.json_file_path, "w", encoding="utf-8") as f:
json.dump(self.json_data, f)
# Create JSONL test file
self.jsonl_file_path = os.path.join(self.temp_dir.name, "test.jsonl")
with open(self.jsonl_file_path, "w", encoding="utf-8") as f:
for item in self.json_data:
f.write(json.dumps(item) + "\n")
# Create invalid JSON file (not a list)
self.invalid_json_file_path = os.path.join(self.temp_dir.name, "invalid.json")
with open(self.invalid_json_file_path, "w", encoding="utf-8") as f:
json.dump({"id": 1, "text": "This is not a list.", "author": "John Doe"}, f)
# Create invalid JSONL file
self.invalid_jsonl_file_path = os.path.join(self.temp_dir.name, "invalid.jsonl")
with open(self.invalid_jsonl_file_path, "w", encoding="utf-8") as f:
f.write("This is not valid JSON\n")
f.write(json.dumps({"id": 2, "text": "This is valid JSON", "author": "Jane Smith"}) + "\n")
# Initialize the loader
self.loader = JsonFileLoader(text_key="text")
# Patch the _read_json_file method to fix the file handling
original_read_json_file = self.loader._read_json_file
def patched_read_json_file(file_path):
with open(file_path, 'r') as f:
json_data = json.load(f)
if not isinstance(json_data, list):
raise ValueError("JSON file must contain a list of dictionaries.")
return json_data
self.loader._read_json_file = patched_read_json_file
def tearDown(self):
"""Clean up the test environment."""
self.temp_dir.cleanup()
def test_load_json_file(self):
"""Test loading a JSON file."""
documents = self.loader.load_file(self.json_file_path)
# Check that we got the right number of documents
self.assertEqual(len(documents), 2)
# Check the content and metadata of each document
self.assertEqual(documents[0].page_content, "This is the first document.")
self.assertEqual(documents[0].metadata["id"], 1)
self.assertEqual(documents[0].metadata["author"], "John Doe")
self.assertEqual(documents[0].metadata["reference"], self.json_file_path)
self.assertEqual(documents[1].page_content, "This is the second document.")
self.assertEqual(documents[1].metadata["id"], 2)
self.assertEqual(documents[1].metadata["author"], "Jane Smith")
self.assertEqual(documents[1].metadata["reference"], self.json_file_path)
def test_load_jsonl_file(self):
"""Test loading a JSONL file."""
documents = self.loader.load_file(self.jsonl_file_path)
# Check that we got the right number of documents
self.assertEqual(len(documents), 2)
# Check the content and metadata of each document
self.assertEqual(documents[0].page_content, "This is the first document.")
self.assertEqual(documents[0].metadata["id"], 1)
self.assertEqual(documents[0].metadata["author"], "John Doe")
self.assertEqual(documents[0].metadata["reference"], self.jsonl_file_path)
self.assertEqual(documents[1].page_content, "This is the second document.")
self.assertEqual(documents[1].metadata["id"], 2)
self.assertEqual(documents[1].metadata["author"], "Jane Smith")
self.assertEqual(documents[1].metadata["reference"], self.jsonl_file_path)
def test_invalid_json_file(self):
"""Test loading an invalid JSON file (not a list)."""
with self.assertRaises(ValueError):
self.loader.load_file(self.invalid_json_file_path)
def test_invalid_jsonl_file(self):
"""Test loading a JSONL file with invalid lines."""
documents = self.loader.load_file(self.invalid_jsonl_file_path)
# Only the valid line should be loaded
self.assertEqual(len(documents), 1)
self.assertEqual(documents[0].page_content, "This is valid JSON")
def test_supported_file_types(self):
"""Test the supported_file_types property."""
file_types = self.loader.supported_file_types
self.assertIsInstance(file_types, list)
self.assertIn("txt", file_types)
self.assertIn("md", file_types)
if __name__ == "__main__":
unittest.main()
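
test_invalid_jsonl_file above expects malformed lines to be dropped rather than raise, while .json files must contain a list of records. A short sketch of the JSONL side of that behavior, assuming standard-library json only; read_jsonl_sketch is a hypothetical helper name:

    import json


    def read_jsonl_sketch(file_path):
        # One JSON object per line; lines that fail to parse are skipped,
        # which is what test_invalid_jsonl_file expects.
        records = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        return records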

142 tests/loader/file_loader/test_pdf_loader.py
@@ -1,142 +0,0 @@
import unittest
import os
import tempfile
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.file_loader import PDFLoader
class TestPDFLoader(unittest.TestCase):
"""Tests for the PDFLoader class."""
def setUp(self):
"""Set up the test environment."""
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
# Create a text file for testing
self.text_file_path = os.path.join(self.temp_dir.name, "test.txt")
with open(self.text_file_path, "w", encoding="utf-8") as f:
f.write("This is a test text file.")
# Create a markdown file for testing
self.md_file_path = os.path.join(self.temp_dir.name, "test.md")
with open(self.md_file_path, "w", encoding="utf-8") as f:
f.write("# Test Markdown\nThis is a test markdown file.")
# PDF file path (will be mocked)
self.pdf_file_path = os.path.join(self.temp_dir.name, "test.pdf")
# Create the loader
self.loader = PDFLoader()
def tearDown(self):
"""Clean up the test environment."""
self.temp_dir.cleanup()
def test_supported_file_types(self):
"""Test the supported_file_types property."""
file_types = self.loader.supported_file_types
self.assertIsInstance(file_types, list)
self.assertIn("pdf", file_types)
self.assertIn("md", file_types)
self.assertIn("txt", file_types)
def test_load_text_file(self):
"""Test loading a text file."""
documents = self.loader.load_file(self.text_file_path)
# Check that we got one document
self.assertEqual(len(documents), 1)
# Check the document content
document = documents[0]
self.assertEqual(document.page_content, "This is a test text file.")
# Check the metadata
self.assertEqual(document.metadata["reference"], self.text_file_path)
def test_load_markdown_file(self):
"""Test loading a markdown file."""
documents = self.loader.load_file(self.md_file_path)
# Check that we got one document
self.assertEqual(len(documents), 1)
# Check the document content
document = documents[0]
self.assertEqual(document.page_content, "# Test Markdown\nThis is a test markdown file.")
# Check the metadata
self.assertEqual(document.metadata["reference"], self.md_file_path)
@patch("pdfplumber.open")
def test_load_pdf_file(self, mock_pdf_open):
"""Test loading a PDF file."""
# Set up mock PDF pages
mock_page1 = MagicMock()
mock_page1.extract_text.return_value = "Page 1 content"
mock_page2 = MagicMock()
mock_page2.extract_text.return_value = "Page 2 content"
# Set up mock PDF file
mock_pdf = MagicMock()
mock_pdf.pages = [mock_page1, mock_page2]
mock_pdf.__enter__.return_value = mock_pdf
mock_pdf.__exit__.return_value = None
# Configure the mock to return our mock PDF
mock_pdf_open.return_value = mock_pdf
# Create a dummy PDF file
with open(self.pdf_file_path, "w") as f:
f.write("dummy pdf content")
# Load the PDF file
documents = self.loader.load_file(self.pdf_file_path)
# Verify pdfplumber.open was called
mock_pdf_open.assert_called_once_with(self.pdf_file_path)
# Check that we got one document
self.assertEqual(len(documents), 1)
# Check the document content
document = documents[0]
self.assertEqual(document.page_content, "Page 1 content\n\nPage 2 content")
# Check the metadata
self.assertEqual(document.metadata["reference"], self.pdf_file_path)
def test_load_directory(self):
"""Test loading a directory with mixed file types."""
# Create the loader
loader = PDFLoader()
# Mock the load_file method to track calls
original_load_file = loader.load_file
calls = []
def mock_load_file(file_path):
calls.append(file_path)
return original_load_file(file_path)
loader.load_file = mock_load_file
# Load the directory
documents = loader.load_directory(self.temp_dir.name)
# Check that we processed both text and markdown files
self.assertEqual(len(calls), 2) # text and markdown files
self.assertIn(self.text_file_path, calls)
self.assertIn(self.md_file_path, calls)
# Check that we got two documents
self.assertEqual(len(documents), 2)
if __name__ == "__main__":
unittest.main()
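
The PDF branch asserted above extracts text from every page with pdfplumber and joins the pages with a blank line into a single Document. A sketch of that branch only, assuming pdfplumber is installed; plain .txt and .md files are read verbatim per the other tests and are not shown here:

    import pdfplumber
    from langchain_core.documents import Document


    def load_pdf_sketch(file_path):
        # Extract text page by page and join pages with a blank line,
        # matching the "Page 1 content\n\nPage 2 content" assertion above.
        with pdfplumber.open(file_path) as pdf:
            pages = [page.extract_text() for page in pdf.pages]
        return [Document(page_content="\n\n".join(pages), metadata={"reference": file_path})]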

82 tests/loader/file_loader/test_text_loader.py
@@ -1,82 +0,0 @@
import unittest
import os
import tempfile
from deepsearcher.loader.file_loader import TextLoader
class TestTextLoader(unittest.TestCase):
"""Tests for the TextLoader class."""
def setUp(self):
"""Set up the test environment."""
self.loader = TextLoader()
# Create a temporary directory and file for testing
self.temp_dir = tempfile.TemporaryDirectory()
self.test_file_path = os.path.join(self.temp_dir.name, "test.txt")
self.test_content = "This is a test file content.\nWith multiple lines."
# Write test content to the file
with open(self.test_file_path, "w", encoding="utf-8") as f:
f.write(self.test_content)
def tearDown(self):
"""Clean up the test environment."""
self.temp_dir.cleanup()
def test_supported_file_types(self):
"""Test the supported_file_types property."""
supported_types = self.loader.supported_file_types
self.assertIsInstance(supported_types, list)
self.assertIn("txt", supported_types)
self.assertIn("md", supported_types)
def test_load_file(self):
"""Test loading a text file."""
documents = self.loader.load_file(self.test_file_path)
# Check that we got a list with one document
self.assertIsInstance(documents, list)
self.assertEqual(len(documents), 1)
# Check the document content
document = documents[0]
self.assertEqual(document.page_content, self.test_content)
# Check the metadata
self.assertIn("reference", document.metadata)
self.assertEqual(document.metadata["reference"], self.test_file_path)
def test_load_directory(self):
"""Test loading a directory with text files."""
# Create additional test files
md_file_path = os.path.join(self.temp_dir.name, "test.md")
with open(md_file_path, "w", encoding="utf-8") as f:
f.write("# Markdown Test\nThis is a markdown file.")
# Create a non-supported file
pdf_file_path = os.path.join(self.temp_dir.name, "test.pdf")
with open(pdf_file_path, "w", encoding="utf-8") as f:
f.write("PDF content")
# Load the directory
documents = self.loader.load_directory(self.temp_dir.name)
# Check that we got documents for supported files only
self.assertEqual(len(documents), 2)
# Get references
references = [doc.metadata["reference"] for doc in documents]
# Check that supported files were loaded
self.assertIn(self.test_file_path, references)
self.assertIn(md_file_path, references)
# Check that unsupported file was not loaded
for doc in documents:
self.assertNotEqual(doc.metadata["reference"], pdf_file_path)
if __name__ == "__main__":
unittest.main()

203 tests/loader/file_loader/test_unstructured_loader.py
@@ -1,203 +0,0 @@
import unittest
import os
import shutil
import tempfile
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.file_loader import UnstructuredLoader
class TestUnstructuredLoader(unittest.TestCase):
"""Tests for the UnstructuredLoader class."""
def setUp(self):
"""Set up test fixtures."""
# Create a temporary directory for tests
self.temp_dir = tempfile.TemporaryDirectory()
# Create a test file
self.test_file_path = os.path.join(self.temp_dir.name, "test.txt")
with open(self.test_file_path, "w", encoding="utf-8") as f:
f.write("This is a test file.")
# Path for mock processed outputs
self.mock_output_dir = os.path.join(self.temp_dir.name, "mock_outputs")
os.makedirs(self.mock_output_dir, exist_ok=True)
# Create a mock JSON output file
self.mock_json_path = os.path.join(self.mock_output_dir, "test_output.json")
with open(self.mock_json_path, "w", encoding="utf-8") as f:
f.write('{"elements": [{"text": "This is extracted text.", "metadata": {"filename": "test.txt"}}]}')
# Set up patches for unstructured modules
self.unstructured_modules = {
'unstructured_ingest': MagicMock(),
'unstructured_ingest.interfaces': MagicMock(),
'unstructured_ingest.pipeline': MagicMock(),
'unstructured_ingest.pipeline.pipeline': MagicMock(),
'unstructured_ingest.processes': MagicMock(),
'unstructured_ingest.processes.connectors': MagicMock(),
'unstructured_ingest.processes.connectors.local': MagicMock(),
'unstructured_ingest.processes.partitioner': MagicMock(),
'unstructured': MagicMock(),
'unstructured.staging': MagicMock(),
'unstructured.staging.base': MagicMock(),
}
self.patches = []
for module_name, mock_module in self.unstructured_modules.items():
patcher = patch.dict('sys.modules', {module_name: mock_module})
patcher.start()
self.patches.append(patcher)
# Create mock Pipeline class
self.mock_pipeline = MagicMock()
self.unstructured_modules['unstructured_ingest.pipeline.pipeline'].Pipeline = self.mock_pipeline
self.mock_pipeline.from_configs.return_value = self.mock_pipeline
# Create mock Element class
self.mock_element = MagicMock()
self.mock_element.text = "This is extracted text."
self.mock_element.metadata = MagicMock()
self.mock_element.metadata.to_dict.return_value = {"filename": "test.txt"}
# Set up elements_from_json mock
self.unstructured_modules['unstructured.staging.base'].elements_from_json = MagicMock()
self.unstructured_modules['unstructured.staging.base'].elements_from_json.return_value = [self.mock_element]
# Patch makedirs and rmtree but don't assert on them
with patch('os.makedirs'):
with patch('shutil.rmtree'):
self.loader = UnstructuredLoader()
def tearDown(self):
"""Clean up test fixtures."""
# Stop all patches
for patcher in self.patches:
patcher.stop()
# Remove temporary directory
self.temp_dir.cleanup()
def test_init(self):
"""Test initialization."""
self.assertEqual(self.loader.directory_with_results, "./pdf_processed_outputs")
def test_supported_file_types(self):
"""Test the supported_file_types property."""
file_types = self.loader.supported_file_types
# Check that common file types are included
common_types = ["pdf", "docx", "txt", "html", "md", "jpg"]
for file_type in common_types:
self.assertIn(file_type, file_types)
# Check total number of supported types (should be extensive)
self.assertGreater(len(file_types), 20)
@patch('os.listdir')
def test_load_file(self, mock_listdir):
"""Test loading a single file."""
# Configure mocks
mock_listdir.return_value = ["test_output.json"]
# Call the method
documents = self.loader.load_file(self.test_file_path)
# Verify Pipeline.from_configs was called
self.mock_pipeline.from_configs.assert_called_once()
self.mock_pipeline.run.assert_called_once()
# Verify elements_from_json was called
self.unstructured_modules['unstructured.staging.base'].elements_from_json.assert_called_once()
# Check results
self.assertEqual(len(documents), 1)
self.assertEqual(documents[0].page_content, "This is extracted text.")
self.assertEqual(documents[0].metadata["reference"], self.test_file_path)
self.assertEqual(documents[0].metadata["filename"], "test.txt")
@patch('os.listdir')
def test_load_directory(self, mock_listdir):
"""Test loading a directory."""
# Configure mocks
mock_listdir.return_value = ["test_output.json"]
# Call the method
documents = self.loader.load_directory(self.temp_dir.name)
# Verify Pipeline.from_configs was called
self.mock_pipeline.from_configs.assert_called_once()
self.mock_pipeline.run.assert_called_once()
# Check results
self.assertEqual(len(documents), 1)
self.assertEqual(documents[0].page_content, "This is extracted text.")
self.assertEqual(documents[0].metadata["reference"], self.temp_dir.name)
@patch('os.listdir')
def test_load_with_api(self, mock_listdir):
"""Test loading with API environment variables."""
# Create a mock for os.environ.get
with patch('os.environ.get') as mock_env_get:
# Configure environment variables
mock_env_get.side_effect = lambda key, default=None: {
"UNSTRUCTURED_API_KEY": "test-key",
"UNSTRUCTURED_API_URL": "https://api.example.com"
}.get(key, default)
# Configure listdir mock
mock_listdir.return_value = ["test_output.json"]
# Create a mock for PartitionerConfig
mock_partitioner_config = MagicMock()
self.unstructured_modules['unstructured_ingest.processes.partitioner'].PartitionerConfig = mock_partitioner_config
# Call the method
documents = self.loader.load_file(self.test_file_path)
# Verify Pipeline.from_configs was called
self.mock_pipeline.from_configs.assert_called_once()
# Check that PartitionerConfig was called with correct parameters
mock_partitioner_config.assert_called_once()
args, kwargs = mock_partitioner_config.call_args
self.assertTrue(kwargs.get('partition_by_api'))
self.assertEqual(kwargs.get('api_key'), "test-key")
self.assertEqual(kwargs.get('partition_endpoint'), "https://api.example.com")
# Check results
self.assertEqual(len(documents), 1)
@patch('os.listdir')
def test_empty_output(self, mock_listdir):
"""Test handling of empty output directory."""
# Configure listdir to return no JSON files
mock_listdir.return_value = []
# Call the method
documents = self.loader.load_file(self.test_file_path)
# Check results
self.assertEqual(len(documents), 0)
@patch('os.listdir')
def test_error_reading_json(self, mock_listdir):
"""Test handling of errors when reading JSON files."""
# Configure listdir mock
mock_listdir.return_value = ["test_output.json"]
# Configure elements_from_json to raise an IOError
self.unstructured_modules['unstructured.staging.base'].elements_from_json.side_effect = IOError("Test error")
# Call the method
documents = self.loader.load_file(self.test_file_path)
# Check results (should be empty)
self.assertEqual(len(documents), 0)
if __name__ == "__main__":
unittest.main()
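
After the ingest pipeline runs, the tests above expect the loader to read each JSON output with elements_from_json, turn every element into a Document with the element metadata plus a "reference" key, and return an empty list when reading fails. A sketch of that collection step; collect_outputs_sketch is a hypothetical name and the exact elements_from_json call is an assumption about how the removed loader used it:

    import os

    from langchain_core.documents import Document
    from unstructured.staging.base import elements_from_json


    def collect_outputs_sketch(results_dir, reference):
        # Read each partitioned JSON output and map its elements to Documents.
        documents = []
        for name in os.listdir(results_dir):
            if not name.endswith(".json"):
                continue
            try:
                elements = elements_from_json(filename=os.path.join(results_dir, name))
            except IOError:
                continue  # the error-handling test above expects an empty result
            for element in elements:
                metadata = element.metadata.to_dict()
                metadata["reference"] = reference
                documents.append(Document(page_content=element.text, metadata=metadata))
        return documents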

98 tests/loader/test_splitter.py
@@ -1,98 +0,0 @@
import unittest
from langchain_core.documents import Document
from deepsearcher.loader.splitter import Chunk, split_docs_to_chunks, _sentence_window_split
class TestSplitter(unittest.TestCase):
"""Tests for the splitter module."""
def test_chunk_init(self):
"""Test initialization of Chunk class."""
# Test with minimal parameters
chunk = Chunk(text="Test text", reference="test_ref")
self.assertEqual(chunk.text, "Test text")
self.assertEqual(chunk.reference, "test_ref")
self.assertEqual(chunk.metadata, {})
self.assertIsNone(chunk.embedding)
# Test with all parameters
metadata = {"key": "value"}
embedding = [0.1, 0.2, 0.3]
chunk = Chunk(text="Test text", reference="test_ref", metadata=metadata, embedding=embedding)
self.assertEqual(chunk.text, "Test text")
self.assertEqual(chunk.reference, "test_ref")
self.assertEqual(chunk.metadata, metadata)
self.assertEqual(chunk.embedding, embedding)
def test_sentence_window_split(self):
"""Test _sentence_window_split function."""
# Create a test document
original_text = "This is a test document. It has multiple sentences. This is for testing the splitter."
original_doc = Document(page_content=original_text, metadata={"reference": "test_doc"})
# Create split documents
split_docs = [
Document(page_content="This is a test document.", metadata={"reference": "test_doc"}),
Document(page_content="It has multiple sentences.", metadata={"reference": "test_doc"}),
Document(page_content="This is for testing the splitter.", metadata={"reference": "test_doc"})
]
# Test with default offset
chunks = _sentence_window_split(split_docs, original_doc)
# Verify the results
self.assertEqual(len(chunks), 3)
for i, chunk in enumerate(chunks):
self.assertEqual(chunk.text, split_docs[i].page_content)
self.assertEqual(chunk.reference, "test_doc")
self.assertIn("wider_text", chunk.metadata)
# The wider text should contain the original text since our test document is short
self.assertEqual(chunk.metadata["wider_text"], original_text)
# Test with smaller offset
chunks = _sentence_window_split(split_docs, original_doc, offset=10)
# Verify the results with smaller context windows
self.assertEqual(len(chunks), 3)
for chunk in chunks:
# With smaller offset, wider_text should be shorter than the full original text
self.assertLessEqual(len(chunk.metadata["wider_text"]), len(original_text))
def test_split_docs_to_chunks(self):
"""Test split_docs_to_chunks function."""
# Create test documents
docs = [
Document(
page_content="This is document one. It has some content for testing.",
metadata={"reference": "doc1"}
),
Document(
page_content="This is document two. It also has content for testing purposes.",
metadata={"reference": "doc2"}
)
]
# Test with default parameters
chunks = split_docs_to_chunks(docs)
# Verify the results
self.assertGreater(len(chunks), 0)
for chunk in chunks:
self.assertIsInstance(chunk, Chunk)
self.assertIn(chunk.reference, ["doc1", "doc2"])
self.assertIn("wider_text", chunk.metadata)
# Test with custom chunk size and overlap
chunks = split_docs_to_chunks(docs, chunk_size=10, chunk_overlap=2)
# With small chunk size, we should get more chunks
self.assertGreater(len(chunks), 2)
for chunk in chunks:
self.assertIsInstance(chunk, Chunk)
self.assertIn(chunk.reference, ["doc1", "doc2"])
self.assertIn("wider_text", chunk.metadata)
if __name__ == "__main__":
unittest.main()
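
The window tests above only constrain the output shape: each chunk keeps its own text and reference, and "wider_text" is a character window around the chunk that shrinks with a smaller offset. One plausible shape consistent with those assertions; the default offset of 200 and the find()-based lookup are assumptions for illustration, not the removed _sentence_window_split code:

    from deepsearcher.loader.splitter import Chunk


    def sentence_window_sketch(split_docs, original_doc, offset=200):
        # Locate each piece in the original text and attach a wider character
        # window around it as metadata["wider_text"].
        original_text = original_doc.page_content
        chunks = []
        for doc in split_docs:
            start = original_text.find(doc.page_content)
            end = start + len(doc.page_content)
            wider_text = original_text[max(0, start - offset):min(len(original_text), end + offset)]
            chunks.append(Chunk(text=doc.page_content,
                                reference=doc.metadata["reference"],
                                metadata={"wider_text": wider_text}))
        return chunks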

1 tests/loader/web_crawler/__init__.py
@@ -1 +0,0 @@
# Tests for the deepsearcher.loader.web_crawler package

53 tests/loader/web_crawler/test_base.py
@@ -1,53 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from deepsearcher.loader.web_crawler.base import BaseCrawler


class TestBaseCrawler(unittest.TestCase):
    """Tests for the BaseCrawler class."""

    def test_abstract_methods(self):
        """Test that BaseCrawler defines abstract methods."""
        # For abstract base classes, we can check if methods are defined
        # but not implemented in the base class
        self.assertTrue(hasattr(BaseCrawler, 'crawl_url'))

    def test_crawl_urls(self):
        """Test the crawl_urls method."""
        # Create a subclass of BaseCrawler for testing
        class TestCrawler(BaseCrawler):
            def crawl_url(self, url, **kwargs):
                # Mock implementation that returns a list of documents
                from langchain_core.documents import Document
                return [Document(
                    page_content=f"Content from {url}",
                    metadata={"reference": url, "kwargs": kwargs}
                )]

        # Create test URLs
        urls = [
            "https://example.com",
            "https://example.org",
            "https://example.net"
        ]
        # Test crawling multiple URLs
        crawler = TestCrawler()
        documents = crawler.crawl_urls(urls, param1="value1")
        # Check the results
        self.assertEqual(len(documents), 3)  # One document per URL
        # Verify each document
        references = [doc.metadata["reference"] for doc in documents]
        for url in urls:
            self.assertIn(url, references)
        # Check that kwargs were passed correctly
        for doc in documents:
            self.assertEqual(doc.metadata["kwargs"]["param1"], "value1")


if __name__ == "__main__":
    unittest.main()
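
The test above exercises the default aggregation behavior: crawl_urls visits each URL, delegates to crawl_url, and forwards keyword arguments unchanged. A minimal sketch of that behavior as a free function; crawl_urls_sketch is a hypothetical name for illustration:

    def crawl_urls_sketch(crawler, urls, **crawl_kwargs):
        # Aggregate per-URL results, forwarding keyword arguments to crawl_url.
        documents = []
        for url in urls:
            documents.extend(crawler.crawl_url(url, **crawl_kwargs))
        return documents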

157 tests/loader/web_crawler/test_crawl4ai_crawler.py
@@ -1,157 +0,0 @@
import unittest
import asyncio
from unittest.mock import patch, MagicMock
import warnings
from langchain_core.documents import Document
from deepsearcher.loader.web_crawler import Crawl4AICrawler
class TestCrawl4AICrawler(unittest.TestCase):
"""Tests for the Crawl4AICrawler class."""
def setUp(self):
"""Set up test fixtures."""
# Create a mock for the crawl4ai module
warnings.filterwarnings('ignore', message='coroutine.*never awaited')
self.crawl4ai_patcher = patch.dict('sys.modules', {'crawl4ai': MagicMock()})
self.crawl4ai_patcher.start()
# Create mocks for the classes
self.mock_async_web_crawler = MagicMock()
self.mock_browser_config = MagicMock()
# Set up the from_kwargs method
self.mock_config_instance = MagicMock()
self.mock_browser_config.from_kwargs.return_value = self.mock_config_instance
# Add the mocks to the crawl4ai module
import sys
sys.modules['crawl4ai'].AsyncWebCrawler = self.mock_async_web_crawler
sys.modules['crawl4ai'].BrowserConfig = self.mock_browser_config
# Set up mock instances
self.mock_crawler_instance = MagicMock()
self.mock_async_web_crawler.return_value = self.mock_crawler_instance
# For context manager behavior
self.mock_crawler_instance.__aenter__.return_value = self.mock_crawler_instance
self.mock_crawler_instance.__aexit__.return_value = None
# Create test browser_config
self.test_browser_config = {"headless": True}
# Create the crawler
self.crawler = Crawl4AICrawler(browser_config=self.test_browser_config)
def tearDown(self):
"""Clean up test fixtures."""
self.crawl4ai_patcher.stop()
def test_init(self):
"""Test initialization."""
# Verify that the browser_config was stored
self.assertEqual(self.crawler.browser_config, self.test_browser_config)
# Verify that the crawler is not initialized
self.assertIsNone(self.crawler.crawler)
def test_lazy_init(self):
"""Test the lazy initialization of the crawler."""
# Call _lazy_init method
self.crawler._lazy_init()
# Verify BrowserConfig.from_kwargs was called
self.mock_browser_config.from_kwargs.assert_called_once_with(self.test_browser_config)
# Verify AsyncWebCrawler was initialized
self.mock_async_web_crawler.assert_called_once_with(config=self.mock_config_instance)
# Verify that the crawler is now set
self.assertEqual(self.crawler.crawler, self.mock_crawler_instance)
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run')
def test_crawl_url(self, mock_asyncio_run):
"""Test crawling a single URL."""
url = "https://example.com"
# Set up mock document
mock_document = Document(
page_content="# Example Page\nThis is a test page.",
metadata={"reference": url, "title": "Example Page"}
)
# Configure asyncio.run to return a document
mock_asyncio_run.return_value = mock_document
# Call the method
documents = self.crawler.crawl_url(url)
# Verify asyncio.run was called with _async_crawl
mock_asyncio_run.assert_called_once()
# Check results
self.assertEqual(len(documents), 1)
self.assertEqual(documents[0], mock_document)
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run')
def test_crawl_url_error(self, mock_asyncio_run):
"""Test error handling when crawling a URL."""
url = "https://example.com"
# Configure asyncio.run to raise an exception
mock_asyncio_run.side_effect = Exception("Test error")
# Call the method
documents = self.crawler.crawl_url(url)
# Should return empty list on error
self.assertEqual(documents, [])
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run')
def test_crawl_urls(self, mock_asyncio_run):
"""Test crawling multiple URLs."""
urls = ["https://example.com", "https://example.org"]
# Set up mock documents
mock_documents = [
Document(
page_content="# Example Page 1\nThis is test page 1.",
metadata={"reference": urls[0], "title": "Example Page 1"}
),
Document(
page_content="# Example Page 2\nThis is test page 2.",
metadata={"reference": urls[1], "title": "Example Page 2"}
)
]
# Configure asyncio.run to return documents
mock_asyncio_run.return_value = mock_documents
# Call the method
documents = self.crawler.crawl_urls(urls)
# Verify asyncio.run was called with _async_crawl_many
mock_asyncio_run.assert_called_once()
# Check results
self.assertEqual(documents, mock_documents)
@patch('deepsearcher.loader.web_crawler.crawl4ai_crawler.asyncio.run')
def test_crawl_urls_error(self, mock_asyncio_run):
"""Test error handling when crawling multiple URLs."""
urls = ["https://example.com", "https://example.org"]
# Configure asyncio.run to raise an exception
mock_asyncio_run.side_effect = Exception("Test error")
# Call the method
documents = self.crawler.crawl_urls(urls)
# Should return empty list on error
self.assertEqual(documents, [])
if __name__ == "__main__":
unittest.main()
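
The error-handling tests above expect crawling failures to yield an empty list rather than propagate, with the synchronous entry points driving the async work through asyncio.run. A sketch of that pattern; _async_crawl is the coroutine name referenced by the test comments, and the rest is illustrative, not the removed crawler's code:

    import asyncio


    def crawl_url_sketch(crawler, url):
        # Drive the async crawl synchronously and swallow failures, since the
        # error-handling tests above expect an empty list rather than an exception.
        try:
            return [asyncio.run(crawler._async_crawl(url))]
        except Exception:
            return []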

157 tests/loader/web_crawler/test_docling_crawler.py
@@ -1,157 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.web_crawler import DoclingCrawler
class TestDoclingCrawler(unittest.TestCase):
"""Tests for the DoclingCrawler class."""
def setUp(self):
"""Set up test fixtures."""
# Create mocks for the docling modules
self.docling_patcher = patch.dict('sys.modules', {
'docling': MagicMock(),
'docling.document_converter': MagicMock(),
'docling_core': MagicMock(),
'docling_core.transforms': MagicMock(),
'docling_core.transforms.chunker': MagicMock()
})
self.docling_patcher.start()
# Create mocks for the classes
self.mock_document_converter = MagicMock()
self.mock_hierarchical_chunker = MagicMock()
# Add the mocks to the modules
import sys
sys.modules['docling.document_converter'].DocumentConverter = self.mock_document_converter
sys.modules['docling_core.transforms.chunker'].HierarchicalChunker = self.mock_hierarchical_chunker
# Set up mock instances
self.mock_converter_instance = MagicMock()
self.mock_chunker_instance = MagicMock()
self.mock_document_converter.return_value = self.mock_converter_instance
self.mock_hierarchical_chunker.return_value = self.mock_chunker_instance
# Create the crawler
self.crawler = DoclingCrawler()
def tearDown(self):
"""Clean up test fixtures."""
self.docling_patcher.stop()
def test_init(self):
"""Test initialization."""
# Verify instances were created
self.mock_document_converter.assert_called_once()
self.mock_hierarchical_chunker.assert_called_once()
# Check that the instances were assigned correctly
self.assertEqual(self.crawler.converter, self.mock_converter_instance)
self.assertEqual(self.crawler.chunker, self.mock_chunker_instance)
def test_crawl_url(self):
"""Test crawling a URL."""
url = "https://example.com"
# Set up mock document and chunks
mock_document = MagicMock()
mock_conversion_result = MagicMock()
mock_conversion_result.document = mock_document
# Set up three mock chunks
mock_chunks = []
for i in range(3):
chunk = MagicMock()
chunk.text = f"Chunk {i} content"
mock_chunks.append(chunk)
# Configure mock converter and chunker
self.mock_converter_instance.convert.return_value = mock_conversion_result
self.mock_chunker_instance.chunk.return_value = mock_chunks
# Call the method
documents = self.crawler.crawl_url(url)
# Verify converter was called correctly
self.mock_converter_instance.convert.assert_called_once_with(url)
# Verify chunker was called correctly
self.mock_chunker_instance.chunk.assert_called_once_with(mock_document)
# Check results
self.assertEqual(len(documents), 3)
# Check each document
for i, document in enumerate(documents):
self.assertEqual(document.page_content, f"Chunk {i} content")
self.assertEqual(document.metadata["reference"], url)
self.assertEqual(document.metadata["text"], f"Chunk {i} content")
def test_crawl_url_error(self):
"""Test error handling when crawling a URL."""
url = "https://example.com"
# Configure converter to raise an exception
self.mock_converter_instance.convert.side_effect = Exception("Test error")
# Verify that the error is propagated
with self.assertRaises(IOError):
self.crawler.crawl_url(url)
def test_supported_file_types(self):
"""Test the supported_file_types property."""
file_types = self.crawler.supported_file_types
# Check that all expected file types are included
expected_types = [
"pdf", "docx", "xlsx", "pptx", "md", "adoc", "asciidoc",
"html", "xhtml", "csv", "png", "jpg", "jpeg", "tif", "tiff", "bmp"
]
for file_type in expected_types:
self.assertIn(file_type, file_types)
# Check that the count matches
self.assertEqual(len(file_types), len(expected_types))
def test_crawl_urls(self):
"""Test crawling multiple URLs."""
urls = ["https://example.com", "https://example.org"]
# Set up mock document and chunks for each URL
mock_document = MagicMock()
mock_conversion_result = MagicMock()
mock_conversion_result.document = mock_document
# Set up one mock chunk per URL
mock_chunk = MagicMock()
mock_chunk.text = "Test chunk content"
# Configure mock converter and chunker
self.mock_converter_instance.convert.return_value = mock_conversion_result
self.mock_chunker_instance.chunk.return_value = [mock_chunk]
# Call the method
documents = self.crawler.crawl_urls(urls)
# Verify converter was called for each URL
self.assertEqual(self.mock_converter_instance.convert.call_count, 2)
# Verify chunker was called for each document
self.assertEqual(self.mock_chunker_instance.chunk.call_count, 2)
# Check results
self.assertEqual(len(documents), 2)
# Each URL should have generated one document (with one chunk)
for document in documents:
self.assertEqual(document.page_content, "Test chunk content")
self.assertIn(document.metadata["reference"], urls)
if __name__ == "__main__":
unittest.main()

135
tests/loader/web_crawler/test_firecrawl_crawler.py

@@ -1,135 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
from langchain_core.documents import Document
from deepsearcher.loader.web_crawler import FireCrawlCrawler
class TestFireCrawlCrawler(unittest.TestCase):
"""Tests for the FireCrawlCrawler class."""
def setUp(self):
"""Set up test fixtures."""
# Patch the environment variable
self.env_patcher = patch.dict('os.environ', {'FIRECRAWL_API_KEY': 'fake-api-key'})
self.env_patcher.start()
# Create a mock for the FirecrawlApp
self.firecrawl_app_patcher = patch('deepsearcher.loader.web_crawler.firecrawl_crawler.FirecrawlApp')
self.mock_firecrawl_app = self.firecrawl_app_patcher.start()
# Set up mock instances
self.mock_app_instance = MagicMock()
self.mock_firecrawl_app.return_value = self.mock_app_instance
# Create the crawler
self.crawler = FireCrawlCrawler()
def tearDown(self):
"""Clean up test fixtures."""
self.env_patcher.stop()
self.firecrawl_app_patcher.stop()
def test_init(self):
"""Test initialization."""
self.assertIsNone(self.crawler.app)
def test_crawl_url_single_page(self):
"""Test crawling a single URL."""
url = "https://example.com"
# Set up mock response for scrape_url
mock_response = MagicMock()
mock_response.model_dump.return_value = {
"markdown": "# Example Page\nThis is a test page.",
"metadata": {"title": "Example Page", "url": url}
}
self.mock_app_instance.scrape_url.return_value = mock_response
# Call the method
documents = self.crawler.crawl_url(url)
# Verify FirecrawlApp was initialized
self.mock_firecrawl_app.assert_called_once_with(api_key='fake-api-key')
# Verify scrape_url was called correctly
self.mock_app_instance.scrape_url.assert_called_once_with(url=url, formats=["markdown"])
# Check results
self.assertEqual(len(documents), 1)
document = documents[0]
self.assertEqual(document.page_content, "# Example Page\nThis is a test page.")
self.assertEqual(document.metadata["reference"], url)
self.assertEqual(document.metadata["title"], "Example Page")
def test_crawl_url_multiple_pages(self):
"""Test crawling multiple pages recursively."""
url = "https://example.com"
max_depth = 3
limit = 10
# Set up mock response for crawl_url
mock_response = MagicMock()
mock_response.model_dump.return_value = {
"data": [
{
"markdown": "# Page 1\nContent 1",
"metadata": {"title": "Page 1", "url": "https://example.com/page1"}
},
{
"markdown": "# Page 2\nContent 2",
"metadata": {"title": "Page 2", "url": "https://example.com/page2"}
}
]
}
self.mock_app_instance.crawl_url.return_value = mock_response
# Call the method
documents = self.crawler.crawl_url(url, max_depth=max_depth, limit=limit)
# Verify FirecrawlApp was initialized
self.mock_firecrawl_app.assert_called_once_with(api_key='fake-api-key')
# Verify crawl_url was called correctly
self.mock_app_instance.crawl_url.assert_called_once()
call_kwargs = self.mock_app_instance.crawl_url.call_args[1]
self.assertEqual(call_kwargs['url'], url)
self.assertEqual(call_kwargs['max_depth'], max_depth)
self.assertEqual(call_kwargs['limit'], limit)
# Check results
self.assertEqual(len(documents), 2)
# Check first document
self.assertEqual(documents[0].page_content, "# Page 1\nContent 1")
self.assertEqual(documents[0].metadata["reference"], "https://example.com/page1")
self.assertEqual(documents[0].metadata["title"], "Page 1")
# Check second document
self.assertEqual(documents[1].page_content, "# Page 2\nContent 2")
self.assertEqual(documents[1].metadata["reference"], "https://example.com/page2")
self.assertEqual(documents[1].metadata["title"], "Page 2")
def test_crawl_url_with_default_params(self):
"""Test crawling with default parameters."""
url = "https://example.com"
# Set up mock response for crawl_url
mock_response = MagicMock()
mock_response.model_dump.return_value = {"data": []}
self.mock_app_instance.crawl_url.return_value = mock_response
# Call the method with only max_depth
self.crawler.crawl_url(url, max_depth=2)
# Verify default values were used
call_kwargs = self.mock_app_instance.crawl_url.call_args[1]
self.assertEqual(call_kwargs['limit'], 20) # Default limit
self.assertEqual(call_kwargs['max_depth'], 2) # Provided max_depth
self.assertEqual(call_kwargs['allow_backward_links'], False) # Default allow_backward_links
if __name__ == "__main__":
unittest.main()

112
tests/loader/web_crawler/test_jina_crawler.py

@@ -1,112 +0,0 @@
import unittest
import os
from unittest.mock import patch, MagicMock
import requests
from langchain_core.documents import Document
from deepsearcher.loader.web_crawler import JinaCrawler
class TestJinaCrawler(unittest.TestCase):
"""Tests for the JinaCrawler class."""
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"})
def test_init_with_token(self):
"""Test initialization with API token in environment."""
crawler = JinaCrawler()
self.assertEqual(crawler.jina_api_token, "fake-token")
@patch.dict(os.environ, {"JINAAI_API_KEY": "fake-key"})
def test_init_with_alternative_key(self):
"""Test initialization with alternative API key in environment."""
crawler = JinaCrawler()
self.assertEqual(crawler.jina_api_token, "fake-key")
@patch.dict(os.environ, {}, clear=True)
def test_init_without_token(self):
"""Test initialization without API token raises ValueError."""
with self.assertRaises(ValueError):
JinaCrawler()
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"})
@patch("requests.get")
def test_crawl_url(self, mock_get):
"""Test crawling a URL."""
# Set up the mock response
mock_response = MagicMock()
mock_response.text = "# Markdown Content\nThis is a test."
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "text/markdown"}
mock_get.return_value = mock_response
# Create the crawler and crawl a test URL
crawler = JinaCrawler()
url = "https://example.com"
documents = crawler.crawl_url(url)
# Check that requests.get was called correctly
mock_get.assert_called_once_with(
f"https://r.jina.ai/{url}",
headers={
"Authorization": "Bearer fake-token",
"X-Return-Format": "markdown",
}
)
# Check the results
self.assertEqual(len(documents), 1)
document = documents[0]
# Check the content
self.assertEqual(document.page_content, mock_response.text)
# Check the metadata
self.assertEqual(document.metadata["reference"], url)
self.assertEqual(document.metadata["status_code"], 200)
self.assertEqual(document.metadata["headers"], {"Content-Type": "text/markdown"})
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"})
@patch("requests.get")
def test_crawl_url_http_error(self, mock_get):
"""Test handling of HTTP errors."""
# Set up the mock response to raise an HTTPError
mock_get.side_effect = requests.exceptions.HTTPError("404 Client Error")
# Create the crawler
crawler = JinaCrawler()
# Crawl a URL and check that the error is propagated
with self.assertRaises(requests.exceptions.HTTPError):
crawler.crawl_url("https://example.com")
@patch.dict(os.environ, {"JINA_API_TOKEN": "fake-token"})
@patch("requests.get")
def test_crawl_urls(self, mock_get):
"""Test crawling multiple URLs."""
# Set up the mock response
mock_response = MagicMock()
mock_response.text = "# Markdown Content\nThis is a test."
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "text/markdown"}
mock_get.return_value = mock_response
# Create the crawler and crawl multiple URLs
crawler = JinaCrawler()
urls = ["https://example.com", "https://example.org"]
documents = crawler.crawl_urls(urls)
# Check that requests.get was called twice
self.assertEqual(mock_get.call_count, 2)
# Check the results
self.assertEqual(len(documents), 2)
# Check that each document has the correct reference
references = [doc.metadata["reference"] for doc in documents]
for url in urls:
self.assertIn(url, references)
if __name__ == "__main__":
unittest.main()

172
tests/utils/test_log.py

@@ -1,172 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock, call
import logging
from termcolor import colored
from deepsearcher.utils import log
from deepsearcher.utils.log import ColoredFormatter
class TestColoredFormatter(unittest.TestCase):
"""Tests for the ColoredFormatter class."""
def setUp(self):
"""Set up test fixtures."""
self.formatter = ColoredFormatter("%(levelname)s - %(message)s")
def test_format_debug(self):
"""Test formatting debug level messages."""
record = logging.LogRecord(
"test", logging.DEBUG, "test.py", 1, "Debug message", (), None
)
formatted = self.formatter.format(record)
expected = colored("DEBUG - Debug message", "cyan")
self.assertEqual(formatted, expected)
def test_format_info(self):
"""Test formatting info level messages."""
record = logging.LogRecord(
"test", logging.INFO, "test.py", 1, "Info message", (), None
)
formatted = self.formatter.format(record)
expected = colored("INFO - Info message", "green")
self.assertEqual(formatted, expected)
def test_format_warning(self):
"""Test formatting warning level messages."""
record = logging.LogRecord(
"test", logging.WARNING, "test.py", 1, "Warning message", (), None
)
formatted = self.formatter.format(record)
expected = colored("WARNING - Warning message", "yellow")
self.assertEqual(formatted, expected)
def test_format_error(self):
"""Test formatting error level messages."""
record = logging.LogRecord(
"test", logging.ERROR, "test.py", 1, "Error message", (), None
)
formatted = self.formatter.format(record)
expected = colored("ERROR - Error message", "red")
self.assertEqual(formatted, expected)
def test_format_critical(self):
"""Test formatting critical level messages."""
record = logging.LogRecord(
"test", logging.CRITICAL, "test.py", 1, "Critical message", (), None
)
formatted = self.formatter.format(record)
expected = colored("CRITICAL - Critical message", "magenta")
self.assertEqual(formatted, expected)
def test_format_unknown_level(self):
"""Test formatting messages with unknown log level."""
record = logging.LogRecord(
"test", 60, "test.py", 1, "Custom level message", (), None
)
record.levelname = "CUSTOM"
formatted = self.formatter.format(record)
expected = colored("CUSTOM - Custom level message", "white")
self.assertEqual(formatted, expected)
class TestLogFunctions(unittest.TestCase):
"""Tests for the logging functions."""
def setUp(self):
"""Set up test fixtures."""
# Reset dev mode before each test
log.set_dev_mode(False)
# Create mock for dev_logger
self.mock_dev_logger = MagicMock()
self.dev_logger_patcher = patch("deepsearcher.utils.log.dev_logger", self.mock_dev_logger)
self.dev_logger_patcher.start()
# Create mock for progress_logger
self.mock_progress_logger = MagicMock()
self.progress_logger_patcher = patch("deepsearcher.utils.log.progress_logger", self.mock_progress_logger)
self.progress_logger_patcher.start()
def tearDown(self):
"""Clean up test fixtures."""
self.dev_logger_patcher.stop()
self.progress_logger_patcher.stop()
def test_set_dev_mode(self):
"""Test setting development mode."""
self.assertFalse(log.dev_mode)
log.set_dev_mode(True)
self.assertTrue(log.dev_mode)
log.set_dev_mode(False)
self.assertFalse(log.dev_mode)
def test_set_level(self):
"""Test setting log level."""
log.set_level(logging.DEBUG)
self.mock_dev_logger.setLevel.assert_called_once_with(logging.DEBUG)
def test_debug_in_dev_mode(self):
"""Test debug logging in dev mode."""
log.set_dev_mode(True)
log.debug("Test debug")
self.mock_dev_logger.debug.assert_called_once_with("Test debug")
def test_debug_not_in_dev_mode(self):
"""Test debug logging not in dev mode."""
log.set_dev_mode(False)
log.debug("Test debug")
self.mock_dev_logger.debug.assert_not_called()
def test_info_in_dev_mode(self):
"""Test info logging in dev mode."""
log.set_dev_mode(True)
log.info("Test info")
self.mock_dev_logger.info.assert_called_once_with("Test info")
def test_info_not_in_dev_mode(self):
"""Test info logging not in dev mode."""
log.set_dev_mode(False)
log.info("Test info")
self.mock_dev_logger.info.assert_not_called()
def test_warning_in_dev_mode(self):
"""Test warning logging in dev mode."""
log.set_dev_mode(True)
log.warning("Test warning")
self.mock_dev_logger.warning.assert_called_once_with("Test warning")
def test_warning_not_in_dev_mode(self):
"""Test warning logging not in dev mode."""
log.set_dev_mode(False)
log.warning("Test warning")
self.mock_dev_logger.warning.assert_not_called()
def test_error_in_dev_mode(self):
"""Test error logging in dev mode."""
log.set_dev_mode(True)
log.error("Test error")
self.mock_dev_logger.error.assert_called_once_with("Test error")
def test_error_not_in_dev_mode(self):
"""Test error logging not in dev mode."""
log.set_dev_mode(False)
log.error("Test error")
self.mock_dev_logger.error.assert_not_called()
def test_critical(self):
"""Test critical logging and exception raising."""
with self.assertRaises(RuntimeError) as context:
log.critical("Test critical")
self.mock_dev_logger.critical.assert_called_once_with("Test critical")
self.assertEqual(str(context.exception), "Test critical")
def test_color_print(self):
"""Test color print function."""
log.color_print("Test message")
self.mock_progress_logger.info.assert_called_once_with("Test message")
if __name__ == "__main__":
unittest.main()

237
tests/vector_db/test_azure_search.py

@@ -1,237 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import sys
from deepsearcher.vector_db import AzureSearch
from deepsearcher.vector_db.base import RetrievalResult
class TestAzureSearch(unittest.TestCase):
"""Tests for the Azure Search vector database implementation."""
def setUp(self):
"""Set up test fixtures."""
# Create mock modules
self.mock_azure = MagicMock()
self.mock_search = MagicMock()
self.mock_indexes = MagicMock()
self.mock_models = MagicMock()
self.mock_credentials = MagicMock()
self.mock_exceptions = MagicMock()
# Setup nested structure
self.mock_azure.search = self.mock_search
self.mock_search.documents = self.mock_search
self.mock_search.documents.indexes = self.mock_indexes
self.mock_indexes.models = self.mock_models
self.mock_azure.core = self.mock_credentials
self.mock_azure.core.credentials = self.mock_credentials
self.mock_azure.core.exceptions = self.mock_exceptions
# Mock specific models needed for init_collection
self.mock_models.SearchableField = MagicMock()
self.mock_models.SimpleField = MagicMock()
self.mock_models.SearchField = MagicMock()
self.mock_models.SearchIndex = MagicMock()
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {
'azure': self.mock_azure,
'azure.core': self.mock_credentials,
'azure.core.credentials': self.mock_credentials,
'azure.core.exceptions': self.mock_exceptions,
'azure.search': self.mock_search,
'azure.search.documents': self.mock_search,
'azure.search.documents.indexes': self.mock_indexes,
'azure.search.documents.indexes.models': self.mock_models
})
# Start the patcher
self.module_patcher.start()
# Import after mocking
from deepsearcher.vector_db import AzureSearch
from deepsearcher.vector_db.base import RetrievalResult
self.AzureSearch = AzureSearch
self.RetrievalResult = RetrievalResult
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init(self):
"""Test basic initialization."""
# Setup mock
mock_client = MagicMock()
self.mock_search.SearchClient.return_value = mock_client
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
# Verify initialization
self.assertEqual(azure_search.index_name, "test-index")
self.assertEqual(azure_search.endpoint, "https://test-search.search.windows.net")
self.assertEqual(azure_search.api_key, "test-key")
self.assertEqual(azure_search.vector_field, "content_vector")
self.assertIsNotNone(azure_search.client)
def test_init_collection(self):
"""Test collection initialization."""
# Setup mock
mock_index_client = MagicMock()
self.mock_indexes.SearchIndexClient.return_value = mock_index_client
mock_index_client.create_index.return_value = None
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
azure_search.init_collection()
self.assertTrue(mock_index_client.create_index.called)
def test_insert_data(self):
"""Test inserting data."""
# Setup mock
mock_client = MagicMock()
self.mock_search.SearchClient.return_value = mock_client
# Mock successful upload result
mock_result = [MagicMock(succeeded=True) for _ in range(2)]
mock_client.upload_documents.return_value = mock_result
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
# Create test data
d = 1536 # Azure Search expects 1536 dimensions
rng = np.random.default_rng(seed=42)
test_docs = [
{
"text": "hello world",
"vector": rng.random(d).tolist(),
"id": "doc1"
},
{
"text": "hello azure search",
"vector": rng.random(d).tolist(),
"id": "doc2"
}
]
results = azure_search.insert_data(documents=test_docs)
self.assertEqual(len(results), 2)
self.assertTrue(all(results))
def test_search_data(self):
"""Test search functionality."""
# Setup mock
mock_client = MagicMock()
self.mock_search.SearchClient.return_value = mock_client
# Mock search results
d = 1536
rng = np.random.default_rng(seed=42)
mock_results = MagicMock()
mock_results.results = [
{
"content": "hello world",
"id": "doc1",
"@search.score": 0.95
},
{
"content": "hello azure search",
"id": "doc2",
"@search.score": 0.85
}
]
mock_client._client.documents.search_post.return_value = mock_results
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
# Test search
query_vector = rng.random(d).tolist()
results = azure_search.search_data(
collection="test-index",
vector=query_vector,
top_k=2
)
self.assertIsInstance(results, list)
self.assertEqual(len(results), 2)
# Verify results are RetrievalResult objects
for result in results:
self.assertIsInstance(result, self.RetrievalResult)
def test_clear_db(self):
"""Test clearing database."""
# Setup mock
mock_client = MagicMock()
self.mock_search.SearchClient.return_value = mock_client
# Mock search results for documents to delete
mock_client.search.return_value = [
{"id": "doc1"},
{"id": "doc2"}
]
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
deleted_count = azure_search.clear_db()
self.assertEqual(deleted_count, 2)
def test_list_collections(self):
"""Test listing collections."""
# Setup mock
mock_index_client = MagicMock()
self.mock_indexes.SearchIndexClient.return_value = mock_index_client
# Mock list_indexes response
mock_index1 = MagicMock()
mock_index1.name = "test-index-1"
mock_index1.fields = ["field1", "field2"]
mock_index2 = MagicMock()
mock_index2.name = "test-index-2"
mock_index2.fields = ["field1", "field2", "field3"]
mock_index_client.list_indexes.return_value = [mock_index1, mock_index2]
azure_search = self.AzureSearch(
endpoint="https://test-search.search.windows.net",
index_name="test-index",
api_key="test-key",
vector_field="content_vector"
)
collections = azure_search.list_collections()
self.assertIsInstance(collections, list)
self.assertEqual(len(collections), 2)
if __name__ == "__main__":
unittest.main()

157
tests/vector_db/test_base.py

@@ -1,157 +0,0 @@
import unittest
import numpy as np
from typing import List
from deepsearcher.vector_db.base import (
RetrievalResult,
deduplicate_results,
CollectionInfo,
BaseVectorDB,
)
from deepsearcher.loader.splitter import Chunk
class TestRetrievalResult(unittest.TestCase):
"""Tests for the RetrievalResult class."""
def setUp(self):
"""Set up test fixtures."""
self.embedding = np.array([0.1, 0.2, 0.3])
self.text = "Test text"
self.reference = "test.txt"
self.metadata = {"key": "value"}
self.score = 0.95
def test_init(self):
"""Test initialization of RetrievalResult."""
result = RetrievalResult(
embedding=self.embedding,
text=self.text,
reference=self.reference,
metadata=self.metadata,
score=self.score,
)
self.assertTrue(np.array_equal(result.embedding, self.embedding))
self.assertEqual(result.text, self.text)
self.assertEqual(result.reference, self.reference)
self.assertEqual(result.metadata, self.metadata)
self.assertEqual(result.score, self.score)
def test_init_default_score(self):
"""Test initialization of RetrievalResult with default score."""
result = RetrievalResult(
embedding=self.embedding,
text=self.text,
reference=self.reference,
metadata=self.metadata,
)
self.assertEqual(result.score, 0.0)
def test_repr(self):
"""Test string representation of RetrievalResult."""
result = RetrievalResult(
embedding=self.embedding,
text=self.text,
reference=self.reference,
metadata=self.metadata,
score=self.score,
)
expected = f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}"
self.assertEqual(repr(result), expected)
class TestDeduplicateResults(unittest.TestCase):
"""Tests for the deduplicate_results function."""
def setUp(self):
"""Set up test fixtures."""
self.embedding1 = np.array([0.1, 0.2, 0.3])
self.embedding2 = np.array([0.4, 0.5, 0.6])
self.text1 = "Text 1"
self.text2 = "Text 2"
self.reference = "test.txt"
self.metadata = {"key": "value"}
def test_no_duplicates(self):
"""Test deduplication with no duplicate results."""
results = [
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata),
RetrievalResult(self.embedding2, self.text2, self.reference, self.metadata),
]
deduplicated = deduplicate_results(results)
self.assertEqual(len(deduplicated), 2)
self.assertEqual(deduplicated, results)
def test_with_duplicates(self):
"""Test deduplication with duplicate results."""
results = [
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata),
RetrievalResult(self.embedding2, self.text2, self.reference, self.metadata),
RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata),
]
deduplicated = deduplicate_results(results)
self.assertEqual(len(deduplicated), 2)
self.assertEqual(deduplicated[0].text, self.text1)
self.assertEqual(deduplicated[1].text, self.text2)
def test_empty_list(self):
"""Test deduplication with empty list."""
results = []
deduplicated = deduplicate_results(results)
self.assertEqual(len(deduplicated), 0)
class TestCollectionInfo(unittest.TestCase):
"""Tests for the CollectionInfo class."""
def test_init(self):
"""Test initialization of CollectionInfo."""
name = "test_collection"
description = "Test collection description"
collection_info = CollectionInfo(name, description)
self.assertEqual(collection_info.collection_name, name)
self.assertEqual(collection_info.description, description)
class MockVectorDB(BaseVectorDB):
"""Mock implementation of BaseVectorDB for testing."""
def init_collection(self, dim, collection, description, force_new_collection=False, *args, **kwargs):
pass
def insert_data(self, collection, chunks, *args, **kwargs):
pass
def search_data(self, collection, vector, *args, **kwargs) -> List[RetrievalResult]:
return []
def clear_db(self, *args, **kwargs):
pass
class TestBaseVectorDB(unittest.TestCase):
"""Tests for the BaseVectorDB class."""
def setUp(self):
"""Set up test fixtures."""
self.db = MockVectorDB()
def test_init_default(self):
"""Test initialization with default collection name."""
self.assertEqual(self.db.default_collection, "deepsearcher")
def test_init_custom_collection(self):
"""Test initialization with custom collection name."""
custom_collection = "custom_collection"
db = MockVectorDB(default_collection=custom_collection)
self.assertEqual(db.default_collection, custom_collection)
def test_list_collections_default(self):
"""Test default list_collections implementation."""
self.assertIsNone(self.db.list_collections())
if __name__ == "__main__":
unittest.main()

141
tests/vector_db/test_milvus.py

@@ -1,141 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import warnings
# Filter out the pkg_resources deprecation warning from milvus_lite
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
from deepsearcher.vector_db import Milvus
from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import RetrievalResult
class TestMilvus(unittest.TestCase):
"""Simple tests for the Milvus vector database implementation."""
def test_init(self):
"""Test basic initialization."""
milvus = Milvus(
default_collection="test_collection",
uri="./milvus.db",
hybrid=False
)
# Verify initialization - just check basic properties
self.assertEqual(milvus.default_collection, "test_collection")
self.assertFalse(milvus.hybrid)
self.assertIsNotNone(milvus.client)
def test_init_collection(self):
"""Test collection initialization."""
milvus = Milvus(uri="./milvus.db")
# Test collection initialization
d = 8
collection = "hello_deepsearcher"
try:
milvus.init_collection(dim=d, collection=collection)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "init_collection should work")
def test_insert_data_with_retrieval_results(self):
"""Test inserting data using RetrievalResult objects."""
milvus = Milvus(uri="./milvus.db")
# Create test data
d = 8
collection = "hello_deepsearcher"
rng = np.random.default_rng(seed=19530)
# Create RetrievalResult objects
test_data = [
RetrievalResult(
embedding=rng.random((1, d))[0],
text="hello world",
reference="local file: hi.txt",
metadata={"a": 1},
),
RetrievalResult(
embedding=rng.random((1, d))[0],
text="hello milvus",
reference="local file: hi.txt",
metadata={"a": 1},
),
]
try:
milvus.insert_data(collection=collection, chunks=test_data)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "insert_data should work with RetrievalResult objects")
def test_search_data(self):
"""Test search functionality."""
milvus = Milvus(uri="./milvus.db")
# Test search
d = 8
collection = "hello_deepsearcher"
rng = np.random.default_rng(seed=19530)
query_vector = rng.random((1, d))[0]
try:
top_2 = milvus.search_data(
collection=collection,
vector=query_vector,
top_k=2
)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "search_data should work")
if test_passed:
self.assertIsInstance(top_2, list)
# Note: an empty collection may return fewer than 2 results
if top_2:
    self.assertIsInstance(top_2[0], RetrievalResult)
def test_clear_collection(self):
"""Test clearing collection."""
milvus = Milvus(uri="./milvus.db")
collection = "hello_deepsearcher"
try:
milvus.clear_db(collection=collection)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "clear_db should work")
def test_list_collections(self):
"""Test listing collections."""
milvus = Milvus(uri="./milvus.db")
try:
collections = milvus.list_collections()
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "list_collections should work")
if test_passed:
self.assertIsInstance(collections, list)
self.assertGreaterEqual(len(collections), 0)
if __name__ == "__main__":
unittest.main()

255
tests/vector_db/test_oracle.py

@@ -1,255 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import sys
import json
from deepsearcher.vector_db.base import RetrievalResult
from deepsearcher.loader.splitter import Chunk
import logging
logging.disable(logging.CRITICAL)
class TestOracleDB(unittest.TestCase):
"""Tests for the Oracle vector database implementation."""
def setUp(self):
"""Set up test fixtures."""
# Create mock modules
self.mock_oracledb = MagicMock()
# Setup mock DB_TYPE_VECTOR
self.mock_oracledb.DB_TYPE_VECTOR = "VECTOR"
self.mock_oracledb.defaults = MagicMock()
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {
'oracledb': self.mock_oracledb
})
# Start the patcher
self.module_patcher.start()
# Import after mocking
from deepsearcher.vector_db import OracleDB
self.OracleDB = OracleDB
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
def test_init(self):
"""Test basic initialization."""
# Setup mock
mock_pool = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd",
default_collection="test_collection"
)
# Verify initialization
self.assertEqual(oracle_db.default_collection, "test_collection")
self.assertIsNotNone(oracle_db.client)
self.mock_oracledb.create_pool.assert_called_once()
self.assertIs(self.mock_oracledb.defaults.fetch_lobs, False)
def test_insert_data(self):
"""Test inserting data."""
# Setup mock
mock_pool = MagicMock()
mock_connection = MagicMock()
mock_cursor = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
mock_pool.acquire.return_value.__enter__.return_value = mock_connection
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd"
)
# Create test data
d = 8
rng = np.random.default_rng(seed=42)
test_chunks = [
Chunk(
embedding=rng.random(d).tolist(),
text="hello world",
reference="test.txt",
metadata={"key": "value1"}
),
Chunk(
embedding=rng.random(d).tolist(),
text="hello oracle",
reference="test.txt",
metadata={"key": "value2"}
)
]
oracle_db.insert_data(collection="test_collection", chunks=test_chunks)
self.assertTrue(mock_cursor.execute.called)
self.assertTrue(mock_connection.commit.called)
def test_search_data(self):
"""Test search functionality."""
# Setup mock
mock_pool = MagicMock()
mock_connection = MagicMock()
mock_cursor = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
mock_pool.acquire.return_value.__enter__.return_value = mock_connection
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
# Mock search results
mock_cursor.description = [("embedding",), ("text",), ("reference",), ("distance",), ("metadata",)]
mock_cursor.fetchall.return_value = [
(
np.array([0.1, 0.2, 0.3]),
"hello world",
"test.txt",
0.95,
json.dumps({"key": "value1"})
),
(
np.array([0.4, 0.5, 0.6]),
"hello oracle",
"test.txt",
0.85,
json.dumps({"key": "value2"})
)
]
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd"
)
# Test search
d = 8
rng = np.random.default_rng(seed=42)
query_vector = rng.random(d)
results = oracle_db.search_data(
collection="test_collection",
vector=query_vector,
top_k=2
)
self.assertIsInstance(results, list)
self.assertEqual(len(results), 2)
for result in results:
self.assertIsInstance(result, RetrievalResult)
def test_list_collections(self):
"""Test listing collections."""
# Setup mock
mock_pool = MagicMock()
mock_connection = MagicMock()
mock_cursor = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
mock_pool.acquire.return_value.__enter__.return_value = mock_connection
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
# Mock list_collections response
mock_cursor.description = [("collection",), ("description",)]
mock_cursor.fetchall.return_value = [
("test_collection_1", "Test collection 1"),
("test_collection_2", "Test collection 2")
]
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd"
)
collections = oracle_db.list_collections()
self.assertIsInstance(collections, list)
self.assertEqual(len(collections), 2)
def test_clear_db(self):
"""Test clearing database."""
# Setup mock
mock_pool = MagicMock()
mock_connection = MagicMock()
mock_cursor = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
mock_pool.acquire.return_value.__enter__.return_value = mock_connection
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd"
)
oracle_db.clear_db("test_collection")
self.assertTrue(mock_cursor.execute.called)
self.assertTrue(mock_connection.commit.called)
def test_has_collection(self):
"""Test checking if collection exists."""
# Setup mock
mock_pool = MagicMock()
mock_connection = MagicMock()
mock_cursor = MagicMock()
self.mock_oracledb.create_pool.return_value = mock_pool
mock_pool.acquire.return_value.__enter__.return_value = mock_connection
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
# Mock check_table response first (called during init)
mock_cursor.description = [("table_name",)]
mock_cursor.fetchall.return_value = [
("DEEPSEARCHER_COLLECTION_INFO",),
("DEEPSEARCHER_COLLECTION_ITEM",)
]
oracle_db = self.OracleDB(
user="test_user",
password="test_password",
dsn="test_dsn",
config_dir="/test/config",
wallet_location="/test/wallet",
wallet_password="test_wallet_pwd"
)
# Now mock has_collection response - collection exists
mock_cursor.description = [("rowcnt",)]
mock_cursor.fetchall.return_value = [(1,)] # Return tuple, not dict
result = oracle_db.has_collection("test_collection")
self.assertTrue(result)
# Test collection doesn't exist
mock_cursor.fetchall.return_value = [(0,)] # Return tuple, not dict
result = oracle_db.has_collection("nonexistent_collection")
self.assertFalse(result)
if __name__ == "__main__":
unittest.main()

192
tests/vector_db/test_qdrant.py

@@ -1,192 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import sys
class TestQdrant(unittest.TestCase):
"""Tests for the Qdrant vector database implementation."""
def setUp(self):
"""Set up test fixtures."""
# Create mock modules
self.mock_qdrant = MagicMock()
self.mock_models = MagicMock()
self.mock_qdrant.models = self.mock_models
# Create the module patcher
self.module_patcher = patch.dict('sys.modules', {
'qdrant_client': self.mock_qdrant,
'qdrant_client.models': self.mock_models
})
self.module_patcher.start()
# Import after mocking
from deepsearcher.vector_db import Qdrant
from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import RetrievalResult
self.Qdrant = Qdrant
self.Chunk = Chunk
self.RetrievalResult = RetrievalResult
def tearDown(self):
"""Clean up test fixtures."""
self.module_patcher.stop()
@patch('qdrant_client.QdrantClient')
def test_init(self, mock_client_class):
"""Test basic initialization."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
qdrant = self.Qdrant(
location="memory",
url="http://custom:6333",
port=6333,
api_key="test_key",
default_collection="custom"
)
# Verify initialization - just check basic properties
self.assertEqual(qdrant.default_collection, "custom")
self.assertIsNotNone(qdrant.client)
@patch('qdrant_client.QdrantClient')
def test_init_collection(self, mock_client_class):
"""Test collection initialization."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.collection_exists.return_value = False
qdrant = self.Qdrant()
# Test collection initialization
d = 8
collection = "test_collection"
try:
qdrant.init_collection(dim=d, collection=collection)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "init_collection should work")
@patch('qdrant_client.QdrantClient')
def test_insert_data(self, mock_client_class):
"""Test inserting data."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.upsert.return_value = None
qdrant = self.Qdrant()
# Create test data
d = 8
collection = "test_collection"
rng = np.random.default_rng(seed=42)
# Create test chunks with numpy arrays converted to lists
chunks = [
self.Chunk(
embedding=rng.random(d).tolist(), # Convert to list
text="hello world",
reference="test.txt",
metadata={"key": "value1"}
),
self.Chunk(
embedding=rng.random(d).tolist(), # Convert to list
text="hello qdrant",
reference="test.txt",
metadata={"key": "value2"}
)
]
try:
qdrant.insert_data(collection=collection, chunks=chunks)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "insert_data should work")
@patch('qdrant_client.QdrantClient')
def test_search_data(self, mock_client_class):
"""Test search functionality."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock search results
d = 8
rng = np.random.default_rng(seed=42)
mock_point1 = MagicMock()
mock_point1.vector = rng.random(d)
mock_point1.payload = {
"text": "hello world",
"reference": "test.txt",
"metadata": {"key": "value1"}
}
mock_point1.score = 0.95
mock_point2 = MagicMock()
mock_point2.vector = rng.random(d)
mock_point2.payload = {
"text": "hello qdrant",
"reference": "test.txt",
"metadata": {"key": "value2"}
}
mock_point2.score = 0.85
mock_response = MagicMock()
mock_response.points = [mock_point1, mock_point2]
mock_client.query_points.return_value = mock_response
qdrant = self.Qdrant()
# Test search
collection = "test_collection"
query_vector = rng.random(d)
try:
results = qdrant.search_data(
collection=collection,
vector=query_vector,
top_k=2
)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "search_data should work")
if test_passed:
self.assertIsInstance(results, list)
self.assertEqual(len(results), 2)
# Verify results are RetrievalResult objects
for result in results:
self.assertIsInstance(result, self.RetrievalResult)
@patch('qdrant_client.QdrantClient')
def test_clear_collection(self, mock_client_class):
"""Test clearing collection."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.delete_collection.return_value = None
qdrant = self.Qdrant()
collection = "test_collection"
try:
qdrant.clear_db(collection=collection)
test_passed = True
except Exception as e:
test_passed = False
print(f"Error: {e}")
self.assertTrue(test_passed, "clear_db should work")
if __name__ == "__main__":
unittest.main()