File size: 5,365 Bytes
5306da4
 
 
 
ecaf3da
5306da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecaf3da
 
5306da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import asyncio
import os
import uuid

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

# Initialize Qdrant client and collection.
# Host/port default to a local Qdrant instance on the standard port; they can be
# overridden with the QDRANT_HOST / QDRANT_PORT environment variables instead of
# editing this file.
client = QdrantClient(
    host=os.environ.get('QDRANT_HOST', 'localhost'),
    port=int(os.environ.get('QDRANT_PORT', '6333')),
)
collection_name = 'vocRT_collection'


async def initialize_collection():
    """
    Ensure the module's Qdrant collection exists, creating it when missing.

    A missing collection is created with 768-dimensional vectors compared by
    cosine distance; an existing one is left as-is (its configuration is not
    checked). A status line is printed in both cases.

    Note: although declared ``async``, the body only makes calls on the
    synchronous ``QdrantClient`` and never awaits.
    """
    known = {col.name for col in client.get_collections().collections}

    if collection_name in known:
        print(f"Collection '{collection_name}' already exists.")
        return

    # 768 must match the dimension of the embeddings callers store.
    vector_params = rest.VectorParams(size=768, distance=rest.Distance.COSINE)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=vector_params,
    )
    print(f"Collection '{collection_name}' created.")


# Create the collection eagerly at import time (presumably so search/delete work
# before the first store, since only store_embeddings re-checks it).
# asyncio.run() raises RuntimeError when an event loop is already running in
# this thread (e.g. the module is imported lazily from async code or a
# notebook), so in that case schedule the coroutine on the running loop instead.
try:
    _running_loop = asyncio.get_running_loop()
except RuntimeError:
    _running_loop = None

if _running_loop is None:
    asyncio.run(initialize_collection())
else:
    # Keep a strong reference so the task is not garbage-collected before it runs.
    _init_task = _running_loop.create_task(initialize_collection())

async def store_embeddings(session_id, embeddings, texts=None, name="", title="", summary="", categories=""):
    """
    Store embeddings for a specific session_id, one Qdrant point per embedding.

    Every point gets a fresh random UUID and a payload holding the session id,
    the matching text passage (when ``texts`` is given) and the document
    metadata below, which is the same for all points written by this call.

    Parameters:
    - session_id (str): Unique identifier for the session/user.
    - embeddings (iterable of numpy arrays or lists): The embeddings to store.
    - texts (list of str, optional): Text passage for each embedding, in the
      same order; must have the same length as ``embeddings``.
    - name (str, optional): Stored under payload key 'filename' unless None.
    - title (str, optional): Stored under payload key 'title' unless None.
    - summary (str, optional): Stored under payload key 'summary' unless None.
    - categories (str or iterable of str, optional): Stored under payload key
      'categories' as a list unless None. A string is split on commas, each
      entry is stripped, and empty entries are dropped.

    Raises:
    - ValueError: if ``texts`` is given and its length differs from
      ``embeddings``.
    """
    await initialize_collection()

    if texts is not None and len(embeddings) != len(texts):
        raise ValueError(
            "The number of embeddings and texts must be the same.")

    # The document metadata is identical for every point, so build it once.
    # Materializing `categories` here (rather than per point) also matters for
    # iterators: list() on an iterator inside the loop would exhaust it after
    # the first point and give every later point an empty list.
    metadata = {}
    if name is not None:
        metadata['filename'] = name
    if title is not None:
        metadata['title'] = title
    if summary is not None:
        metadata['summary'] = summary
    if categories is not None:
        if isinstance(categories, str):
            metadata['categories'] = [cat.strip()
                                      for cat in categories.split(',') if cat.strip()]
        else:
            metadata['categories'] = list(categories)

    points = []
    for idx, embedding in enumerate(embeddings):
        payload = {'session_id': session_id}
        if texts is not None:
            payload['text'] = texts[idx]
        payload.update(metadata)

        points.append(rest.PointStruct(
            id=str(uuid.uuid4()),
            # Convert numpy arrays to plain lists before handing them to the client.
            vector=embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            payload=payload
        ))

    if not points:
        # Nothing to write; skip the round trip to Qdrant.
        return

    client.upsert(
        collection_name=collection_name,
        wait=True,
        points=points
    )
    print(f"Embeddings stored for session_id: {session_id}")


def search_embeddings(session_id, query_embedding, limit=10):
    """
    Find the stored embeddings of one session that are closest to a query.

    Parameters:
    - session_id (str): Only points whose payload 'session_id' equals this
      value are considered.
    - query_embedding (numpy array or list): The query vector; numpy arrays
      are converted to plain lists first.
    - limit (int): Maximum number of hits to return.

    Returns:
    - The hits returned by ``client.search`` (point id, score and payload
      for each).

    NOTE(review): newer qdrant-client releases deprecate ``search`` in favour
    of ``query_points`` — revisit when upgrading the client.
    """
    vector = query_embedding.tolist() if isinstance(query_embedding, np.ndarray) else query_embedding

    # Restrict the search to this session's points.
    session_only = rest.Filter(
        must=[rest.FieldCondition(key='session_id', match=rest.MatchValue(value=session_id))]
    )

    return client.search(
        collection_name=collection_name,
        query_vector=vector,
        query_filter=session_only,
        limit=limit,
        with_payload=True,
    )


def delete_embeddings(session_id):
    """
    Delete all embeddings for a specific session_id.

    Matching point ids are collected by scrolling the collection in pages of
    100 and are then removed with a single delete-by-id request.

    Parameters:
    - session_id (str): Unique identifier for the session/user.

    Returns:
    - bool: True when the points were deleted or none existed; False if any
      Qdrant call (scroll or delete) raised. Errors are printed, not re-raised.
    """
    session_filter = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id)
            )
        ]
    )

    # Scrolling is inside the try as well, so a failure while listing points
    # honours the same "return False" contract as a failure while deleting.
    try:
        point_ids = []
        offset = None
        while True:
            # scroll yields (points, next_page_offset); offset is None on the last page.
            points, offset = client.scroll(
                collection_name=collection_name,
                scroll_filter=session_filter,
                limit=100,
                offset=offset,
                with_payload=False
            )
            point_ids.extend(point.id for point in points)
            if not points or offset is None:
                break

        if not point_ids:
            print(f"No embeddings found for session_id: {session_id}")
            return True

        client.delete(
            collection_name=collection_name,
            points_selector=rest.PointIdsList(points=point_ids)
        )
    except Exception as e:
        # Best-effort boundary: report failure to the caller instead of raising.
        print("Error in deleting embeddings : ", e)
        return False

    print(f"Deleted embeddings for session_id: {session_id}")
    return True