CatPtain's picture
Upload 697 files
20f348c verified
import os
from collections import UserDict
from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient # type: ignore
from pymochow.model.database import Database # type: ignore
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
from pymochow.model.table import Table # type: ignore
from requests.adapters import HTTPAdapter
class AttrDict(UserDict):
    """Dictionary whose keys can also be read as attributes.

    ``obj.key`` returns the value stored under ``"key"``. A key that is not
    present yields ``None`` instead of raising, mirroring ``dict.get``.
    """

    def __getattr__(self, item):
        """Return the value stored under ``item``, or ``None`` if absent.

        Raises:
            AttributeError: for ``data`` and for dunder names, which are
                never served from the mapping.
        """
        # __getattr__ only runs after normal attribute lookup has failed.
        # If ``data`` itself is the missing attribute (an instance created
        # through __new__ by copy/pickle before __init__ ran), delegating to
        # self.get() would look up self.data again and recurse without end.
        # Dunder probes such as hasattr(obj, "__setstate__") must also fail
        # the normal way rather than being answered with None.
        if item == "data" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)
        return self.get(item)
class MockBaiduVectorDBClass:
    """Canned replacements for pymochow client, database and table methods.

    The functions below are patched onto ``MochowClient``, ``Database`` and
    ``Table`` by ``setup_baiduvectordb_mock``, so ``self`` is an instance of
    whichever class a function was patched onto, not of this class. Every
    method returns a fixed response and performs no network I/O.
    """

    def mock_vector_db_client(
        self,
        config=None,
        adapter: HTTPAdapter = None,
    ):
        """Stand-in constructor: attach mock ``conn`` and ``_config`` objects."""
        self.conn = MagicMock()
        self._config = MagicMock()

    def list_databases(self, config=None) -> list[Database]:
        """Return a single database named ``dify``."""
        dify_database = Database(
            conn=self.conn,
            database_name="dify",
            config=self._config,
        )
        return [dify_database]

    def create_database(self, database_name: str, config=None) -> Database:
        """Return a ``Database`` handle carrying the requested name."""
        created = Database(conn=self.conn, database_name=database_name, config=config)
        return created

    def list_table(self, config=None) -> list[Table]:
        """Report that the database holds no tables."""
        tables: list[Table] = []
        return tables

    def drop_table(self, table_name: str, config=None):
        """Pretend the table was dropped and return a success response."""
        response = {"code": 0, "msg": "Success"}
        return response

    def create_table(
        self,
        table_name: str,
        replication: int,
        partition: int,
        schema,
        enable_dynamic_field=False,
        description: str = "",
        config=None,
    ) -> Table:
        """Return a ``Table`` built from the arguments exactly as given."""
        return Table(
            self,
            table_name,
            replication,
            partition,
            schema,
            enable_dynamic_field,
            description,
            config,
        )

    def describe_table(self, table_name: str, config=None) -> Table:
        """Return a ``Table`` in the ``NORMAL`` state with fixed settings."""
        # Positional order follows the Table(...) call in create_table.
        replication, partition, schema = 3, 1, None
        table_options = {
            "enable_dynamic_field": False,
            "description": "table for dify",
            "config": config,
            "state": TableState.NORMAL,
        }
        return Table(self, table_name, replication, partition, schema, **table_options)

    def upsert(self, rows, config=None):
        """Ignore ``rows`` and report one affected row."""
        response = {"code": 0, "msg": "operation success", "affectedCount": 1}
        return response

    def rebuild_index(self, index_name: str, config=None):
        """Pretend the index rebuild was accepted."""
        response = {"code": 0, "msg": "Success"}
        return response

    def describe_index(self, index_name: str, config=None):
        """Return an HNSW/L2 index description in the ``NORMAL`` state."""
        hnsw_params = HNSWParams(m=16, efconstruction=200)
        index = VectorIndex(
            index_name=index_name,
            index_type=IndexType.HNSW,
            field="vector",
            metric_type=MetricType.L2,
            params=hnsw_params,
            auto_build=False,
            state=IndexState.NORMAL,
        )
        return index

    def query(
        self,
        primary_key,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return one fixed row whose ``id`` echoes ``primary_key["id"]``."""
        row = {
            "id": primary_key.get("id"),
            "vector": [0.23432432, 0.8923744, 0.89238432],
            "text": "text",
            "metadata": '{"doc_id": "doc_id_001"}',
        }
        response = {"row": row, "code": 0, "msg": "Success"}
        return AttrDict(response)

    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
        """Pretend the delete succeeded, whatever the selection arguments."""
        response = {"code": 0, "msg": "Success"}
        return response

    def search(
        self,
        anns,
        partition_key=None,
        projections=None,
        retrieve_vector=False,
        read_consistency=ReadConsistency.EVENTUAL,
        config=None,
    ):
        """Return a single fixed hit for any search request."""
        row = {
            "id": "doc_id_001",
            "vector": [0.23432432, 0.8923744, 0.89238432],
            "text": "text",
            "metadata": '{"doc_id": "doc_id_001"}',
        }
        hit = {"row": row, "distance": 0.1, "score": 0.5}
        response = {"rows": [hit], "code": 0, "msg": "Success"}
        return AttrDict(response)
# Mocking is opt-in: it is enabled only when the MOCK_SWITCH environment
# variable is set to "true" (compared case-insensitively).
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Replace pymochow client, database and table methods with mocks.

    When ``MOCK`` is true, the methods of ``MockBaiduVectorDBClass`` are
    patched onto ``MochowClient``, ``Database`` and ``Table`` for the duration
    of the requesting test. When ``MOCK`` is false the fixture does nothing,
    so the test talks to the real client.
    """
    if MOCK:
        mock = MockBaiduVectorDBClass
        patches = (
            (MochowClient, "__init__", mock.mock_vector_db_client),
            (MochowClient, "list_databases", mock.list_databases),
            (MochowClient, "create_database", mock.create_database),
            # Database.table is served by the same mock as describe_table.
            (Database, "table", mock.describe_table),
            (Database, "list_table", mock.list_table),
            (Database, "create_table", mock.create_table),
            (Database, "drop_table", mock.drop_table),
            (Database, "describe_table", mock.describe_table),
            # The mock class defines upsert; install it so writes return the
            # canned response like every other mocked Table method.
            (Table, "upsert", mock.upsert),
            (Table, "rebuild_index", mock.rebuild_index),
            (Table, "describe_index", mock.describe_index),
            (Table, "delete", mock.delete),
            (Table, "query", mock.query),
            (Table, "search", mock.search),
        )
        for target, attribute, replacement in patches:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    # No explicit monkeypatch.undo(): pytest restores every patch when the
    # monkeypatch fixture is torn down, and calling undo() here would also
    # revert patches that other fixtures made through the same instance.