You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
8.5 KiB
255 lines
8.5 KiB
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
import numpy as np
|
|
import sys
|
|
import json
|
|
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
from deepsearcher.loader.splitter import Chunk
|
|
import logging
|
|
# Silence log output from the code under test. NOTE: logging.disable is a
# process-wide setting applied when this module is imported, so it also mutes
# logging for any other tests that run later in the same interpreter.
logging.disable(logging.CRITICAL)
|
|
|
|
class TestOracleDB(unittest.TestCase):
    """Tests for the Oracle vector database implementation.

    Each test runs with ``sys.modules["oracledb"]`` replaced by a
    ``MagicMock``, so ``OracleDB`` is exercised without a real driver or
    database connection.
    """

    def setUp(self):
        """Install the mocked ``oracledb`` module and import ``OracleDB``."""
        self.mock_oracledb = MagicMock()
        self.mock_oracledb.DB_TYPE_VECTOR = "VECTOR"
        self.mock_oracledb.defaults = MagicMock()

        self.module_patcher = patch.dict(
            "sys.modules", {"oracledb": self.mock_oracledb}
        )
        self.module_patcher.start()
        # Registered as a cleanup instead of being undone in tearDown:
        # unittest skips tearDown when setUp raises (e.g. the import below
        # fails), which would leave the fake driver in sys.modules for every
        # later test. Cleanups run even when setUp fails.
        self.addCleanup(self.module_patcher.stop)

        # Import only after patching so any `import oracledb` performed by
        # the package resolves to the mock.
        from deepsearcher.vector_db import OracleDB

        self.OracleDB = OracleDB

    def _mock_connection_chain(self):
        """Make ``oracledb.create_pool`` return a pool wired to a mock cursor.

        ``pool.acquire()`` and ``connection.cursor()`` are configured as
        context managers that yield the next mock in the chain.

        Returns:
            tuple: ``(mock_pool, mock_connection, mock_cursor)``.
        """
        mock_pool = MagicMock()
        mock_connection = MagicMock()
        mock_cursor = MagicMock()

        self.mock_oracledb.create_pool.return_value = mock_pool
        mock_pool.acquire.return_value.__enter__.return_value = mock_connection
        mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
        return mock_pool, mock_connection, mock_cursor

    def _make_db(self, **overrides):
        """Construct ``OracleDB`` with the shared fake credentials.

        Args:
            **overrides: Keyword arguments added to, or replacing, the
                default constructor arguments.

        Returns:
            The constructed ``OracleDB`` instance.
        """
        kwargs = {
            "user": "test_user",
            "password": "test_password",
            "dsn": "test_dsn",
            "config_dir": "/test/config",
            "wallet_location": "/test/wallet",
            "wallet_password": "test_wallet_pwd",
        }
        kwargs.update(overrides)
        return self.OracleDB(**kwargs)

    def test_init(self):
        """Test basic initialization."""
        self._mock_connection_chain()

        oracle_db = self._make_db(default_collection="test_collection")

        self.assertEqual(oracle_db.default_collection, "test_collection")
        self.assertIsNotNone(oracle_db.client)
        self.mock_oracledb.create_pool.assert_called_once()
        self.assertIs(self.mock_oracledb.defaults.fetch_lobs, False)

    def test_insert_data(self):
        """Test inserting data."""
        _, mock_connection, mock_cursor = self._mock_connection_chain()
        oracle_db = self._make_db()
        # Construction may already run SQL (e.g. the table check done during
        # init), so discard those calls: the assertions below must be caused
        # by insert_data itself.
        mock_cursor.execute.reset_mock()
        mock_connection.commit.reset_mock()

        d = 8
        rng = np.random.default_rng(seed=42)
        test_chunks = [
            Chunk(
                embedding=rng.random(d).tolist(),
                text="hello world",
                reference="test.txt",
                metadata={"key": "value1"},
            ),
            Chunk(
                embedding=rng.random(d).tolist(),
                text="hello oracle",
                reference="test.txt",
                metadata={"key": "value2"},
            ),
        ]

        oracle_db.insert_data(collection="test_collection", chunks=test_chunks)

        mock_cursor.execute.assert_called()
        mock_connection.commit.assert_called()

    def test_search_data(self):
        """Test search functionality."""
        _, _, mock_cursor = self._mock_connection_chain()

        # Mock search results (configured before construction, as any SQL
        # run during init will also see these rows).
        mock_cursor.description = [
            ("embedding",), ("text",), ("reference",), ("distance",), ("metadata",)
        ]
        mock_cursor.fetchall.return_value = [
            (
                np.array([0.1, 0.2, 0.3]),
                "hello world",
                "test.txt",
                0.95,
                json.dumps({"key": "value1"}),
            ),
            (
                np.array([0.4, 0.5, 0.6]),
                "hello oracle",
                "test.txt",
                0.85,
                json.dumps({"key": "value2"}),
            ),
        ]

        oracle_db = self._make_db()

        d = 8
        rng = np.random.default_rng(seed=42)
        query_vector = rng.random(d)

        results = oracle_db.search_data(
            collection="test_collection",
            vector=query_vector,
            top_k=2,
        )

        self.assertIsInstance(results, list)
        self.assertEqual(len(results), 2)
        for result in results:
            self.assertIsInstance(result, RetrievalResult)

    def test_list_collections(self):
        """Test listing collections."""
        _, _, mock_cursor = self._mock_connection_chain()

        # Mock list_collections response
        mock_cursor.description = [("collection",), ("description",)]
        mock_cursor.fetchall.return_value = [
            ("test_collection_1", "Test collection 1"),
            ("test_collection_2", "Test collection 2"),
        ]

        oracle_db = self._make_db()

        collections = oracle_db.list_collections()
        self.assertIsInstance(collections, list)
        self.assertEqual(len(collections), 2)

    def test_clear_db(self):
        """Test clearing database."""
        _, mock_connection, mock_cursor = self._mock_connection_chain()
        oracle_db = self._make_db()
        # Discard SQL issued during construction so the assertions below
        # really exercise clear_db.
        mock_cursor.execute.reset_mock()
        mock_connection.commit.reset_mock()

        oracle_db.clear_db("test_collection")

        mock_cursor.execute.assert_called()
        mock_connection.commit.assert_called()

    def test_has_collection(self):
        """Test checking if collection exists."""
        _, _, mock_cursor = self._mock_connection_chain()

        # Mock check_table response first (called during init)
        mock_cursor.description = [("table_name",)]
        mock_cursor.fetchall.return_value = [
            ("DEEPSEARCHER_COLLECTION_INFO",),
            ("DEEPSEARCHER_COLLECTION_ITEM",),
        ]

        oracle_db = self._make_db()

        # Now mock has_collection response - collection exists
        mock_cursor.description = [("rowcnt",)]
        mock_cursor.fetchall.return_value = [(1,)]  # Return tuple, not dict
        self.assertTrue(oracle_db.has_collection("test_collection"))

        # Collection doesn't exist
        mock_cursor.fetchall.return_value = [(0,)]  # Return tuple, not dict
        self.assertFalse(oracle_db.has_collection("nonexistent_collection"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly; verbosity=1
    # is unittest.main's default and is spelled out only for clarity.
    unittest.main(verbosity=1)
|