rfmantoan committed on
Commit 27b3217 · 1 Parent(s): 8b39863
utils/data_preprocessing.py ADDED
@@ -0,0 +1,29 @@
+ import pandas as pd
+
+ def load_data(path='catalog_1k.xlsx'):
+     # Read the product catalog from the Excel file
+     catalog = pd.read_excel(path)
+     return catalog
+
+ def preprocess_data(catalog):
+     # Clean description
+     catalog['Description'] = catalog['Description'].str.replace('\n', '')
+
+     # Id column to integer
+     catalog['Id'] = pd.to_numeric(catalog['Id'], errors='coerce').astype('Int64')
+
+     # Map gender
+     catalog['Gender'] = catalog['Gender'].map({1: 'Women', 2: 'Men', 3: 'Unisex'})
+
+     # Drop sub-sub-categories
+     catalog = catalog.drop(['L3'], axis=1)
+
+     # Drop items without gender
+     catalog = catalog.dropna(subset=['Gender'])
+
+     # Use best image link
+     catalog['Image'] = catalog['Image'].str.split(',').str[-1]
+
+     # Convert the columns to strings before joining them
+     catalog["SimpleMetadata"] = catalog[["L1", "L2", "Gender", "MaterialName", "BrandName", "Name"]].astype(str).agg(', '.join, axis=1)
+
+     return catalog
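A minimal usage sketch for this module; the only assumption is that catalog_1k.xlsx sits in the working directory, as in the path hard-coded above:

from utils.data_preprocessing import load_data, preprocess_data

# Load the raw spreadsheet, then apply the cleaning steps defined above
catalog = preprocess_data(load_data('catalog_1k.xlsx'))
print(catalog[['Id', 'Gender', 'SimpleMetadata']].head())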
utils/embedding_generation.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ from PIL import Image
+ from utils.load_models import fclip_model, fclip_processor
+ from utils.load_models import siglip_model, siglip_preprocess_train, siglip_preprocess_val, siglip_tokenizer
+
+ def get_info(catalog, column):
+     image_paths = []
+     text_descriptions = []
+
+     for index, row in catalog.iterrows():
+         path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
+         image_paths.append(path)
+         text_descriptions.append(row[column])
+
+     return image_paths, text_descriptions
+
+ def normalize_embedding(embedding):
+     norm = torch.norm(embedding, p=2, dim=-1, keepdim=True).item()  # Get the norm before normalization
+     embedding = embedding / norm
+     return embedding.detach().cpu().numpy()
+
+ def normalize_embeddings(embeddings):
+     norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+     normalized_embeddings = embeddings / norm
+     return normalized_embeddings
+
+ def generate_fclip_embeddings(image_paths, texts, batch_size, alpha):
+     image_embeds_list = []
+     text_embeds_list = []
+
+     # Batch processing loop
+     for i in range(0, len(image_paths), batch_size):
+         batch_image_paths = image_paths[i:i + batch_size]
+         batch_texts = texts[i:i + batch_size]
+
+         # Load and preprocess batch of images and texts
+         images = [Image.open(path).convert("RGB") for path in batch_image_paths]
+
+         # Set the maximum sequence length to 77 to match the position embeddings
+         inputs = fclip_processor(text=batch_texts, images=images, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+         # Move inputs to the GPU
+         if torch.cuda.is_available():
+             inputs = {k: v.to("cuda") for k, v in inputs.items()}  # Move inputs to GPU
+
+         # Generate embeddings
+         with torch.no_grad():
+             outputs = fclip_model(**inputs)
+
+         image_embeds_list.append(outputs.image_embeds)
+         text_embeds_list.append(outputs.text_embeds)
+
+     # Concatenate all embeddings
+     image_embeds = torch.cat(image_embeds_list, dim=0)
+     text_embeds = torch.cat(text_embeds_list, dim=0)
+
+     # Normalize embeddings
+     image_embeds = normalize_embeddings(image_embeds)
+     text_embeds = normalize_embeddings(text_embeds)
+
+     # Average embeddings
+     avg_embeds = (image_embeds + text_embeds) / 2
+     weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
+     avg_embeds = normalize_embeddings(avg_embeds)
+     weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)
+
+     return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
+
+ def generate_siglip_embeddings(image_paths, texts, batch_size, alpha):
+     image_embeds_list = []
+     text_embeds_list = []
+
+     # Batch processing loop
+     for i in range(0, len(image_paths), batch_size):
+         batch_image_paths = image_paths[i:i + batch_size]
+         batch_texts = texts[i:i + batch_size]
+
+         # Load and preprocess batch of images and texts
+         images = [siglip_preprocess_val(Image.open(image_path).convert('RGB')).unsqueeze(0) for image_path in batch_image_paths]
+         images = torch.cat(images)
+
+         tokens = siglip_tokenizer(batch_texts)
+
+         # Move images to the same device as the model weights (GPU if available)
+         if torch.cuda.is_available():
+             images = images.cuda()
+             tokens = tokens.cuda()
+
+         # Generate embeddings
+         with torch.no_grad():
+             image_embeddings_batch = siglip_model.encode_image(images)
+             text_embeddings_batch = siglip_model.encode_text(tokens)
+
+         # Store embeddings
+         image_embeds_list.append(image_embeddings_batch)
+         text_embeds_list.append(text_embeddings_batch)
+
+     # Concatenate all embeddings
+     image_embeds = torch.cat(image_embeds_list, dim=0)
+     text_embeds = torch.cat(text_embeds_list, dim=0)
+
+     # Normalize embeddings
+     image_embeds = normalize_embeddings(image_embeds)
+     text_embeds = normalize_embeddings(text_embeds)
+
+     # Average embeddings
+     avg_embeds = (image_embeds + text_embeds) / 2
+     weighted_avg_embeds = alpha * image_embeds + (1 - alpha) * text_embeds
+     avg_embeds = normalize_embeddings(avg_embeds)
+     weighted_avg_embeds = normalize_embeddings(weighted_avg_embeds)
+
+     return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), avg_embeds.cpu().numpy(), weighted_avg_embeds.cpu().numpy()
+
+ # Function to process text embedding for any model
+ def generate_text_embedding(model, tokenizer, query, model_type):
+     if model_type == "fashionCLIP":
+         # Process the text with the tokenizer and move to GPU
+         inputs = tokenizer(text=query, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+         if torch.cuda.is_available():
+             inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+         # Get text embedding from the model
+         text_embed = model.get_text_features(**inputs)
+     elif model_type == "fashionSigLIP":
+         # Tokenize text and move to GPU
+         tokens = tokenizer(query)
+
+         if torch.cuda.is_available():
+             tokens = tokens.to("cuda")
+
+         # Get text embedding from the model
+         text_embed = model.encode_text(tokens)
+     else:
+         raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")
+
+     return normalize_embedding(text_embed)
+
+ # Function to process image embedding for any model
+ def generate_image_embedding(model, processor, image_path, model_type):
+     image = Image.open(image_path).convert("RGB")
+
+     if model_type == "fashionCLIP":
+         # Preprocess image for FashionCLIP and move to GPU
+         inputs = processor(images=image, return_tensors="pt")
+
+         if torch.cuda.is_available():
+             inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+         # Get image embedding from the model
+         image_embed = model.get_image_features(**inputs)
+     elif model_type == "fashionSigLIP":
+         # Preprocess image for SigLIP and move to GPU
+         image_tensor = processor(image).unsqueeze(0)
+
+         if torch.cuda.is_available():
+             image_tensor = image_tensor.to("cuda")
+
+         # Get image embedding from the model
+         image_embed = model.encode_image(image_tensor)
+     else:
+         raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")
+
+     return normalize_embedding(image_embed)
+
+ # Unified function to generate embeddings for both models and query types
+ def generate_query_embedding(query, query_type, model, processor, tokenizer, model_type):
+     if query_type == "text":
+         return generate_text_embedding(model, tokenizer, query, model_type)
+     elif query_type == "image":
+         return generate_image_embedding(model, processor, query, model_type)
+     else:
+         raise ValueError("Invalid query type. Choose 'text' or 'image'.")
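A hedged sketch of how these batch helpers could be driven end to end; the batch size and alpha weighting below are illustrative choices, not values fixed by this commit:

from utils.data_preprocessing import load_data, preprocess_data
from utils.embedding_generation import get_info, generate_fclip_embeddings

catalog = preprocess_data(load_data())
image_paths, texts = get_info(catalog, "SimpleMetadata")

# 32-item batches and an even image/text weighting (alpha=0.5) are arbitrary here
img_emb, txt_emb, avg_emb, w_avg_emb = generate_fclip_embeddings(image_paths, texts, batch_size=32, alpha=0.5)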
utils/load_models.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ import open_clip
+ from transformers import CLIPProcessor, CLIPModel
+
+ fclip_model = None
+ fclip_processor = None
+ siglip_model = None
+ siglip_tokenizer = None
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ fclip_model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
+ fclip_processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
+
+ siglip_model, siglip_preprocess_train, siglip_preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
+ siglip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
+
+ if torch.cuda.is_available():
+     fclip_model.to(device)
+     siglip_model.to(device)
utils/refine_metadata.py ADDED
@@ -0,0 +1,72 @@
+ import torch
+ from PIL import Image
+ from transformers import BitsAndBytesConfig, pipeline
+
+ def ask_llava(metadata, image_path):
+     """
+     Get a refined item description from LLaVA, given the item metadata and its image.
+     """
+
+     # Unpack metadata
+     category = metadata.get('category', '')
+     subcategory = metadata.get('subcategory', '')
+     material = metadata.get('material', '')
+     gender = metadata.get('gender', '')
+     brand = metadata.get('brand', '')
+     name = metadata.get('name', '')
+
+     # Build the prompt for LLaVA
+     image = Image.open(image_path)
+     #prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, use your knowledge of fashion trends, styles, colors, gender preferences and brand information as well as your ability to describe, analyze and understand the image of the item to refine the metadata. Your goal is to improve the embedding process for models like FashionCLIP and Marqo-FashionSigLIP by creating a more nuanced and detailed description that would boost the performance of the models. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name} - Description: {description} Refine and expand the metadata by incorporating information from the image and about the fashion item's style, cut, pattern, color scheme, brand, and any notable details. Include insights on current fashion trends and how the item fits within those trends. Be mindful that it should be around 77 tokens only; therefore, try to be concise and keep the description direct and useful for text-to-image and text-to-text search. Return the refined metadata as a single paragraph.\nASSISTANT:"""
+     prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, return the enhanced metadata structured in a single sentence with each field separated by a comma (do not include the field name, just use the same order). Keep it very concise and simple but make it more understandable for embedding models that will be used for search purposes. Also do a color analysis and add an extra field for the color of the item. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name}.\nASSISTANT:"""
+
+     # Generate description
+     outputs = img2text_pipeline(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
+
+     description = outputs[0]["generated_text"]
+
+     # Keep only the model's answer after the "ASSISTANT: " marker
+     description = description.split("ASSISTANT: ")
+
+     return description[1]
+
+ def refine_metadata(catalog, column):
+     catalog[column] = ""
+
+     # Iterate over the DataFrame and process each item
+     for index, row in catalog.iterrows():
+
+         metadata = {
+             'category': row['L1'],
+             'subcategory': row['L2'],
+             'material': row['MaterialName'],
+             'gender': row['Gender'],
+             'brand': row['BrandName'],
+             'name': row['Name'],
+             'description': row['Description']
+         }
+
+         # Ensure the image ID is converted to a string
+         #image_path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
+         image_path = "/images/" + str(row["Id"]) + ".jpg"
+
+         # Generate the image description using LLaVA
+         refined_metadata = ask_llava(metadata, image_path)
+
+         # Store results back in the DataFrame
+         catalog.at[index, column] = refined_metadata
+
+     return catalog
+
+ img2text_pipeline = None
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16
+ )
+
+ model_id = "llava-hf/llava-1.5-7b-hf"
+
+ if torch.cuda.is_available():
+     img2text_pipeline = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
+ else:
+     img2text_pipeline = pipeline("image-to-text", model=model_id)
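A short usage sketch; the "RefinedMetadata" column name is a hypothetical choice, and images are expected at /images/<Id>.jpg as hard-coded above:

from utils.refine_metadata import refine_metadata

# Adds an LLaVA-refined description per row (one pipeline call per item, so this is slow)
catalog = refine_metadata(catalog, "RefinedMetadata")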
utils/search_functions.py ADDED
@@ -0,0 +1,80 @@
+ import pandas as pd
+ from utils.vector_database import search_in_milvus, fashionclip_collection, fashionsiglip_collection
+ from utils.embedding_generation import generate_query_embedding
+ from utils.load_models import fclip_model, fclip_processor
+ from utils.load_models import siglip_model, siglip_preprocess_val, siglip_tokenizer
+
+ # Function to dynamically select the Milvus collection and search field
+ def get_milvus_collection_and_field(model_type, embedding_type):
+     # Map the model type to its Milvus collection
+     if model_type == "fashionCLIP":
+         collection = fashionclip_collection
+     elif model_type == "fashionSigLIP":
+         collection = fashionsiglip_collection
+     else:
+         raise ValueError("Invalid model type. Choose 'fashionCLIP' or 'fashionSigLIP'.")
+
+     # Map the embedding type to the vector field stored in the collection
+     field_map = {
+         "text": "text_embedding",
+         "image": "image_embedding",
+         "average": "avg_embedding",
+         "weighted average": "weighted_avg_embedding",
+     }
+     if embedding_type not in field_map:
+         raise ValueError("Invalid embedding type. Choose 'text', 'image', 'average' or 'weighted average'.")
+     search_field = field_map[embedding_type]
+
+     return collection, search_field
+
+ # Function to handle the complete search flow
+ def search(query, query_type, model_type, embedding_type):
+     # Step 1: Generate the query embedding based on the user input and model type
+     if model_type == "fashionCLIP":
+         query_embedding = generate_query_embedding(query, query_type, fclip_model, fclip_processor, fclip_processor, "fashionCLIP")
+     elif model_type == "fashionSigLIP":
+         query_embedding = generate_query_embedding(query, query_type, siglip_model, siglip_preprocess_val, siglip_tokenizer, "fashionSigLIP")
+
+     # Step 2: Get the appropriate Milvus collection and search field
+     collection, search_field = get_milvus_collection_and_field(model_type, embedding_type)
+
+     # Step 3: Perform search in Milvus using the query embedding
+     search_results = search_in_milvus(collection, search_field, query_embedding, top_k=10)
+
+     # Step 4: Extract images, similarity scores, and metadata from the search results
+     images = [result['image'] for result in search_results]
+     scores = [result['similarity_score'] for result in search_results]
+     metadata = [result['metadata'] for result in search_results]
+
+     return images, scores, metadata
+
+ # Function to run the search and get results for both models
+ def run_search(query_type, embedding_type, query_input_text, query_input_image):
+     if query_type == "text":
+         query = query_input_text
+     else:
+         query = query_input_image
+
+     # Perform search for FashionCLIP
+     fclip_images, fclip_scores, fclip_metadata = search(query, query_type, "fashionCLIP", embedding_type)
+
+     # Perform search for Marqo-FashionSigLIP
+     siglip_images, siglip_scores, siglip_metadata = search(query, query_type, "fashionSigLIP", embedding_type)
+
+     # Convert scores and metadata into a pandas DataFrame for each model
+     fclip_results_df = pd.DataFrame({
+         "Score": fclip_scores,
+         "Metadata": fclip_metadata,
+     })
+
+     siglip_results_df = pd.DataFrame({
+         "Score": siglip_scores,
+         "Metadata": siglip_metadata,
+     })
+
+     return fclip_images, fclip_results_df, siglip_images, siglip_results_df
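A brief usage sketch for the search flow; the query string is purely illustrative and assumes the Milvus collections have already been populated with the embeddings generated above:

from utils.search_functions import run_search

# Text query against the weighted-average embeddings of both models
fclip_images, fclip_df, siglip_images, siglip_df = run_search(
    query_type="text",
    embedding_type="weighted average",
    query_input_text="red leather handbag",  # hypothetical query
    query_input_image=None,
)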