File size: 5,365 Bytes
5306da4 ecaf3da 5306da4 ecaf3da 5306da4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
import numpy as np
import uuid
import asyncio
# Module-wide Qdrant connection plus the single collection name that every
# helper below reads from and writes to.
collection_name = 'vocRT_collection'
client = QdrantClient(port=6333, host='localhost')
async def initialize_collection(vector_size=768, distance=rest.Distance.COSINE):
    """
    Initialize the Qdrant collection if it doesn't exist.

    The collection named ``collection_name`` is created only when no
    collection with that name is listed by the server. An existing collection
    is left untouched: its vector size and distance are NOT checked against
    the arguments below.

    Declared ``async`` so callers can ``await`` it, but the Qdrant calls made
    here are synchronous (there is no ``await`` inside).

    Parameters:
    - vector_size (int): Dimension of the stored vectors; must match the
      output dimension of the embedding model. Defaults to 768.
    - distance (rest.Distance): Similarity metric for the collection.
      Defaults to cosine.
    """
    existing_names = {col.name for col in client.get_collections().collections}
    if collection_name in existing_names:
        print(f"Collection '{collection_name}' already exists.")
        return
    client.create_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=vector_size,
            distance=distance
        ),
    )
    print(f"Collection '{collection_name}' created.")
# Ensure the collection exists as soon as the module is imported.
# asyncio.run() raises RuntimeError when called while an event loop is already
# running (e.g. this module imported from async code or a notebook cell), which
# would abort the import. In that case schedule the initialization on the
# running loop instead.
try:
    _running_loop = asyncio.get_running_loop()
except RuntimeError:
    _running_loop = None
if _running_loop is None:
    asyncio.run(initialize_collection())
else:
    # Keep a reference so the task is not garbage-collected before it runs.
    _init_task = _running_loop.create_task(initialize_collection())
async def store_embeddings(session_id, embeddings, texts=None, name="", title="", summary="", categories=""):
    """
    Store embeddings for a specific session_id.

    Each embedding becomes one Qdrant point with a fresh random UUID id. Its
    payload holds the session id, the matching text (when ``texts`` is given)
    and the document metadata shared by all points of this call. The
    collection is created first if it does not exist yet.

    Parameters:
    - session_id (str): Unique identifier for the session/user.
    - embeddings (list of numpy arrays or lists): The embeddings to store.
    - texts (list of str, optional): Corresponding text passages for the embeddings.
    - name (str, optional): Stored under ``filename`` in every payload; omitted if None.
    - title (str, optional): Stored under ``title``; omitted if None.
    - summary (str, optional): Stored under ``summary``; omitted if None.
    - categories (str or iterable of str, optional): A comma-separated string
      (split, stripped, empty entries dropped) or an iterable of categories;
      stored as a list under ``categories``; omitted if None.

    Raises:
    - ValueError: if ``texts`` is given and its length differs from ``embeddings``.
    """
    await initialize_collection()
    if texts is not None and len(embeddings) != len(texts):
        raise ValueError(
            "The number of embeddings and texts must be the same.")

    # The metadata is identical for every point, so build it (and parse the
    # categories) once instead of on every loop iteration.
    shared_payload = {}
    if name is not None:
        shared_payload['filename'] = name
    if title is not None:
        shared_payload['title'] = title
    if summary is not None:
        shared_payload['summary'] = summary
    categories_list = None
    if categories is not None:
        if isinstance(categories, str):
            categories_list = [cat.strip()
                               for cat in categories.split(',') if cat.strip()]
        else:
            categories_list = list(categories)

    points = []
    for idx, embedding in enumerate(embeddings):
        payload = {'session_id': session_id}
        if texts is not None:
            payload['text'] = texts[idx]
        payload.update(shared_payload)
        if categories_list is not None:
            # Each payload gets its own list object, as before.
            payload['categories'] = list(categories_list)
        points.append(rest.PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist() if isinstance(embedding, np.ndarray) else embedding,
            payload=payload
        ))

    if not points:
        # Nothing to write: skip the pointless request to the server.
        print(f"No embeddings to store for session_id: {session_id}")
        return

    client.upsert(
        collection_name=collection_name,
        wait=True,
        points=points
    )
    print(f"Embeddings stored for session_id: {session_id}")
def search_embeddings(session_id, query_embedding, limit=10):
    """
    Find the points most similar to ``query_embedding`` within one session.

    Parameters:
    - session_id (str): Only points whose payload ``session_id`` equals this
      value are considered.
    - query_embedding (numpy array or list): Query vector; a numpy array is
      converted to a plain list before being sent.
    - limit (int): Maximum number of hits to return.

    Returns:
    - The hits returned by ``client.search`` for the session-filtered query,
      with payloads included.
    """
    if isinstance(query_embedding, np.ndarray):
        vector = query_embedding.tolist()
    else:
        vector = query_embedding

    only_this_session = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id),
            )
        ]
    )
    return client.search(
        collection_name=collection_name,
        query_vector=vector,
        query_filter=only_this_session,
        limit=limit,
        with_payload=True,
    )
def delete_embeddings(session_id):
    """
    Delete all embeddings for a specific session_id.

    Parameters:
    - session_id (str): Unique identifier for the session/user.

    Returns:
    - bool: True if matching points were deleted or none existed; False if the
      delete request raised (the error is printed, not re-raised). Errors from
      the initial existence check are not caught and propagate to the caller.
    """
    session_filter = rest.Filter(
        must=[
            rest.FieldCondition(
                key='session_id',
                match=rest.MatchValue(value=session_id)
            )
        ]
    )
    # Fetching a single matching point is enough to know whether anything
    # exists; there is no need to page through and collect every id.
    points, _ = client.scroll(
        collection_name=collection_name,
        scroll_filter=session_filter,
        limit=1,
        with_payload=False
    )
    if not points:
        print(f"No embeddings found for session_id: {session_id}")
        return True
    try:
        # Delete server-side by filter: one request no matter how many points
        # match, and points added after the check above are removed as well.
        client.delete(
            collection_name=collection_name,
            points_selector=rest.FilterSelector(filter=session_filter)
        )
    except Exception as e:
        # Best-effort: report the failure to the caller via the return value.
        print("Error in deleting embeddings : ", e)
        return False
    print(f"Deleted embeddings for session_id: {session_id}")
    return True
|