File size: 4,418 Bytes
a8b3f00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
from typing import Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient
from tcvectordb.model.database import Collection, Database
from tcvectordb.model.document import Document, Filter
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import Index
from xinference_client.types import Embedding
class MockTcvectordbClass:
    """Canned stand-ins for tcvectordb client, database and collection methods.

    These functions are not called on ``MockTcvectordbClass`` instances.
    ``setup_tcvectordb_mock`` monkeypatches them onto ``VectorDBClient``,
    ``Database`` and ``Collection``. Inside each one, ``self`` is therefore an
    instance of the class it was patched onto. None of them contacts a server;
    every method returns fixed data.
    """

    def mock_vector_db_client(
        self,
        url=None,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=5,
        adapter: Optional[HTTPAdapter] = None,
    ):
        """Replacement for ``VectorDBClient.__init__`` that opens no connection.

        Only ``read_consistency`` is stored. ``url``, ``username``, ``key``,
        ``timeout`` and ``adapter`` are accepted so callers' arguments still
        bind, but they are ignored.
        """
        # No real connection exists; list_databases passes this None on to Database.
        self._conn = None
        self._read_consistency = read_consistency

    def list_databases(self) -> list[Database]:
        """Replacement for ``VectorDBClient.list_databases``.

        Returns:
            A one-element list with a ``Database`` named ``"dify"``. It is built
            from the ``_conn`` and ``_read_consistency`` stored by
            ``mock_vector_db_client``.
        """
        return [
            Database(
                conn=self._conn,
                read_consistency=self._read_consistency,
                name="dify",
            )
        ]

    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
        """Replacement for ``Database.list_collections``: always reports no collections."""
        return []

    def drop_collection(self, name: str, timeout: Optional[float] = None):
        """Replacement for ``Database.drop_collection``.

        Nothing is dropped; a fixed success payload is returned.
        """
        return {"code": 0, "msg": "operation success"}

    def create_collection(
        self,
        name: str,
        shard: int,
        replicas: int,
        description: str,
        index: Index,
        # NOTE(review): `Embedding` is imported from xinference_client rather than
        # tcvectordb; it is only used as an annotation here. Confirm the intended type.
        embedding: Optional[Embedding] = None,
        timeout: Optional[float] = None,
    ) -> Collection:
        """Replacement for ``Database.create_collection``.

        Builds a ``Collection`` object bound to this database (``self``) from
        the given arguments, without contacting a server.

        Returns:
            The newly constructed ``Collection``.
        """
        return Collection(
            self,
            name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            # NOTE(review): relies on the Database instance exposing
            # `_read_consistency`. Confirm against the installed tcvectordb.
            read_consistency=self._read_consistency,
            timeout=timeout,
        )

    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
        """Replacement for ``Database.collection``.

        Returns a ``Collection`` handle named ``name``. It uses fixed
        ``shard=1`` and ``replicas=2``, and reuses the name as its description.
        """
        return Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)

    def collection_upsert(
        self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs
    ):
        """Replacement for ``Collection.upsert``.

        Discards ``documents`` (and any extra keyword arguments) and returns a
        fixed success payload.
        """
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        vectors: list[list[float]],
        # `filter` shadows the builtin; the name is kept so keyword callers still match.
        filter: Optional[Filter] = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        """Replacement for ``Collection.search``.

        Returns exactly one result list containing one hit: ``doc_id``
        ``"foo1"``, score ``0.1``, and ``metadata`` as a JSON-encoded string.
        This is the same regardless of how many query ``vectors`` are given or
        what ``limit`` is.
        """
        return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        document_ids: Optional[list] = None,
        retrieve_vector: bool = False,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        filter: Optional[Filter] = None,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Replacement for ``Collection.query``.

        Ignores every argument and returns the single canned document
        ``"foo1"``. Its ``metadata`` is a JSON-encoded string.
        """
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        document_ids: Optional[list[str]] = None,
        filter: Optional[Filter] = None,
        timeout: Optional[float] = None,
    ):
        """Replacement for ``Collection.delete``.

        Nothing is deleted; a fixed success payload is returned.
        """
        return {"code": 0, "msg": "operation success"}
# Mocking is opt-in. It is enabled only when the MOCK_SWITCH environment
# variable equals "true" in any letter case; an unset variable counts as "false".
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Swap tcvectordb client methods for ``MockTcvectordbClass`` stand-ins.

    The patches are installed only when the module-level ``MOCK`` flag is
    true. Otherwise the fixture only yields and the real client is used.

    Args:
        request: pytest request object (unused, kept in the fixture signature).
        monkeypatch: pytest's monkeypatch fixture, used to install the
            replacements and to undo them after the test.
    """
    if MOCK:
        # (owner class, attribute name, replacement), applied in this order.
        replacements = [
            (VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client),
            (VectorDBClient, "list_databases", MockTcvectordbClass.list_databases),
            (Database, "collection", MockTcvectordbClass.describe_collection),
            (Database, "list_collections", MockTcvectordbClass.list_collections),
            (Database, "drop_collection", MockTcvectordbClass.drop_collection),
            (Database, "create_collection", MockTcvectordbClass.create_collection),
            (Collection, "upsert", MockTcvectordbClass.collection_upsert),
            (Collection, "search", MockTcvectordbClass.collection_search),
            (Collection, "query", MockTcvectordbClass.collection_query),
            (Collection, "delete", MockTcvectordbClass.collection_delete),
        ]
        for owner, attribute, stand_in in replacements:
            monkeypatch.setattr(owner, attribute, stand_in)

    yield

    if MOCK:
        # Explicitly restore the original tcvectordb methods once the test is done.
        monkeypatch.undo()
|