You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

215 lines
7.0 KiB

import importlib
import os
from typing import Literal, Optional

import yaml

from deepsearcher.agent import BaseAgent, DeepSearch
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.loader.file_loader.base import BaseLoader
from deepsearcher.loader.web_crawler.base import BaseCrawler
from deepsearcher.vector_db.base import BaseVectorDB
# Absolute directory of this module; the default config file is expected alongside it.
current_dir = os.path.split(os.path.abspath(__file__))[0]
DEFAULT_CONFIG_YAML_PATH = os.path.join(current_dir, "config.yaml")

# The configurable component slots, as used by Configuration and ModuleFactory.
FeatureType = Literal[
    "llm",
    "embedding",
    "file_loader",
    "web_crawler",
    "vector_db",
]
class Configuration:
    """
    Configuration class for DeepSearcher.

    This class manages the configuration settings for various components of the DeepSearcher system,
    including LLM providers, embedding models, file loaders, web crawlers, and vector databases.
    It loads configurations from a YAML file and provides methods to get and set provider configurations.

    Attributes:
        provide_settings: The YAML file's "provide_settings" section, keyed by feature name
            (e.g. "llm"); each entry holds a "provider" name and a "config" dict.
        query_settings: The YAML file's "query_settings" section.
        load_settings: The YAML file's "load_settings" section.
    """

    def __init__(self, config_path: str = DEFAULT_CONFIG_YAML_PATH):
        """
        Initialize the Configuration object.

        Args:
            config_path: Path to the configuration YAML file. Defaults to the
                config.yaml file located in the same directory as this module.

        Raises:
            ValueError: If the file's top-level YAML document is not a mapping
                (for example, the file is empty).
            KeyError: If one of the required top-level sections is missing.
        """
        config_data = self.load_config_from_yaml(config_path)
        # yaml.safe_load yields None for an empty file and a list/scalar for a
        # non-mapping document; report that clearly instead of letting the
        # subscripts below fail with an opaque TypeError.
        if not isinstance(config_data, dict):
            raise ValueError(
                f"Configuration file {config_path!r} must contain a YAML mapping, "
                f"got {type(config_data).__name__}"
            )
        self.provide_settings = config_data["provide_settings"]
        self.query_settings = config_data["query_settings"]
        self.load_settings = config_data["load_settings"]

    def load_config_from_yaml(self, config_path: str):
        """
        Load configuration from a YAML file.

        Args:
            config_path: Path to the configuration YAML file.

        Returns:
            The parsed YAML document: a dictionary for a well-formed config file,
            or None if the file is empty.
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding. safe_load only builds plain Python
        # objects, so arbitrary tags in the file cannot construct objects.
        with open(config_path, encoding="utf-8") as file:
            return yaml.safe_load(file)

    def set_provider_config(self, feature: FeatureType, provider: str, provider_configs: dict):
        """
        Set the provider and its configurations for a given feature.

        Args:
            feature: The feature to configure (e.g., 'llm', 'file_loader', 'web_crawler').
            provider: The provider name (e.g., 'openai', 'deepseek').
            provider_configs: A dictionary with configurations specific to the provider.

        Raises:
            ValueError: If the feature is not supported.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")
        self.provide_settings[feature]["provider"] = provider
        self.provide_settings[feature]["config"] = provider_configs

    def get_provider_config(self, feature: FeatureType):
        """
        Get the current provider and configuration for a given feature.

        Args:
            feature: The feature to retrieve (e.g., 'llm', 'file_loader', 'web_crawler').

        Returns:
            A dictionary with provider and its configurations. This is the live
            settings entry, not a copy.

        Raises:
            ValueError: If the feature is not supported.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")
        return self.provide_settings[feature]
class ModuleFactory:
    """
    Factory class for creating instances of various modules in the DeepSearcher system.

    This class creates instances of LLMs, embedding models, file loaders, web crawlers,
    and vector databases based on the configuration settings.
    """

    def __init__(self, config: Configuration):
        """
        Initialize the ModuleFactory.

        Args:
            config: The Configuration object containing provider settings.
        """
        self.config = config

    def _create_module_instance(self, feature: FeatureType, module_name: str):
        """
        Create an instance of a module based on the feature and module name.

        The provider configured for ``feature`` is looked up as a class name in
        the module ``module_name``. That class is instantiated with the
        feature's configured keyword arguments.

        Args:
            feature: The feature type (e.g., 'llm', 'embedding').
            module_name: Dotted name of the module that exports the provider class
                (e.g. "deepsearcher.loader.file_loader").

        Returns:
            An instance of the specified module.

        Raises:
            ValueError: If ``feature`` has no entry in the configuration.
            ImportError: If ``module_name`` cannot be imported.
            AttributeError: If ``module_name`` does not define the configured provider.
        """
        # Go through the validating accessor rather than indexing provide_settings directly.
        settings = self.config.get_provider_config(feature)
        class_name = settings["provider"]
        module = importlib.import_module(module_name)
        try:
            provider_class = getattr(module, class_name)
        except AttributeError as err:
            raise AttributeError(
                f"Provider {class_name!r} for feature {feature!r} "
                f"is not defined in module {module_name!r}"
            ) from err
        # An empty ``config:`` entry in YAML parses to None; treat it as "no
        # keyword arguments" rather than failing on ``**None``.
        provider_kwargs = settings.get("config") or {}
        return provider_class(**provider_kwargs)

    def create_llm(self) -> BaseLLM:
        """
        Create an instance of a language model.

        Returns:
            An instance of a BaseLLM implementation.
        """
        return self._create_module_instance("llm", "deepsearcher.llm")

    def create_embedding(self) -> BaseEmbedding:
        """
        Create an instance of an embedding model.

        Returns:
            An instance of a BaseEmbedding implementation.
        """
        return self._create_module_instance("embedding", "deepsearcher.embedding")

    def create_file_loader(self) -> BaseLoader:
        """
        Create an instance of a file loader.

        Returns:
            An instance of a BaseLoader implementation.
        """
        return self._create_module_instance("file_loader", "deepsearcher.loader.file_loader")

    def create_web_crawler(self) -> BaseCrawler:
        """
        Create an instance of a web crawler.

        Returns:
            An instance of a BaseCrawler implementation.
        """
        return self._create_module_instance("web_crawler", "deepsearcher.loader.web_crawler")

    def create_vector_db(self) -> BaseVectorDB:
        """
        Create an instance of a vector database.

        Returns:
            An instance of a BaseVectorDB implementation.
        """
        return self._create_module_instance("vector_db", "deepsearcher.vector_db")
# Default configuration, loaded from DEFAULT_CONFIG_YAML_PATH when this module is imported.
# NOTE: init_config() does not rebind this global; it only uses the Configuration passed to it.
config = Configuration()

# Module-wide component instances. They are None until init_config() populates them.
module_factory: Optional[ModuleFactory] = None
llm: Optional[BaseLLM] = None
embedding_model: Optional[BaseEmbedding] = None
file_loader: Optional[BaseLoader] = None
vector_db: Optional[BaseVectorDB] = None
web_crawler: Optional[BaseCrawler] = None
default_searcher: Optional[BaseAgent] = None
def init_config(config: Configuration):
    """
    Build every DeepSearcher component from ``config`` and publish it module-wide.

    A ModuleFactory is created for ``config`` and used to instantiate the LLM,
    embedding model, file loader, web crawler and vector database, in that order.
    A DeepSearch agent is then built on the LLM, embedding model and vector
    database and becomes the default searcher. Every object created here is
    stored in the module-level global of the same name.

    Args:
        config: The Configuration whose provider and query settings are used.
    """
    global module_factory, llm, embedding_model
    global file_loader, vector_db, web_crawler, default_searcher

    module_factory = ModuleFactory(config)

    llm = module_factory.create_llm()
    embedding_model = module_factory.create_embedding()
    file_loader = module_factory.create_file_loader()
    web_crawler = module_factory.create_web_crawler()
    vector_db = module_factory.create_vector_db()

    # Only the iteration limit comes from the configuration; the other
    # DeepSearch options are fixed here.
    iteration_limit = config.query_settings["max_iter"]
    default_searcher = DeepSearch(
        llm=llm,
        embedding_model=embedding_model,
        vector_db=vector_db,
        max_iter=iteration_limit,
        route_collection=False,
        text_window_splitter=True,
    )