You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

247 lines
8.0 KiB

import hashlib
from typing import List
import numpy as np
from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult
def calculate_text_hash(text: str) -> str:
"""
计算文本的哈希值用于作为主键
Args:
text (str): 输入文本
Returns:
str: 文本的MD5哈希值
"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
class DeduplicatedVectorDB:
"""
支持去重的向量数据库包装器
该类在现有向量数据库基础上增加以下功能
1. 使用文本哈希值作为主键
2. 避免插入重复数据
3. 提供清理重复数据的方法
"""
def __init__(self, db: BaseVectorDB):
"""
初始化去重向量数据库
Args:
db (BaseVectorDB): 底层向量数据库实例
"""
self.db = db
def init_collection(self, dim: int, collection: str, description: str,
force_new_collection=False, *args, **kwargs):
"""
初始化集合
Args:
dim (int): 向量维度
collection (str): 集合名称
description (str): 集合描述
force_new_collection (bool): 是否强制创建新集合
*args: 其他参数
**kwargs: 其他关键字参数
"""
return self.db.init_collection(dim, collection, description,
force_new_collection, *args, **kwargs)
def insert_data(self, collection: str, chunks: List[Chunk],
batch_size: int = 256, *args, **kwargs):
"""
插入数据避免重复
该方法会先检查数据库中是否已存在相同文本的哈希值
如果不存在才进行插入操作
Args:
collection (str): 集合名称
chunks (List[Chunk]): 要插入的数据块列表
batch_size (int): 批处理大小
*args: 其他参数
**kwargs: 其他关键字参数
"""
# 为每个chunk计算哈希值并添加到metadata中
for chunk in chunks:
if 'hash' not in chunk.metadata:
chunk.metadata['hash'] = calculate_text_hash(chunk.text)
# 过滤掉已存在的数据
filtered_chunks = self._filter_duplicate_chunks(collection, chunks)
# 插入过滤后的数据
if filtered_chunks:
return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)
def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
"""
过滤掉已存在于数据库中的chunks
Args:
collection (str): 集合名称
chunks (List[Chunk]): 待过滤的数据块列表
Returns:
List[Chunk]: 过滤后的数据块列表
"""
# 注意:这个实现依赖于具体的数据库实现
# 对于生产环境,应该使用数据库特定的查询方法来检查重复
# 这里提供一个通用但效率较低的实现
return chunks
def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
"""
搜索数据
Args:
collection (str): 集合名称
vector (np.array): 查询向量
*args: 其他参数
**kwargs: 其他关键字参数
Returns:
List[RetrievalResult]: 搜索结果
"""
return self.db.search_data(collection, vector, *args, **kwargs)
def list_collections(self, *args, **kwargs):
"""
列出所有集合
Args:
*args: 其他参数
**kwargs: 其他关键字参数
Returns:
集合信息列表
"""
return self.db.list_collections(*args, **kwargs)
def clear_db(self, collection: str = None, *args, **kwargs):
"""
清空数据库集合
Args:
collection (str): 集合名称
*args: 其他参数
**kwargs: 其他关键字参数
"""
return self.db.clear_db(collection, *args, **kwargs)
def remove_duplicate_data(self, collection: str = None) -> int:
"""
移除数据库中的重复数据
注意这是一个通用实现具体数据库可能有更高效的实现方式
Args:
collection (str): 集合名称
Returns:
int: 移除的重复记录数
"""
# 这个方法的具体实现需要根据数据库类型来定制
# 这里仅提供一个概念性的实现
raise NotImplementedError("需要根据具体的数据库实现去重逻辑")
# 为Qdrant实现专门的去重工具
class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
"""
Qdrant专用的去重向量数据库实现
"""
def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
"""
过滤掉已存在于Qdrant数据库中的chunks
Args:
collection (str): 集合名称
chunks (List[Chunk]): 待过滤的数据块列表
Returns:
List[Chunk]: 过滤后的数据块列表
"""
try:
# 获取Qdrant客户端
qdrant_client = self.db.client
# 收集所有要检查的哈希值
hashes = [calculate_text_hash(chunk.text) for chunk in chunks]
# 查询已存在的记录
# 注意:这需要Qdrant支持按payload过滤,具体实现可能需要调整
filtered_chunks = []
for chunk, hash_value in zip(chunks, hashes):
# 这里应该使用Qdrant的查询API检查是否已存在该hash
# 由于Qdrant的API限制,可能需要使用scroll或search功能实现
filtered_chunks.append(chunk)
return filtered_chunks
except Exception:
# 出错时返回所有chunks(不进行过滤)
return chunks
def remove_duplicate_data(self, collection: str = None) -> int:
"""
移除Qdrant数据库中的重复数据
Args:
collection (str): 集合名称
Returns:
int: 移除的重复记录数
"""
# Qdrant去重实现
# 这需要先检索所有数据,识别重复项,然后删除重复项
# 具体实现略
return 0
# 为Milvus实现专门的去重工具
class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
"""
Milvus专用的去重向量数据库实现
"""
def insert_data(self, collection: str, chunks: List[Chunk],
batch_size: int = 256, *args, **kwargs):
"""
插入数据到Milvus使用文本哈希作为主键
Args:
collection (str): 集合名称
chunks (List[Chunk]): 要插入的数据块列表
batch_size (int): 批处理大小
*args: 其他参数
**kwargs: 其他关键字参数
"""
# 为每个chunk设置ID为文本哈希值
for chunk in chunks:
chunk.metadata['id'] = calculate_text_hash(chunk.text)
# 调用父类方法进行插入
return super().insert_data(collection, chunks, batch_size, *args, **kwargs)
def remove_duplicate_data(self, collection: str = None) -> int:
"""
移除Milvus数据库中的重复数据
Args:
collection (str): 集合名称
Returns:
int: 移除的重复记录数
"""
# Milvus去重实现
# 可以通过查询所有数据,找出ID重复的记录并删除
return 0