Spaces:
Runtime error
Runtime error
import os

from PIL import Image
from pymilvus import connections, Collection
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from utils.fetch_image import fetch_image
def load_collection(name):
    """Return a pymilvus ``Collection`` handle for the existing collection *name*.

    NOTE(review): despite the name, this only constructs the handle and does
    not call ``Collection.load()``; presumably the collection is already
    loaded server-side before it is searched — confirm.
    """
    return Collection(name)
# Vector fields stored per image, in schema order.
_VECTOR_FIELD_NAMES = (
    "text_embedding",
    "image_embedding",
    "avg_embedding",
    "weighted_avg_embedding",
)


def create_collection(name, description, dim=512, metadata_max_length=5000):
    """Create a Milvus collection holding four embedding variants per image.

    The schema contains four FLOAT_VECTOR fields (text, image, average and
    weighted-average embeddings), an INT64 primary key ``image_id`` and a
    VARCHAR ``metadata`` field. Field order matters: ``insert_data`` inserts
    its columns positionally in exactly this order.

    Args:
        name: Name of the collection to create.
        description: Human-readable description stored on the schema.
        dim: Dimensionality of every embedding field (default 512).
        metadata_max_length: ``max_length`` of the ``metadata`` VARCHAR field
            (default 5000).

    Returns:
        The newly created pymilvus ``Collection``.
    """
    fields = [
        FieldSchema(name=field_name, dtype=DataType.FLOAT_VECTOR, dim=dim)
        for field_name in _VECTOR_FIELD_NAMES
    ]
    fields.append(FieldSchema(name="image_id", dtype=DataType.INT64, is_primary=True))
    fields.append(
        FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=metadata_max_length)
    )
    schema = CollectionSchema(fields, description=description)
    return Collection(name=name, schema=schema)
def create_hnsw_index(collection):
    """Build an HNSW index (inner-product metric) on each embedding field.

    The same parameters (M=32, efConstruction=200) are applied to all four
    vector fields.

    NOTE(review): inner product equals cosine similarity only when the
    stored vectors are L2-normalised — confirm the embeddings are normalised.
    """
    hnsw_params = {
        "index_type": "HNSW",
        "metric_type": "IP",
        "params": {"M": 32, "efConstruction": 200},
    }
    for vector_field in (
        "text_embedding",
        "image_embedding",
        "avg_embedding",
        "weighted_avg_embedding",
    ):
        collection.create_index(field_name=vector_field, index_params=hnsw_params)
def insert_data(collection, catalog, column, text_embeds, image_embeds, avg_embeds, w_avg_embeds):
    """Insert one row per catalog entry into *collection* (column-based insert).

    Columns are passed positionally in schema order: the four embedding
    columns, then ``catalog['Id']`` as the primary key, then
    ``catalog[column]`` as the metadata text.

    Args:
        collection: Target pymilvus ``Collection``.
        catalog: DataFrame-like object with an ``'Id'`` column and *column*.
        column: Name of the catalog column stored as metadata.
        text_embeds, image_embeds, avg_embeds, w_avg_embeds: Per-row
            embedding columns, presumably aligned with the catalog's row
            order — confirm against the caller.
    """
    columns_in_schema_order = [
        text_embeds,
        image_embeds,
        avg_embeds,
        w_avg_embeds,
        catalog['Id'].tolist(),
        catalog[column].tolist(),
    ]
    collection.insert(columns_in_schema_order)
def search_in_milvus(collection, search_field, query_embedding, top_k=6):
    """Return the *top_k* nearest catalog items to a query embedding.

    Args:
        collection: pymilvus ``Collection`` to search.
        search_field: Name of the vector field to search in
            (e.g. ``"text_embedding"``).
        query_embedding: Either a single vector or a batch of vectors, as an
            object with ``.tolist()`` (NumPy array / tensor) or a plain
            sequence. Only the hits for the first query vector are returned.
        top_k: Maximum number of hits to return.

    Returns:
        A list of dicts, best hit first, each with keys ``"image"`` (the
        result of ``fetch_image`` on the hit's ``url``), ``"metadata"`` and
        ``"similarity_score"`` (the raw IP distance reported by Milvus).
    """
    if hasattr(query_embedding, "tolist"):
        vectors = query_embedding.tolist()
    else:
        vectors = list(query_embedding)
    # Milvus expects a list of vectors. A 1-D embedding becomes a flat list of
    # floats, which Milvus rejects, so wrap it as a batch of one.
    if vectors and not isinstance(vectors[0], list):
        vectors = [vectors]

    # HNSW requires the search-time ef to be >= limit, so grow it for large top_k.
    search_params = {"metric_type": "IP", "params": {"ef": max(128, top_k)}}
    results = collection.search(
        vectors,
        search_field,
        param=search_params,
        limit=top_k,
        # NOTE(review): the schema built by create_collection() has no "url"
        # field; this only works for collections whose schema includes one.
        output_fields=["image_id", "metadata", "url"],
    )

    search_results = []
    for hit in results[0]:  # hits for the first query vector
        search_results.append({
            "image": fetch_image(hit.entity.get("url")),
            "metadata": hit.entity.get("metadata"),
            "similarity_score": hit.distance,
        })
    return search_results
# Zilliz/Milvus connection settings. Environment variables take precedence
# over the hard-coded fallbacks, so deployments can inject them as secrets.
# NOTE(review): the fallback token is a credential committed to source code;
# it should be rotated and removed once ZILLIZ_TOKEN is configured.
_ZILLIZ_URI = os.environ.get(
    "ZILLIZ_URI",
    "https://in03-6efb78578dde7a3.serverless.gcp-us-west1.cloud.zilliz.com",
)
_ZILLIZ_TOKEN = os.environ.get(
    "ZILLIZ_TOKEN",
    "78a82c19d7a02c531dab34d97ffde11caba0aa18b58ad02c46ee98df99d912291043835a002e427d89d5ddbb65b342191c36c1ae",
)

# Open the "default" connection alias used implicitly by Collection(...) below.
conn = connections.connect("default", uri=_ZILLIZ_URI, token=_ZILLIZ_TOKEN)

fashionclip_collection = load_collection("fashionclip")
fashionsiglip_collection = load_collection("fashionsiglip")