wisdom196473 committed on
Commit 1fd9ec2 · 0 Parent(s):

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.pyc
+ .env
+ .venv/
+ venv/
+ .idea/
+ .vscode/
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,46 @@
+ # Amazon E-commerce Visual Assistant
+
+ A multimodal AI assistant that helps users search and explore Amazon products through natural language and image-based interactions.
+
+ ## Features
+
+ - Text and image-based product search
+ - Product comparisons and recommendations
+ - Visual product recognition
+ - Detailed product information retrieval
+ - Price analysis and comparison
+
+ ## Technologies Used
+
+ - FashionCLIP for visual understanding
+ - Mistral-7B Language Model for text generation
+ - FAISS for efficient similarity search
+ - Streamlit for the user interface
+
+ ## Setup and Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/wisdom196473/amazon-multimodal-product-assistant.git
+ cd amazon-multimodal-product-assistant
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Run the application:
+ ```bash
+ streamlit run amazon_app.py
+ ```
+
+ ## Project Structure
+
+ - `amazon_app.py`: Main Streamlit application
+ - `model.py`: Core AI model implementations
+ - `requirements.txt`: Project dependencies
+
+ ## License
+
+ MIT License
.ipynb_checkpoints/Vision_AI-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,762 @@
1
+ # Standard libraries
2
+ import os
3
+ import io
4
+ import json
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import Dict, List, Tuple, Optional
8
+ import requests
9
+ from PIL import Image
10
+ import matplotlib.pyplot as plt
11
+ from io import BytesIO
12
+
13
+ # Deep learning frameworks
14
+ import torch
15
+ from torch.cuda.amp import autocast
16
+ import open_clip
17
+
18
+ # Hugging Face
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ BitsAndBytesConfig,
23
+ pipeline,
24
+ PreTrainedModel,
25
+ PreTrainedTokenizer
26
+ )
27
+ from huggingface_hub import hf_hub_download
28
+ from langchain.prompts import PromptTemplate
29
+
30
+ # Vector database
31
+ import faiss
32
+
33
+ # Type hints
34
+ from typing import Dict, List, Tuple, Optional, Union
35
+
36
+ # Global variables
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ clip_model: Optional[PreTrainedModel] = None
39
+ clip_preprocess: Optional[callable] = None
40
+ clip_tokenizer: Optional[PreTrainedTokenizer] = None
41
+ llm_tokenizer: Optional[PreTrainedTokenizer] = None
42
+ llm_model: Optional[PreTrainedModel] = None
43
+ product_df: Optional[pd.DataFrame] = None
44
+ metadata: Dict = {}
45
+ embeddings_df: Optional[pd.DataFrame] = None
46
+ text_faiss: Optional[object] = None
47
+ image_faiss: Optional[object] = None
48
+
49
+ def initialize_models() -> bool:
50
+ """
51
+ Initialize CLIP and LLM models with proper error handling and GPU optimization.
52
+
53
+ Returns:
54
+ bool: True if initialization successful, raises RuntimeError otherwise
55
+ """
56
+ global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
57
+
58
+ try:
59
+ print(f"Initializing models on device: {device}")
60
+
61
+ # Initialize CLIP model with error handling
62
+ try:
63
+ clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
64
+ 'hf-hub:Marqo/marqo-fashionCLIP'
65
+ )
66
+ clip_model = clip_model.to(device)
67
+ clip_model.eval()
68
+ clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
69
+ print("CLIP model initialized successfully")
70
+ except Exception as e:
71
+ raise RuntimeError(f"Failed to initialize CLIP model: {str(e)}")
72
+
73
+ # Initialize LLM with optimized settings
74
+ try:
75
+ model_name = "mistralai/Mistral-7B-v0.1"
76
+ quantization_config = BitsAndBytesConfig(
77
+ load_in_4bit=True,
78
+ bnb_4bit_compute_dtype=torch.float16,
79
+ bnb_4bit_use_double_quant=True,
80
+ bnb_4bit_quant_type="nf4"
81
+ )
82
+
83
+ llm_tokenizer = AutoTokenizer.from_pretrained(
84
+ model_name,
85
+ padding_side="left",
86
+ truncation_side="left"
87
+ )
88
+ llm_tokenizer.pad_token = llm_tokenizer.eos_token
89
+
90
+ llm_model = AutoModelForCausalLM.from_pretrained(
91
+ model_name,
92
+ quantization_config=quantization_config,
93
+ device_map="auto",
94
+ torch_dtype=torch.float16
95
+ )
96
+ llm_model.eval()
97
+ print("LLM initialized successfully")
98
+ except Exception as e:
99
+ raise RuntimeError(f"Failed to initialize LLM: {str(e)}")
100
+
101
+ return True
102
+
103
+ except Exception as e:
104
+ raise RuntimeError(f"Model initialization failed: {str(e)}")
105
+
106
+ # Data loading
107
+ def load_data() -> bool:
108
+ """
109
+ Load and initialize all required data with enhanced metadata support and error handling.
110
+
111
+ Returns:
112
+ bool: True if data loading successful, raises RuntimeError otherwise
113
+ """
114
+ global product_df, metadata, embeddings_df, text_faiss, image_faiss
115
+
116
+ try:
117
+ print("Loading product data...")
118
+ # Load cleaned product data
119
+ try:
120
+ cleaned_data_path = hf_hub_download(
121
+ repo_id="chen196473/amazon_product_2020_cleaned",
122
+ filename="amazon_cleaned.parquet",
123
+ repo_type="dataset"
124
+ )
125
+ product_df = pd.read_parquet(cleaned_data_path)
126
+
127
+ # Add validation columns
128
+ product_df['Has_Valid_Image'] = product_df['Processed Image'].notna()
129
+ product_df['Image_Status'] = product_df['Has_Valid_Image'].map({
130
+ True: 'valid',
131
+ False: 'invalid'
132
+ })
133
+ print("Product data loaded successfully")
134
+ except Exception as e:
135
+ raise RuntimeError(f"Failed to load product data: {str(e)}")
136
+
137
+ # Load enhanced metadata
138
+ print("Loading metadata...")
139
+ try:
140
+ metadata = {}
141
+ metadata_files = [
142
+ 'base_metadata.json',
143
+ 'category_index.json',
144
+ 'price_range_index.json',
145
+ 'keyword_index.json',
146
+ 'brand_index.json',
147
+ 'product_name_index.json'
148
+ ]
149
+
150
+ for file in metadata_files:
151
+ file_path = hf_hub_download(
152
+ repo_id="chen196473/amazon_product_2020_metadata",
153
+ filename=file,
154
+ repo_type="dataset"
155
+ )
156
+ with open(file_path, 'r') as f:
157
+ index_name = file.replace('.json', '')
158
+ data = json.load(f)
159
+
160
+ if index_name == 'base_metadata':
161
+ data = {item['Uniq_Id']: item for item in data}
162
+ for item in data.values():
163
+ if 'Keywords' in item:
164
+ item['Keywords'] = set(item['Keywords'])
165
+
166
+ metadata[index_name] = data
167
+ print("Metadata loaded successfully")
168
+ except Exception as e:
169
+ raise RuntimeError(f"Failed to load metadata: {str(e)}")
170
+
171
+ # Load embeddings
172
+ print("Loading embeddings...")
173
+ try:
174
+ text_embeddings_dict, image_embeddings_dict = load_embeddings_from_huggingface(
175
+ "chen196473/amazon_vector_database"
176
+ )
177
+
178
+ # Create embeddings DataFrame
179
+ embeddings_df = pd.DataFrame({
180
+ 'text_embeddings': list(text_embeddings_dict.values()),
181
+ 'image_embeddings': list(image_embeddings_dict.values()),
182
+ 'Uniq_Id': list(text_embeddings_dict.keys())
183
+ })
184
+
185
+ # Merge with product data
186
+ product_df = product_df.merge(
187
+ embeddings_df,
188
+ left_on='Uniq Id',
189
+ right_on='Uniq_Id',
190
+ how='inner'
191
+ )
192
+ print("Embeddings loaded and merged successfully")
193
+
194
+ # Create FAISS indexes
195
+ print("Creating FAISS indexes...")
196
+ try:
197
+ create_faiss_indexes(text_embeddings_dict, image_embeddings_dict)
198
+ print("FAISS indexes created successfully")
199
+
200
+ # Verify FAISS indexes are properly initialized and contain data
201
+ if text_faiss is None or image_faiss is None:
202
+ raise RuntimeError("FAISS indexes were not properly initialized")
203
+
204
+ # Test a simple query to verify indexes are working
205
+ test_query = "test"
206
+ tokens = clip_tokenizer(test_query).to(device)
207
+ with torch.no_grad():
208
+ text_embedding = clip_model.encode_text(tokens)
209
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
210
+ text_embedding = text_embedding.cpu().numpy()
211
+
212
+ # Verify search works
213
+ test_results = text_faiss.search(text_embedding[0], k=1)
214
+ if not test_results:
215
+ raise RuntimeError("FAISS indexes are empty")
216
+
217
+ print("FAISS indexes verified successfully")
218
+
219
+ except Exception as e:
220
+ raise RuntimeError(f"Failed to create or verify FAISS indexes: {str(e)}")
221
+
222
+ except Exception as e:
223
+ raise RuntimeError(f"Failed to load embeddings: {str(e)}")
224
+
225
+ # Validate required columns
226
+ required_columns = [
227
+ 'Uniq Id', 'Product Name', 'Category', 'Selling Price',
228
+ 'Model Number', 'Image', 'Normalized Description'
229
+ ]
230
+ missing_cols = set(required_columns) - set(product_df.columns)
231
+ if missing_cols:
232
+ raise ValueError(f"Missing required columns: {missing_cols}")
233
+
234
+ # Add enhanced metadata fields
235
+ if 'Search_Text' not in product_df.columns:
236
+ product_df['Search_Text'] = product_df.apply(
237
+ lambda x: metadata['base_metadata'].get(x['Uniq Id'], {}).get('Search_Text', ''),
238
+ axis=1
239
+ )
240
+
241
+ # Final verification of loaded data
242
+ if product_df is None or product_df.empty:
243
+ raise RuntimeError("Product DataFrame is empty or not initialized")
244
+
245
+ if not metadata:
246
+ raise RuntimeError("Metadata dictionary is empty")
247
+
248
+ if embeddings_df is None or embeddings_df.empty:
249
+ raise RuntimeError("Embeddings DataFrame is empty or not initialized")
250
+
251
+ print("Data loading completed successfully")
252
+ return True
253
+
254
+ except Exception as e:
255
+ # Clean up any partially loaded data
256
+ product_df = None
257
+ metadata = {}
258
+ embeddings_df = None
259
+ text_faiss = None
260
+ image_faiss = None
261
+ raise RuntimeError(f"Data loading failed: {str(e)}")
262
+
263
+ def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
264
+ """
265
+ Load embeddings from Hugging Face repository with enhanced error handling.
266
+
267
+ Args:
268
+ repo_id (str): Hugging Face repository ID
269
+
270
+ Returns:
271
+ Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
272
+ """
273
+ print("Loading embeddings from Hugging Face...")
274
+ try:
275
+ file_path = hf_hub_download(
276
+ repo_id=repo_id,
277
+ filename="embeddings.parquet",
278
+ repo_type="dataset"
279
+ )
280
+ df = pd.read_parquet(file_path)
281
+
282
+ # Extract embedding columns
283
+ text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
284
+ image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
285
+
286
+ # Create embedding dictionaries
287
+ text_embeddings_dict = {
288
+ row['Uniq_Id']: row[text_cols].values.astype(np.float32)
289
+ for _, row in df.iterrows()
290
+ }
291
+ image_embeddings_dict = {
292
+ row['Uniq_Id']: row[image_cols].values.astype(np.float32)
293
+ for _, row in df.iterrows()
294
+ }
295
+
296
+ print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
297
+ return text_embeddings_dict, image_embeddings_dict
298
+
299
+ except Exception as e:
300
+ raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
301
+
302
+ # FAISS index creation
303
+ class MultiModalFAISSIndex:
304
+ def __init__(self, dimension, index_type='L2'):
305
+ import faiss
306
+ self.dimension = dimension
307
+ self.index = faiss.IndexFlatL2(dimension) if index_type == 'L2' else faiss.IndexFlatIP(dimension)
308
+ self.id_to_metadata = {}
309
+
310
+ def add_embeddings(self, embeddings, metadata_list):
311
+ import numpy as np
312
+ embeddings = np.array(embeddings).astype('float32')
313
+ self.index.add(embeddings)
314
+ for i, metadata in enumerate(metadata_list):
315
+ self.id_to_metadata[i] = metadata
316
+
317
+ def search(self, query_embedding, k=5):
318
+ import numpy as np
319
+ query_embedding = np.array([query_embedding]).astype('float32')
320
+ distances, indices = self.index.search(query_embedding, k)
321
+ results = []
322
+ for idx in indices[0]:
323
+ if idx in self.id_to_metadata:
324
+ results.append(self.id_to_metadata[idx])
325
+ return results
326
+
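+ # Usage sketch for MultiModalFAISSIndex (illustrative only; the real indexes are built
+ # from the precomputed product embeddings in create_faiss_indexes below, and 512 is just
+ # an example dimension):
+ #
+ #   index = MultiModalFAISSIndex(dimension=512)
+ #   index.add_embeddings(
+ #       [np.random.rand(512), np.random.rand(512)],        # one vector per product
+ #       [{'id': 'p1', 'product_name': 'Echo Dot'},          # metadata aligned by position
+ #        {'id': 'p2', 'product_name': 'Nest Mini'}],
+ #   )
+ #   matches = index.search(np.random.rand(512), k=1)        # -> list of metadata dicts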
327
+ def create_faiss_indexes(text_embeddings_dict, image_embeddings_dict):
328
+ """Create FAISS indexes with error handling"""
329
+ global text_faiss, image_faiss
330
+
331
+ try:
332
+ # Get embedding dimension
333
+ text_dim = next(iter(text_embeddings_dict.values())).shape[0]
334
+ image_dim = next(iter(image_embeddings_dict.values())).shape[0]
335
+
336
+ # Create indexes
337
+ text_faiss = MultiModalFAISSIndex(text_dim)
338
+ image_faiss = MultiModalFAISSIndex(image_dim)
339
+
340
+ # Prepare text embeddings and metadata
341
+ text_embeddings = []
342
+ text_metadata = []
343
+ for text_id, embedding in text_embeddings_dict.items():
344
+ if text_id in product_df['Uniq Id'].values:
345
+ product = product_df[product_df['Uniq Id'] == text_id].iloc[0]
346
+ text_embeddings.append(embedding)
347
+ text_metadata.append({
348
+ 'id': text_id,
349
+ 'description': product['Normalized Description'],
350
+ 'product_name': product['Product Name']
351
+ })
352
+
353
+ # Add text embeddings
354
+ if text_embeddings:
355
+ text_faiss.add_embeddings(text_embeddings, text_metadata)
356
+
357
+ # Prepare image embeddings and metadata
358
+ image_embeddings = []
359
+ image_metadata = []
360
+ for image_id, embedding in image_embeddings_dict.items():
361
+ if image_id in product_df['Uniq Id'].values:
362
+ product = product_df[product_df['Uniq Id'] == image_id].iloc[0]
363
+ image_embeddings.append(embedding)
364
+ image_metadata.append({
365
+ 'id': image_id,
366
+ 'image_url': product['Image'],
367
+ 'product_name': product['Product Name']
368
+ })
369
+
370
+ # Add image embeddings
371
+ if image_embeddings:
372
+ image_faiss.add_embeddings(image_embeddings, image_metadata)
373
+
374
+ return True
375
+
376
+ except Exception as e:
377
+ raise RuntimeError(f"Failed to create FAISS indexes: {str(e)}")
378
+
379
+ def get_few_shot_product_comparison_template():
380
+ return """Compare these specific products based on their actual features and specifications:
381
+
382
+ Example 1:
383
+ Question: Compare iPhone 13 and Samsung Galaxy S21
384
+ Answer: The iPhone 13 features a 6.1-inch Super Retina XDR display and dual 12MP cameras, while the Galaxy S21 has a 6.2-inch Dynamic AMOLED display and triple camera setup. Both phones offer 5G connectivity, but the iPhone uses A15 Bionic chip while S21 uses Snapdragon 888.
385
+
386
+ Example 2:
387
+ Question: Compare Amazon Echo Dot and Google Nest Mini
388
+ Answer: The Amazon Echo Dot features Alexa voice assistant and a 1.6-inch speaker, while the Google Nest Mini comes with Google Assistant and a 40mm driver. Both devices offer smart home control and music playback, but differ in their ecosystem integration.
389
+
390
+ Current Question: {query}
391
+ Context: {context}
392
+
393
+ Guidelines:
394
+ - Only compare the specific products mentioned in the query
395
+ - Focus on actual product features and specifications
396
+ - Keep response to 2-3 clear sentences
397
+ - Ensure factual accuracy based on the context provided
398
+
399
+ Answer:"""
400
+
401
+ def get_zero_shot_product_template():
402
+ return """You are a product information specialist. Describe only the specific product's actual features based on the provided context.
403
+
404
+ Context: {context}
405
+
406
+ Question: {query}
407
+
408
+ Guidelines:
409
+ - Only describe the specific product mentioned in the query
410
+ - Focus on actual features and specifications from the context
411
+ - Keep response to 2-3 factual sentences
412
+ - Ensure information accuracy
413
+
414
+ Answer:"""
415
+
416
+ def get_zero_shot_image_template():
417
+ return """Analyze this product image and provide a concise description:
418
+
419
+ Product Information:
420
+ {context}
421
+
422
+ Guidelines:
423
+ - Describe the main product features and intended use
424
+ - Highlight key specifications and materials
425
+ - Keep response to 2-3 sentences
426
+ - Focus on practical information
427
+
428
+ Answer:"""
429
+
430
+ # Image processing functions
431
+ def process_image(image):
432
+ try:
433
+ if isinstance(image, str):
434
+ response = requests.get(image)
435
+ image = Image.open(io.BytesIO(response.content))
436
+
437
+ processed_image = clip_preprocess(image).unsqueeze(0).to(device)
438
+
439
+ with torch.no_grad():
440
+ image_features = clip_model.encode_image(processed_image)
441
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
442
+
443
+ return image_features.cpu().numpy()
444
+ except Exception as e:
445
+ raise Exception(f"Error processing image: {str(e)}")
446
+
447
+ def load_image_from_url(url):
448
+ response = requests.get(url)
449
+ if response.status_code == 200:
450
+ return Image.open(io.BytesIO(response.content))
451
+ else:
452
+ raise Exception(f"Failed to fetch image from URL: {url}, Status Code: {response.status_code}")
453
+
454
+ # Context retrieval and enhancement
455
+ def filter_by_metadata(query, metadata_index):
456
+ relevant_products = set()
457
+
458
+ # Check category index
459
+ if 'category_index' in metadata_index:
460
+ categories = metadata_index['category_index']
461
+ for category in categories:
462
+ if any(term.lower() in category.lower() for term in query.split()):
463
+ relevant_products.update(categories[category])
464
+
465
+ # Check product name index
466
+ if 'product_name_index' in metadata_index:
467
+ product_names = metadata_index['product_name_index']
468
+ for term in query.split():
469
+ if term.lower() in product_names:
470
+ relevant_products.update(product_names[term.lower()])
471
+
472
+ # Check price ranges
473
+ price_terms = {'cheap', 'expensive', 'price', 'cost', 'affordable'}
474
+ if any(term in query.lower() for term in price_terms) and 'price_range_index' in metadata_index:
475
+ price_ranges = metadata_index['price_range_index']
476
+ for price_range in price_ranges:
477
+ relevant_products.update(price_ranges[price_range])
478
+
479
+ return relevant_products if relevant_products else None
480
+
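+ # Note: the metadata layout assumed by filter_by_metadata above (inferred from how the
+ # indexes are read here; the authoritative source is the amazon_product_2020_metadata dataset):
+ #   category_index     : {category_name: [product_id, ...]}
+ #   product_name_index : {lower-cased term: [product_id, ...]}
+ #   price_range_index  : {range_label: [product_id, ...]}
+ #   base_metadata      : {product_id: {'Product_Name', 'Category', 'Selling_Price', 'Keywords', ...}}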
481
+ def enhance_context_with_metadata(product, metadata_index):
482
+ """Enhanced context building using new metadata structure"""
483
+ # Access base_metadata using product ID directly since it's now a dictionary
484
+ base_metadata = metadata_index['base_metadata'].get(product['Uniq Id'])
485
+
486
+ if base_metadata:
487
+ # Get keywords and search text from enhanced metadata
488
+ keywords = base_metadata.get('Keywords', [])
489
+ search_text = base_metadata.get('Search_Text', '')
490
+
491
+ # Build enhanced description
492
+ description = []
493
+ description.append(f"Product Name: {base_metadata['Product_Name']}")
494
+ description.append(f"Category: {base_metadata['Category']}")
495
+ description.append(f"Price: ${base_metadata['Selling_Price']:.2f}")
496
+
497
+ # Add key features from normalized description
498
+ if 'Normalized_Description' in base_metadata:
499
+ features = []
500
+ for feature in base_metadata['Normalized_Description'].split('|'):
501
+ if ':' in feature:
502
+ key, value = feature.split(':', 1)
503
+ if not any(skip in key.lower() for skip in
504
+ ['uniq id', 'product url', 'specifications', 'asin']):
505
+ features.append(f"{key.strip()}: {value.strip()}")
506
+ if features:
507
+ description.append("Key Features:")
508
+ description.extend(features[:3])
509
+
510
+ # Add relevant keywords
511
+ if keywords:
512
+ description.append("Related Terms: " + ", ".join(list(keywords)[:5]))
513
+
514
+ return "\n".join(description)
515
+
516
+ return None
517
+
518
+ def retrieve_context(query, image=None, top_k=5):
519
+ """Enhanced context retrieval using both FAISS and metadata"""
520
+ # Initialize context lists
521
+ similar_items = []
522
+ context = []
523
+
524
+ if image is not None:
525
+ # Process image query
526
+ image_embedding = process_image(image)
527
+ image_embedding = image_embedding.reshape(1, -1)
528
+ similar_items = image_faiss.search(image_embedding[0], k=top_k)
529
+ else:
530
+ # Process text query with enhanced metadata filtering
531
+ relevant_products = filter_by_metadata(query, metadata)
532
+
533
+ tokens = clip_tokenizer(query).to(device)
534
+ with torch.no_grad():
535
+ text_embedding = clip_model.encode_text(tokens)
536
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
537
+ text_embedding = text_embedding.cpu().numpy()
538
+
539
+ # Get FAISS results
540
+ similar_items = text_faiss.search(text_embedding[0], k=top_k*2) # Get more results for filtering
541
+
542
+ # Filter results using metadata if available
543
+ if relevant_products:
544
+ similar_items = [item for item in similar_items if item['id'] in relevant_products][:top_k]
545
+
546
+ # Build enhanced context
547
+ for item in similar_items:
548
+ product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
549
+ enhanced_context = enhance_context_with_metadata(product, metadata)
550
+ if enhanced_context:
551
+ context.append(enhanced_context)
552
+
553
+ return "\n\n".join(context), similar_items
554
+
555
+ def display_product_images(similar_items, max_images=1):
556
+ displayed_images = []
557
+
558
+ for item in similar_items[:max_images]:
559
+ try:
560
+ # Get image URL from product data
561
+ image_url = item['Image'] if isinstance(item, pd.Series) else item.get('Image')
562
+ if not image_url:
563
+ continue
564
+
565
+ # Handle multiple image URLs
566
+ image_urls = image_url.split('|')
567
+ image_url = image_urls[0] # Take first image
568
+
569
+ # Load image
570
+ response = requests.get(image_url)
571
+ img = Image.open(BytesIO(response.content))
572
+
573
+ # Get product details
574
+ product_name = item['Product Name'] if isinstance(item, pd.Series) else item.get('product_name')
575
+ price = item['Selling Price'] if isinstance(item, pd.Series) else item.get('price', 0)
576
+
577
+ # Add to displayed images
578
+ displayed_images.append({
579
+ 'image': img,
580
+ 'product_name': product_name,
581
+ 'price': float(price)
582
+ })
583
+
584
+ except Exception as e:
585
+ print(f"Error processing item: {str(e)}")
586
+ continue
587
+
588
+ return displayed_images
589
+
590
+ def classify_query(query):
591
+ """Classify the type of query to determine the retrieval strategy."""
592
+ query_lower = query.lower()
593
+ if any(keyword in query_lower for keyword in ['compare', 'difference between']):
594
+ return 'comparison'
595
+ elif any(keyword in query_lower for keyword in ['show', 'picture', 'image', 'photo']):
596
+ return 'image_search'
597
+ else:
598
+ return 'product_info'
599
+
600
+ def boost_category_relevance(query, product, similarity_score):
601
+ query_terms = set(query.lower().split())
602
+ category_terms = set(product['Category'].lower().split())
603
+ category_overlap = len(query_terms & category_terms)
604
+ category_boost = 1 + (category_overlap * 0.2) # 20% boost per matching term
605
+ return similarity_score * category_boost
606
+
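+ # Worked example for boost_category_relevance: for the query "wireless headphones" and a
+ # product Category of "Electronics | Headphones", the overlapping term set is {"headphones"},
+ # so category_boost = 1 + 1 * 0.2 = 1.2 and the similarity score is scaled by 1.2.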
607
+ def hybrid_retrieval(query, top_k=5):
608
+ query_type = classify_query(query)
609
+
610
+ tokens = clip_tokenizer(query).to(device)
611
+ with torch.no_grad():
612
+ text_embedding = clip_model.encode_text(tokens)
613
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
614
+ text_embedding = text_embedding.cpu().numpy()
615
+
616
+ # First get text matches
617
+ text_results = text_faiss.search(text_embedding[0], k=top_k*2)
618
+
619
+ if query_type == 'image_search':
620
+ image_results = []
621
+ for item in text_results:
622
+ # Get original product with embeddings intact
623
+ product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
624
+ # Get image embeddings from embeddings_df instead
625
+ image_embedding = embeddings_df[embeddings_df['Uniq_Id'] == item['id']]['image_embeddings'].iloc[0]
626
+ similarity = np.dot(text_embedding.flatten(), image_embedding.flatten())
627
+ boosted_similarity = boost_category_relevance(query, product, similarity)
628
+ image_results.append((product, boosted_similarity))
629
+
630
+ image_results.sort(key=lambda x: x[1], reverse=True)
631
+ results = [item for item, _ in image_results[:top_k]]
632
+ else:
633
+ results = [product_df[product_df['Uniq Id'] == item['id']].iloc[0] for item in text_results[:top_k]]
634
+
635
+ return results, query_type
636
+
637
+
638
+ def fallback_text_search(query, top_k=10):
639
+ relevant_products = filter_by_metadata(query, metadata)
640
+ if not relevant_products:
641
+ # Check brand index specifically
642
+ if 'brand_index' in metadata:
643
+ query_terms = query.lower().split()
644
+ for term in query_terms:
645
+ if term in metadata['brand_index']:
646
+ relevant_products = set(metadata['brand_index'][term])
647
+ break
648
+
649
+ if relevant_products:
650
+ results = [product_df[product_df['Uniq Id'] == pid].iloc[0] for pid in list(relevant_products)[:top_k]]
651
+ else:
652
+ query_lower = query.lower()
653
+ results = product_df[
654
+ (product_df['Product Name'].str.lower().str.contains(query_lower)) |
655
+ (product_df['Category'].str.lower().str.contains(query_lower)) |
656
+ (product_df['Normalized Description'].str.lower().str.contains(query_lower))
657
+ ].head(top_k)
658
+
659
+ return results
660
+
661
+ def generate_rag_response(query, context, image=None):
662
+ """Enhanced RAG response generation"""
663
+ # Select template based on query type and metadata
664
+ if "compare" in query.lower() or "difference between" in query.lower() or "vs." in query.lower():
665
+ template = get_few_shot_product_comparison_template()
666
+ elif image is not None:
667
+ template = get_zero_shot_image_template()
668
+ else:
669
+ template = get_zero_shot_product_template()
670
+
671
+ # Create enhanced prompt with metadata context
672
+ prompt = PromptTemplate(
673
+ template=template,
674
+ input_variables=["query", "context"]
675
+ )
676
+
677
+ # Configure generation parameters
678
+ pipe = pipeline(
679
+ "text-generation",
680
+ model=llm_model,
681
+ tokenizer=llm_tokenizer,
682
+ max_new_tokens=300,
683
+ temperature=0.1,
684
+ do_sample=False,
685
+ repetition_penalty=1.2,
686
+ early_stopping=True,
687
+ truncation=True,
688
+ padding=True
689
+ )
690
+
691
+ # Generate and clean response
692
+ formatted_prompt = prompt.format(query=query, context=context)
693
+ response = pipe(formatted_prompt)[0]['generated_text']
694
+
695
+ # Clean response
696
+ for section in ["Answer:", "Question:", "Guidelines:", "Context:"]:
697
+ if section in response:
698
+ response = response.split(section)[-1].strip()
699
+
700
+ return response
701
+
702
+ def chatbot(query, image_input=None):
703
+ """
704
+ Main chatbot function to handle queries and provide responses.
705
+ """
706
+ if image_input is not None:
707
+ try:
708
+ # Convert URL to image if needed
709
+ if isinstance(image_input, str):
710
+ image_input = load_image_from_url(image_input)
711
+ elif not isinstance(image_input, Image.Image):
712
+ raise ValueError("Invalid image input type")
713
+
714
+ # Get context and generate response
715
+ context, _ = retrieve_context(query, image_input)
716
+ if not context:
717
+ return "No relevant products found for this image."
718
+ response = generate_rag_response(query, context, image_input)
719
+ return response
720
+
721
+ except Exception as e:
722
+ print(f"Error processing image: {str(e)}")
723
+ return f"Failed to process image: {str(e)}"
724
+ else:
725
+ try:
726
+ print(f"Processing query: {query}")
727
+ if text_faiss is None or image_faiss is None:
728
+ return "Search indexes not initialized. Please try again."
729
+
730
+ results, query_type = hybrid_retrieval(query)
731
+ print(f"Query type: {query_type}")
732
+
733
+ if not results and query_type == 'image_search':
734
+ print("No relevant images found. Falling back to text search.")
735
+ results = fallback_text_search(query)
736
+
737
+ if not results:
738
+ return "No relevant products found."
739
+
740
+ context = "\n\n".join([enhance_context_with_metadata(item, metadata) for item in results])
741
+ response = generate_rag_response(query, context)
742
+
743
+ if query_type == 'image_search':
744
+ print("\nFound matching products:")
745
+ displayed_images = display_product_images(results)
746
+
747
+ # Always return a dictionary with both text and images for image search queries
748
+ return {
749
+ 'text': response,
750
+ 'images': displayed_images
751
+ }
752
+
753
+ return response
754
+ except Exception as e:
755
+ print(f"Error processing query: {str(e)}")
756
+ return f"Error processing request: {str(e)}"
757
+
758
+
759
+ def cleanup_resources():
760
+ if torch.cuda.is_available():
761
+ torch.cuda.empty_cache()
762
+ print("GPU memory cleared")
README.md ADDED
@@ -0,0 +1,46 @@
+ # Amazon E-commerce Visual Assistant
+
+ A multimodal AI assistant that helps users search and explore Amazon products through natural language and image-based interactions.
+
+ ## Features
+
+ - Text and image-based product search
+ - Product comparisons and recommendations
+ - Visual product recognition
+ - Detailed product information retrieval
+ - Price analysis and comparison
+
+ ## Technologies Used
+
+ - FashionCLIP for visual understanding
+ - Mistral-7B Language Model for text generation
+ - FAISS for efficient similarity search
+ - Streamlit for the user interface
+
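+ These components form a retrieval-augmented pipeline: FashionCLIP embeds the query (text or image), FAISS retrieves the nearest product embeddings, and Mistral-7B generates the answer from the retrieved context. A minimal sketch of the retrieval step (illustrative only; the full logic lives in `model.py`):
+
+ ```python
+ import faiss, open_clip, torch
+
+ model, _, _ = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionCLIP')
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
+
+ with torch.no_grad():
+     query = model.encode_text(tokenizer(["wireless headphones"]))
+     query = (query / query.norm(dim=-1, keepdim=True)).numpy().astype('float32')
+
+ index = faiss.IndexFlatL2(query.shape[1])  # the app builds this from precomputed product embeddings
+ # index.add(product_embeddings); distances, ids = index.search(query, k=5)
+ ```
+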
+ ## Setup and Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/wisdom196473/amazon-multimodal-product-assistant.git
+ cd amazon-multimodal-product-assistant
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Run the application:
+ ```bash
+ streamlit run amazon_app.py
+ ```
+
+ ## Project Structure
+
+ - `amazon_app.py`: Main Streamlit application
+ - `model.py`: Core AI model implementations
+ - `requirements.txt`: Project dependencies
+
+ ## License
+
+ MIT License
Vision_AI.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
amazon_app.py ADDED
@@ -0,0 +1,269 @@
1
+ import streamlit as st
2
+
3
+ # Configure page
4
+ st.set_page_config(
5
+ page_title="E-commerce Visual Assistant",
6
+ page_icon="🛍️",
7
+ layout="wide"
8
+ )
9
+
10
+ from streamlit_chat import message
11
+ import torch
12
+ from PIL import Image
13
+ import requests
14
+ from io import BytesIO
15
+ from model import initialize_models, load_data, chatbot, cleanup_resources
16
+
17
+ # Helper functions
18
+ def load_image_from_url(url):
19
+ try:
20
+ response = requests.get(url)
21
+ img = Image.open(BytesIO(response.content))
22
+ return img
23
+ except Exception as e:
24
+ st.error(f"Error loading image from URL: {str(e)}")
25
+ return None
26
+
27
+ def initialize_assistant():
28
+ if not st.session_state.models_loaded:
29
+ with st.spinner("Loading models and data..."):
30
+ initialize_models()
31
+ load_data()
32
+ st.session_state.models_loaded = True
33
+ st.success("Assistant is ready!")
34
+
35
+ def display_chat_history():
36
+ for message in st.session_state.messages:
37
+ with st.chat_message(message["role"]):
38
+ st.markdown(message["content"])
39
+ if "image" in message:
40
+ st.image(message["image"], caption="Uploaded Image", width=200)
41
+ if "display_images" in message:
42
+ # Since we only have one image, we don't need multiple columns
43
+ img_data = message["display_images"][0] # Get the first (and only) image
44
+ st.image(
45
+ img_data['image'],
46
+ caption=f"{img_data['product_name']}\nPrice: ${img_data['price']:.2f}",
47
+ width=350 # Adjusted width for single image display
48
+ )
49
+
50
+ def handle_user_input(prompt, uploaded_image):
51
+ # Add user message
52
+ st.session_state.messages.append({"role": "user", "content": prompt})
53
+
54
+ # Generate response
55
+ with st.spinner("Processing your request..."):
56
+ try:
57
+ response = chatbot(prompt, image_input=uploaded_image)
58
+
59
+ if isinstance(response, dict):
60
+ assistant_message = {
61
+ "role": "assistant",
62
+ "content": response['text']
63
+ }
64
+ if 'images' in response and response['images']:
65
+ assistant_message["display_images"] = response['images']
66
+ st.session_state.messages.append(assistant_message)
67
+ else:
68
+ st.session_state.messages.append({
69
+ "role": "assistant",
70
+ "content": response
71
+ })
72
+
73
+ except Exception as e:
74
+ st.error(f"Error: {str(e)}")
75
+ st.session_state.messages.append({
76
+ "role": "assistant",
77
+ "content": f"I encountered an error: {str(e)}"
78
+ })
79
+
80
+ st.rerun()
81
+
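+ # Shape of the entries kept in st.session_state.messages (as written above and read back
+ # in display_chat_history); the optional keys are only present when images are involved:
+ #   {"role": "user" | "assistant",
+ #    "content": str,
+ #    "image": PIL.Image,            # optional: image attached to a user message
+ #    "display_images": [{"image": PIL.Image, "product_name": str, "price": float}]}  # optional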
82
+ # Custom CSS for enhanced styling
83
+ st.markdown("""
84
+ <style>
85
+ /* Main container styling */
86
+ .main {
87
+ background: linear-gradient(135deg, #f5f7fa 0%, #e8edf2 100%);
88
+ padding: 20px;
89
+ border-radius: 15px;
90
+ }
91
+
92
+ /* Header styling */
93
+ .stTitle {
94
+ color: #1e3d59;
95
+ font-size: 2.5rem !important;
96
+ text-align: center;
97
+ padding: 20px;
98
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
99
+ }
100
+
101
+ /* Sidebar styling */
102
+ .css-1d391kg {
103
+ background: linear-gradient(180deg, #1e3d59 0%, #2b5876 100%);
104
+ }
105
+
106
+ /* Chat container styling */
107
+ .stChatMessage {
108
+ background-color: white;
109
+ border-radius: 15px;
110
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
111
+ margin: 10px 0;
112
+ padding: 15px;
113
+ }
114
+
115
+ /* Input box styling */
116
+ .stTextInput > div > div > input {
117
+ border-radius: 20px;
118
+ border: 2px solid #1e3d59;
119
+ padding: 10px 20px;
120
+ }
121
+
122
+ /* Radio button styling */
123
+ .stRadio > label {
124
+ background-color: white;
125
+ padding: 10px 20px;
126
+ border-radius: 10px;
127
+ margin: 5px;
128
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
129
+ }
130
+
131
+ /* Button styling */
132
+ .stButton > button {
133
+ background: linear-gradient(90deg, #1e3d59 0%, #2b5876 100%);
134
+ color: white;
135
+ border-radius: 20px;
136
+ padding: 10px 25px;
137
+ border: none;
138
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
139
+ transition: all 0.3s ease;
140
+ }
141
+
142
+ .stButton > button:hover {
143
+ transform: translateY(-2px);
144
+ box-shadow: 0 6px 8px rgba(0,0,0,0.2);
145
+ }
146
+
147
+ /* Footer styling */
148
+ footer {
149
+ background-color: white;
150
+ border-radius: 10px;
151
+ padding: 20px;
152
+ margin-top: 30px;
153
+ text-align: center;
154
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
155
+ }
156
+ </style>
157
+ """, unsafe_allow_html=True)
158
+
159
+ # Initialize session state
160
+ if 'messages' not in st.session_state:
161
+ st.session_state.messages = []
162
+ if 'models_loaded' not in st.session_state:
163
+ st.session_state.models_loaded = False
164
+
165
+ # Main title with enhanced styling
166
+ st.markdown("<h1 class='stTitle'>🛍️ Amazon E-commerce Visual Assistant</h1>", unsafe_allow_html=True)
167
+
168
+ # Sidebar configuration with enhanced styling
169
+ with st.sidebar:
170
+ st.title("Assistant Features")
171
+
172
+ st.markdown("### 🤖 How It Works")
173
+ st.markdown("""
174
+ This AI-powered shopping assistant combines:
175
+
176
+ **🧠 Advanced Technologies**
177
+ - FashionCLIP Visual AI
178
+ - Mistral-7B Language Model
179
+ - Multimodal Understanding
180
+
181
+ **💫 Capabilities**
182
+ - Product Search & Recognition
183
+ - Visual Analysis
184
+ - Detailed Comparisons
185
+ - Price Analysis
186
+ """)
187
+
188
+ st.markdown("---")
189
+
190
+ st.markdown("### 👥 Development Team")
191
+ # Use a list (not a set) so the team members are displayed in a stable order
+ team_members = [
+ "Yu-Chih (Wisdom) Chen",
+ "Feier Xu",
+ "Yanchen Dong",
+ "Kitae Kim"
+ ]
197
+
198
+ for name in team_members:
199
+ st.markdown(f"**{name}**")
200
+
201
+ st.markdown("---")
202
+
203
+ if st.button("🔄 Reset Chat"):
204
+ st.session_state.messages = []
205
+ st.rerun()
206
+
207
+ # Main chat interface
208
+ def main():
209
+ # Initialize assistant
210
+ initialize_assistant()
211
+
212
+ # Chat container
213
+ chat_container = st.container()
214
+
215
+ # User input section at the bottom
216
+ input_container = st.container()
217
+
218
+ with input_container:
219
+ # Chat input
220
+ prompt = st.chat_input("What would you like to know?")
221
+
222
+ # Input options below chat input
223
+ col1, col2, col3 = st.columns([1,1,1])
224
+ with col1:
225
+ input_option = st.radio(
226
+ "Input Method:",
227
+ ("Text Only", "Upload Image", "Image URL"),
228
+ key="input_method"
229
+ )
230
+
231
+ # Handle different input methods
232
+ uploaded_image = None
233
+ if input_option == "Upload Image":
234
+ with col2:
235
+ uploaded_file = st.file_uploader("Choose image", type=["jpg", "jpeg", "png"])
236
+ if uploaded_file:
237
+ uploaded_image = Image.open(uploaded_file)
238
+ st.image(uploaded_image, caption="Uploaded Image", width=200)
239
+
240
+ elif input_option == "Image URL":
241
+ with col2:
242
+ image_url = st.text_input("Enter image URL")
243
+ if image_url:
244
+ uploaded_image = load_image_from_url(image_url)
245
+ if uploaded_image:
246
+ st.image(uploaded_image, caption="Image from URL", width=200)
247
+
248
+ # Display chat history
249
+ with chat_container:
250
+ display_chat_history()
251
+
252
+ # Handle user input and generate response
253
+ if prompt:
254
+ handle_user_input(prompt, uploaded_image)
255
+
256
+ # Footer
257
+ st.markdown("""
258
+ <footer>
259
+ <h3>💡 Tips for Best Results</h3>
260
+ <p>Be specific in your questions for more accurate responses!</p>
261
+ <p>Try asking about product features, comparisons, or prices.</p>
262
+ </footer>
263
+ """, unsafe_allow_html=True)
264
+
265
+ if __name__ == "__main__":
266
+ try:
267
+ main()
268
+ finally:
269
+ cleanup_resources()
clip_embedding_evaluation_results/evaluation_metrics.csv ADDED
@@ -0,0 +1,2 @@
+ Timestamp,Model,Dataset,Recall@1,Precision@1,Recall@5,Precision@5,NDCG@5,Recall@10,Precision@10,NDCG@10
+ 20241205,FashionCLIP-FAISS,Amazon Product Dataset,0.638,0.638,0.851,0.17,0.756,0.901,0.09,0.772
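For context: Recall@1 equalling Precision@1, and Precision@5 being roughly Recall@5 / 5, indicate a single relevant product per query, so metrics of this kind can be reproduced with simple binary-relevance formulas. A minimal sketch (an illustration, not the project's actual evaluation code):

```python
import numpy as np

def recall_at_k(ranked_ids, relevant_id, k):
    # With one relevant product per query, Recall@k is 1 if it appears in the top k, else 0.
    return float(relevant_id in ranked_ids[:k])

def ndcg_at_k(ranked_ids, relevant_id, k):
    # Binary relevance, single relevant item: the ideal DCG is 1, so NDCG@k reduces to the
    # discounted gain 1 / log2(rank + 1) of the relevant item if it is ranked within k.
    for rank, pid in enumerate(ranked_ids[:k], start=1):
        if pid == relevant_id:
            return 1.0 / np.log2(rank + 1)
    return 0.0

# Example: the relevant product is ranked 2nd -> Recall@5 = 1.0, NDCG@5 ~ 0.63
print(recall_at_k(["b", "a", "c"], "a", 5), ndcg_at_k(["b", "a", "c"], "a", 5))
```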
model.py ADDED
@@ -0,0 +1,762 @@
1
+ # Standard libraries
2
+ import os
3
+ import io
4
+ import json
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import Dict, List, Tuple, Optional
8
+ import requests
9
+ from PIL import Image
10
+ import matplotlib.pyplot as plt
11
+ from io import BytesIO
12
+
13
+ # Deep learning frameworks
14
+ import torch
15
+ from torch.cuda.amp import autocast
16
+ import open_clip
17
+
18
+ # Hugging Face
19
+ from transformers import (
20
+ AutoTokenizer,
21
+ AutoModelForCausalLM,
22
+ BitsAndBytesConfig,
23
+ pipeline,
24
+ PreTrainedModel,
25
+ PreTrainedTokenizer
26
+ )
27
+ from huggingface_hub import hf_hub_download
28
+ from langchain.prompts import PromptTemplate
29
+
30
+ # Vector database
31
+ import faiss
32
+
33
+ # Type hints
34
+ from typing import Dict, List, Tuple, Optional, Union
35
+
36
+ # Global variables
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ clip_model: Optional[PreTrainedModel] = None
39
+ clip_preprocess: Optional[callable] = None
40
+ clip_tokenizer: Optional[PreTrainedTokenizer] = None
41
+ llm_tokenizer: Optional[PreTrainedTokenizer] = None
42
+ llm_model: Optional[PreTrainedModel] = None
43
+ product_df: Optional[pd.DataFrame] = None
44
+ metadata: Dict = {}
45
+ embeddings_df: Optional[pd.DataFrame] = None
46
+ text_faiss: Optional[object] = None
47
+ image_faiss: Optional[object] = None
48
+
49
+ def initialize_models() -> bool:
50
+ """
51
+ Initialize CLIP and LLM models with proper error handling and GPU optimization.
52
+
53
+ Returns:
54
+ bool: True if initialization successful, raises RuntimeError otherwise
55
+ """
56
+ global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
57
+
58
+ try:
59
+ print(f"Initializing models on device: {device}")
60
+
61
+ # Initialize CLIP model with error handling
62
+ try:
63
+ clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
64
+ 'hf-hub:Marqo/marqo-fashionCLIP'
65
+ )
66
+ clip_model = clip_model.to(device)
67
+ clip_model.eval()
68
+ clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
69
+ print("CLIP model initialized successfully")
70
+ except Exception as e:
71
+ raise RuntimeError(f"Failed to initialize CLIP model: {str(e)}")
72
+
73
+ # Initialize LLM with optimized settings
74
+ try:
75
+ model_name = "mistralai/Mistral-7B-v0.1"
76
+ quantization_config = BitsAndBytesConfig(
77
+ load_in_4bit=True,
78
+ bnb_4bit_compute_dtype=torch.float16,
79
+ bnb_4bit_use_double_quant=True,
80
+ bnb_4bit_quant_type="nf4"
81
+ )
82
+
83
+ llm_tokenizer = AutoTokenizer.from_pretrained(
84
+ model_name,
85
+ padding_side="left",
86
+ truncation_side="left"
87
+ )
88
+ llm_tokenizer.pad_token = llm_tokenizer.eos_token
89
+
90
+ llm_model = AutoModelForCausalLM.from_pretrained(
91
+ model_name,
92
+ quantization_config=quantization_config,
93
+ device_map="auto",
94
+ torch_dtype=torch.float16
95
+ )
96
+ llm_model.eval()
97
+ print("LLM initialized successfully")
98
+ except Exception as e:
99
+ raise RuntimeError(f"Failed to initialize LLM: {str(e)}")
100
+
101
+ return True
102
+
103
+ except Exception as e:
104
+ raise RuntimeError(f"Model initialization failed: {str(e)}")
105
+
106
+ # Data loading
107
+ def load_data() -> bool:
108
+ """
109
+ Load and initialize all required data with enhanced metadata support and error handling.
110
+
111
+ Returns:
112
+ bool: True if data loading successful, raises RuntimeError otherwise
113
+ """
114
+ global product_df, metadata, embeddings_df, text_faiss, image_faiss
115
+
116
+ try:
117
+ print("Loading product data...")
118
+ # Load cleaned product data
119
+ try:
120
+ cleaned_data_path = hf_hub_download(
121
+ repo_id="chen196473/amazon_product_2020_cleaned",
122
+ filename="amazon_cleaned.parquet",
123
+ repo_type="dataset"
124
+ )
125
+ product_df = pd.read_parquet(cleaned_data_path)
126
+
127
+ # Add validation columns
128
+ product_df['Has_Valid_Image'] = product_df['Processed Image'].notna()
129
+ product_df['Image_Status'] = product_df['Has_Valid_Image'].map({
130
+ True: 'valid',
131
+ False: 'invalid'
132
+ })
133
+ print("Product data loaded successfully")
134
+ except Exception as e:
135
+ raise RuntimeError(f"Failed to load product data: {str(e)}")
136
+
137
+ # Load enhanced metadata
138
+ print("Loading metadata...")
139
+ try:
140
+ metadata = {}
141
+ metadata_files = [
142
+ 'base_metadata.json',
143
+ 'category_index.json',
144
+ 'price_range_index.json',
145
+ 'keyword_index.json',
146
+ 'brand_index.json',
147
+ 'product_name_index.json'
148
+ ]
149
+
150
+ for file in metadata_files:
151
+ file_path = hf_hub_download(
152
+ repo_id="chen196473/amazon_product_2020_metadata",
153
+ filename=file,
154
+ repo_type="dataset"
155
+ )
156
+ with open(file_path, 'r') as f:
157
+ index_name = file.replace('.json', '')
158
+ data = json.load(f)
159
+
160
+ if index_name == 'base_metadata':
161
+ data = {item['Uniq_Id']: item for item in data}
162
+ for item in data.values():
163
+ if 'Keywords' in item:
164
+ item['Keywords'] = set(item['Keywords'])
165
+
166
+ metadata[index_name] = data
167
+ print("Metadata loaded successfully")
168
+ except Exception as e:
169
+ raise RuntimeError(f"Failed to load metadata: {str(e)}")
170
+
171
+ # Load embeddings
172
+ print("Loading embeddings...")
173
+ try:
174
+ text_embeddings_dict, image_embeddings_dict = load_embeddings_from_huggingface(
175
+ "chen196473/amazon_vector_database"
176
+ )
177
+
178
+ # Create embeddings DataFrame
179
+ embeddings_df = pd.DataFrame({
180
+ 'text_embeddings': list(text_embeddings_dict.values()),
181
+ 'image_embeddings': list(image_embeddings_dict.values()),
182
+ 'Uniq_Id': list(text_embeddings_dict.keys())
183
+ })
184
+
185
+ # Merge with product data
186
+ product_df = product_df.merge(
187
+ embeddings_df,
188
+ left_on='Uniq Id',
189
+ right_on='Uniq_Id',
190
+ how='inner'
191
+ )
192
+ print("Embeddings loaded and merged successfully")
193
+
194
+ # Create FAISS indexes
195
+ print("Creating FAISS indexes...")
196
+ try:
197
+ create_faiss_indexes(text_embeddings_dict, image_embeddings_dict)
198
+ print("FAISS indexes created successfully")
199
+
200
+ # Verify FAISS indexes are properly initialized and contain data
201
+ if text_faiss is None or image_faiss is None:
202
+ raise RuntimeError("FAISS indexes were not properly initialized")
203
+
204
+ # Test a simple query to verify indexes are working
205
+ test_query = "test"
206
+ tokens = clip_tokenizer(test_query).to(device)
207
+ with torch.no_grad():
208
+ text_embedding = clip_model.encode_text(tokens)
209
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
210
+ text_embedding = text_embedding.cpu().numpy()
211
+
212
+ # Verify search works
213
+ test_results = text_faiss.search(text_embedding[0], k=1)
214
+ if not test_results:
215
+ raise RuntimeError("FAISS indexes are empty")
216
+
217
+ print("FAISS indexes verified successfully")
218
+
219
+ except Exception as e:
220
+ raise RuntimeError(f"Failed to create or verify FAISS indexes: {str(e)}")
221
+
222
+ except Exception as e:
223
+ raise RuntimeError(f"Failed to load embeddings: {str(e)}")
224
+
225
+ # Validate required columns
226
+ required_columns = [
227
+ 'Uniq Id', 'Product Name', 'Category', 'Selling Price',
228
+ 'Model Number', 'Image', 'Normalized Description'
229
+ ]
230
+ missing_cols = set(required_columns) - set(product_df.columns)
231
+ if missing_cols:
232
+ raise ValueError(f"Missing required columns: {missing_cols}")
233
+
234
+ # Add enhanced metadata fields
235
+ if 'Search_Text' not in product_df.columns:
236
+ product_df['Search_Text'] = product_df.apply(
237
+ lambda x: metadata['base_metadata'].get(x['Uniq Id'], {}).get('Search_Text', ''),
238
+ axis=1
239
+ )
240
+
241
+ # Final verification of loaded data
242
+ if product_df is None or product_df.empty:
243
+ raise RuntimeError("Product DataFrame is empty or not initialized")
244
+
245
+ if not metadata:
246
+ raise RuntimeError("Metadata dictionary is empty")
247
+
248
+ if embeddings_df is None or embeddings_df.empty:
249
+ raise RuntimeError("Embeddings DataFrame is empty or not initialized")
250
+
251
+ print("Data loading completed successfully")
252
+ return True
253
+
254
+ except Exception as e:
255
+ # Clean up any partially loaded data
256
+ product_df = None
257
+ metadata = {}
258
+ embeddings_df = None
259
+ text_faiss = None
260
+ image_faiss = None
261
+ raise RuntimeError(f"Data loading failed: {str(e)}")
262
+
263
+ def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
264
+ """
265
+ Load embeddings from Hugging Face repository with enhanced error handling.
266
+
267
+ Args:
268
+ repo_id (str): Hugging Face repository ID
269
+
270
+ Returns:
271
+ Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
272
+ """
273
+ print("Loading embeddings from Hugging Face...")
274
+ try:
275
+ file_path = hf_hub_download(
276
+ repo_id=repo_id,
277
+ filename="embeddings.parquet",
278
+ repo_type="dataset"
279
+ )
280
+ df = pd.read_parquet(file_path)
281
+
282
+ # Extract embedding columns
283
+ text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
284
+ image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
285
+
286
+ # Create embedding dictionaries
287
+ text_embeddings_dict = {
288
+ row['Uniq_Id']: row[text_cols].values.astype(np.float32)
289
+ for _, row in df.iterrows()
290
+ }
291
+ image_embeddings_dict = {
292
+ row['Uniq_Id']: row[image_cols].values.astype(np.float32)
293
+ for _, row in df.iterrows()
294
+ }
295
+
296
+ print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
297
+ return text_embeddings_dict, image_embeddings_dict
298
+
299
+ except Exception as e:
300
+ raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
301
+
302
+ # FAISS index creation
303
+ class MultiModalFAISSIndex:
304
+ def __init__(self, dimension, index_type='L2'):
305
+ import faiss
306
+ self.dimension = dimension
307
+ self.index = faiss.IndexFlatL2(dimension) if index_type == 'L2' else faiss.IndexFlatIP(dimension)
308
+ self.id_to_metadata = {}
309
+
310
+ def add_embeddings(self, embeddings, metadata_list):
311
+ import numpy as np
312
+ embeddings = np.array(embeddings).astype('float32')
313
+ self.index.add(embeddings)
314
+ for i, metadata in enumerate(metadata_list):
315
+ self.id_to_metadata[i] = metadata
316
+
317
+ def search(self, query_embedding, k=5):
318
+ import numpy as np
319
+ query_embedding = np.array([query_embedding]).astype('float32')
320
+ distances, indices = self.index.search(query_embedding, k)
321
+ results = []
322
+ for idx in indices[0]:
323
+ if idx in self.id_to_metadata:
324
+ results.append(self.id_to_metadata[idx])
325
+ return results
326
+
327
+ def create_faiss_indexes(text_embeddings_dict, image_embeddings_dict):
328
+ """Create FAISS indexes with error handling"""
329
+ global text_faiss, image_faiss
330
+
331
+ try:
332
+ # Get embedding dimension
333
+ text_dim = next(iter(text_embeddings_dict.values())).shape[0]
334
+ image_dim = next(iter(image_embeddings_dict.values())).shape[0]
335
+
336
+ # Create indexes
337
+ text_faiss = MultiModalFAISSIndex(text_dim)
338
+ image_faiss = MultiModalFAISSIndex(image_dim)
339
+
340
+ # Prepare text embeddings and metadata
341
+ text_embeddings = []
342
+ text_metadata = []
343
+ for text_id, embedding in text_embeddings_dict.items():
344
+ if text_id in product_df['Uniq Id'].values:
345
+ product = product_df[product_df['Uniq Id'] == text_id].iloc[0]
346
+ text_embeddings.append(embedding)
347
+ text_metadata.append({
348
+ 'id': text_id,
349
+ 'description': product['Normalized Description'],
350
+ 'product_name': product['Product Name']
351
+ })
352
+
353
+ # Add text embeddings
354
+ if text_embeddings:
355
+ text_faiss.add_embeddings(text_embeddings, text_metadata)
356
+
357
+ # Prepare image embeddings and metadata
358
+ image_embeddings = []
359
+ image_metadata = []
360
+ for image_id, embedding in image_embeddings_dict.items():
361
+ if image_id in product_df['Uniq Id'].values:
362
+ product = product_df[product_df['Uniq Id'] == image_id].iloc[0]
363
+ image_embeddings.append(embedding)
364
+ image_metadata.append({
365
+ 'id': image_id,
366
+ 'image_url': product['Image'],
367
+ 'product_name': product['Product Name']
368
+ })
369
+
370
+ # Add image embeddings
371
+ if image_embeddings:
372
+ image_faiss.add_embeddings(image_embeddings, image_metadata)
373
+
374
+ return True
375
+
376
+ except Exception as e:
377
+ raise RuntimeError(f"Failed to create FAISS indexes: {str(e)}")
378
+
379
+ def get_few_shot_product_comparison_template():
380
+ return """Compare these specific products based on their actual features and specifications:
381
+
382
+ Example 1:
383
+ Question: Compare iPhone 13 and Samsung Galaxy S21
384
+ Answer: The iPhone 13 features a 6.1-inch Super Retina XDR display and dual 12MP cameras, while the Galaxy S21 has a 6.2-inch Dynamic AMOLED display and triple camera setup. Both phones offer 5G connectivity, but the iPhone uses A15 Bionic chip while S21 uses Snapdragon 888.
385
+
386
+ Example 2:
387
+ Question: Compare Amazon Echo Dot and Google Nest Mini
388
+ Answer: The Amazon Echo Dot features Alexa voice assistant and a 1.6-inch speaker, while the Google Nest Mini comes with Google Assistant and a 40mm driver. Both devices offer smart home control and music playback, but differ in their ecosystem integration.
389
+
390
+ Current Question: {query}
391
+ Context: {context}
392
+
393
+ Guidelines:
394
+ - Only compare the specific products mentioned in the query
395
+ - Focus on actual product features and specifications
396
+ - Keep response to 2-3 clear sentences
397
+ - Ensure factual accuracy based on the context provided
398
+
399
+ Answer:"""
400
+
401
+ def get_zero_shot_product_template():
402
+ return """You are a product information specialist. Describe only the specific product's actual features based on the provided context.
403
+
404
+ Context: {context}
405
+
406
+ Question: {query}
407
+
408
+ Guidelines:
409
+ - Only describe the specific product mentioned in the query
410
+ - Focus on actual features and specifications from the context
411
+ - Keep response to 2-3 factual sentences
412
+ - Ensure information accuracy
413
+
414
+ Answer:"""
415
+
416
+ def get_zero_shot_image_template():
417
+ return """Analyze this product image and provide a concise description:
418
+
419
+ Product Information:
420
+ {context}
421
+
422
+ Guidelines:
423
+ - Describe the main product features and intended use
424
+ - Highlight key specifications and materials
425
+ - Keep response to 2-3 sentences
426
+ - Focus on practical information
427
+
428
+ Answer:"""
429
+
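For reference, a minimal sketch (illustrative, not from the committed file) of how one of these templates is rendered with LangChain's `PromptTemplate` — the same pattern `generate_rag_response` uses further down. The query and context values here are made-up placeholders:

```python
# Illustrative only: rendering the zero-shot product template defined above.
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    template=get_zero_shot_product_template(),
    input_variables=["query", "context"],
)
print(prompt.format(
    query="What does the Echo Dot do?",
    context="Product Name: Echo Dot\nCategory: Smart Speakers\nPrice: $49.99",
))
```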
430
+ # Image processing functions
431
+ def process_image(image):
432
+ try:
433
+ if isinstance(image, str):
434
+ response = requests.get(image)
435
+ image = Image.open(io.BytesIO(response.content))
436
+
437
+ processed_image = clip_preprocess(image).unsqueeze(0).to(device)
438
+
439
+ with torch.no_grad():
440
+ image_features = clip_model.encode_image(processed_image)
441
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
442
+
443
+ return image_features.cpu().numpy()
444
+ except Exception as e:
445
+ raise Exception(f"Error processing image: {str(e)}")
446
+
447
+ def load_image_from_url(url):
448
+ response = requests.get(url)
449
+ if response.status_code == 200:
450
+ return Image.open(io.BytesIO(response.content))
451
+ else:
452
+ raise Exception(f"Failed to fetch image from URL: {url}, Status Code: {response.status_code}")
453
+
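A minimal usage sketch for the two helpers above (illustrative only; it assumes the module-level `clip_model`, `clip_preprocess`, and `device` objects are initialized earlier in this file, and the URL is a placeholder):

```python
# Illustrative only: fetch a product photo and encode it with FashionCLIP.
img = load_image_from_url("https://example.com/some-product.jpg")  # placeholder URL
embedding = process_image(img)   # numpy array of shape (1, embedding_dim), L2-normalized
print(embedding.shape)
```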
454
+ # Context retrieval and enhancement
455
+ def filter_by_metadata(query, metadata_index):
456
+ relevant_products = set()
457
+
458
+ # Check category index
459
+ if 'category_index' in metadata_index:
460
+ categories = metadata_index['category_index']
461
+ for category in categories:
462
+ if any(term.lower() in category.lower() for term in query.split()):
463
+ relevant_products.update(categories[category])
464
+
465
+ # Check product name index
466
+ if 'product_name_index' in metadata_index:
467
+ product_names = metadata_index['product_name_index']
468
+ for term in query.split():
469
+ if term.lower() in product_names:
470
+ relevant_products.update(product_names[term.lower()])
471
+
472
+ # Check price ranges
473
+ price_terms = {'cheap', 'expensive', 'price', 'cost', 'affordable'}
474
+ if any(term in query.lower() for term in price_terms) and 'price_range_index' in metadata_index:
475
+ price_ranges = metadata_index['price_range_index']
476
+ for price_range in price_ranges:
477
+ relevant_products.update(price_ranges[price_range])
478
+
479
+ return relevant_products if relevant_products else None
480
+
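The metadata index itself is built elsewhere in the project; the sketch below (illustrative, with made-up IDs) only shows the dictionary shape `filter_by_metadata` expects — string keys mapping to collections of product IDs:

```python
# Illustrative only: toy metadata_index matching the lookups in filter_by_metadata.
toy_metadata_index = {
    "category_index": {"Electronics|Headphones": {"id-001", "id-002"}},
    "product_name_index": {"headphones": {"id-001"}},
    "price_range_index": {"0-25": {"id-002"}},
}
# "headphones" matches both the category and the product-name index,
# so the result is the union {"id-001", "id-002"}.
print(filter_by_metadata("wireless headphones", toy_metadata_index))
```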
481
+ def enhance_context_with_metadata(product, metadata_index):
482
+ """Enhanced context building using new metadata structure"""
483
+ # Access base_metadata using product ID directly since it's now a dictionary
484
+ base_metadata = metadata_index['base_metadata'].get(product['Uniq Id'])
485
+
486
+ if base_metadata:
487
+ # Get keywords and search text from enhanced metadata
488
+ keywords = base_metadata.get('Keywords', [])
489
+ search_text = base_metadata.get('Search_Text', '')
490
+
491
+ # Build enhanced description
492
+ description = []
493
+ description.append(f"Product Name: {base_metadata['Product_Name']}")
494
+ description.append(f"Category: {base_metadata['Category']}")
495
+ description.append(f"Price: ${base_metadata['Selling_Price']:.2f}")
496
+
497
+ # Add key features from normalized description
498
+ if 'Normalized_Description' in base_metadata:
499
+ features = []
500
+ for feature in base_metadata['Normalized_Description'].split('|'):
501
+ if ':' in feature:
502
+ key, value = feature.split(':', 1)
503
+ if not any(skip in key.lower() for skip in
504
+ ['uniq id', 'product url', 'specifications', 'asin']):
505
+ features.append(f"{key.strip()}: {value.strip()}")
506
+ if features:
507
+ description.append("Key Features:")
508
+ description.extend(features[:3])
509
+
510
+ # Add relevant keywords
511
+ if keywords:
512
+ description.append("Related Terms: " + ", ".join(list(keywords)[:5]))
513
+
514
+ return "\n".join(description)
515
+
516
+ return None
517
+
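A sketch of the `metadata['base_metadata']` entry shape this helper reads (field names taken from the code above; the values and ID are made up for illustration):

```python
# Illustrative only: minimal base_metadata entry for one product.
toy_metadata = {"base_metadata": {"id-001": {
    "Product_Name": "Echo Dot (4th Gen)",
    "Category": "Smart Speakers",
    "Selling_Price": 49.99,
    "Normalized_Description": "Color: Charcoal | Connectivity: Wi-Fi, Bluetooth",
    "Keywords": ["alexa", "smart speaker"],
    "Search_Text": "echo dot smart speaker alexa",
}}}
# Only the 'Uniq Id' field of the product row is accessed, so a dict works here.
print(enhance_context_with_metadata({"Uniq Id": "id-001"}, toy_metadata))
```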
518
+ def retrieve_context(query, image=None, top_k=5):
519
+ """Enhanced context retrieval using both FAISS and metadata"""
520
+ # Initialize context lists
521
+ similar_items = []
522
+ context = []
523
+
524
+ if image is not None:
525
+ # Process image query
526
+ image_embedding = process_image(image)
527
+ image_embedding = image_embedding.reshape(1, -1)
528
+ similar_items = image_faiss.search(image_embedding[0], k=top_k)
529
+ else:
530
+ # Process text query with enhanced metadata filtering
531
+ relevant_products = filter_by_metadata(query, metadata)
532
+
533
+ tokens = clip_tokenizer(query).to(device)
534
+ with torch.no_grad():
535
+ text_embedding = clip_model.encode_text(tokens)
536
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
537
+ text_embedding = text_embedding.cpu().numpy()
538
+
539
+ # Get FAISS results
540
+ similar_items = text_faiss.search(text_embedding[0], k=top_k*2) # Get more results for filtering
541
+
542
+ # Filter results using metadata if available
543
+ if relevant_products:
544
+ similar_items = [item for item in similar_items if item['id'] in relevant_products][:top_k]
545
+
546
+ # Build enhanced context
547
+ for item in similar_items:
548
+ product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
549
+ enhanced_context = enhance_context_with_metadata(product, metadata)
550
+ if enhanced_context:
551
+ context.append(enhanced_context)
552
+
553
+ return "\n\n".join(context), similar_items
554
+
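A usage sketch for the combined retrieval path (illustrative only; it assumes the module-level globals `clip_model`, `clip_tokenizer`, `text_faiss`, `image_faiss`, `metadata`, and `product_df` have already been loaded):

```python
# Illustrative only: text-driven retrieval.
context, items = retrieve_context("noise cancelling headphones", top_k=3)
print(context[:300])

# Image-driven retrieval: pass a PIL image (or a URL, which process_image downloads).
# photo = load_image_from_url("https://example.com/photo.jpg")  # placeholder URL
# context, items = retrieve_context("what is this?", image=photo, top_k=3)
```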
555
+ def display_product_images(similar_items, max_images=1):
556
+ displayed_images = []
557
+
558
+ for item in similar_items[:max_images]:
559
+ try:
560
+ # Get image URL from product data
561
+ image_url = item['Image'] if isinstance(item, pd.Series) else item.get('Image')
562
+ if not image_url:
563
+ continue
564
+
565
+ # Handle multiple image URLs
566
+ image_urls = image_url.split('|')
567
+ image_url = image_urls[0] # Take first image
568
+
569
+ # Load image
570
+ response = requests.get(image_url)
571
+ img = Image.open(io.BytesIO(response.content))
572
+
573
+ # Get product details
574
+ product_name = item['Product Name'] if isinstance(item, pd.Series) else item.get('product_name')
575
+ price = item['Selling Price'] if isinstance(item, pd.Series) else item.get('price', 0)
576
+
577
+ # Add to displayed images
578
+ displayed_images.append({
579
+ 'image': img,
580
+ 'product_name': product_name,
581
+ 'price': float(price)
582
+ })
583
+
584
+ except Exception as e:
585
+ print(f"Error processing item: {str(e)}")
586
+ continue
587
+
588
+ return displayed_images
589
+
590
+ def classify_query(query):
591
+ """Classify the type of query to determine the retrieval strategy."""
592
+ query_lower = query.lower()
593
+ if any(keyword in query_lower for keyword in ['compare', 'difference between']):
594
+ return 'comparison'
595
+ elif any(keyword in query_lower for keyword in ['show', 'picture', 'image', 'photo']):
596
+ return 'image_search'
597
+ else:
598
+ return 'product_info'
599
+
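A few example queries and the routes `classify_query` assigns them (illustrative; keyword matching is a simple substring check on the lowercased query):

```python
# Illustrative only: how queries are routed.
assert classify_query("Compare iPhone 13 and Galaxy S21") == "comparison"
assert classify_query("Show me a picture of the Echo Dot") == "image_search"
assert classify_query("How much storage does the Kindle have?") == "product_info"
```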
600
+ def boost_category_relevance(query, product, similarity_score):
601
+ query_terms = set(query.lower().split())
602
+ category_terms = set(product['Category'].lower().split())
603
+ category_overlap = len(query_terms & category_terms)
604
+ category_boost = 1 + (category_overlap * 0.2) # 20% boost per matching term
605
+ return similarity_score * category_boost
606
+
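A quick worked example of the 20%-per-term boost (illustrative; the function is written against a DataFrame row, but only `['Category']` is accessed, so a plain dict stands in here):

```python
# Illustrative only: "wireless" and "headphones" both appear in the category,
# so the raw similarity 0.75 is scaled by 1 + 2 * 0.2 = 1.4, giving about 1.05.
product = {"Category": "Wireless Headphones & Earbuds"}
print(boost_category_relevance("wireless headphones", product, 0.75))
```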
607
+ def hybrid_retrieval(query, top_k=5):
608
+ query_type = classify_query(query)
609
+
610
+ tokens = clip_tokenizer(query).to(device)
611
+ with torch.no_grad():
612
+ text_embedding = clip_model.encode_text(tokens)
613
+ text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
614
+ text_embedding = text_embedding.cpu().numpy()
615
+
616
+ # First get text matches
617
+ text_results = text_faiss.search(text_embedding[0], k=top_k*2)
618
+
619
+ if query_type == 'image_search':
620
+ image_results = []
621
+ for item in text_results:
622
+ # Get original product with embeddings intact
623
+ product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
624
+ # Get image embeddings from embeddings_df instead
625
+ image_embedding = embeddings_df[embeddings_df['Uniq_Id'] == item['id']]['image_embeddings'].iloc[0]
626
+ similarity = np.dot(text_embedding.flatten(), image_embedding.flatten())
627
+ boosted_similarity = boost_category_relevance(query, product, similarity)
628
+ image_results.append((product, boosted_similarity))
629
+
630
+ image_results.sort(key=lambda x: x[1], reverse=True)
631
+ results = [item for item, _ in image_results[:top_k]]
632
+ else:
633
+ results = [product_df[product_df['Uniq Id'] == item['id']].iloc[0] for item in text_results[:top_k]]
634
+
635
+ return results, query_type
636
+
637
+
638
+ def fallback_text_search(query, top_k=10):
639
+ relevant_products = filter_by_metadata(query, metadata)
640
+ if not relevant_products:
641
+ # Check brand index specifically
642
+ if 'brand_index' in metadata:
643
+ query_terms = query.lower().split()
644
+ for term in query_terms:
645
+ if term in metadata['brand_index']:
646
+ relevant_products = set(metadata['brand_index'][term])
647
+ break
648
+
649
+ if relevant_products:
650
+ results = [product_df[product_df['Uniq Id'] == pid].iloc[0] for pid in list(relevant_products)[:top_k]]
651
+ else:
652
+ query_lower = query.lower()
653
+ results = product_df[
654
+ (product_df['Product Name'].str.lower().str.contains(query_lower, na=False)) |
655
+ (product_df['Category'].str.lower().str.contains(query_lower, na=False)) |
656
+ (product_df['Normalized Description'].str.lower().str.contains(query_lower, na=False))
657
+ ].head(top_k)
658
+
659
+ return results
660
+
661
+ def generate_rag_response(query, context, image=None):
662
+ """Enhanced RAG response generation"""
663
+ # Select template based on query type and metadata
664
+ if "compare" in query.lower() or "difference between" in query.lower() or "vs." in query.lower():
665
+ template = get_few_shot_product_comparison_template()
666
+ elif image is not None:
667
+ template = get_zero_shot_image_template()
668
+ else:
669
+ template = get_zero_shot_product_template()
670
+
671
+ # Create enhanced prompt with metadata context
672
+ prompt = PromptTemplate(
673
+ template=template,
674
+ input_variables=["query", "context"]
675
+ )
676
+
677
+ # Configure generation parameters
678
+ pipe = pipeline(
679
+ "text-generation",
680
+ model=llm_model,
681
+ tokenizer=llm_tokenizer,
682
+ max_new_tokens=300,
683
+ temperature=0.1,  # note: ignored while do_sample=False (greedy decoding)
684
+ do_sample=False,
685
+ repetition_penalty=1.2,
686
+ early_stopping=True,
687
+ truncation=True,
688
+ padding=True
689
+ )
690
+
691
+ # Generate and clean response
692
+ formatted_prompt = prompt.format(query=query, context=context)
693
+ response = pipe(formatted_prompt)[0]['generated_text']
694
+
695
+ # Clean response
696
+ for section in ["Answer:", "Question:", "Guidelines:", "Context:"]:
697
+ if section in response:
698
+ response = response.split(section)[-1].strip()
699
+
700
+ return response
701
+
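A minimal call sketch (illustrative only; it assumes `llm_model` and `llm_tokenizer` were loaded earlier in this module, and uses a made-up context string in place of the output of `retrieve_context`):

```python
# Illustrative only: generating an answer from a hand-written context block.
toy_context = (
    "Product Name: Echo Dot (4th Gen)\n"
    "Category: Smart Speakers\n"
    "Price: $49.99"
)
answer = generate_rag_response("Tell me about the Echo Dot", toy_context)
print(answer)
```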
702
+ def chatbot(query, image_input=None):
703
+ """
704
+ Main chatbot function to handle queries and provide responses.
705
+ """
706
+ if image_input is not None:
707
+ try:
708
+ # Convert URL to image if needed
709
+ if isinstance(image_input, str):
710
+ image_input = load_image_from_url(image_input)
711
+ elif not isinstance(image_input, Image.Image):
712
+ raise ValueError("Invalid image input type")
713
+
714
+ # Get context and generate response
715
+ context, _ = retrieve_context(query, image_input)
716
+ if not context:
717
+ return "No relevant products found for this image."
718
+ response = generate_rag_response(query, context, image_input)
719
+ return response
720
+
721
+ except Exception as e:
722
+ print(f"Error processing image: {str(e)}")
723
+ return f"Failed to process image: {str(e)}"
724
+ else:
725
+ try:
726
+ print(f"Processing query: {query}")
727
+ if text_faiss is None or image_faiss is None:
728
+ return "Search indexes not initialized. Please try again."
729
+
730
+ results, query_type = hybrid_retrieval(query)
731
+ print(f"Query type: {query_type}")
732
+
733
+ if not results and query_type == 'image_search':
734
+ print("No relevant images found. Falling back to text search.")
735
+ results = fallback_text_search(query)
736
+
737
+ if not results:
738
+ return "No relevant products found."
739
+
740
+ context = "\n\n".join([enhance_context_with_metadata(item, metadata) for item in results])
741
+ response = generate_rag_response(query, context)
742
+
743
+ if query_type == 'image_search':
744
+ print("\nFound matching products:")
745
+ displayed_images = display_product_images(results)
746
+
747
+ # Always return a dictionary with both text and images for image search queries
748
+ return {
749
+ 'text': response,
750
+ 'images': displayed_images
751
+ }
752
+
753
+ return response
754
+ except Exception as e:
755
+ print(f"Error processing query: {str(e)}")
756
+ return f"Error processing request: {str(e)}"
757
+
758
+
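A sketch of the two entry points to `chatbot` (illustrative only; it assumes all indexes and models are initialized). Image-search style queries return a dict with `text` and `images`, everything else returns a plain string:

```python
# Illustrative only: text query that classify_query routes to image search.
result = chatbot("Show me wireless headphones")
if isinstance(result, dict):
    print(result["text"])
    for entry in result["images"]:
        print(entry["product_name"], entry["price"])
else:
    print(result)

# With an image URL (downloaded inside chatbot via load_image_from_url):
# print(chatbot("What is this product?", image_input="https://example.com/item.jpg"))
```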
759
+ def cleanup_resources():
760
+ if torch.cuda.is_available():
761
+ torch.cuda.empty_cache()
762
+ print("GPU memory cleared")
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.28.2
2
+ streamlit-chat==0.1.1
3
+ torch==2.1.1
4
+ transformers==4.35.2
5
+ open_clip_torch==2.23.0
6
+ pillow==10.1.0
7
+ pandas==2.1.3
8
+ numpy==1.26.2
9
+ faiss-cpu==1.7.4
10
+ huggingface_hub==0.19.4
11
+ langchain==0.0.339
12
+ requests==2.31.0
13
+ pyngrok==7.0.3
14
+ bitsandbytes==0.41.1