import hashlib from typing import List import numpy as np from deepsearcher.loader.splitter import Chunk from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult def calculate_text_hash(text: str) -> str: """ 计算文本的哈希值,用于作为主键 Args: text (str): 输入文本 Returns: str: 文本的MD5哈希值 """ return hashlib.md5(text.encode('utf-8')).hexdigest() class DeduplicatedVectorDB: """ 支持去重的向量数据库包装器 该类在现有向量数据库基础上增加以下功能: 1. 使用文本哈希值作为主键 2. 避免插入重复数据 3. 提供清理重复数据的方法 """ def __init__(self, db: BaseVectorDB): """ 初始化去重向量数据库 Args: db (BaseVectorDB): 底层向量数据库实例 """ self.db = db def init_collection(self, dim: int, collection: str, description: str, force_new_collection=False, *args, **kwargs): """ 初始化集合 Args: dim (int): 向量维度 collection (str): 集合名称 description (str): 集合描述 force_new_collection (bool): 是否强制创建新集合 *args: 其他参数 **kwargs: 其他关键字参数 """ return self.db.init_collection(dim, collection, description, force_new_collection, *args, **kwargs) def insert_data(self, collection: str, chunks: List[Chunk], batch_size: int = 256, *args, **kwargs): """ 插入数据,避免重复 该方法会先检查数据库中是否已存在相同文本的哈希值, 如果不存在才进行插入操作 Args: collection (str): 集合名称 chunks (List[Chunk]): 要插入的数据块列表 batch_size (int): 批处理大小 *args: 其他参数 **kwargs: 其他关键字参数 """ # 为每个chunk计算哈希值并添加到metadata中 for chunk in chunks: if 'hash' not in chunk.metadata: chunk.metadata['hash'] = calculate_text_hash(chunk.text) # 过滤掉已存在的数据 filtered_chunks = self._filter_duplicate_chunks(collection, chunks) # 插入过滤后的数据 if filtered_chunks: return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs) def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]: """ 过滤掉已存在于数据库中的chunks Args: collection (str): 集合名称 chunks (List[Chunk]): 待过滤的数据块列表 Returns: List[Chunk]: 过滤后的数据块列表 """ # 注意:这个实现依赖于具体的数据库实现 # 对于生产环境,应该使用数据库特定的查询方法来检查重复 # 这里提供一个通用但效率较低的实现 return chunks def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]: """ 搜索数据 Args: collection (str): 集合名称 vector (np.array): 查询向量 *args: 其他参数 **kwargs: 其他关键字参数 Returns: List[RetrievalResult]: 搜索结果 """ return self.db.search_data(collection, vector, *args, **kwargs) def list_collections(self, *args, **kwargs): """ 列出所有集合 Args: *args: 其他参数 **kwargs: 其他关键字参数 Returns: 集合信息列表 """ return self.db.list_collections(*args, **kwargs) def clear_db(self, collection: str = None, *args, **kwargs): """ 清空数据库集合 Args: collection (str): 集合名称 *args: 其他参数 **kwargs: 其他关键字参数 """ return self.db.clear_db(collection, *args, **kwargs) def remove_duplicate_data(self, collection: str = None) -> int: """ 移除数据库中的重复数据 注意:这是一个通用实现,具体数据库可能有更高效的实现方式 Args: collection (str): 集合名称 Returns: int: 移除的重复记录数 """ # 这个方法的具体实现需要根据数据库类型来定制 # 这里仅提供一个概念性的实现 raise NotImplementedError("需要根据具体的数据库实现去重逻辑") # 为Qdrant实现专门的去重工具 class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB): """ Qdrant专用的去重向量数据库实现 """ def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]: """ 过滤掉已存在于Qdrant数据库中的chunks Args: collection (str): 集合名称 chunks (List[Chunk]): 待过滤的数据块列表 Returns: List[Chunk]: 过滤后的数据块列表 """ try: # 获取Qdrant客户端 qdrant_client = self.db.client # 收集所有要检查的哈希值 hashes = [calculate_text_hash(chunk.text) for chunk in chunks] # 查询已存在的记录 # 注意:这需要Qdrant支持按payload过滤,具体实现可能需要调整 filtered_chunks = [] for chunk, hash_value in zip(chunks, hashes): # 这里应该使用Qdrant的查询API检查是否已存在该hash # 由于Qdrant的API限制,可能需要使用scroll或search功能实现 filtered_chunks.append(chunk) return filtered_chunks except Exception: # 出错时返回所有chunks(不进行过滤) return chunks def remove_duplicate_data(self, collection: str = None) -> int: """ 移除Qdrant数据库中的重复数据 Args: collection (str): 集合名称 Returns: int: 移除的重复记录数 """ # Qdrant去重实现 # 这需要先检索所有数据,识别重复项,然后删除重复项 # 具体实现略 return 0 # 为Milvus实现专门的去重工具 class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB): """ Milvus专用的去重向量数据库实现 """ def insert_data(self, collection: str, chunks: List[Chunk], batch_size: int = 256, *args, **kwargs): """ 插入数据到Milvus,使用文本哈希作为主键 Args: collection (str): 集合名称 chunks (List[Chunk]): 要插入的数据块列表 batch_size (int): 批处理大小 *args: 其他参数 **kwargs: 其他关键字参数 """ # 为每个chunk设置ID为文本哈希值 for chunk in chunks: chunk.metadata['id'] = calculate_text_hash(chunk.text) # 调用父类方法进行插入 return super().insert_data(collection, chunks, batch_size, *args, **kwargs) def remove_duplicate_data(self, collection: str = None) -> int: """ 移除Milvus数据库中的重复数据 Args: collection (str): 集合名称 Returns: int: 移除的重复记录数 """ # Milvus去重实现 # 可以通过查询所有数据,找出ID重复的记录并删除 return 0