"""Oracle vector database backend for DeepSearcher."""
import array
import json
from typing import List, Optional, Union
import numpy as np
from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
class OracleDB(BaseVectorDB):
    """Vector database implementation backed by an Oracle connection pool.

    Collection metadata lives in ``DEEPSEARCHER_COLLECTION_INFO`` and chunk
    rows in ``DEEPSEARCHER_COLLECTION_ITEM``; "dropping" a collection is a
    soft delete that flips each row's ``status`` column to 0.
    """

    # oracledb.ConnectionPool; assigned in __init__.
    client = None

    def __init__(
        self,
        user: str,
        password: str,
        dsn: str,
        config_dir: str,
        wallet_location: str,
        wallet_password: str,
        min: int = 1,
        max: int = 10,
        increment: int = 1,
        default_collection: str = "deepsearcher",
    ):
        """
        Initialize the Oracle database connection.

        Args:
            user (str): Oracle database username.
            password (str): Oracle database password.
            dsn (str): Oracle database connection string.
            config_dir (str): Directory containing Oracle configuration files.
            wallet_location (str): Location of the Oracle wallet.
            wallet_password (str): Password for the Oracle wallet.
            min (int, optional): Minimum number of connections in the pool. Defaults to 1.
            max (int, optional): Maximum number of connections in the pool. Defaults to 10.
            increment (int, optional): Increment for adding new connections. Defaults to 1.
            default_collection (str, optional): Default collection name. Defaults to "deepsearcher".

        Raises:
            Exception: If the connection pool cannot be created.
        """
        super().__init__(default_collection)
        self.default_collection = default_collection

        # Imported lazily so the module can be loaded without the driver installed.
        # NOTE: `min`/`max` shadow builtins but mirror oracledb.create_pool's
        # keyword names and are part of this class's public interface.
        import oracledb

        # Fetch CLOB columns as str instead of LOB locators.
        oracledb.defaults.fetch_lobs = False
        self.DB_TYPE_VECTOR = oracledb.DB_TYPE_VECTOR
        try:
            self.client = oracledb.create_pool(
                user=user,
                password=password,
                dsn=dsn,
                config_dir=config_dir,
                wallet_location=wallet_location,
                wallet_password=wallet_password,
                min=min,
                max=max,
                increment=increment,
            )
            log.color_print(f"Connected to Oracle database at {dsn}")
            self.check_table()
        except Exception as e:
            log.critical(f"Failed to connect to Oracle database at {dsn}")
            log.critical(f"Oracle database error in init: {e}")
            raise

    def numpy_converter_in(self, value):
        """Convert a numpy array to an ``array.array`` for binding as VECTOR.

        Args:
            value (np.ndarray): Array to convert; float64 -> 'd', float32 -> 'f',
                anything else is bound as int8 ('b').

        Returns:
            array.array: Typed array accepted by the Oracle driver.
        """
        if value.dtype == np.float64:
            dtype = "d"
        elif value.dtype == np.float32:
            dtype = "f"
        else:
            dtype = "b"
        return array.array(dtype, value)

    def input_type_handler(self, cursor, value, arraysize):
        """Bind numpy arrays as Oracle VECTOR values on input.

        Returns None (driver default handling) for non-ndarray values.
        """
        if isinstance(value, np.ndarray):
            return cursor.var(
                self.DB_TYPE_VECTOR,
                arraysize=arraysize,
                inconverter=self.numpy_converter_in,
            )

    def numpy_converter_out(self, value):
        """Convert an ``array.array`` fetched from Oracle to a numpy array.

        Args:
            value (array.array): Typed array from the driver; 'b' -> int8,
                'f' -> float32, anything else -> float64.

        Returns:
            np.ndarray: Zero-copy view over the fetched data where possible.
        """
        if value.typecode == "b":
            dtype = np.int8
        elif value.typecode == "f":
            dtype = np.float32
        else:
            dtype = np.float64
        return np.array(value, copy=False, dtype=dtype)

    def output_type_handler(self, cursor, metadata):
        """Fetch Oracle VECTOR columns as numpy arrays on output."""
        if metadata.type_code is self.DB_TYPE_VECTOR:
            return cursor.var(
                metadata.type_code,
                arraysize=cursor.arraysize,
                outconverter=self.numpy_converter_out,
            )

    def query(self, sql: str, params: dict = None) -> Union[dict, None]:
        """
        Execute a SQL query and return the results.

        Args:
            sql (str): SQL query to execute.
            params (dict, optional): Bind parameters for the query. Defaults to None.

        Returns:
            list[dict]: One dict per row keyed by lower-cased column name;
            empty list if the query returned no rows.

        Raises:
            Exception: If there's an error executing the query.
        """
        with self.client.acquire() as connection:
            connection.inputtypehandler = self.input_type_handler
            connection.outputtypehandler = self.output_type_handler
            with connection.cursor() as cursor:
                try:
                    if log.dev_mode:
                        print("sql:\n", sql)
                    cursor.execute(sql, params)
                except Exception as e:
                    log.critical(f"Oracle database error in query: {e}")
                    raise
                columns = [column[0].lower() for column in cursor.description]
                rows = cursor.fetchall()
                if rows:
                    data = [dict(zip(columns, row)) for row in rows]
                else:
                    data = []
                if log.dev_mode:
                    print("data:\n", data)
                return data

    def execute(self, sql: str, data: Union[list, dict] = None):
        """
        Execute a SQL statement without returning results; commits on success.

        Args:
            sql (str): SQL statement to execute.
            data (Union[list, dict], optional): Bind data for the statement. Defaults to None.

        Raises:
            Exception: If there's an error executing the statement.
        """
        try:
            with self.client.acquire() as connection:
                connection.inputtypehandler = self.input_type_handler
                connection.outputtypehandler = self.output_type_handler
                with connection.cursor() as cursor:
                    if data is None:
                        cursor.execute(sql)
                    else:
                        cursor.execute(sql, data)
                connection.commit()
        except Exception as e:
            log.critical(f"Oracle database error in execute: {e}")
            log.error("ERROR sql:\n" + sql)
            # str() — `data` is usually a dict/list; concatenating it raw would
            # raise TypeError and mask the original database error.
            log.error("ERROR data:\n" + str(data))
            raise

    def has_collection(self, collection: str = "deepsearcher"):
        """
        Check if a collection exists in the database.

        Args:
            collection (str, optional): Collection name to check. Defaults to "deepsearcher".

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        SQL = SQL_TEMPLATES["has_collection"]
        params = {"collection": collection}
        res = self.query(SQL, params)
        # A positive row count means an active (status=1) collection record exists.
        return bool(res) and res[0]["rowcnt"] > 0

    def check_table(self):
        """
        Check that both backing tables exist and create any that are missing.

        Raises:
            Exception: If there's an error checking or creating tables.
        """
        SQL = SQL_TEMPLATES["has_table"]
        try:
            res = self.query(SQL)
            if len(res) < 2:
                missing_table = TABLES.keys() - set([i["table_name"] for i in res])
                for table in missing_table:
                    self.create_tables(table)
        except Exception as e:
            log.critical(f"Failed to check table in Oracle database, error info: {e}")
            raise

    def create_tables(self, table_name):
        """
        Create a table in the database from its DDL in ``TABLES``.

        Args:
            table_name: Name of the table to create (a key of ``TABLES``).

        Raises:
            Exception: If there's an error creating the table.
        """
        SQL = TABLES[table_name]
        try:
            self.execute(SQL)
            log.color_print(f"Created table {table_name} in Oracle database")
        except Exception as e:
            log.critical(f"Failed to create table {table_name} in Oracle database, error info: {e}")
            raise

    def drop_collection(self, collection: str = "deepsearcher"):
        """
        Drop a collection (soft delete: sets status=0 on info and item rows).

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".

        Raises:
            Exception: If there's an error dropping the collection.
        """
        try:
            params = {"collection": collection}
            SQL = SQL_TEMPLATES["drop_collection"]
            self.execute(SQL, params)
            SQL = SQL_TEMPLATES["drop_collection_item"]
            self.execute(SQL, params)
            log.color_print(f"Collection {collection} dropped")
        except Exception as e:
            log.critical(f"fail to drop collection, error info: {e}")
            raise

    def insertone(self, data):
        """
        Insert a single chunk record into DEEPSEARCHER_COLLECTION_ITEM.

        Args:
            data: Bind dict with collection, embedding, text, reference, metadata.
        """
        SQL = SQL_TEMPLATES["insert"]
        self.execute(SQL, data)
        log.debug("insert done!")

    def searchone(
        self,
        collection: Optional[str],
        vector: Union[np.array, List[float]],
        top_k: int = 5,
        max_distance: float = 0.8,
    ):
        """
        Search for similar vectors in a collection by cosine distance.

        Args:
            collection (Optional[str]): Collection name to search in.
            vector (Union[np.array, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            max_distance (float, optional): Maximum cosine distance for a row to
                be returned. Defaults to 0.8 (the previously hard-coded cutoff).

        Returns:
            list: List of search results.

        Raises:
            Exception: If there's an error during search.
        """
        log.debug("def searchone:" + collection)
        try:
            if isinstance(vector, list):
                vector = np.array(vector)
            # The vector is inlined as text and the dimension/dtype are
            # substituted via str.format because they cannot be bind variables
            # inside the vector() constructor.
            embedding_string = "[" + ", ".join(map(str, vector.tolist())) + "]"
            dimension = vector.shape[0]
            dtype = str(vector.dtype).upper()
            SQL = SQL_TEMPLATES["search"].format(dimension=dimension, dtype=dtype)
            params = {
                "collection": collection,
                "embedding_string": embedding_string,
                "top_k": top_k,
                "max_distance": max_distance,
            }
            res = self.query(SQL, params)
            if res:
                return res
            else:
                return []
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = "deepsearcher",
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        metric_type: str = "L2",
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in the database.

        Args:
            dim (int): Dimension of the vector embeddings (unused here; the
                VECTOR column is dimensionless and the dimension is supplied at
                query time).
            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher".
            description (Optional[str], optional): Collection description. Defaults to "".
            force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
            metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        if not collection:
            collection = self.default_collection
        if description is None:
            description = ""
        try:
            has_collection = self.has_collection(collection)
            if force_new_collection and has_collection:
                self.drop_collection(collection)
            elif has_collection:
                return
            # Insert the collection-info row.
            SQL = SQL_TEMPLATES["insert_collection"]
            params = {"collection": collection, "description": description}
            self.execute(SQL, params)
        except Exception as e:
            # NOTE: unlike the other methods, failures here are logged but not
            # re-raised (best-effort initialization); preserved as-is.
            log.critical(f"fail to init_collection for oracle, error info: {e}")

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert chunk data into a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            chunks (List[Chunk]): List of Chunk objects to insert.
            batch_size (int, optional): Number of chunks per batch. Defaults to 256.
                NOTE: rows are still inserted one at a time within each batch.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error inserting data.
        """
        if not collection:
            collection = self.default_collection
        datas = []
        for chunk in chunks:
            _data = {
                "embedding": self.numpy_converter_in(np.array(chunk.embedding)),
                "text": chunk.text,
                "reference": chunk.reference,
                "metadata": json.dumps(chunk.metadata),
                "collection": collection,
            }
            datas.append(_data)
        batch_datas = [datas[i : i + batch_size] for i in range(0, len(datas), batch_size)]
        try:
            for batch_data in batch_datas:
                for _data in batch_data:
                    self.insertone(data=_data)
            log.color_print(f"Successfully insert {len(datas)} data")
        except Exception as e:
            log.critical(f"fail to insert data, error info: {e}")
            raise

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.array, List[float]],
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            vector (Union[np.array, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[RetrievalResult]: List of retrieval results containing similar vectors.

        Raises:
            Exception: If there's an error during search.
        """
        if not collection:
            collection = self.default_collection
        try:
            search_results = self.searchone(collection=collection, vector=vector, top_k=top_k)
            return [
                RetrievalResult(
                    embedding=b["embedding"],
                    text=b["text"],
                    reference=b["reference"],
                    score=b["distance"],
                    metadata=json.loads(b["metadata"]),
                )
                for b in search_results
            ]
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all active collections in the database.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[CollectionInfo]: List of collection information objects.

        Raises:
            Exception: If there's an error listing collections.
        """
        collection_infos = []
        try:
            SQL = SQL_TEMPLATES["list_collections"]
            log.debug("def list_collections:" + SQL)
            collections = self.query(SQL)
            if collections:
                for collection in collections:
                    collection_infos.append(
                        CollectionInfo(
                            collection_name=collection["collection"],
                            description=collection["description"],
                        )
                    )
            return collection_infos
        except Exception as e:
            log.critical(f"fail to list collections, error info: {e}")
            raise

    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
        """
        Clear (drop) a collection from the database.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error clearing the collection.
        """
        if not collection:
            collection = self.default_collection
        try:
            # Fixed: was self.client.drop_collection(...) — the connection pool
            # has no such method; the soft-delete lives on this class.
            self.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear db, error info: {e}")
            raise
# DDL for the two backing tables, keyed by table name (used by create_tables /
# check_table). Rows carry a ``status`` column: 1 = active, 0 = soft-deleted.
TABLES = {
"DEEPSEARCHER_COLLECTION_INFO": """CREATE TABLE DEEPSEARCHER_COLLECTION_INFO (
id INT generated by default as identity primary key,
collection varchar(256),
description CLOB,
status NUMBER DEFAULT 1,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL)""",
"DEEPSEARCHER_COLLECTION_ITEM": """CREATE TABLE DEEPSEARCHER_COLLECTION_ITEM (
id INT generated by default as identity primary key,
collection varchar(256),
embedding VECTOR,
text CLOB,
reference varchar(4000),
metadata CLOB,
status NUMBER DEFAULT 1,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL)""",
}
# SQL statements used by OracleDB, with ``:name`` bind parameters.
# The "search" template additionally takes {dimension} and {dtype} via
# str.format, since those appear inside the vector() constructor where bind
# variables are not usable.
SQL_TEMPLATES = {
"has_table": f"""SELECT table_name FROM all_tables
WHERE table_name in ({",".join([f"'{k}'" for k in TABLES.keys()])})""",
"has_collection": "select count(*) as rowcnt from DEEPSEARCHER_COLLECTION_INFO where collection=:collection and status=1",
"list_collections": "select collection,description from DEEPSEARCHER_COLLECTION_INFO where status=1",
"drop_collection": "update DEEPSEARCHER_COLLECTION_INFO set status=0 where collection=:collection and status=1",
"drop_collection_item": "update DEEPSEARCHER_COLLECTION_ITEM set status=0 where collection=:collection and status=1",
"insert_collection": """INSERT INTO DEEPSEARCHER_COLLECTION_INFO (collection,description)
values (:collection,:description)""",
"insert": """INSERT INTO DEEPSEARCHER_COLLECTION_ITEM (collection,embedding,text,reference,metadata)
values (:collection,:embedding,:text,:reference,:metadata)""",
"search": """SELECT * FROM
(SELECT t.*,
VECTOR_DISTANCE(t.embedding,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
FROM DEEPSEARCHER_COLLECTION_ITEM t
JOIN DEEPSEARCHER_COLLECTION_INFO c ON t.collection=c.collection
WHERE t.collection=:collection AND t.status=1 AND c.status=1)
WHERE distance<:max_distance ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
}