"""Oracle database vector store implementation for DeepSearcher."""
import array
import json
from typing import List, Optional, Union
import numpy as np
from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
class OracleDB(BaseVectorDB):
    """Vector store backed by an Oracle database with native VECTOR support.

    Collection metadata is kept in the DEEPSEARCHER_COLLECTION_INFO table and
    chunk rows (text, reference, metadata JSON, embedding) in
    DEEPSEARCHER_COLLECTION_ITEM.  Rows are soft-deleted via a ``status``
    column (1 = active, 0 = dropped) rather than physically removed.
    """

    # Connection pool handle (oracledb pool); assigned in __init__.
    client = None

    def __init__(
        self,
        user: str,
        password: str,
        dsn: str,
        config_dir: str,
        wallet_location: str,
        wallet_password: str,
        min: int = 1,
        max: int = 10,
        increment: int = 1,
        default_collection: str = "deepsearcher",
    ):
        """
        Initialize the Oracle database connection pool.

        Args:
            user (str): Oracle database username.
            password (str): Oracle database password.
            dsn (str): Oracle database connection string.
            config_dir (str): Directory containing Oracle configuration files.
            wallet_location (str): Location of the Oracle wallet.
            wallet_password (str): Password for the Oracle wallet.
            min (int, optional): Minimum number of connections in the pool. Defaults to 1.
            max (int, optional): Maximum number of connections in the pool. Defaults to 10.
            increment (int, optional): Increment for adding new connections. Defaults to 1.
            default_collection (str, optional): Default collection name. Defaults to "deepsearcher".

        Raises:
            Exception: If the pool cannot be created or the required tables
                cannot be verified/created.
        """
        super().__init__(default_collection)
        self.default_collection = default_collection

        # Imported lazily so this module can be imported without the
        # optional `oracledb` driver installed.
        import oracledb

        # Fetch CLOB/BLOB columns as plain Python values, not LOB handles.
        oracledb.defaults.fetch_lobs = False
        self.DB_TYPE_VECTOR = oracledb.DB_TYPE_VECTOR
        try:
            self.client = oracledb.create_pool(
                user=user,
                password=password,
                dsn=dsn,
                config_dir=config_dir,
                wallet_location=wallet_location,
                wallet_password=wallet_password,
                min=min,
                max=max,
                increment=increment,
            )
            log.color_print(f"Connected to Oracle database at {dsn}")
            self.check_table()
        except Exception as e:
            log.critical(f"Failed to connect to Oracle database at {dsn}")
            log.critical(f"Oracle database error in init: {e}")
            raise

    def numpy_converter_in(self, value):
        """Convert a numpy array to an array.array for DB binding.

        float64 -> 'd', float32 -> 'f'; any other dtype is packed as
        signed bytes ('b'), i.e. assumed int8-compatible.
        """
        if value.dtype == np.float64:
            dtype = "d"
        elif value.dtype == np.float32:
            dtype = "f"
        else:
            dtype = "b"
        return array.array(dtype, value)

    def input_type_handler(self, cursor, value, arraysize):
        """Bind numpy arrays as Oracle VECTOR values on input."""
        if isinstance(value, np.ndarray):
            return cursor.var(
                self.DB_TYPE_VECTOR,
                arraysize=arraysize,
                inconverter=self.numpy_converter_in,
            )

    def numpy_converter_out(self, value):
        """Convert an array.array fetched from the DB back to a numpy array."""
        if value.typecode == "b":
            dtype = np.int8
        elif value.typecode == "f":
            dtype = np.float32
        else:
            dtype = np.float64
        return np.array(value, copy=False, dtype=dtype)

    def output_type_handler(self, cursor, metadata):
        """Fetch Oracle VECTOR columns as numpy arrays on output."""
        if metadata.type_code is self.DB_TYPE_VECTOR:
            return cursor.var(
                metadata.type_code,
                arraysize=cursor.arraysize,
                outconverter=self.numpy_converter_out,
            )

    def query(self, sql: str, params: dict = None) -> list:
        """
        Execute a SQL query and return the results.

        Args:
            sql (str): SQL query to execute.
            params (dict, optional): Bind parameters for the query. Defaults to None.

        Returns:
            list: One dict per row, keyed by lower-cased column name;
                an empty list if the query matched nothing.
                (The previous ``Union[dict, None]`` annotation was wrong:
                this method always returns a list.)

        Raises:
            Exception: If there's an error executing the query.
        """
        with self.client.acquire() as connection:
            connection.inputtypehandler = self.input_type_handler
            connection.outputtypehandler = self.output_type_handler
            with connection.cursor() as cursor:
                try:
                    if log.dev_mode:
                        print("sql:\n", sql)
                    cursor.execute(sql, params)
                except Exception as e:
                    log.critical(f"Oracle database error in query: {e}")
                    raise
                columns = [column[0].lower() for column in cursor.description]
                rows = cursor.fetchall()
                data = [dict(zip(columns, row)) for row in rows] if rows else []
                if log.dev_mode:
                    print("data:\n", data)
                return data

    def execute(self, sql: str, data: Union[list, dict] = None):
        """
        Execute a SQL statement without returning results, committing on success.

        Args:
            sql (str): SQL statement to execute.
            data (Union[list, dict], optional): Bind data for the statement. Defaults to None.

        Raises:
            Exception: If there's an error executing the statement.
        """
        try:
            with self.client.acquire() as connection:
                connection.inputtypehandler = self.input_type_handler
                connection.outputtypehandler = self.output_type_handler
                with connection.cursor() as cursor:
                    if data is None:
                        cursor.execute(sql)
                    else:
                        cursor.execute(sql, data)
                connection.commit()
        except Exception as e:
            log.critical(f"Oracle database error in execute: {e}")
            log.error("ERROR sql:\n" + sql)
            # BUGFIX: `data` may be a dict/list; the old str concatenation
            # raised TypeError here and masked the original DB error.
            log.error(f"ERROR data:\n{data}")
            raise

    def has_collection(self, collection: str = "deepsearcher"):
        """
        Check if a collection exists (status = 1) in the database.

        Args:
            collection (str, optional): Collection name to check. Defaults to "deepsearcher".

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        SQL = SQL_TEMPLATES["has_collection"]
        params = {"collection": collection}
        res = self.query(SQL, params)
        return bool(res) and res[0]["rowcnt"] > 0

    def check_table(self):
        """
        Check that both required tables exist, creating any that are missing.

        Raises:
            Exception: If there's an error checking or creating tables.
        """
        SQL = SQL_TEMPLATES["has_table"]
        try:
            res = self.query(SQL)
            if len(res) < 2:
                # Create only the tables absent from the schema.
                missing_tables = TABLES.keys() - {i["table_name"] for i in res}
                for table in missing_tables:
                    self.create_tables(table)
        except Exception as e:
            log.critical(f"Failed to check table in Oracle database, error info: {e}")
            raise

    def create_tables(self, table_name):
        """
        Create a table in the database from its DDL in TABLES.

        Args:
            table_name: Name of the table to create (a key of TABLES).

        Raises:
            Exception: If there's an error creating the table.
        """
        SQL = TABLES[table_name]
        try:
            self.execute(SQL)
            log.color_print(f"Created table {table_name} in Oracle database")
        except Exception as e:
            log.critical(f"Failed to create table {table_name} in Oracle database, error info: {e}")
            raise

    def drop_collection(self, collection: str = "deepsearcher"):
        """
        Soft-drop a collection: set status=0 on its info row and item rows.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".

        Raises:
            Exception: If there's an error dropping the collection.
        """
        try:
            params = {"collection": collection}
            SQL = SQL_TEMPLATES["drop_collection"]
            self.execute(SQL, params)
            SQL = SQL_TEMPLATES["drop_collection_item"]
            self.execute(SQL, params)
            log.color_print(f"Collection {collection} dropped")
        except Exception as e:
            log.critical(f"fail to drop collection, error info: {e}")
            raise

    def insertone(self, data):
        """
        Insert a single chunk row into DEEPSEARCHER_COLLECTION_ITEM.

        Args:
            data: Bind dict with collection, embedding, text, reference, metadata.
        """
        SQL = SQL_TEMPLATES["insert"]
        self.execute(SQL, data)
        log.debug("insert done!")

    def searchone(
        self,
        collection: Optional[str],
        vector: Union[np.array, List[float]],
        top_k: int = 5,
        max_distance: float = 0.8,
    ):
        """
        Search a collection for vectors similar to `vector` (cosine distance).

        Args:
            collection (Optional[str]): Collection name to search in.
            vector (Union[np.array, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            max_distance (float, optional): Exclude rows whose cosine distance
                is >= this threshold. Defaults to 0.8 (the previously
                hard-coded value).

        Returns:
            list: List of result rows (dicts), possibly empty.

        Raises:
            Exception: If there's an error during search.
        """
        log.debug("def searchone:" + collection)
        try:
            if isinstance(vector, list):
                vector = np.array(vector)
            # Oracle's VECTOR() constructor takes a textual vector literal;
            # dimension/dtype must be spliced into the SQL (cannot be bound).
            embedding_string = "[" + ", ".join(map(str, vector.tolist())) + "]"
            dimension = vector.shape[0]
            dtype = str(vector.dtype).upper()
            SQL = SQL_TEMPLATES["search"].format(dimension=dimension, dtype=dtype)
            params = {
                "collection": collection,
                "embedding_string": embedding_string,
                "top_k": top_k,
                "max_distance": max_distance,
            }
            res = self.query(SQL, params)
            return res if res else []
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = "deepsearcher",
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        metric_type: str = "L2",
        *args,
        **kwargs,
    ):
        """
        Initialize (register) a collection in the database.

        Args:
            dim (int): Dimension of the vector embeddings (unused here; the
                VECTOR column is dimension-flexible).
            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher".
            description (Optional[str], optional): Collection description. Defaults to "".
            force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
            metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error initializing the collection.
        """
        if not collection:
            collection = self.default_collection
        if description is None:
            description = ""
        try:
            has_collection = self.has_collection(collection)
            if force_new_collection and has_collection:
                self.drop_collection(collection)
            elif has_collection:
                return
            # Register the collection in the info table.
            SQL = SQL_TEMPLATES["insert_collection"]
            params = {"collection": collection, "description": description}
            self.execute(SQL, params)
        except Exception as e:
            log.critical(f"fail to init_collection for oracle, error info: {e}")
            # BUGFIX: previously swallowed silently; re-raise for consistency
            # with every other method so callers see the failure.
            raise

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert chunk data into a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            chunks (List[Chunk]): List of Chunk objects to insert.
            batch_size (int, optional): Retained for interface compatibility;
                rows are currently inserted one at a time. Defaults to 256.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            Exception: If there's an error inserting data.
        """
        if not collection:
            collection = self.default_collection
        datas = [
            {
                "embedding": self.numpy_converter_in(np.array(chunk.embedding)),
                "text": chunk.text,
                "reference": chunk.reference,
                "metadata": json.dumps(chunk.metadata),
                "collection": collection,
            }
            for chunk in chunks
        ]
        try:
            # The old code batched the list and then looped items anyway;
            # a flat loop over all rows is behaviorally identical.
            for _data in datas:
                self.insertone(data=_data)
            log.color_print(f"Successfully insert {len(datas)} data")
        except Exception as e:
            log.critical(f"fail to insert data, error info: {e}")
            raise

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.array, List[float]],
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a collection.

        Args:
            collection (Optional[str]): Collection name. If None, uses default_collection.
            vector (Union[np.array, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[RetrievalResult]: List of retrieval results containing similar vectors.

        Raises:
            Exception: If there's an error during search.
        """
        if not collection:
            collection = self.default_collection
        try:
            search_results = self.searchone(collection=collection, vector=vector, top_k=top_k)
            return [
                RetrievalResult(
                    embedding=row["embedding"],
                    text=row["text"],
                    reference=row["reference"],
                    score=row["distance"],
                    metadata=json.loads(row["metadata"]),
                )
                for row in search_results
            ]
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            raise

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all active collections in the database.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[CollectionInfo]: List of collection information objects.
        """
        collection_infos = []
        try:
            SQL = SQL_TEMPLATES["list_collections"]
            log.debug("def list_collections:" + SQL)
            collections = self.query(SQL)
            if collections:
                for collection in collections:
                    collection_infos.append(
                        CollectionInfo(
                            collection_name=collection["collection"],
                            description=collection["description"],
                        )
                    )
            return collection_infos
        except Exception as e:
            log.critical(f"fail to list collections, error info: {e}")
            raise

    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
        """
        Clear (soft-drop) a collection from the database.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher".
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        if not collection:
            collection = self.default_collection
        try:
            # BUGFIX: previously called self.client.drop_collection(...), but
            # self.client is an oracledb connection pool with no such method;
            # use this class's own soft-drop instead.
            self.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear db, error info: {e}")
            raise
TABLES = {
"DEEPSEARCHER_COLLECTION_INFO": """CREATE TABLE DEEPSEARCHER_COLLECTION_INFO (
id INT generated by default as identity primary key,
collection varchar(256),
description CLOB,
status NUMBER DEFAULT 1,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL)""",
"DEEPSEARCHER_COLLECTION_ITEM": """CREATE TABLE DEEPSEARCHER_COLLECTION_ITEM (
id INT generated by default as identity primary key,
collection varchar(256),
embedding VECTOR,
text CLOB,
reference varchar(4000),
metadata CLOB,
status NUMBER DEFAULT 1,
createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updatetime TIMESTAMP DEFAULT NULL)""",
}
SQL_TEMPLATES = {
"has_table": f"""SELECT table_name FROM all_tables
WHERE table_name in ({",".join([f"'{k}'" for k in TABLES.keys()])})""",
"has_collection": "select count(*) as rowcnt from DEEPSEARCHER_COLLECTION_INFO where collection=:collection and status=1",
"list_collections": "select collection,description from DEEPSEARCHER_COLLECTION_INFO where status=1",
"drop_collection": "update DEEPSEARCHER_COLLECTION_INFO set status=0 where collection=:collection and status=1",
"drop_collection_item": "update DEEPSEARCHER_COLLECTION_ITEM set status=0 where collection=:collection and status=1",
"insert_collection": """INSERT INTO DEEPSEARCHER_COLLECTION_INFO (collection,description)
values (:collection,:description)""",
"insert": """INSERT INTO DEEPSEARCHER_COLLECTION_ITEM (collection,embedding,text,reference,metadata)
values (:collection,:embedding,:text,:reference,:metadata)""",
"search": """SELECT * FROM
(SELECT t.*,
VECTOR_DISTANCE(t.embedding,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
FROM DEEPSEARCHER_COLLECTION_ITEM t
JOIN DEEPSEARCHER_COLLECTION_INFO c ON t.collection=c.collection
WHERE t.collection=:collection AND t.status=1 AND c.status=1)
WHERE distance<:max_distance ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
}