Spaces:
Runtime error
Runtime error
rfmantoan
commited on
Commit
·
27b3217
1
Parent(s):
8b39863
add utils
Browse files- utils/data_preprocessing.py +29 -0
- utils/embedding_generation.py +169 -0
- utils/load_models.py +20 -0
- utils/refine_metadata.py +72 -0
- utils/search_functions.py +80 -0
utils/data_preprocessing.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
|
3 |
+
def load_data(catalog):
|
4 |
+
catalog = pd.read_excel('catalog_1k.xlsx')
|
5 |
+
return catalog
|
6 |
+
|
7 |
+
def preprocess_data(catalog):
|
8 |
+
# Clean description
|
9 |
+
catalog['Description'] = catalog['Description'].str.replace('\n', '')
|
10 |
+
|
11 |
+
# Id column to integer
|
12 |
+
catalog['Id'] = pd.to_numeric(catalog['Id'], errors='coerce').astype('Int64')
|
13 |
+
|
14 |
+
# Map gender
|
15 |
+
catalog['Gender'] = catalog['Gender'].map({1: 'Women', 2: 'Men', 3: 'Unisex'})
|
16 |
+
|
17 |
+
# Drop sub-sub-categories
|
18 |
+
catalog = catalog.drop(['L3'], axis=1)
|
19 |
+
|
20 |
+
# Drop items without gender
|
21 |
+
catalog = catalog.dropna(subset=['Gender'])
|
22 |
+
|
23 |
+
# Use best image link
|
24 |
+
catalog['Image'] = catalog['Image'].str.split(',').str[-1]
|
25 |
+
|
26 |
+
# Convert the columns to strings before joining them
|
27 |
+
catalog["SimpleMetadata"] = catalog[["L1", "L2", "Gender", "MaterialName", "BrandName", "Name"]].astype(str).agg(', '.join, axis=1)
|
28 |
+
|
29 |
+
return catalog
|
utils/embedding_generation.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from utils.load_models import fclip_model, fclip_processor
|
4 |
+
from utils.load_models import siglip_model, siglip_preprocess_train, siglip_preprocess_val, siglip_tokenizer
|
5 |
+
|
6 |
+
def get_info(catalog, column):
|
7 |
+
image_paths = []
|
8 |
+
text_descriptions = []
|
9 |
+
|
10 |
+
for index, row in catalog.iterrows():
|
11 |
+
path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
|
12 |
+
image_paths.append(path)
|
13 |
+
text_descriptions.append(row[column])
|
14 |
+
|
15 |
+
return image_paths, text_descriptions
|
16 |
+
|
17 |
+
def normalize_embedding(embedding):
|
18 |
+
norm = torch.norm(embedding, p=2, dim=-1, keepdim=True).item() # Get the norm before normalization
|
19 |
+
embedding = embedding / norm
|
20 |
+
return embedding.detach().cpu().numpy()
|
21 |
+
|
22 |
+
def normalize_embeddings(embeddings):
|
23 |
+
norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
|
24 |
+
normalized_embeddings = embeddings / norm
|
25 |
+
return normalized_embeddings
|
26 |
+
|
27 |
+
def generate_fclip_embeddings(image_paths, texts, batch_size, alpha):
|
28 |
+
image_embeds_list = []
|
29 |
+
text_embeds_list = []
|
30 |
+
|
31 |
+
# Batch processing loop
|
32 |
+
for i in range(0, len(image_paths), batch_size):
|
33 |
+
batch_image_paths = image_paths[i:i + batch_size]
|
34 |
+
batch_texts = texts[i:i + batch_size]
|
35 |
+
|
36 |
+
# Load and preprocess batch of images and texts
|
37 |
+
images = [Image.open(path).convert("RGB") for path in batch_image_paths]
|
38 |
+
|
39 |
+
# Set the maximum sequence length to 77 to match the position embeddings
|
40 |
+
inputs = fclip_processor(text=batch_texts, images=images, return_tensors="pt", padding=True, truncation=True, max_length=77)
|
41 |
+
|
42 |
+
# Move inputs to the GPU
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()} # Move inputs to GPU
|
45 |
+
|
46 |
+
# Generate embeddings
|
47 |
+
with torch.no_grad():
|
48 |
+
outputs = fclip_model(**inputs)
|
49 |
+
|
50 |
+
image_embeds_list.append(outputs.image_embeds)
|
51 |
+
text_embeds_list.append(outputs.text_embeds)
|
52 |
+
|
53 |
+
# Concatenate all embeddings
|
54 |
+
image_embeds = torch.cat(image_embeds_list, dim=0)
|
55 |
+
text_embeds = torch.cat(text_embeds_list, dim=0)
|
56 |
+
|
57 |
+
# Normalize embeddings
|
58 |
+
image_embeds = normalize_embeddings(image_embeds)
|
59 |
+
text_embeds = normalize_embeddings(text_embeds)
|
60 |
+
|
61 |
+
# Average embeddings
|
62 |
+
avg_embeds = (image_embeds + text_embeds) / 2
|
63 |
+
weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
|
64 |
+
avg_embeds = normalize_embeddings(avg_embeds)
|
65 |
+
weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)
|
66 |
+
|
67 |
+
return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
|
68 |
+
|
69 |
+
def generate_siglip_embeddings(image_paths, texts, batch_size, alpha):
|
70 |
+
image_embeds_list = []
|
71 |
+
text_embeds_list = []
|
72 |
+
|
73 |
+
# Batch processing loop
|
74 |
+
for i in range(0, len(image_paths), batch_size):
|
75 |
+
batch_image_paths = image_paths[i:i + batch_size]
|
76 |
+
batch_texts = texts[i:i + batch_size]
|
77 |
+
|
78 |
+
# Load and preprocess batch of images and texts
|
79 |
+
images = [siglip_preprocess_val(Image.open(image_path).convert('RGB')).unsqueeze(0) for image_path in batch_image_paths]
|
80 |
+
images = torch.cat(images)
|
81 |
+
|
82 |
+
tokens = siglip_tokenizer(batch_texts)
|
83 |
+
|
84 |
+
# Move images to the same device as the model weights (GPU if available)
|
85 |
+
if torch.cuda.is_available():
|
86 |
+
images = images.cuda()
|
87 |
+
tokens = tokens.cuda()
|
88 |
+
|
89 |
+
# Generate embeddings
|
90 |
+
with torch.no_grad():
|
91 |
+
image_embeddings_batch = siglip_model.encode_image(images)
|
92 |
+
text_embeddings_batch = siglip_model.encode_text(tokens)
|
93 |
+
|
94 |
+
# Store embeddings
|
95 |
+
image_embeds_list.append(image_embeddings_batch)
|
96 |
+
text_embeds_list.append(text_embeddings_batch)
|
97 |
+
|
98 |
+
# Concatenate all embeddings
|
99 |
+
image_embeds = torch.cat(image_embeds_list, dim=0)
|
100 |
+
text_embeds = torch.cat(text_embeds_list, dim=0)
|
101 |
+
|
102 |
+
# Normalize embeddings
|
103 |
+
image_embeds = normalize_embeddings(image_embeds)
|
104 |
+
text_embeds = normalize_embeddings(text_embeds)
|
105 |
+
|
106 |
+
# Average embeddings
|
107 |
+
avg_embeds = (image_embeds + text_embeds) / 2
|
108 |
+
weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
|
109 |
+
avg_embeds = normalize_embeddings(avg_embeds)
|
110 |
+
weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)
|
111 |
+
|
112 |
+
return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
|
113 |
+
|
114 |
+
# Function to process text embedding for any model
|
115 |
+
def generate_text_embedding(model, tokenizer, query, model_type):
|
116 |
+
if model_type == "fashionCLIP":
|
117 |
+
# Process the text with the tokenizer and move to GPU
|
118 |
+
inputs = tokenizer(text=query, return_tensors="pt", padding=True, truncation=True, max_length=77)
|
119 |
+
|
120 |
+
if torch.cuda.is_available():
|
121 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
122 |
+
|
123 |
+
# Get text embedding from the model
|
124 |
+
text_embed = model.get_text_features(**inputs)
|
125 |
+
elif model_type == "fashionSigLIP":
|
126 |
+
tokens = tokenizer(query)
|
127 |
+
|
128 |
+
# Tokenize text and move to GPU
|
129 |
+
if torch.cuda.is_available():
|
130 |
+
tokens = tokens.to("cuda")
|
131 |
+
|
132 |
+
# Get text embedding from the model
|
133 |
+
text_embed = model.encode_text(tokens)
|
134 |
+
|
135 |
+
return normalize_embedding(text_embed)
|
136 |
+
|
137 |
+
# Function to process image embedding for any model
|
138 |
+
def generate_image_embedding(model, processor, image_path, model_type):
|
139 |
+
image = Image.open(image_path).convert("RGB")
|
140 |
+
|
141 |
+
if model_type == "fashionCLIP":
|
142 |
+
# Preprocess image for FashionCLIP and move to GPU
|
143 |
+
inputs = processor(images=image, return_tensors="pt")
|
144 |
+
|
145 |
+
if torch.cuda.is_available():
|
146 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
147 |
+
|
148 |
+
# Get image embedding from the model
|
149 |
+
image_embed = model.get_image_features(**inputs)
|
150 |
+
elif model_type == "fashionSigLIP":
|
151 |
+
# Preprocess image for SigLip and move to GPU
|
152 |
+
image_tensor = processor(image).unsqueeze(0)
|
153 |
+
|
154 |
+
if torch.cuda.is_available():
|
155 |
+
image_tensor = image_tensor.to("cuda")
|
156 |
+
|
157 |
+
# Get image embedding from the model
|
158 |
+
image_embed = model.encode_image(image_tensor)
|
159 |
+
|
160 |
+
return normalize_embedding(image_embed)
|
161 |
+
|
162 |
+
# Unified function to generate embeddings for both models and query types
|
163 |
+
def generate_query_embedding(query, query_type, model, processor, tokenizer, model_type):
|
164 |
+
if query_type == "text":
|
165 |
+
return generate_text_embedding(model, tokenizer, query, model_type)
|
166 |
+
elif query_type == "image":
|
167 |
+
return generate_image_embedding(model, processor, query, model_type)
|
168 |
+
else:
|
169 |
+
raise ValueError("Invalid query type. Choose 'text' or 'image'.")
|
utils/load_models.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import open_clip
|
3 |
+
from transformers import CLIPProcessor, CLIPModel
|
4 |
+
|
5 |
+
fclip_model = None
|
6 |
+
fclip_processor = None
|
7 |
+
siglip_model = None
|
8 |
+
siglip_tokenizer = None
|
9 |
+
|
10 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
+
|
12 |
+
fclip_model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
|
13 |
+
fclip_processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
|
14 |
+
|
15 |
+
siglip_model, siglip_preprocess_train, siglip_preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
|
16 |
+
siglip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
|
17 |
+
|
18 |
+
if torch.cuda.is_available():
|
19 |
+
fclip_model.to(device)
|
20 |
+
siglip_model.to(device)
|
utils/refine_metadata.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
from transformers import BitsAndBytesConfig, pipeline
|
4 |
+
|
5 |
+
def ask_llava(metadata, image_path):
|
6 |
+
"""
|
7 |
+
Function to get the image description using LLaVA.
|
8 |
+
"""
|
9 |
+
|
10 |
+
# Unpack metadata
|
11 |
+
category = metadata.get('category', '')
|
12 |
+
subcategory = metadata.get('subcategory', '')
|
13 |
+
material = metadata.get('material', '')
|
14 |
+
gender = metadata.get('gender', '')
|
15 |
+
brand = metadata.get('brand', '')
|
16 |
+
name = metadata.get('name', '')
|
17 |
+
|
18 |
+
# Build the prompt for LLaVA
|
19 |
+
image = Image.open(image_path)
|
20 |
+
#prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, use your knowledge of fashion trends, styles, colors, gender preferences and brand information as well as your ability to describe, analyze and understand the image of the item to refine the metadata. Your goal is to improve the embedding process for models like FashionCLIP and MARGO-FashionSigLip by creating a more nuanced and detailed description that would boost the performance of the models. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name} - Description: {description} Refine and expand the metadata by incorporating information from the image and about the fashion item's style, cut, pattern, color scheme, brand, and any notable details. Include insights on current fashion trends and how the item fits within those trends. Be mindful that the it should be too around 77 tokens only, therefore, try to be concise and keep the description direct and useful for text to image and text to text search. Return the refined metadata as a single paragraph.\nASSISTANT:"""
|
21 |
+
prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, return an enhanced metadata structured in a single sentence with each field separated by a comma (do not include the field name, just use the same order). Keep it very concise and simple but make it more unterstandle for embedding models that will be used for search purposes. Also do a color analysis and add an extra field for the color of the item. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name}.\nASSISTANT:"""
|
22 |
+
|
23 |
+
# Generate description
|
24 |
+
outputs = img2text_pipeline(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
|
25 |
+
|
26 |
+
description = outputs[0]["generated_text"]
|
27 |
+
|
28 |
+
description = description.split("ASSISTANT: ")
|
29 |
+
|
30 |
+
return description[1]
|
31 |
+
|
32 |
+
def refine_metadata(catalog, column):
|
33 |
+
catalog[column] = ""
|
34 |
+
|
35 |
+
# Iterate over the DataFrame and process each item
|
36 |
+
for index, row in catalog.iterrows():
|
37 |
+
|
38 |
+
metadata = {
|
39 |
+
'category': row['L1'],
|
40 |
+
'subcategory': row['L2'],
|
41 |
+
'material': row['MaterialName'],
|
42 |
+
'gender': row['Gender'],
|
43 |
+
'brand': row['BrandName'],
|
44 |
+
'name': row['Name'],
|
45 |
+
'description': row['Description']
|
46 |
+
}
|
47 |
+
|
48 |
+
# Ensure the image ID is converted to a string
|
49 |
+
#image_path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
|
50 |
+
image_path = "/images/" + str(row["Id"]) + ".jpg"
|
51 |
+
|
52 |
+
# Generate the image description using LLaVA
|
53 |
+
refined_metadata = refine_metadata(metadata, image_path)
|
54 |
+
|
55 |
+
# Store results back in the DataFrame
|
56 |
+
catalog.at[index, column] = refined_metadata
|
57 |
+
|
58 |
+
return catalog
|
59 |
+
|
60 |
+
img2text_pipeline = None
|
61 |
+
|
62 |
+
quantization_config = BitsAndBytesConfig(
|
63 |
+
load_in_4bit=True,
|
64 |
+
bnb_4bit_compute_dtype=torch.float16
|
65 |
+
)
|
66 |
+
|
67 |
+
model_id = "llava-hf/llava-1.5-7b-hf"
|
68 |
+
|
69 |
+
if torch.cuda.is_available():
|
70 |
+
img2text_pipeline = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
|
71 |
+
else:
|
72 |
+
img2text_pipeline = pipeline("image-to-text", model=model_id)
|
utils/search_functions.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from utils.vector_database import search_in_milvus, fashionclip_collection, fashionsiglip_collection
|
3 |
+
from utils.embedding_generation import generate_query_embedding
|
4 |
+
from utils.load_models import fclip_model, fclip_processor
|
5 |
+
from utils.load_models import siglip_model, siglip_preprocess_val, siglip_tokenizer
|
6 |
+
|
7 |
+
# Function to dynamically select the Milvus collection and search field
|
8 |
+
def get_milvus_collection_and_field(model_type, embedding_type):
|
9 |
+
# Define mapping of model and embedding types to collections and fields
|
10 |
+
if model_type == "fashionCLIP":
|
11 |
+
collection = fashionclip_collection
|
12 |
+
if embedding_type == "text":
|
13 |
+
search_field = "text_embedding"
|
14 |
+
elif embedding_type == "image":
|
15 |
+
search_field = "image_embedding"
|
16 |
+
elif embedding_type == "average":
|
17 |
+
search_field = "avg_embedding"
|
18 |
+
elif embedding_type == "weighted average":
|
19 |
+
search_field = "weighted_avg_embedding"
|
20 |
+
elif model_type == "fashionSigLIP":
|
21 |
+
collection = fashionsiglip_collection
|
22 |
+
if embedding_type == "text":
|
23 |
+
search_field = "text_embedding"
|
24 |
+
elif embedding_type == "image":
|
25 |
+
search_field = "image_embedding"
|
26 |
+
elif embedding_type == "average":
|
27 |
+
search_field = "avg_embedding"
|
28 |
+
elif embedding_type == "weighted average":
|
29 |
+
search_field = "weighted_avg_embedding"
|
30 |
+
else:
|
31 |
+
raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")
|
32 |
+
|
33 |
+
return collection, search_field
|
34 |
+
|
35 |
+
# Function to handle the complete search flow
|
36 |
+
def search(query, query_type, model_type, embedding_type):
|
37 |
+
# Step 1: Generate the query embedding based on the user input and model type
|
38 |
+
if model_type == "fashionCLIP":
|
39 |
+
query_embedding = generate_query_embedding(query, query_type, fclip_model, fclip_processor, fclip_processor, "fashionCLIP")
|
40 |
+
elif model_type == "fashionSigLIP":
|
41 |
+
query_embedding = generate_query_embedding(query, query_type, siglip_model, siglip_preprocess_val, siglip_tokenizer, "fashionSigLIP")
|
42 |
+
|
43 |
+
# Step 2: Get the appropriate Milvus collection and search field
|
44 |
+
collection, search_field = get_milvus_collection_and_field(model_type, embedding_type)
|
45 |
+
|
46 |
+
# Step 3: Perform search in Milvus using the query embedding
|
47 |
+
search_results = search_in_milvus(collection, search_field, query_embedding, top_k=10)
|
48 |
+
|
49 |
+
# Step 4: Extract images, similarity scores, and metadata from the search results
|
50 |
+
images = [result['image'] for result in search_results]
|
51 |
+
scores = [result['similarity_score'] for result in search_results]
|
52 |
+
metadata = [result['metadata'] for result in search_results]
|
53 |
+
|
54 |
+
return images, scores, metadata
|
55 |
+
|
56 |
+
# Function to run the search and get results for both models
|
57 |
+
def run_search(query_type, embedding_type, query_input_text, query_input_image):
|
58 |
+
if query_type == "text":
|
59 |
+
query = query_input_text
|
60 |
+
else:
|
61 |
+
query = query_input_image
|
62 |
+
|
63 |
+
# Perform search for FashionCLIP
|
64 |
+
fclip_images, fclip_scores, fclip_metadata = search(query, query_type, "fashionCLIP", embedding_type)
|
65 |
+
|
66 |
+
# Perform search for MARGO-FashionSigLip
|
67 |
+
siglip_images, siglip_scores, siglip_metadata = search(query, query_type, "fashionSigLIP", embedding_type)
|
68 |
+
|
69 |
+
# Convert scores and metadata into a pandas DataFrame for each model
|
70 |
+
fclip_results_df = pd.DataFrame({
|
71 |
+
"Score": fclip_scores,
|
72 |
+
"Metadata": fclip_metadata,
|
73 |
+
})
|
74 |
+
|
75 |
+
siglip_results_df = pd.DataFrame({
|
76 |
+
"Score": siglip_scores,
|
77 |
+
"Metadata": siglip_metadata,
|
78 |
+
})
|
79 |
+
|
80 |
+
return fclip_images, fclip_results_df, siglip_images, siglip_results_df
|