File size: 3,736 Bytes
3fc9d29
 
 
b6b8856
3fc9d29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6b8856
3fc9d29
 
 
 
 
 
 
b6b8856
3fc9d29
 
 
 
b6b8856
 
3fc9d29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os

from PIL import Image
from pymilvus import connections, Collection
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

from utils.fetch_image import fetch_image

def load_collection(name):
    """Return a pymilvus ``Collection`` handle for the collection called *name*.

    Only the name is passed (no schema), and ``.load()`` is not called here.
    NOTE(review): presumably the collection must already exist on the
    connected server - confirm against the pymilvus version in use.
    """
    return Collection(name)

def create_collection(name, description):
    """Build the embedding schema and return a ``Collection`` created with it.

    The schema holds four 512-dimensional float-vector fields, followed by
    the INT64 primary key ``image_id`` and a VARCHAR ``metadata`` field
    (max length 5000). The field order is the column order used by
    ``insert_data``.

    Args:
        name: Name given to the collection.
        description: Free-text description stored on the schema.

    Returns:
        The ``Collection`` object constructed from ``name`` and the schema.
    """
    vector_field_names = (
        "text_embedding",
        "image_embedding",
        "avg_embedding",
        "weighted_avg_embedding",
    )

    # Vector fields first, in the same order as the positional insert.
    schema_fields = [
        FieldSchema(name=field_name, dtype=DataType.FLOAT_VECTOR, dim=512)
        for field_name in vector_field_names
    ]
    schema_fields.append(
        FieldSchema(name="image_id", dtype=DataType.INT64, is_primary=True)
    )
    schema_fields.append(
        FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=5000)
    )

    return Collection(
        name=name,
        schema=CollectionSchema(schema_fields, description=description),
    )

def create_hnsw_index(collection):
    """Create an HNSW index on each of the four embedding fields.

    Every field gets the same parameters: inner-product metric, ``M=32``
    and ``efConstruction=200``.

    Args:
        collection: Object exposing ``create_index(field_name=..., index_params=...)``.
    """
    # Inner product is used as the similarity metric.
    # NOTE(review): IP only equals cosine similarity when the stored vectors
    # are L2-normalised - confirm the embeddings are normalised upstream.
    hnsw_settings = dict(
        index_type="HNSW",
        metric_type="IP",
        params=dict(M=32, efConstruction=200),
    )

    embedding_fields = (
        "text_embedding",
        "image_embedding",
        "avg_embedding",
        "weighted_avg_embedding",
    )
    for vector_field in embedding_fields:
        collection.create_index(field_name=vector_field, index_params=hnsw_settings)

def insert_data(collection, catalog, column, text_embeds, image_embeds, avg_embeds, w_avg_embeds):
    """Insert one batch of embeddings plus ids and metadata into *collection*.

    Data is passed column-wise, in the same order as the fields declared in
    ``create_collection``: text, image, average and weighted-average
    embeddings, then image ids, then metadata.

    Args:
        collection: Object exposing ``insert(columns)``.
        catalog: Mapping-like table; ``catalog['Id']`` and ``catalog[column]``
            must each provide ``.tolist()`` (presumably a pandas DataFrame -
            verify against caller).
        column: Key of the catalog column used as metadata.
        text_embeds: Values for the ``text_embedding`` field.
        image_embeds: Values for the ``image_embedding`` field.
        avg_embeds: Values for the ``avg_embedding`` field.
        w_avg_embeds: Values for the ``weighted_avg_embedding`` field.
    """
    columns = [text_embeds, image_embeds, avg_embeds, w_avg_embeds]
    columns.append(catalog['Id'].tolist())      # primary keys
    columns.append(catalog[column].tolist())    # metadata strings
    collection.insert(columns)

def search_in_milvus(collection, search_field, query_embedding, top_k=6):
    """Run a vector search and return the hits of the first query with images.

    Args:
        collection: Object exposing a pymilvus-style ``search`` method.
        search_field: Name of the vector field to search in.
        query_embedding: Object with ``.tolist()``; the resulting list is
            passed to ``search`` as the query data.
            NOTE(review): presumably a 2-D array of shape (1, dim), since
            only ``results[0]`` is read - confirm against callers.
        top_k: Maximum number of hits requested (default 6).

    Returns:
        A list of dicts, one per hit, with keys ``"image"`` (whatever
        ``fetch_image(url)`` returns), ``"metadata"`` and
        ``"similarity_score"`` (the hit's ``distance``). An empty list is
        returned when the search yields no result set at all.
    """
    # Step 1: perform the search in Milvus.
    search_params = {"metric_type": "IP", "params": {"ef": 128}}
    # NOTE(review): "url" is requested here but is not among the fields that
    # create_collection() declares - confirm the searched collections have it.
    results = collection.search(
        query_embedding.tolist(),  # query vector(s)
        search_field,              # field to search in
        param=search_params,
        limit=top_k,               # top-k results
        output_fields=["image_id", "metadata", "url"],
    )

    # Guard against an empty result set; indexing [0] would raise IndexError.
    if len(results) == 0:
        return []

    # Step 2: turn each hit of the first (only) query into a plain dict.
    search_results = []
    for hit in results[0]:
        entity = hit.entity
        metadata = entity.get("metadata")
        url = entity.get("url")
        search_results.append({
            "image": fetch_image(url),
            "metadata": metadata,
            "similarity_score": hit.distance,
        })

    # Step 3: return the collected hits.
    return search_results

# Module-level connection setup: runs once, at import time.
#
# The API token is read from the environment instead of being committed to
# the source tree. The token that used to be hard-coded here should be
# treated as leaked and rotated.
_ZILLIZ_DEFAULT_URI = 'https://in03-6efb78578dde7a3.serverless.gcp-us-west1.cloud.zilliz.com'

_zilliz_token = os.environ.get("ZILLIZ_TOKEN")
if not _zilliz_token:
    raise RuntimeError(
        "ZILLIZ_TOKEN environment variable is not set; "
        "it is required to connect to the Zilliz/Milvus endpoint."
    )

# ZILLIZ_URI may override the endpoint; otherwise the original URI is used.
conn = connections.connect(
    "default",
    uri=os.environ.get("ZILLIZ_URI", _ZILLIZ_DEFAULT_URI),
    token=_zilliz_token,
)

# Handles for the two collections, created once at import time.
fashionclip_collection = load_collection("fashionclip")
fashionsiglip_collection = load_collection("fashionsiglip")