rfmantoan commited on
Commit
3fc9d29
·
1 Parent(s): 27b3217

add missing utils file

Browse files
Files changed (1) hide show
  1. utils/vector_database.py +91 -0
utils/vector_database.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from pymilvus import connections, Collection
3
+ from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
4
+
5
+ def load_collection(name):
6
+ collection = Collection(name)
7
+ return collection
8
+
9
+ def create_collection(name, description):
10
+ fields = [
11
+ FieldSchema(name="text_embedding", dtype=DataType.FLOAT_VECTOR, dim=512),
12
+ FieldSchema(name="image_embedding", dtype=DataType.FLOAT_VECTOR, dim=512),
13
+ FieldSchema(name="avg_embedding", dtype=DataType.FLOAT_VECTOR, dim=512),
14
+ FieldSchema(name="weighted_avg_embedding", dtype=DataType.FLOAT_VECTOR, dim=512),
15
+ FieldSchema(name="image_id", dtype=DataType.INT64, is_primary=True),
16
+ FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=5000)
17
+ ]
18
+
19
+ schema = CollectionSchema(fields, description=description)
20
+ collection = Collection(name=name, schema=schema)
21
+
22
+ return collection
23
+
24
+ def create_hnsw_index(collection):
25
+ index_params = {
26
+ "index_type": "HNSW",
27
+ "metric_type": "IP", # IP for cosine similarity
28
+ "params": {"M": 32, "efConstruction": 200}
29
+ }
30
+
31
+ collection.create_index(field_name="text_embedding", index_params=index_params)
32
+ collection.create_index(field_name="image_embedding", index_params=index_params)
33
+ collection.create_index(field_name="avg_embedding", index_params=index_params)
34
+ collection.create_index(field_name="weighted_avg_embedding", index_params=index_params)
35
+
36
+ def insert_data(collection, catalog, column, text_embeds, image_embeds, avg_embeds, w_avg_embeds):
37
+ image_ids = catalog['Id'].tolist()
38
+ metadata = catalog[column].tolist()
39
+
40
+ collection.insert([
41
+ text_embeds,
42
+ image_embeds,
43
+ avg_embeds,
44
+ w_avg_embeds,
45
+ image_ids,
46
+ metadata
47
+ ])
48
+
49
+ def search_in_milvus(collection, search_field, query_embedding, top_k=6):
50
+
51
+ # Step 1: Perform search in Milvus
52
+ search_params = {"metric_type": "IP", "params": {"ef": 128}}
53
+ results = collection.search(
54
+ query_embedding.tolist(), # Query vector
55
+ search_field, # Field to search in
56
+ param=search_params,
57
+ limit=top_k, # Top k results
58
+ output_fields=["image_id", "metadata"]
59
+ )
60
+
61
+ # Step 2: Extract the relevant information from the search results
62
+ search_results = []
63
+ for result in results[0]: # The first element of 'results' contains the search results
64
+ image_id = result.entity.get("image_id") # Retrieve the image ID
65
+ metadata = result.entity.get("metadata") # Retrieve metadata (such as description, brand, etc.)
66
+ similarity_score = result.distance # Retrieve similarity score (distance)
67
+
68
+ # Load the image (you can use PIL to load the image)
69
+ #image_path = "/content/drive/MyDrive/images/" + str(image_id) + ".jpg"
70
+ image_path = "/images/" + str(image_id) + ".jpg"
71
+ image = Image.open(image_path)
72
+
73
+ # Append the image, metadata, and score to the search results
74
+ search_results.append({
75
+ "image": image,
76
+ "metadata": metadata,
77
+ "similarity_score": similarity_score
78
+ })
79
+
80
+ # Step 3: Return the search results
81
+ return search_results
82
+
83
+ conn = None
84
+
85
+ conn = connections.connect("default",
86
+ uri='https://in03-6efb78578dde7a3.serverless.gcp-us-west1.cloud.zilliz.com',
87
+ token='78a82c19d7a02c531dab34d97ffde11caba0aa18b58ad02c46ee98df99d912291043835a002e427d89d5ddbb65b342191c36c1ae'
88
+ )
89
+
90
+ fashionclip_collection = load_collection("fashionclip")
91
+ fashionsiglip_collection = load_collection("fashionsiglip")