import os
import json
import base64
from typing import List, Tuple, Dict
from urllib.parse import quote

import gradio as gr
import httpx
from sqlalchemy import create_engine, text
from dotenv import load_dotenv
import google.generativeai as genai


def get_secret(secret_name, service="", username=""):
    """Look up a secret, trying several backends in order.

    1. Google Colab ``userdata`` (only importable inside Colab).
    2. The process environment ``os.environ`` (which includes anything
       ``load_dotenv()`` has loaded).
    3. The system keyring via ``keyring.get_password(service, username)``.

    Args:
        secret_name: Name of the secret in Colab / the environment.
        service: Keyring service name used by the last-resort lookup.
        username: Keyring user name used by the last-resort lookup.

    Returns:
        The secret value. The keyring lookup may return ``None`` when nothing
        is stored for ``(service, username)``.
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Not running in Colab, or the secret is not registered there.
        # ``Exception`` rather than a bare ``except`` so Ctrl-C / SystemExit
        # still propagate.
        pass
    try:
        return os.environ[secret_name]
    except KeyError:
        import keyring
        return keyring.get_password(service, username)


# Load variables from a local .env file (if present) into os.environ so that
# get_secret() can find them.
load_dotenv()

# Database configuration
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"
# The password is percent-escaped: characters such as '@', ':', '/' or '#'
# would otherwise corrupt the connection URL. ``or ''`` avoids a TypeError at
# import time when the secret is missing; the connection then fails on use.
DATABASE_URL = (
    f"postgresql://{DB_USER}:{quote(DB_PASSWORD or '', safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# Create the SQLAlchemy engine
db_conn = create_engine(DATABASE_URL)

# Configure Gemini API
genai.configure(api_key=get_secret("GEMINI_API_KEY"))

# Image MIME types accepted by the Gemini API for inline image data.
_GEMINI_IMAGE_MIME_TYPES = frozenset(
    {'image/png', 'image/jpeg', 'image/webp', 'image/heic', 'image/heif'}
)


def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Get marketplace and main image URL for a product.

    Args:
        id_product_money: Product key in ``public.products``.

    Returns:
        ``(mp, image)`` columns of the matching ``public.products`` row.

    Raises:
        ValueError: If no row matches ``id_product_money``.
    """
    query = text("""
        select mp, image as main_image_url
        from public.products
        where id_product_money = :id_product_money
    """)
    with db_conn.connect() as connection:
        result = connection.execute(query, {"id_product_money": id_product_money}).first()
    if result is None:
        raise ValueError(f"No product found with id_product_money: {id_product_money}")
    return result.mp, result.main_image_url


def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Get additional image URLs for a product, depending on its marketplace.

    * ``'lamoda'``: reads ``info_chrc->'gallery'`` from
      ``public.lamoda_chrc_and_reviews``. Each gallery entry is appended to
      ``https://a.lmcdn.ru/product``.
    * ``'wildberries'``: reads ``features->>'images'`` from ``public.wb_chrc``
      (a JSON string). It collects ``image_url`` from each dict item.

    Returns:
        The list of URLs. It is empty for other marketplaces, missing rows,
        or unparseable data (parse errors are printed, not raised).
    """
    if marketplace == 'lamoda':
        query = text("""
            select info_chrc->'gallery' as more_images
            from public.lamoda_chrc_and_reviews
            where id_product_money = :id_product_money
            limit 1
        """)
        with db_conn.connect() as connection:
            result = connection.execute(query, {"id_product_money": id_product_money}).first()
        if result and result.more_images:
            print(f"Lamoda raw more_images: {result.more_images}")
            try:
                # The driver may hand back the JSON already decoded (a list) or
                # as a string, depending on column type / driver settings.
                if isinstance(result.more_images, str):
                    paths = json.loads(result.more_images)
                else:
                    paths = result.more_images
                if not isinstance(paths, list):
                    return []
                return [f"https://a.lmcdn.ru/product{path}" for path in paths]
            except Exception as e:
                # Degrade to "no extra images" like the wildberries branch
                # instead of aborting the whole analysis.
                print(f"Error parsing JSON: {str(e)}")
                print(f"Type of more_images: {type(result.more_images)}")
                return []

    elif marketplace == 'wildberries':
        query = text("""
            select features->>'images' as more_images
            from public.wb_chrc
            where id_product_money = :id_product_money
            limit 1
        """)
        with db_conn.connect() as connection:
            result = connection.execute(query, {"id_product_money": id_product_money}).first()
        if result and result.more_images:
            print(f"Wildberries raw more_images: {result.more_images}")
            try:
                data = json.loads(result.more_images)
                if isinstance(data, list):
                    # Extract image URLs from the JSON structure; skip entries
                    # that are not objects or have no 'image_url'.
                    return [
                        item['image_url']
                        for item in data
                        if isinstance(item, dict) and item.get('image_url')
                    ]
                return []
            except Exception as e:
                print(f"Error parsing JSON: {str(e)}")
                print(f"Type of more_images: {type(result.more_images)}")
                return []

    return []


def try_scaled_image_url(client: httpx.Client, url: str, marketplace: str, max_retries: int = 3) -> str:
    """Try to get a scaled version of the image URL, fall back to original if not available.

    The scaled URL is built by rewriting a size segment of the CDN path:
    the first ``'product'`` becomes ``'img600x866'`` (lamoda), or
    ``'/big/'`` becomes ``'/c516x688/'`` (wildberries). It is then probed
    with a GET.

    Args:
        client: HTTP client used for the probe.
        url: Original image URL.
        marketplace: ``'lamoda'`` or ``'wildberries'``; other values are
            returned unchanged without any request.
        max_retries: Number of attempts on probe timeouts.

    Returns:
        The scaled URL if it answered with HTTP 200, otherwise ``url``.
    """
    if marketplace == 'lamoda':
        # Replace only the first occurrence (the path's size segment);
        # a plain replace() would also rewrite 'product' inside file names.
        scaled_url = url.replace('product', 'img600x866', 1)
    elif marketplace == 'wildberries':
        scaled_url = url.replace('/big/', '/c516x688/', 1)
    else:
        return url  # No scaling for other marketplaces

    if scaled_url == url:
        # Nothing was rewritten; probing would just download the original.
        return url

    for attempt in range(max_retries):
        try:
            response = client.get(scaled_url, timeout=5.0)
            if response.status_code == 200:
                print(f"Using scaled image: {scaled_url}")
                return scaled_url
            print(f"Scaled image not available (status {response.status_code}), using original: {url}")
            return url
        except httpx.TimeoutException:
            print(f"Timeout checking scaled image (attempt {attempt + 1}/{max_retries})")
            if attempt == max_retries - 1:
                print(f"Max retries reached, using original: {url}")
                return url
        except Exception as e:
            print(f"Error checking scaled image: {type(e).__name__}: {str(e)}")
            return url
    return url


def _image_mime_type(response: httpx.Response, default: str = 'image/jpeg') -> str:
    """Return the response's Content-Type if Gemini accepts it, else *default*."""
    content_type = response.headers.get('content-type', '')
    mime = content_type.split(';', 1)[0].strip().lower()
    return mime if mime in _GEMINI_IMAGE_MIME_TYPES else default


def download_and_encode_images(image_urls: List[str], marketplace: str) -> Tuple[List[Dict], List[str]]:
    """Download images and convert them to base64 format for Gemini.

    Each URL is first swapped for its scaled variant when one is available
    (see ``try_scaled_image_url``). It is then downloaded with up to 3
    attempts. Client errors (4xx other than 429) are not retried.

    Args:
        image_urls: Image URLs to fetch, in order.
        marketplace: Marketplace name used to pick the scaled-URL rule.

    Returns:
        ``(encoded_images, successful_urls)``. ``encoded_images`` holds
        ``{'mime_type', 'data'}`` dicts with base64 text. ``successful_urls``
        holds the URL actually fetched (scaled or original) for each of them.
        Images that fail every attempt are skipped.
    """
    encoded_images = []
    successful_urls = []  # Track URLs that were successfully downloaded
    max_retries = 3

    timeout = httpx.Timeout(10.0, connect=5.0)
    with httpx.Client(timeout=timeout) as client:
        for url in image_urls:
            # Resolve the scaled URL once per image: try_scaled_image_url has
            # its own retry loop, so calling it on every download attempt
            # multiplied the probes (and re-downloaded the image each time).
            final_url = try_scaled_image_url(client, url, marketplace)
            for attempt in range(max_retries):
                try:
                    response = client.get(final_url)
                    response.raise_for_status()
                except httpx.TimeoutException:
                    print(f"Timeout downloading image (attempt {attempt + 1}/{max_retries}): {url}")
                except httpx.HTTPStatusError as e:
                    print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                    status = e.response.status_code
                    if 400 <= status < 500 and status != 429:
                        # Retrying a 404/403 cannot succeed.
                        print(f"Client error {status}, skipping image: {url}")
                        break
                except Exception as e:
                    print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                else:
                    encoded_images.append({
                        'mime_type': _image_mime_type(response),
                        'data': base64.b64encode(response.content).decode('utf-8'),
                    })
                    successful_urls.append(final_url)  # Store the URL that worked (original or scaled)
                    break  # Success, exit retry loop
                if attempt == max_retries - 1:
                    print(f"Max retries reached, skipping image: {url}")

    return encoded_images, successful_urls


def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Get response from a Gemini model.

    Sends every image part followed by the text prompt in a single
    ``generate_content`` call.

    Returns:
        The response text. Any exception is turned into an
        ``"Error with <model>: ..."`` string instead of being raised, so one
        failing model does not hide the other's answer.
    """
    try:
        model = genai.GenerativeModel(model_name)
        # Images first, prompt as the final content part.
        content = [*encoded_images, prompt]
        response = model.generate_content(content)
        return response.text
    except Exception as e:
        return f"Error with {model_name}: {str(e)}"


def process_input(id_product_money: str, prompt: str, progress=gr.Progress()) -> Tuple[List[str], str, str, str]:
    """Main processing function (Gradio click handler).

    Looks up the product and collects its image URLs (main image first,
    de-duplicated, empty/None entries dropped). It downloads the images and
    asks two Gemini models the same prompt.

    Returns:
        ``(image_urls, gemini_1_5_text, gemini_2_0_text, status_message)``.
        On any error, it returns an empty gallery and the error text in all
        three text outputs.
    """
    try:
        status_msg = "Getting product data from database..."
        progress(0, desc=status_msg)
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")

        status_msg = "Fetching additional product images..."
        progress(0.2, desc=status_msg)
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # Combine all images, drop empty/None URLs (e.g. a product without a
        # main image) and remove duplicates while preserving order.
        all_image_urls = list(dict.fromkeys(
            url for url in [main_image, *additional_images] if url
        ))
        print(f"\nAll image URLs: {all_image_urls}")

        status_msg = "Downloading and processing images..."
        progress(0.4, desc=status_msg)
        encoded_images, successful_urls = download_and_encode_images(all_image_urls, marketplace)
        print(f"Number of encoded images: {len(encoded_images)}")

        if not encoded_images:
            raise ValueError("No images could be downloaded")

        # Show only the images that were actually sent to the models.
        all_image_urls = successful_urls

        status_msg = "Getting response from Gemini 1.5 Flash..."
        progress(0.6, desc=status_msg)
        gemini_1_5_response = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)

        status_msg = "Getting response from Gemini 2.0 Flash Exp..."
        progress(0.8, desc=status_msg)
        gemini_2_0_response = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)

        status_msg = "Analysis complete!"
        progress(1.0, desc=status_msg)
        return all_image_urls, gemini_1_5_response, gemini_2_0_response, status_msg
    except Exception as e:
        print(f"\nError in process_input: {str(e)}")
        status_msg = f"Error: {str(e)}"
        progress(1.0, desc=status_msg)
        return [], f"Error: {str(e)}", f"Error: {str(e)}", status_msg


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Product Image Analysis with Gemini Models")

    with gr.Row():
        id_input = gr.Textbox(label="Product ID (id_product_money)")
        prompt_input = gr.Textbox(label="Prompt for VLMs", value="What is this?")

    submit_btn = gr.Button("Analyze")

    # Status indicator
    status = gr.Textbox(label="Status", value="Waiting for input...", interactive=False)

    with gr.Row():
        image_gallery = gr.Gallery(label="Product Images", show_label=True)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Gemini 1.5 Flash Response")
            gemini_1_5_output = gr.Textbox(label="", show_copy_button=True)
        with gr.Column():
            gr.Markdown("### Gemini 2.0 Flash Exp Response")
            gemini_2_0_output = gr.Textbox(label="", show_copy_button=True)

    # Output order must match process_input's return tuple.
    submit_btn.click(
        fn=process_input,
        inputs=[id_input, prompt_input],
        outputs=[image_gallery, gemini_1_5_output, gemini_2_0_output, status],
        show_progress="full",
    )

if __name__ == "__main__":
    demo.launch(share=True)