"""Configuration management for DeepSearcher.

Loads provider, query and load settings from a YAML file and wires up the
LLM, embedding model, file loader and vector-database modules."""
import os
from typing import Literal
import yaml
from deepsearcher.agent import BaseAgent, DeepSearch
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.loader.file_loader.base import BaseLoader
from deepsearcher.vector_db.base import BaseVectorDB
# Resolve the default config.yaml that ships alongside this module.
current_dir = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONFIG_YAML_PATH = os.path.join(current_dir, "config.yaml")

# The pluggable features whose provider can be configured.
FeatureType = Literal["llm", "embedding", "file_loader", "vector_db"]


class Configuration:
    """Configuration for the DeepSearcher system.

    Manages the settings for the various components of DeepSearcher —
    LLM providers, embedding models, file loaders and vector databases.
    Settings are loaded from a YAML file; provider configurations can be
    read and updated per feature.
    """

    def __init__(self, config_path: str = DEFAULT_CONFIG_YAML_PATH):
        """Initialize the Configuration object.

        Args:
            config_path: Path to the configuration YAML file. Defaults to the
                config.yaml bundled with the package.

        Raises:
            KeyError: If the YAML file lacks any of the three required
                top-level sections.
        """
        config_data = self.load_config_from_yaml(config_path)
        # The YAML file must define these three top-level keys; indexing
        # (rather than .get) surfaces a malformed file immediately.
        self.provide_settings = config_data["provide_settings"]
        self.query_settings = config_data["query_settings"]
        self.load_settings = config_data["load_settings"]

    def load_config_from_yaml(self, config_path: str) -> dict:
        """Load and parse a YAML configuration file.

        Args:
            config_path: Path to the configuration YAML file.

        Returns:
            The parsed configuration data as a dictionary.
        """
        # safe_load never executes arbitrary YAML tags; the explicit
        # encoding keeps parsing consistent across platforms.
        with open(config_path, encoding="utf-8") as file:
            return yaml.safe_load(file)

    def set_provider_config(self, feature: FeatureType, provider: str, provider_configs: dict):
        """Set the provider and its configurations for a given feature.

        Args:
            feature: The feature to configure (e.g., 'llm', 'file_loader').
            provider: The provider name (e.g., 'openai', 'deepseek').
            provider_configs: A dictionary with configurations specific to the provider.

        Raises:
            ValueError: If the feature is not supported.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")
        self.provide_settings[feature]["provider"] = provider
        self.provide_settings[feature]["config"] = provider_configs

    def get_provider_config(self, feature: FeatureType) -> dict:
        """Get the current provider and configuration for a given feature.

        Args:
            feature: The feature to retrieve (e.g., 'llm', 'file_loader').

        Returns:
            A dictionary with the provider name and its configurations.

        Raises:
            ValueError: If the feature is not supported.
        """
        if feature not in self.provide_settings:
            raise ValueError(f"Unsupported feature: {feature}")
        return self.provide_settings[feature]
class ModuleFactory:
    """Factory for the pluggable modules of the DeepSearcher system.

    Creates instances of LLMs, embedding models, file loaders and vector
    databases based on the provider names and options recorded in a
    Configuration object.
    """

    # Annotations are string forward references so class creation does not
    # depend on the provider base classes being resolvable at definition time.
    def __init__(self, config: "Configuration"):
        """Initialize the ModuleFactory.

        Args:
            config: The Configuration object containing provider settings.
        """
        self.config = config

    def _create_module_instance(self, feature: "FeatureType", module_name: str):
        """Create an instance of the configured provider class for a feature.

        Args:
            feature: The feature type (e.g., 'llm', 'embedding').
            module_name: Dotted module path to import the provider class from
                (e.g., 'deepsearcher.loader.file_loader').

        Returns:
            An instance of the provider class, constructed with the feature's
            configured options as keyword arguments.
        """
        # The provider name stored in the settings doubles as the class name
        # to look up inside the target module.
        class_name = self.config.provide_settings[feature]["provider"]
        module = __import__(module_name, fromlist=[class_name])
        class_ = getattr(module, class_name)
        return class_(**self.config.provide_settings[feature]["config"])

    def create_llm(self) -> "BaseLLM":
        """Create an instance of the configured language model."""
        return self._create_module_instance("llm", "deepsearcher.llm")

    def create_embedding(self) -> "BaseEmbedding":
        """Create an instance of the configured embedding model."""
        return self._create_module_instance("embedding", "deepsearcher.embedding")

    def create_file_loader(self) -> "BaseLoader":
        """Create an instance of the configured file loader."""
        return self._create_module_instance("file_loader", "deepsearcher.loader.file_loader")

    def create_vector_db(self) -> "BaseVectorDB":
        """Create an instance of the configured vector database."""
        return self._create_module_instance("vector_db", "deepsearcher.vector_db")
# Module-level singletons. `config` is loaded eagerly from the default YAML
# at import time; the remaining globals stay None until init_config() is
# called to build the concrete provider instances.
config = Configuration()
module_factory: ModuleFactory = None
llm: BaseLLM = None
embedding_model: BaseEmbedding = None
file_loader: BaseLoader = None
vector_db: BaseVectorDB = None
default_searcher: BaseAgent = None
def init_config(config: Configuration):
    """Initialize the global configuration and create all required modules.

    Builds the LLM, embedding model, file loader, vector database and the
    default DeepSearch agent, storing each in the corresponding module-level
    global.

    Args:
        config: The Configuration object to use for initialization.

    NOTE(review): the `config` parameter shadows the module-level `config`
    global, which is intentionally NOT rebound here — confirm callers do not
    expect the module global to reflect the argument after this call.
    """
    global \
        module_factory, \
        llm, \
        embedding_model, \
        file_loader, \
        vector_db, \
        default_searcher

    module_factory = ModuleFactory(config)
    llm = module_factory.create_llm()
    embedding_model = module_factory.create_embedding()
    file_loader = module_factory.create_file_loader()
    vector_db = module_factory.create_vector_db()
    # Default retrieval agent: max_iter comes from the query settings, while
    # routing and text-window splitting are fixed here.
    default_searcher = DeepSearch(
        llm=llm,
        embedding_model=embedding_model,
        vector_db=vector_db,
        max_iter=config.query_settings["max_iter"],
        route_collection=False,
        text_window_splitter=True,
    )