# NOTE: extraction artifacts (file-size header, git blame hashes, line-number
# dump) removed from the top of this file; they were not valid Python.
from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures
from pymilvus import Collection
import os
class MilvusManager:
    """Manage a Milvus collection of ColBERT-style multi-vector documents.

    Each logical document is stored as one row per token vector, sharing a
    ``doc_id``; the first row of a document carries its file path in ``doc``.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and load or optionally create the collection.

        Args:
            milvus_uri: URI of the Milvus server or local DB file.
            collection_name: name of the collection to use.
            create_collection: when True, create the collection and its index
                if it does not already exist.
            dim: dimensionality of the stored float vectors (default 128).
        """
        # Load variables from a .env file if one exists; search() honors an
        # optional "topk" environment override for the per-collection limit.
        import dotenv

        dotenv.load_dotenv(dotenv.find_dotenv())

        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Create the collection with the token-vector schema (no-op if it exists)."""
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        # Position of the token vector within its document.
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        # All rows of one document share the same doc_id.
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        # File path of the source document (populated only on seq_id == 0).
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Build an IVF_FLAT index with COSINE metric on the vector field."""
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name="vector_index",
            params={"nlist": 128},
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """Search every collection and rerank candidate documents by MaxSim.

        Args:
            data: query token vectors (2-D array-like, one row per token).
            topk: maximum number of results to return; also the default
                per-collection candidate limit (overridable via the "topk"
                environment variable).

        Returns:
            Up to ``topk`` tuples ``(score, doc_id, collection_name)``,
            sorted by score in descending order.
        """
        collections = self.client.list_collections()
        # Note: the index is built with COSINE, so search uses COSINE too
        # (Milvus would otherwise default to the collection's metric).
        search_params = {"metric_type": "COSINE", "params": {}}
        # Unique (doc_id, collection_name) candidate pairs across collections.
        doc_collection_pairs = set()

        # Bug fix: os.environ["topk"] raised KeyError when the variable was
        # unset; fall back to the topk argument instead of requiring it.
        candidate_limit = int(os.environ.get("topk", topk))

        for collection in collections:
            self.client.load_collection(collection_name=collection)
            print("collection loaded:" + collection)
            results = self.client.search(
                collection,
                data,
                limit=candidate_limit,
                output_fields=["vector", "seq_id", "doc_id"],
                search_params=search_params,
            )
            # Accumulate document IDs along with their originating collection.
            for hits in results:
                for hit in hits:
                    doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

        scores = []

        def rerank_single_doc(doc_id, data, client, collection_name):
            # Fetch the full set of token vectors for this candidate document.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
                output_fields=["seq_id", "vector", "doc"],
                limit=16380,
            )
            doc_vecs = np.vstack(
                [row["vector"] for row in doc_colbert_vecs]
            )
            # MaxSim: best doc token per query token, summed over query tokens.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id, collection_name)

        # Rerank each candidate document concurrently (I/O-bound queries).
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, collection
                ): (doc_id, collection)
                for doc_id, collection in doc_collection_pairs
            }
            for future in concurrent.futures.as_completed(futures):
                scores.append(future.result())

        scores.sort(key=lambda x: x[0], reverse=True)

        # Bug fix: release EVERY collection loaded above; the original only
        # released the last loop variable, leaking the rest in memory.
        for collection in collections:
            self.client.release_collection(collection_name=collection)

        # Slicing already handles the len(scores) < topk case.
        return scores[:topk]

    def insert(self, data):
        """Insert one document's token vectors as individual rows.

        Args:
            data: dict with keys "colbert_vecs" (list of vectors), "doc_id"
                (int shared by all rows), and "filepath" (stored in the "doc"
                field of the first row only).
        """
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"]] * seq_length
        seq_ids = list(range(seq_length))
        # Only the first row carries the file path; the rest stay empty.
        docs = [""] * seq_length
        docs[0] = data["filepath"]

        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors):
        """Convert embedded images into insert-ready dicts, using the list
        index as the document ID (i.e. the page number)."""
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert a batch of embedded images, one document per image."""
        data = self.get_images_as_doc(image_data)
        for item in data:
            self.insert(item)