You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
536 lines
19 KiB
536 lines
19 KiB
import array
|
|
import json
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
from deepsearcher.loader.splitter import Chunk
|
|
from deepsearcher.utils import log
|
|
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
|
|
|
|
|
|
class OracleDB(BaseVectorDB):
    """Vector database backend that stores collections and chunks in Oracle.

    All statements are taken from the module-level ``SQL_TEMPLATES`` and
    ``TABLES`` dictionaries and are run on connections acquired from an
    ``oracledb`` connection pool held in ``self.client``.
    """

    # Connection pool; assigned in __init__ by oracledb.create_pool().
    client = None

    def __init__(
        self,
        user: str,
        password: str,
        dsn: str,
        config_dir: str,
        wallet_location: str,
        wallet_password: str,
        min: int = 1,
        max: int = 10,
        increment: int = 1,
        default_collection: str = "deepsearcher",
    ):
        """
        Initialize the Oracle database connection.

        Args:
            user (str): Oracle database username.
            password (str): Oracle database password.
            dsn (str): Oracle database connection string.
            config_dir (str): Directory containing Oracle configuration files.
            wallet_location (str): Location of the Oracle wallet.
            wallet_password (str): Password for the Oracle wallet.
            min (int, optional): Minimum number of connections in the pool. Defaults to 1.
            max (int, optional): Maximum number of connections in the pool. Defaults to 10.
            increment (int, optional): Increment for adding new connections. Defaults to 1.
            default_collection (str, optional): Default collection name. Defaults to "deepsearcher".

        Raises:
            Exception: If the pool cannot be created or the required tables
                cannot be checked/created.
        """
        super().__init__(default_collection)
        self.default_collection = default_collection

        # Imported here rather than at module level, so importing this module
        # does not by itself require the oracledb package.
        import oracledb

        # Fetch LOB columns as plain values instead of LOB locator objects.
        oracledb.defaults.fetch_lobs = False
        self.DB_TYPE_VECTOR = oracledb.DB_TYPE_VECTOR

        try:
            self.client = oracledb.create_pool(
                user=user,
                password=password,
                dsn=dsn,
                config_dir=config_dir,
                wallet_location=wallet_location,
                wallet_password=wallet_password,
                min=min,
                max=max,
                increment=increment,
            )
            log.color_print(f"Connected to Oracle database at {dsn}")
            self.check_table()
        except Exception as e:
            log.critical(f"Failed to connect to Oracle database at {dsn}")
            log.critical(f"Oracle database error in init: {e}")
            raise

    def numpy_converter_in(self, value):
        """Convert a numpy array to the ``array.array`` used for VECTOR binds.

        float64 maps to typecode "d", float32 to "f" and uint8 to "B"; every
        other dtype falls back to signed 8-bit ("b").
        """
        if value.dtype == np.float64:
            typecode = "d"
        elif value.dtype == np.float32:
            typecode = "f"
        elif value.dtype == np.uint8:
            # Without this branch uint8 values above 127 overflow typecode "b".
            typecode = "B"
        else:
            typecode = "b"
        return array.array(typecode, value)

    def input_type_handler(self, cursor, value, arraysize):
        """Bind numpy arrays as VECTOR values; other values use default handling.

        Returns None (implicitly) for anything that is not an ``np.ndarray``.
        """
        if isinstance(value, np.ndarray):
            return cursor.var(
                self.DB_TYPE_VECTOR,
                arraysize=arraysize,
                inconverter=self.numpy_converter_in,
            )

    def numpy_converter_out(self, value):
        """Convert an ``array.array`` fetched from a VECTOR column to numpy.

        Typecode "b" maps to int8, "B" to uint8, "f" to float32 and anything
        else to float64.
        """
        if value.typecode == "b":
            dtype = np.int8
        elif value.typecode == "B":
            dtype = np.uint8
        elif value.typecode == "f":
            dtype = np.float32
        else:
            dtype = np.float64
        # np.asarray avoids a copy when possible but, unlike
        # np.array(..., copy=False) on NumPy 2, never raises when one is needed.
        return np.asarray(value, dtype=dtype)

    def output_type_handler(self, cursor, metadata):
        """Fetch VECTOR columns as numpy arrays; other columns use default handling."""
        if metadata.type_code is self.DB_TYPE_VECTOR:
            return cursor.var(
                metadata.type_code,
                arraysize=cursor.arraysize,
                outconverter=self.numpy_converter_out,
            )

    def query(self, sql: str, params: Optional[dict] = None) -> List[dict]:
        """
        Execute a SQL query and return the results.

        Args:
            sql (str): SQL query to execute.
            params (dict, optional): Parameters for the SQL query. Defaults to None.

        Returns:
            List[dict]: One dict per row, keyed by lower-cased column name.
                Empty if the query returns no rows or produces no result set.

        Raises:
            Exception: If there's an error executing the query.
        """
        with self.client.acquire() as connection:
            connection.inputtypehandler = self.input_type_handler
            connection.outputtypehandler = self.output_type_handler
            with connection.cursor() as cursor:
                try:
                    if log.dev_mode:
                        print("sql:\n", sql)
                    cursor.execute(sql, params)
                except Exception as e:
                    log.critical(f"Oracle database error in query: {e}")
                    raise

                # description is None when the statement is not a query.
                if cursor.description is None:
                    return []

                columns = [column[0].lower() for column in cursor.description]
                data = [dict(zip(columns, row)) for row in cursor.fetchall()]
                if log.dev_mode:
                    print("data:\n", data)
                return data

    def execute(self, sql: str, data: Union[list, dict, None] = None) -> None:
        """
        Execute a SQL statement without returning results, then commit.

        Args:
            sql (str): SQL statement to execute.
            data (Union[list, dict], optional): Data for the SQL statement. Defaults to None.

        Raises:
            Exception: If there's an error executing the statement.
        """
        try:
            with self.client.acquire() as connection:
                connection.inputtypehandler = self.input_type_handler
                connection.outputtypehandler = self.output_type_handler
                with connection.cursor() as cursor:
                    if data is None:
                        cursor.execute(sql)
                    else:
                        cursor.execute(sql, data)
                connection.commit()
        except Exception as e:
            log.critical(f"Oracle database error in execute: {e}")
            log.error("ERROR sql:\n" + sql)
            # data is a dict/list/None, so it must be formatted rather than
            # concatenated; concatenation would raise TypeError here and hide
            # the original database error.
            log.error(f"ERROR data:\n{data}")
            raise

    def has_collection(self, collection: str = "deepsearcher") -> bool:
        """
        Check if a collection exists in the database.

        Args:
            collection (str, optional): Collection name to check. Defaults to "deepsearcher".

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        rows = self.query(SQL_TEMPLATES["has_collection"], {"collection": collection})
        return bool(rows) and rows[0]["rowcnt"] > 0

    def check_table(self):
        """
        Check if required tables exist and create them if they don't.

        Raises:
            Exception: If there's an error checking or creating tables.
        """
        try:
            rows = self.query(SQL_TEMPLATES["has_table"])
            existing = {str(row["table_name"]).upper() for row in rows}
            # Compare against TABLES itself instead of a hard-coded count so
            # duplicate rows or additional tables cannot skew the check.
            for table in TABLES:
                if table not in existing:
                    self.create_tables(table)
        except Exception as e:
            log.critical(f"Failed to check table in Oracle database, error info: {e}")
            raise

    def create_tables(self, table_name):
        """
        Create a table in the database.

        Args:
            table_name: Name of the table to create; must be a key of ``TABLES``.

        Raises:
            Exception: If there's an error creating the table.
        """
        ddl = TABLES[table_name]
        try:
            self.execute(ddl)
            log.color_print(f"Created table {table_name} in Oracle database")
        except Exception as e:
            log.critical(f"Failed to create table {table_name} in Oracle database, error info: {e}")
            raise

    def drop_collection(self, collection: str = "deepsearcher"):
        """
        Drop a collection from the database.

        The collection row and its items are soft-deleted (``status`` set to 0)
        by two separate statements, each committed on its own.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".

        Raises:
            Exception: If there's an error dropping the collection.
        """
        params = {"collection": collection}
        try:
            self.execute(SQL_TEMPLATES["drop_collection"], params)
            self.execute(SQL_TEMPLATES["drop_collection_item"], params)
            log.color_print(f"Collection {collection} dropped")
        except Exception as e:
            log.critical(f"fail to drop collection, error info: {e}")
            raise

    def insertone(self, data):
        """
        Insert a single record into the database.

        Args:
            data: Bind values for the "insert" statement (collection, embedding,
                text, reference, metadata).
        """
        self.execute(SQL_TEMPLATES["insert"], data)
        log.debug("insert done!")

    def searchone(
        self,
        collection: Optional[str],
        vector: Union[np.ndarray, List[float]],
        top_k: int = 5,
        max_distance: float = 0.8,
    ):
        """
        Search for similar vectors in a collection.

        Args:
            collection (Optional[str]): Collection name to search in. If empty,
                uses default_collection.
            vector (Union[np.ndarray, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            max_distance (float, optional): Only rows whose cosine distance is
                strictly below this value are returned. Defaults to 0.8.

        Returns:
            list: List of search results (one dict per row, nearest first).

        Raises:
            Exception: If there's an error during search.
        """
        if not collection:
            collection = self.default_collection
        log.debug(f"def searchone:{collection}")
        try:
            query_vector = np.asarray(vector)
            # The element format is formatted into the SQL text, so restrict it
            # to the names this method knows how to emit; anything else (for
            # example a list of Python ints, which becomes int64) is cast to
            # float64 instead of producing an invalid format such as INT64.
            if query_vector.dtype.name not in ("float32", "float64", "int8"):
                query_vector = query_vector.astype(np.float64)

            embedding_string = "[" + ", ".join(map(str, query_vector.tolist())) + "]"
            dimension = query_vector.shape[0]
            dtype = query_vector.dtype.name.upper()

            sql = SQL_TEMPLATES["search"].format(dimension=dimension, dtype=dtype)
            params = {
                "collection": collection,
                "embedding_string": embedding_string,
                "top_k": top_k,
                "max_distance": max_distance,
            }
            return self.query(sql, params) or []
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = "deepsearcher",
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        metric_type: str = "L2",
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in the database.

        Only ``collection``, ``description`` and ``force_new_collection`` are
        used by this backend; the remaining parameters are accepted and ignored.

        Args:
            dim (int): Dimension of the vector embeddings.
            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher".
            description (Optional[str], optional): Collection description. Defaults to "".
            force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
            metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error initializing the collection.
        """
        if not collection:
            collection = self.default_collection
        if description is None:
            description = ""
        try:
            exists = self.has_collection(collection)
            if exists and not force_new_collection:
                return
            if exists:
                self.drop_collection(collection)
            # insert collection info
            params = {"collection": collection, "description": description}
            self.execute(SQL_TEMPLATES["insert_collection"], params)
        except Exception as e:
            log.critical(f"fail to init_collection for oracle, error info: {e}")
            # Re-raise like every other method here; previously the failure was
            # swallowed and callers carried on without a collection row.
            raise

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert data into a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            chunks (List[Chunk]): List of Chunk objects to insert.
            batch_size (int, optional): Number of chunks to insert in each batch. Defaults to 256.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error inserting data.
        """
        if not collection:
            collection = self.default_collection

        rows = [
            {
                "embedding": self.numpy_converter_in(np.array(chunk.embedding)),
                "text": chunk.text,
                "reference": chunk.reference,
                "metadata": json.dumps(chunk.metadata),
                "collection": collection,
            }
            for chunk in chunks
        ]

        # NOTE: batch_size only groups the rows; each row is still inserted and
        # committed individually through insertone().
        batches = [rows[i : i + batch_size] for i in range(0, len(rows), batch_size)]
        try:
            for batch in batches:
                for row in batch:
                    self.insertone(data=row)
            log.color_print(f"Successfully insert {len(rows)} data")
        except Exception as e:
            log.critical(f"fail to insert data, error info: {e}")
            raise

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.ndarray, List[float]],
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            vector (Union[np.ndarray, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[RetrievalResult]: List of retrieval results containing similar vectors.
                ``score`` carries the distance reported by the search query.

        Raises:
            Exception: If there's an error during search.
        """
        if not collection:
            collection = self.default_collection
        try:
            hits = self.searchone(collection=collection, vector=vector, top_k=top_k)
            return [
                RetrievalResult(
                    embedding=hit["embedding"],
                    text=hit["text"],
                    reference=hit["reference"],
                    score=hit["distance"],
                    # The metadata column is nullable; json.loads(None) would raise.
                    metadata=json.loads(hit["metadata"]) if hit["metadata"] else {},
                )
                for hit in hits
            ]
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all collections in the database.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[CollectionInfo]: List of collection information objects.

        Raises:
            Exception: If there's an error querying the collections.
        """
        try:
            sql = SQL_TEMPLATES["list_collections"]
            log.debug("def list_collections:" + sql)
            return [
                CollectionInfo(
                    collection_name=row["collection"],
                    description=row["description"],
                )
                for row in self.query(sql) or []
            ]
        except Exception as e:
            log.critical(f"fail to list collections, error info: {e}")
            raise

    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
        """
        Clear (drop) a collection from the database.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error dropping the collection.
        """
        if not collection:
            collection = self.default_collection
        try:
            # Delegate to this class's own drop_collection(); self.client is the
            # connection pool created in __init__, not an object that drops
            # collections.
            self.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear db, error info: {e}")
            raise
|
|
|
|
|
|
# DDL for the tables this backend needs, keyed by upper-case table name so the
# keys can be compared directly with the names returned by the "has_table" query.
TABLES = {
    "DEEPSEARCHER_COLLECTION_INFO": """CREATE TABLE DEEPSEARCHER_COLLECTION_INFO (
        id INT generated by default as identity primary key,
        collection varchar(256),
        description CLOB,
        status NUMBER DEFAULT 1,
        createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updatetime TIMESTAMP DEFAULT NULL)""",
    "DEEPSEARCHER_COLLECTION_ITEM": """CREATE TABLE DEEPSEARCHER_COLLECTION_ITEM (
        id INT generated by default as identity primary key,
        collection varchar(256),
        embedding VECTOR,
        text CLOB,
        reference varchar(4000),
        metadata CLOB,
        status NUMBER DEFAULT 1,
        createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updatetime TIMESTAMP DEFAULT NULL)""",
}

# Statements used by OracleDB. Rows are soft-deleted: status=1 means active and
# the "drop_*" statements set status=0. The "search" entry is a str.format
# template ({dimension}, {dtype}); every other value is passed as a bind variable.
SQL_TEMPLATES = {
    # user_tables (not all_tables) so only tables in the connected schema count:
    # the DDL and DML here use unqualified names, and all_tables would also
    # report same-named tables owned by other schemas.
    "has_table": (
        "SELECT table_name FROM user_tables WHERE table_name in ("
        + ",".join(f"'{name}'" for name in TABLES)
        + ")"
    ),
    "has_collection": "select count(*) as rowcnt from DEEPSEARCHER_COLLECTION_INFO where collection=:collection and status=1",
    "list_collections": "select collection,description from DEEPSEARCHER_COLLECTION_INFO where status=1",
    "drop_collection": "update DEEPSEARCHER_COLLECTION_INFO set status=0 where collection=:collection and status=1",
    "drop_collection_item": "update DEEPSEARCHER_COLLECTION_ITEM set status=0 where collection=:collection and status=1",
    "insert_collection": """INSERT INTO DEEPSEARCHER_COLLECTION_INFO (collection,description)
        values (:collection,:description)""",
    "insert": """INSERT INTO DEEPSEARCHER_COLLECTION_ITEM (collection,embedding,text,reference,metadata)
        values (:collection,:embedding,:text,:reference,:metadata)""",
    "search": """SELECT * FROM
        (SELECT t.*,
            VECTOR_DISTANCE(t.embedding,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
        FROM DEEPSEARCHER_COLLECTION_ITEM t
        JOIN DEEPSEARCHER_COLLECTION_INFO c ON t.collection=c.collection
        WHERE t.collection=:collection AND t.status=1 AND c.status=1)
        WHERE distance<:max_distance ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
}
|
|
|