# Standard libraries
import os
import io
import json
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, Union

# Third-party libraries
import numpy as np
import pandas as pd
import requests
from PIL import Image
import matplotlib.pyplot as plt

# Deep learning frameworks
import torch
from torch.cuda.amp import autocast
import open_clip

# Hugging Face
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline,
    PreTrainedModel,
    PreTrainedTokenizer
)
from huggingface_hub import hf_hub_download, login
from langchain.prompts import PromptTemplate

# Vector database
import faiss
# Global variables
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model: Optional[PreTrainedModel] = None
clip_preprocess: Optional[Callable] = None
clip_tokenizer: Optional[PreTrainedTokenizer] = None
llm_tokenizer: Optional[PreTrainedTokenizer] = None
llm_model: Optional[PreTrainedModel] = None
product_df: Optional[pd.DataFrame] = None
metadata: Dict = {}
embeddings_df: Optional[pd.DataFrame] = None
text_faiss: Optional["MultiModalFAISSIndex"] = None
image_faiss: Optional["MultiModalFAISSIndex"] = None
def initialize_models() -> bool:
    """Initialize the CLIP and LLM models, falling back to CPU for CLIP if needed."""
    global clip_model, clip_preprocess, clip_tokenizer, llm_tokenizer, llm_model, device
    try:
        print(f"Initializing models on device: {device}")
        # Initialize CLIP model with error handling and fallback
        try:
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                'hf-hub:Marqo/marqo-fashionCLIP',
                device=device
            )
            clip_model.eval()
            clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
            print("CLIP model initialized successfully")
        except Exception as e:
            print(f"CLIP initialization error: {str(e)}")
            print("Attempting to load CLIP model with CPU fallback...")
            try:
                device = "cpu"
                clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                    'hf-hub:Marqo/marqo-fashionCLIP',
                    device=device
                )
                clip_model.eval()
                clip_tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionCLIP')
                print("CLIP model initialized successfully on CPU")
            except Exception as cpu_e:
                raise RuntimeError(f"Failed to initialize CLIP model on CPU: {str(cpu_e)}")

        # Initialize LLM with optimized settings (4-bit NF4 quantization)
        try:
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                raise RuntimeError("HF_TOKEN environment variable is not set")
            login(token=hf_token)
            model_name = "mistralai/Mistral-7B-v0.1"
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            llm_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left",
                truncation_side="left",
                token=hf_token
            )
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
            llm_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.float16,
                token=hf_token,
                low_cpu_mem_usage=True  # Required to allow device_map usage
            )
            llm_model.eval()
            print("LLM initialized successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to initialize LLM: {str(e)}")
        return True
    except Exception as e:
        raise RuntimeError(f"Model initialization failed: {str(e)}")
# Data loading
def load_data() -> bool:
    """
    Load and initialize all required data with enhanced metadata support and error handling.

    Returns:
        bool: True if data loading succeeded; raises RuntimeError otherwise.
    """
    global product_df, metadata, embeddings_df, text_faiss, image_faiss
    try:
        print("Loading product data...")
        # Load cleaned product data
        try:
            cleaned_data_path = hf_hub_download(
                repo_id="chen196473/amazon_product_2020_cleaned",
                filename="amazon_cleaned.parquet",
                repo_type="dataset"
            )
            product_df = pd.read_parquet(cleaned_data_path)
            # Add validation columns
            product_df['Has_Valid_Image'] = product_df['Processed Image'].notna()
            product_df['Image_Status'] = product_df['Has_Valid_Image'].map({
                True: 'valid',
                False: 'invalid'
            })
            print("Product data loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load product data: {str(e)}")

        # Load enhanced metadata
        print("Loading metadata...")
        try:
            metadata = {}
            metadata_files = [
                'base_metadata.json',
                'category_index.json',
                'price_range_index.json',
                'keyword_index.json',
                'brand_index.json',
                'product_name_index.json'
            ]
            for file in metadata_files:
                file_path = hf_hub_download(
                    repo_id="chen196473/amazon_product_2020_metadata",
                    filename=file,
                    repo_type="dataset"
                )
                with open(file_path, 'r') as f:
                    index_name = file.replace('.json', '')
                    data = json.load(f)
                    if index_name == 'base_metadata':
                        # Re-key the list of records by product ID for O(1) lookups
                        data = {item['Uniq_Id']: item for item in data}
                        for item in data.values():
                            if 'Keywords' in item:
                                item['Keywords'] = set(item['Keywords'])
                    metadata[index_name] = data
            print("Metadata loaded successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to load metadata: {str(e)}")

        # Load embeddings
        print("Loading embeddings...")
        try:
            text_embeddings_dict, image_embeddings_dict = load_embeddings_from_huggingface(
                "chen196473/amazon_vector_database"
            )
            # Create embeddings DataFrame
            embeddings_df = pd.DataFrame({
                'text_embeddings': list(text_embeddings_dict.values()),
                'image_embeddings': list(image_embeddings_dict.values()),
                'Uniq_Id': list(text_embeddings_dict.keys())
            })
            # Merge with product data
            product_df = product_df.merge(
                embeddings_df,
                left_on='Uniq Id',
                right_on='Uniq_Id',
                how='inner'
            )
            print("Embeddings loaded and merged successfully")

            # Create FAISS indexes
            print("Creating FAISS indexes...")
            try:
                create_faiss_indexes(text_embeddings_dict, image_embeddings_dict)
                print("FAISS indexes created successfully")
                # Verify FAISS indexes are properly initialized and contain data
                if text_faiss is None or image_faiss is None:
                    raise RuntimeError("FAISS indexes were not properly initialized")
                # Run a simple query to verify the indexes are searchable
                test_query = "test"
                tokens = clip_tokenizer(test_query).to(device)
                with torch.no_grad():
                    text_embedding = clip_model.encode_text(tokens)
                    text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
                    text_embedding = text_embedding.cpu().numpy()
                test_results = text_faiss.search(text_embedding[0], k=1)
                if not test_results:
                    raise RuntimeError("FAISS indexes are empty")
                print("FAISS indexes verified successfully")
            except Exception as e:
                raise RuntimeError(f"Failed to create or verify FAISS indexes: {str(e)}")
        except Exception as e:
            raise RuntimeError(f"Failed to load embeddings: {str(e)}")

        # Validate required columns
        required_columns = [
            'Uniq Id', 'Product Name', 'Category', 'Selling Price',
            'Model Number', 'Image', 'Normalized Description'
        ]
        missing_cols = set(required_columns) - set(product_df.columns)
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Add enhanced metadata fields
        if 'Search_Text' not in product_df.columns:
            product_df['Search_Text'] = product_df.apply(
                lambda x: metadata['base_metadata'].get(x['Uniq Id'], {}).get('Search_Text', ''),
                axis=1
            )

        # Final verification of loaded data
        if product_df is None or product_df.empty:
            raise RuntimeError("Product DataFrame is empty or not initialized")
        if not metadata:
            raise RuntimeError("Metadata dictionary is empty")
        if embeddings_df is None or embeddings_df.empty:
            raise RuntimeError("Embeddings DataFrame is empty or not initialized")

        print("Data loading completed successfully")
        return True
    except Exception as e:
        # Clean up any partially loaded data
        product_df = None
        metadata = {}
        embeddings_df = None
        text_faiss = None
        image_faiss = None
        raise RuntimeError(f"Data loading failed: {str(e)}")
def load_embeddings_from_huggingface(repo_id: str) -> Tuple[Dict, Dict]:
    """
    Load embeddings from a Hugging Face repository with enhanced error handling.

    Args:
        repo_id (str): Hugging Face repository ID

    Returns:
        Tuple[Dict, Dict]: Dictionaries containing text and image embeddings
    """
    print("Loading embeddings from Hugging Face...")
    try:
        file_path = hf_hub_download(
            repo_id=repo_id,
            filename="embeddings.parquet",
            repo_type="dataset"
        )
        df = pd.read_parquet(file_path)
        # Extract embedding columns
        text_cols = [col for col in df.columns if col.startswith('text_embedding_')]
        image_cols = [col for col in df.columns if col.startswith('image_embedding_')]
        # Create embedding dictionaries keyed by product ID
        text_embeddings_dict = {
            row['Uniq_Id']: row[text_cols].values.astype(np.float32)
            for _, row in df.iterrows()
        }
        image_embeddings_dict = {
            row['Uniq_Id']: row[image_cols].values.astype(np.float32)
            for _, row in df.iterrows()
        }
        print(f"Successfully loaded {len(text_embeddings_dict)} embeddings")
        return text_embeddings_dict, image_embeddings_dict
    except Exception as e:
        raise RuntimeError(f"Failed to load embeddings from Hugging Face: {str(e)}")
# FAISS index creation
class MultiModalFAISSIndex:
    """Thin wrapper around a flat FAISS index that keeps per-vector metadata."""

    def __init__(self, dimension, index_type='L2'):
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension) if index_type == 'L2' else faiss.IndexFlatIP(dimension)
        self.id_to_metadata = {}

    def add_embeddings(self, embeddings, metadata_list):
        """Add a batch of embeddings along with their metadata dictionaries."""
        # NOTE: positions restart at 0 on each call, so this wrapper assumes
        # a single batched add per index.
        embeddings = np.array(embeddings).astype('float32')
        self.index.add(embeddings)
        for i, item_metadata in enumerate(metadata_list):
            self.id_to_metadata[i] = item_metadata

    def search(self, query_embedding, k=5):
        """Return the metadata of the k nearest neighbours of a 1-D query vector."""
        query_embedding = np.array([query_embedding]).astype('float32')
        distances, indices = self.index.search(query_embedding, k)
        results = []
        for idx in indices[0]:
            if idx in self.id_to_metadata:
                results.append(self.id_to_metadata[idx])
        return results
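# Illustrative sketch (hypothetical helper, not called by the app): exercising
# MultiModalFAISSIndex with random vectors. The 512-dim size mirrors typical
# CLIP embeddings but is an assumption here.
def _demo_multimodal_index():
    dim = 512
    index = MultiModalFAISSIndex(dim)
    vectors = np.random.rand(3, dim).astype('float32')
    meta = [{'id': f'demo-{i}', 'product_name': f'Demo product {i}'} for i in range(3)]
    index.add_embeddings(vectors, meta)
    # The nearest neighbour of vectors[0] should be vectors[0] itself
    return index.search(vectors[0], k=2)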
def create_faiss_indexes(text_embeddings_dict, image_embeddings_dict):
    """Create FAISS indexes with error handling"""
    global text_faiss, image_faiss
    try:
        # Get embedding dimensions from the first entry of each dictionary
        text_dim = next(iter(text_embeddings_dict.values())).shape[0]
        image_dim = next(iter(image_embeddings_dict.values())).shape[0]
        # Create indexes
        text_faiss = MultiModalFAISSIndex(text_dim)
        image_faiss = MultiModalFAISSIndex(image_dim)
        # Prepare text embeddings and metadata
        text_embeddings = []
        text_metadata = []
        for text_id, embedding in text_embeddings_dict.items():
            if text_id in product_df['Uniq Id'].values:
                product = product_df[product_df['Uniq Id'] == text_id].iloc[0]
                text_embeddings.append(embedding)
                text_metadata.append({
                    'id': text_id,
                    'description': product['Normalized Description'],
                    'product_name': product['Product Name']
                })
        # Add text embeddings
        if text_embeddings:
            text_faiss.add_embeddings(text_embeddings, text_metadata)
        # Prepare image embeddings and metadata
        image_embeddings = []
        image_metadata = []
        for image_id, embedding in image_embeddings_dict.items():
            if image_id in product_df['Uniq Id'].values:
                product = product_df[product_df['Uniq Id'] == image_id].iloc[0]
                image_embeddings.append(embedding)
                image_metadata.append({
                    'id': image_id,
                    'image_url': product['Image'],
                    'product_name': product['Product Name']
                })
        # Add image embeddings
        if image_embeddings:
            image_faiss.add_embeddings(image_embeddings, image_metadata)
        return True
    except Exception as e:
        raise RuntimeError(f"Failed to create FAISS indexes: {str(e)}")
# Prompt templates
def get_few_shot_product_comparison_template():
    return """Compare these specific products based on their actual features and specifications:

Example 1:
Question: Compare iPhone 13 and Samsung Galaxy S21
Answer: The iPhone 13 features a 6.1-inch Super Retina XDR display and dual 12MP cameras, while the Galaxy S21 has a 6.2-inch Dynamic AMOLED display and a triple camera setup. Both phones offer 5G connectivity, but the iPhone uses the A15 Bionic chip while the S21 uses the Snapdragon 888.

Example 2:
Question: Compare Amazon Echo Dot and Google Nest Mini
Answer: The Amazon Echo Dot features the Alexa voice assistant and a 1.6-inch speaker, while the Google Nest Mini comes with Google Assistant and a 40mm driver. Both devices offer smart home control and music playback, but differ in their ecosystem integration.

Current Question: {query}
Context: {context}

Guidelines:
- Only compare the specific products mentioned in the query
- Focus on actual product features and specifications
- Keep the response to 2-3 clear sentences
- Ensure factual accuracy based on the context provided

Answer:"""

def get_zero_shot_product_template():
    return """You are a product information specialist. Describe only the specific product's actual features based on the provided context.

Context: {context}
Question: {query}

Guidelines:
- Only describe the specific product mentioned in the query
- Focus on actual features and specifications from the context
- Keep the response to 2-3 factual sentences
- Ensure information accuracy

Answer:"""

def get_zero_shot_image_template():
    return """Analyze this product image and provide a concise description:

Product Information:
{context}

Guidelines:
- Describe the main product features and intended use
- Highlight key specifications and materials
- Keep the response to 2-3 sentences
- Focus on practical information

Answer:"""
# Image processing functions
def process_image(image):
    """Encode a PIL image (or image URL) into a normalized CLIP embedding."""
    try:
        if isinstance(image, str):
            response = requests.get(image, timeout=10)  # timeout added as a safety measure
            image = Image.open(io.BytesIO(response.content))
        processed_image = clip_preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(processed_image)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()
    except Exception as e:
        raise Exception(f"Error processing image: {str(e)}")
def load_image_from_url(url):
    """Fetch an image over HTTP and return it as a PIL Image."""
    response = requests.get(url, timeout=10)  # timeout added as a safety measure
    if response.status_code == 200:
        return Image.open(io.BytesIO(response.content))
    else:
        raise Exception(f"Failed to fetch image from URL: {url}, Status Code: {response.status_code}")
# Context retrieval and enhancement
def filter_by_metadata(query, metadata_index):
    """Collect candidate product IDs whose metadata matches terms in the query."""
    relevant_products = set()
    # Check category index
    if 'category_index' in metadata_index:
        categories = metadata_index['category_index']
        for category in categories:
            if any(term.lower() in category.lower() for term in query.split()):
                relevant_products.update(categories[category])
    # Check product name index
    if 'product_name_index' in metadata_index:
        product_names = metadata_index['product_name_index']
        for term in query.split():
            if term.lower() in product_names:
                relevant_products.update(product_names[term.lower()])
    # Check price ranges
    price_terms = {'cheap', 'expensive', 'price', 'cost', 'affordable'}
    if any(term in query.lower() for term in price_terms) and 'price_range_index' in metadata_index:
        price_ranges = metadata_index['price_range_index']
        for price_range in price_ranges:
            relevant_products.update(price_ranges[price_range])
    return relevant_products if relevant_products else None
def enhance_context_with_metadata(product, metadata_index):
    """Enhanced context building using the new metadata structure"""
    # Access base_metadata using the product ID directly, since it is keyed as a dictionary
    base_metadata = metadata_index['base_metadata'].get(product['Uniq Id'])
    if base_metadata:
        # Get keywords and search text from enhanced metadata
        keywords = base_metadata.get('Keywords', [])
        search_text = base_metadata.get('Search_Text', '')
        # Build enhanced description
        description = []
        description.append(f"Product Name: {base_metadata['Product_Name']}")
        description.append(f"Category: {base_metadata['Category']}")
        description.append(f"Price: ${base_metadata['Selling_Price']:.2f}")
        # Add key features from the normalized description
        if 'Normalized_Description' in base_metadata:
            features = []
            for feature in base_metadata['Normalized_Description'].split('|'):
                if ':' in feature:
                    key, value = feature.split(':', 1)
                    if not any(skip in key.lower() for skip in
                               ['uniq id', 'product url', 'specifications', 'asin']):
                        features.append(f"{key.strip()}: {value.strip()}")
            if features:
                description.append("Key Features:")
                description.extend(features[:3])
        # Add relevant keywords
        if keywords:
            description.append("Related Terms: " + ", ".join(list(keywords)[:5]))
        return "\n".join(description)
    return None
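# The enhanced context built above looks roughly like this (hypothetical values):
#   Product Name: Example Bluetooth Speaker
#   Category: Electronics | Portable Audio
#   Price: $29.99
#   Key Features:
#   Battery Life: 10 hours
#   Related Terms: speaker, bluetooth, portable, audio, wireless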
def retrieve_context(query, image=None, top_k=5):
    """Enhanced context retrieval using both FAISS and metadata"""
    # Initialize context lists
    similar_items = []
    context = []
    if image is not None:
        # Process image query
        image_embedding = process_image(image)
        image_embedding = image_embedding.reshape(1, -1)
        similar_items = image_faiss.search(image_embedding[0], k=top_k)
    else:
        # Process text query with enhanced metadata filtering
        relevant_products = filter_by_metadata(query, metadata)
        tokens = clip_tokenizer(query).to(device)
        with torch.no_grad():
            text_embedding = clip_model.encode_text(tokens)
            text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
            text_embedding = text_embedding.cpu().numpy()
        # Get FAISS results (over-fetch to leave room for metadata filtering)
        similar_items = text_faiss.search(text_embedding[0], k=top_k*2)
        # Filter results using metadata if available
        if relevant_products:
            similar_items = [item for item in similar_items if item['id'] in relevant_products][:top_k]
    # Build enhanced context
    for item in similar_items:
        product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
        enhanced_context = enhance_context_with_metadata(product, metadata)
        if enhanced_context:
            context.append(enhanced_context)
    return "\n\n".join(context), similar_items
def display_product_images(similar_items, max_images=1):
    """Fetch product images for the retrieved items and collect display metadata."""
    displayed_images = []
    for item in similar_items[:max_images]:
        try:
            # Get image URL from product data
            image_url = item['Image'] if isinstance(item, pd.Series) else item.get('Image')
            if not image_url:
                continue
            # Handle multiple pipe-separated image URLs; take the first one
            image_urls = image_url.split('|')
            image_url = image_urls[0]
            # Load image (timeout added as a safety measure)
            response = requests.get(image_url, timeout=10)
            img = Image.open(BytesIO(response.content))
            # Get product details
            product_name = item['Product Name'] if isinstance(item, pd.Series) else item.get('product_name')
            price = item['Selling Price'] if isinstance(item, pd.Series) else item.get('price', 0)
            # Add to displayed images
            displayed_images.append({
                'image': img,
                'product_name': product_name,
                'price': float(price)
            })
        except Exception as e:
            print(f"Error processing item: {str(e)}")
            continue
    return displayed_images
def classify_query(query):
    """Classify the type of query to determine the retrieval strategy."""
    query_lower = query.lower()
    if any(keyword in query_lower for keyword in ['compare', 'difference between']):
        return 'comparison'
    elif any(keyword in query_lower for keyword in ['show', 'picture', 'image', 'photo']):
        return 'image_search'
    else:
        return 'product_info'
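# For example (hypothetical inputs):
#   classify_query("Compare Echo Dot and Nest Mini")          -> 'comparison'
#   classify_query("Show me a picture of a red backpack")     -> 'image_search'
#   classify_query("What is the battery life of this mouse?") -> 'product_info'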
def boost_category_relevance(query, product, similarity_score):
    """Boost a similarity score when query terms overlap the product's category."""
    query_terms = set(query.lower().split())
    category_terms = set(product['Category'].lower().split())
    category_overlap = len(query_terms & category_terms)
    category_boost = 1 + (category_overlap * 0.2)  # 20% boost per matching term
    return similarity_score * category_boost
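# Worked example (hypothetical values): for the query "wireless headphones" and
# a product in category "Electronics Headphones", exactly one term overlaps
# ("headphones"), so the similarity score is multiplied by 1 + 1 * 0.2 = 1.2.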
def hybrid_retrieval(query, top_k=5):
    """Retrieve products via text FAISS search, re-ranking by image similarity for image queries."""
    query_type = classify_query(query)
    tokens = clip_tokenizer(query).to(device)
    with torch.no_grad():
        text_embedding = clip_model.encode_text(tokens)
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        text_embedding = text_embedding.cpu().numpy()
    # First get text matches
    text_results = text_faiss.search(text_embedding[0], k=top_k*2)
    if query_type == 'image_search':
        image_results = []
        for item in text_results:
            # Get the original product row with embeddings intact
            product = product_df[product_df['Uniq Id'] == item['id']].iloc[0]
            # Get image embeddings from embeddings_df instead
            image_embedding = embeddings_df[embeddings_df['Uniq_Id'] == item['id']]['image_embeddings'].iloc[0]
            similarity = np.dot(text_embedding.flatten(), image_embedding.flatten())
            boosted_similarity = boost_category_relevance(query, product, similarity)
            image_results.append((product, boosted_similarity))
        image_results.sort(key=lambda x: x[1], reverse=True)
        results = [item for item, _ in image_results[:top_k]]
    else:
        results = [product_df[product_df['Uniq Id'] == item['id']].iloc[0] for item in text_results[:top_k]]
    return results, query_type
def fallback_text_search(query, top_k=10):
    """Metadata- and substring-based fallback when semantic retrieval finds nothing."""
    relevant_products = filter_by_metadata(query, metadata)
    if not relevant_products:
        # Check the brand index specifically
        if 'brand_index' in metadata:
            query_terms = query.lower().split()
            for term in query_terms:
                if term in metadata['brand_index']:
                    relevant_products = set(metadata['brand_index'][term])
                    break
    if relevant_products:
        results = [product_df[product_df['Uniq Id'] == pid].iloc[0] for pid in list(relevant_products)[:top_k]]
    else:
        query_lower = query.lower()
        # na=False and regex=False guard against NaN values and regex metacharacters
        matches = product_df[
            (product_df['Product Name'].str.lower().str.contains(query_lower, na=False, regex=False)) |
            (product_df['Category'].str.lower().str.contains(query_lower, na=False, regex=False)) |
            (product_df['Normalized Description'].str.lower().str.contains(query_lower, na=False, regex=False))
        ]
        # Return a list of rows so callers can test truthiness consistently
        results = [row for _, row in matches.head(top_k).iterrows()]
    return results
def generate_rag_response(query, context, image=None):
    """Enhanced RAG response generation"""
    # Select template based on query type and metadata
    if "compare" in query.lower() or "difference between" in query.lower() or "vs." in query.lower():
        template = get_few_shot_product_comparison_template()
    elif image is not None:
        template = get_zero_shot_image_template()
    else:
        template = get_zero_shot_product_template()
    # Create enhanced prompt with metadata context
    prompt = PromptTemplate(
        template=template,
        input_variables=["query", "context"]
    )
    # Configure generation parameters; note that do_sample=False makes decoding
    # greedy, so temperature has no effect on token selection here
    pipe = pipeline(
        "text-generation",
        model=llm_model,
        tokenizer=llm_tokenizer,
        max_new_tokens=300,
        temperature=0.1,
        do_sample=False,
        repetition_penalty=1.2,
        early_stopping=True,
        truncation=True,
        padding=True
    )
    # Generate and clean response
    formatted_prompt = prompt.format(query=query, context=context)
    response = pipe(formatted_prompt)[0]['generated_text']
    # Strip everything up to the last occurrence of each prompt section marker
    for section in ["Answer:", "Question:", "Guidelines:", "Context:"]:
        if section in response:
            response = response.split(section)[-1].strip()
    return response
def chatbot(query, image_input=None):
    """
    Main chatbot function to handle queries and provide responses.
    """
    if image_input is not None:
        try:
            # Convert URL to image if needed
            if isinstance(image_input, str):
                image_input = load_image_from_url(image_input)
            elif not isinstance(image_input, Image.Image):
                raise ValueError("Invalid image input type")
            # Get context and generate response
            context, _ = retrieve_context(query, image_input)
            if not context:
                return "No relevant products found for this image."
            response = generate_rag_response(query, context, image_input)
            return response
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return f"Failed to process image: {str(e)}"
    else:
        try:
            print(f"Processing query: {query}")
            if text_faiss is None or image_faiss is None:
                return "Search indexes not initialized. Please try again."
            results, query_type = hybrid_retrieval(query)
            print(f"Query type: {query_type}")
            if not results and query_type == 'image_search':
                print("No relevant images found. Falling back to text search.")
                results = fallback_text_search(query)
            if not results:
                return "No relevant products found."
            # Skip items for which no metadata-based context could be built
            context = "\n\n".join(
                c for c in (enhance_context_with_metadata(item, metadata) for item in results) if c
            )
            response = generate_rag_response(query, context)
            if query_type == 'image_search':
                print("\nFound matching products:")
                displayed_images = display_product_images(results)
                # Always return a dictionary with both text and images for image search queries
                return {
                    'text': response,
                    'images': displayed_images
                }
            return response
        except Exception as e:
            print(f"Error processing query: {str(e)}")
            return f"Error processing request: {str(e)}"
def cleanup_resources():
    """Release cached GPU memory, if any."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU memory cleared")