You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
240 lines
7.8 KiB
240 lines
7.8 KiB
import os
|
|
from typing import Literal
|
|
|
|
import yaml
|
|
|
|
from deepsearcher.agent import ChainOfRAG, DeepSearch, NaiveRAG
|
|
from deepsearcher.agent.rag_router import RAGRouter
|
|
from deepsearcher.embedding.base import BaseEmbedding
|
|
from deepsearcher.llm.base import BaseLLM
|
|
from deepsearcher.loader.file_loader.base import BaseLoader
|
|
from deepsearcher.loader.web_crawler.base import BaseCrawler
|
|
from deepsearcher.vector_db.base import BaseVectorDB
|
|
|
|
# Absolute path of the directory that contains this module.
current_dir = os.path.dirname(os.path.abspath(__file__))

# The default configuration file is the "config.yaml" located in that same directory.
DEFAULT_CONFIG_YAML_PATH = os.path.join(current_dir, "config.yaml")

# Closed set of feature names used to annotate the `feature` arguments in this module.
FeatureType = Literal[
    "llm",
    "embedding",
    "file_loader",
    "web_crawler",
    "vector_db",
]
|
|
|
|
|
|
class Configuration:
    """
    Configuration class for DeepSearcher.

    This class manages the configuration settings for various components of the DeepSearcher system,
    including LLM providers, embedding models, file loaders, web crawlers, and vector databases.
    It loads configurations from a YAML file and provides methods to get and set provider configurations.

    Attributes set by ``__init__`` (each is the value of the same-named top-level YAML key):
        provide_settings: Per-feature provider entries; each entry is read and written
            through its ``"provider"`` and ``"config"`` keys.
        query_settings: Query-time settings.
        load_settings: Load-time settings.
    """

    def __init__(self, config_path: str = DEFAULT_CONFIG_YAML_PATH):
        """
        Initialize the Configuration object.

        Args:
            config_path: Path to the configuration YAML file. Defaults to
                ``DEFAULT_CONFIG_YAML_PATH``, i.e. the ``config.yaml`` that sits in the
                same directory as this module.

        Raises:
            KeyError: If the loaded document lacks one of the top-level keys
                ``provide_settings``, ``query_settings`` or ``load_settings``.
        """
        # Initialize default configurations
        config_data = self.load_config_from_yaml(config_path)
        self.provide_settings = config_data["provide_settings"]
        self.query_settings = config_data["query_settings"]
        self.load_settings = config_data["load_settings"]

    def load_config_from_yaml(self, config_path: str):
        """
        Load configuration from a YAML file.

        Args:
            config_path: Path to the configuration YAML file.

        Returns:
            The loaded configuration data, as produced by ``yaml.safe_load``.
        """
        # Decode explicitly as UTF-8: without an ``encoding`` argument ``open`` falls back
        # to the platform locale encoding, which can mis-decode (or fail on) a UTF-8
        # config file that contains non-ASCII characters.
        with open(config_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    def set_provider_config(self, feature: FeatureType, provider: str, provider_configs: dict):
        """
        Set the provider and its configurations for a given feature.

        Args:
            feature: The feature to configure (e.g., 'llm', 'file_loader', 'web_crawler').
            provider: The provider name (e.g., 'openai', 'deepseek').
            provider_configs: A dictionary with configurations specific to the provider.

        Raises:
            ValueError: If the feature is not a key of the loaded ``provide_settings``.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")

        # The existing entry is updated in place, key by key.
        self.provide_settings[feature]["provider"] = provider
        self.provide_settings[feature]["config"] = provider_configs

    def get_provider_config(self, feature: FeatureType):
        """
        Get the current provider and configuration for a given feature.

        Args:
            feature: The feature to retrieve (e.g., 'llm', 'file_loader', 'web_crawler').

        Returns:
            The stored ``provide_settings`` entry for the feature (the object itself,
            not a copy), holding the provider and its configurations.

        Raises:
            ValueError: If the feature is not a key of the loaded ``provide_settings``.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")

        return self.provide_settings[feature]
|
|
|
|
|
|
class ModuleFactory:
    """
    Builds the concrete component objects named in a Configuration.

    For each feature (LLM, embedding model, file loader, web crawler, vector database)
    the factory looks up the configured provider name, resolves it as an attribute of
    the matching package, and instantiates it with the configured keyword arguments.
    """

    def __init__(self, config: Configuration):
        """
        Store the configuration the factory reads from.

        Args:
            config: The Configuration object whose ``provide_settings`` drive instantiation.
        """
        self.config = config

    def _create_module_instance(self, feature: FeatureType, module_name: str):
        """
        Instantiate the provider configured for ``feature``.

        Args:
            feature: Key into ``config.provide_settings`` (e.g., 'llm', 'embedding').
            module_name: Dotted name of the module to import the provider from,
                e.g. "deepsearcher.loader.file_loader" for feature "file_loader".

        Returns:
            The object produced by calling the provider with the feature's
            ``"config"`` mapping expanded as keyword arguments.
        """
        feature_settings = self.config.provide_settings[feature]
        provider_name = feature_settings["provider"]
        # A non-empty fromlist makes __import__ return the module named by
        # module_name itself rather than its top-level package.
        provider_module = __import__(module_name, fromlist=[provider_name])
        provider_cls = getattr(provider_module, provider_name)
        return provider_cls(**feature_settings["config"])

    def create_llm(self) -> BaseLLM:
        """
        Build the configured language model.

        Returns:
            The instance created for the "llm" feature from "deepsearcher.llm".
        """
        llm_instance = self._create_module_instance("llm", "deepsearcher.llm")
        return llm_instance

    def create_embedding(self) -> BaseEmbedding:
        """
        Build the configured embedding model.

        Returns:
            The instance created for the "embedding" feature from "deepsearcher.embedding".
        """
        embedding_instance = self._create_module_instance("embedding", "deepsearcher.embedding")
        return embedding_instance

    def create_file_loader(self) -> BaseLoader:
        """
        Build the configured file loader.

        Returns:
            The instance created for the "file_loader" feature from
            "deepsearcher.loader.file_loader".
        """
        loader_instance = self._create_module_instance(
            "file_loader", "deepsearcher.loader.file_loader"
        )
        return loader_instance

    def create_web_crawler(self) -> BaseCrawler:
        """
        Build the configured web crawler.

        Returns:
            The instance created for the "web_crawler" feature from
            "deepsearcher.loader.web_crawler".
        """
        crawler_instance = self._create_module_instance(
            "web_crawler", "deepsearcher.loader.web_crawler"
        )
        return crawler_instance

    def create_vector_db(self) -> BaseVectorDB:
        """
        Build the configured vector database.

        Returns:
            The instance created for the "vector_db" feature from "deepsearcher.vector_db".
        """
        vector_db_instance = self._create_module_instance("vector_db", "deepsearcher.vector_db")
        return vector_db_instance
|
|
|
|
|
|
# Built at import time from the default YAML path (Configuration's default argument).
config = Configuration()

# Module-level component handles. They start as None and are assigned by init_config().
module_factory: ModuleFactory = None

# Core components produced by the factory.
llm: BaseLLM = None
embedding_model: BaseEmbedding = None
file_loader: BaseLoader = None
vector_db: BaseVectorDB = None
web_crawler: BaseCrawler = None

# RAG entry points assembled from the components above.
default_searcher: RAGRouter = None
naive_rag: NaiveRAG = None
|
|
|
|
|
|
def init_config(config: Configuration):
    """
    Build every runtime component from ``config`` and publish it as a module global.

    The function rebinds the module-level ``module_factory``, ``llm``,
    ``embedding_model``, ``file_loader``, ``web_crawler``, ``vector_db``,
    ``default_searcher`` and ``naive_rag``. The module-level ``config`` is not part of
    the ``global`` declaration, so it is left untouched; inside this function the name
    ``config`` refers to the argument only.

    Args:
        config: The Configuration object to use for initialization.
    """
    global module_factory, llm, embedding_model, file_loader
    global vector_db, web_crawler, default_searcher, naive_rag

    # Components are created in this fixed order: LLM, embedding model, file loader,
    # web crawler, vector database.
    module_factory = ModuleFactory(config)
    llm = module_factory.create_llm()
    embedding_model = module_factory.create_embedding()
    file_loader = module_factory.create_file_loader()
    web_crawler = module_factory.create_web_crawler()
    vector_db = module_factory.create_vector_db()

    # Keyword arguments shared by all three RAG agents.
    shared_agent_kwargs = {
        "llm": llm,
        "embedding_model": embedding_model,
        "vector_db": vector_db,
        "route_collection": True,
        "text_window_splitter": True,
    }

    # The router receives the two iterative agents, both bounded by the configured max_iter.
    deep_search_agent = DeepSearch(
        max_iter=config.query_settings["max_iter"],
        **shared_agent_kwargs,
    )
    chain_of_rag_agent = ChainOfRAG(
        max_iter=config.query_settings["max_iter"],
        **shared_agent_kwargs,
    )
    default_searcher = RAGRouter(
        llm=llm,
        rag_agents=[deep_search_agent, chain_of_rag_agent],
    )

    # The naive agent takes a fixed top_k of 10 instead of max_iter.
    naive_rag = NaiveRAG(top_k=10, **shared_agent_kwargs)
|
|
|