You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
247 lines
8.0 KiB
247 lines
8.0 KiB
import hashlib
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
from deepsearcher.loader.splitter import Chunk
|
|
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult
|
|
|
|
|
|
def calculate_text_hash(text: str) -> str:
|
|
"""
|
|
计算文本的哈希值,用于作为主键
|
|
|
|
Args:
|
|
text (str): 输入文本
|
|
|
|
Returns:
|
|
str: 文本的MD5哈希值
|
|
"""
|
|
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
|
|
|
|
|
class DeduplicatedVectorDB:
|
|
"""
|
|
支持去重的向量数据库包装器
|
|
|
|
该类在现有向量数据库基础上增加以下功能:
|
|
1. 使用文本哈希值作为主键
|
|
2. 避免插入重复数据
|
|
3. 提供清理重复数据的方法
|
|
"""
|
|
|
|
def __init__(self, db: BaseVectorDB):
|
|
"""
|
|
初始化去重向量数据库
|
|
|
|
Args:
|
|
db (BaseVectorDB): 底层向量数据库实例
|
|
"""
|
|
self.db = db
|
|
|
|
def init_collection(self, dim: int, collection: str, description: str,
|
|
force_new_collection=False, *args, **kwargs):
|
|
"""
|
|
初始化集合
|
|
|
|
Args:
|
|
dim (int): 向量维度
|
|
collection (str): 集合名称
|
|
description (str): 集合描述
|
|
force_new_collection (bool): 是否强制创建新集合
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
"""
|
|
return self.db.init_collection(dim, collection, description,
|
|
force_new_collection, *args, **kwargs)
|
|
|
|
def insert_data(self, collection: str, chunks: List[Chunk],
|
|
batch_size: int = 256, *args, **kwargs):
|
|
"""
|
|
插入数据,避免重复
|
|
|
|
该方法会先检查数据库中是否已存在相同文本的哈希值,
|
|
如果不存在才进行插入操作
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
chunks (List[Chunk]): 要插入的数据块列表
|
|
batch_size (int): 批处理大小
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
"""
|
|
# 为每个chunk计算哈希值并添加到metadata中
|
|
for chunk in chunks:
|
|
if 'hash' not in chunk.metadata:
|
|
chunk.metadata['hash'] = calculate_text_hash(chunk.text)
|
|
|
|
# 过滤掉已存在的数据
|
|
filtered_chunks = self._filter_duplicate_chunks(collection, chunks)
|
|
|
|
# 插入过滤后的数据
|
|
if filtered_chunks:
|
|
return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)
|
|
|
|
def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
|
|
"""
|
|
过滤掉已存在于数据库中的chunks
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
chunks (List[Chunk]): 待过滤的数据块列表
|
|
|
|
Returns:
|
|
List[Chunk]: 过滤后的数据块列表
|
|
"""
|
|
# 注意:这个实现依赖于具体的数据库实现
|
|
# 对于生产环境,应该使用数据库特定的查询方法来检查重复
|
|
# 这里提供一个通用但效率较低的实现
|
|
return chunks
|
|
|
|
def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
|
|
"""
|
|
搜索数据
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
vector (np.array): 查询向量
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
|
|
Returns:
|
|
List[RetrievalResult]: 搜索结果
|
|
"""
|
|
return self.db.search_data(collection, vector, *args, **kwargs)
|
|
|
|
def list_collections(self, *args, **kwargs):
|
|
"""
|
|
列出所有集合
|
|
|
|
Args:
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
|
|
Returns:
|
|
集合信息列表
|
|
"""
|
|
return self.db.list_collections(*args, **kwargs)
|
|
|
|
def clear_db(self, collection: str = None, *args, **kwargs):
|
|
"""
|
|
清空数据库集合
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
"""
|
|
return self.db.clear_db(collection, *args, **kwargs)
|
|
|
|
def remove_duplicate_data(self, collection: str = None) -> int:
|
|
"""
|
|
移除数据库中的重复数据
|
|
|
|
注意:这是一个通用实现,具体数据库可能有更高效的实现方式
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
|
|
Returns:
|
|
int: 移除的重复记录数
|
|
"""
|
|
# 这个方法的具体实现需要根据数据库类型来定制
|
|
# 这里仅提供一个概念性的实现
|
|
raise NotImplementedError("需要根据具体的数据库实现去重逻辑")
|
|
|
|
|
|
# 为Qdrant实现专门的去重工具
|
|
class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
|
|
"""
|
|
Qdrant专用的去重向量数据库实现
|
|
"""
|
|
|
|
def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
|
|
"""
|
|
过滤掉已存在于Qdrant数据库中的chunks
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
chunks (List[Chunk]): 待过滤的数据块列表
|
|
|
|
Returns:
|
|
List[Chunk]: 过滤后的数据块列表
|
|
"""
|
|
try:
|
|
# 获取Qdrant客户端
|
|
qdrant_client = self.db.client
|
|
|
|
# 收集所有要检查的哈希值
|
|
hashes = [calculate_text_hash(chunk.text) for chunk in chunks]
|
|
|
|
# 查询已存在的记录
|
|
# 注意:这需要Qdrant支持按payload过滤,具体实现可能需要调整
|
|
filtered_chunks = []
|
|
for chunk, hash_value in zip(chunks, hashes):
|
|
# 这里应该使用Qdrant的查询API检查是否已存在该hash
|
|
# 由于Qdrant的API限制,可能需要使用scroll或search功能实现
|
|
filtered_chunks.append(chunk)
|
|
|
|
return filtered_chunks
|
|
except Exception:
|
|
# 出错时返回所有chunks(不进行过滤)
|
|
return chunks
|
|
|
|
def remove_duplicate_data(self, collection: str = None) -> int:
|
|
"""
|
|
移除Qdrant数据库中的重复数据
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
|
|
Returns:
|
|
int: 移除的重复记录数
|
|
"""
|
|
# Qdrant去重实现
|
|
# 这需要先检索所有数据,识别重复项,然后删除重复项
|
|
# 具体实现略
|
|
return 0
|
|
|
|
|
|
# 为Milvus实现专门的去重工具
|
|
class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
|
|
"""
|
|
Milvus专用的去重向量数据库实现
|
|
"""
|
|
|
|
def insert_data(self, collection: str, chunks: List[Chunk],
|
|
batch_size: int = 256, *args, **kwargs):
|
|
"""
|
|
插入数据到Milvus,使用文本哈希作为主键
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
chunks (List[Chunk]): 要插入的数据块列表
|
|
batch_size (int): 批处理大小
|
|
*args: 其他参数
|
|
**kwargs: 其他关键字参数
|
|
"""
|
|
# 为每个chunk设置ID为文本哈希值
|
|
for chunk in chunks:
|
|
chunk.metadata['id'] = calculate_text_hash(chunk.text)
|
|
|
|
# 调用父类方法进行插入
|
|
return super().insert_data(collection, chunks, batch_size, *args, **kwargs)
|
|
|
|
def remove_duplicate_data(self, collection: str = None) -> int:
|
|
"""
|
|
移除Milvus数据库中的重复数据
|
|
|
|
Args:
|
|
collection (str): 集合名称
|
|
|
|
Returns:
|
|
int: 移除的重复记录数
|
|
"""
|
|
# Milvus去重实现
|
|
# 可以通过查询所有数据,找出ID重复的记录并删除
|
|
return 0
|