# Author: alexander-lazarin
# Commit message: Use WB scaled-down images
# Commit: 70b8860
import base64
import json
import os
from typing import Dict, List, Tuple
from urllib.parse import quote

import google.generativeai as genai
import httpx
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
def get_secret(secret_name, service="", username=""):
    """Return a secret, trying Colab, then the environment, then the keyring.

    Lookup order:
      1. ``google.colab.userdata.get(secret_name)`` -- only importable when
         running inside Google Colab;
      2. ``os.environ[secret_name]``;
      3. ``keyring.get_password(service, username)`` -- note that
         ``secret_name`` is not used for this last lookup.

    Args:
        secret_name: Name of the Colab secret / environment variable.
        service: Keyring service name, used only by the keyring fallback.
        username: Keyring username, used only by the keyring fallback.

    Returns:
        The value from the first source that succeeds. The keyring result is
        returned as-is (possibly ``None`` if it has no matching entry).
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Not running in Colab (ImportError) or Colab could not supply the
        # secret. Unlike a bare ``except:``, this no longer swallows Ctrl-C
        # (KeyboardInterrupt) or SystemExit.
        pass

    try:
        return os.environ[secret_name]
    except KeyError:
        pass

    # Imported lazily: keyring is only needed when the other sources fail.
    import keyring
    return keyring.get_password(service, username)
# Load environment variables (e.g. from a local .env file) before any
# secrets are read below via get_secret().
load_dotenv()

# Database configuration
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"
# Percent-encode the password so characters such as '@', ':', '/' or '#'
# cannot be misread as URL delimiters. A missing secret (None) becomes an
# empty password; presumably this then fails authentication on first connect.
DATABASE_URL = (
    f"postgresql://{DB_USER}:{quote(DB_PASSWORD or '', safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# Create the SQLAlchemy engine; the query helpers open connections from it
# with db_conn.connect().
db_conn = create_engine(DATABASE_URL)

# Configure Gemini API
genai.configure(api_key=get_secret("GEMINI_API_KEY"))
def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Fetch a product's marketplace code and primary image URL.

    Args:
        id_product_money: Value matched against
            ``public.products.id_product_money``.

    Returns:
        ``(mp, main_image_url)`` taken from the first matching row.

    Raises:
        ValueError: If no row in ``public.products`` matches.
    """
    sql = text("""
select mp, image as main_image_url
from public.products
where id_product_money = :id_product_money
""")
    bind = {"id_product_money": id_product_money}
    with db_conn.connect() as conn:
        row = conn.execute(sql, bind).first()
    if row is not None:
        return row.mp, row.main_image_url
    raise ValueError(f"No product found with id_product_money: {id_product_money}")
def _fetch_more_images(query, id_product_money):
    """Run a one-row ``more_images`` query for a product; return the value or None."""
    with db_conn.connect() as connection:
        row = connection.execute(query, {"id_product_money": id_product_money}).first()
    return row.more_images if row is not None else None


def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Return extra gallery image URLs for a product on the given marketplace.

    Supported marketplaces:
      * ``'lamoda'`` -- reads ``info_chrc->'gallery'`` from
        ``public.lamoda_chrc_and_reviews``. Each string path in that list is
        prefixed with ``https://a.lmcdn.ru/product``.
      * ``'wildberries'`` -- reads ``features->>'images'`` (JSON text) from
        ``public.wb_chrc``. Each dict item with a truthy ``image_url`` key
        contributes that URL.

    Any other marketplace, a missing row, or gallery data that cannot be
    parsed yields ``[]``. A bad gallery therefore never aborts the caller,
    which can still use the main image.
    """
    if marketplace == 'lamoda':
        query = text("""
select info_chrc->'gallery' as more_images
from public.lamoda_chrc_and_reviews
where id_product_money = :id_product_money
limit 1
""")
        more_images = _fetch_more_images(query, id_product_money)
        if not more_images:
            return []
        print(f"Lamoda raw more_images: {more_images}")
        try:
            # Handle both string JSON and already-decoded list values.
            paths = json.loads(more_images) if isinstance(more_images, str) else more_images
        except ValueError as e:
            print(f"Error parsing JSON: {str(e)}")
            return []
        if not isinstance(paths, list):
            # Iterating a dict or string here would produce garbage URLs.
            print(f"Type of more_images: {type(paths)}")
            return []
        return [f"https://a.lmcdn.ru/product{path}" for path in paths if isinstance(path, str)]

    if marketplace == 'wildberries':
        query = text("""
select features->>'images' as more_images
from public.wb_chrc
where id_product_money = :id_product_money
limit 1
""")
        more_images = _fetch_more_images(query, id_product_money)
        if not more_images:
            return []
        print(f"Wildberries raw more_images: {more_images}")
        try:
            data = json.loads(more_images)
        except (TypeError, ValueError) as e:
            print(f"Error parsing JSON: {str(e)}")
            print(f"Type of more_images: {type(more_images)}")
            return []
        if not isinstance(data, list):
            return []
        # Skip malformed (non-dict) entries instead of discarding the whole gallery.
        return [
            item['image_url']
            for item in data
            if isinstance(item, dict) and item.get('image_url')
        ]

    return []
def try_scaled_image_url(client: httpx.Client, url: str, marketplace: str, max_retries: int = 3) -> str:
    """Return a smaller CDN rendition of ``url`` if it exists, else ``url``.

    URL rewrites:
      * lamoda: the first ``/product/`` segment becomes ``/img600x866/``;
      * wildberries: the first ``/big/`` segment becomes ``/c516x688/``.

    For other marketplaces, or when the rewrite does not change the URL,
    ``url`` is returned without any request. Otherwise the scaled URL is
    probed with a streamed GET. Only the status and headers are read, so the
    image body is not downloaded here and then again by the caller.

    Timeouts are retried up to ``max_retries`` times. A non-200 status or any
    other error falls back to ``url``.
    """
    if marketplace == 'lamoda':
        # Rewrite only the path segment. A blanket replace of "product" would
        # also corrupt any other part of the URL containing that word.
        scaled_url = url.replace('/product/', '/img600x866/', 1)
    elif marketplace == 'wildberries':
        scaled_url = url.replace('/big/', '/c516x688/', 1)
    else:
        return url  # No scaling for other marketplaces

    if scaled_url == url:
        # Nothing was rewritten; every outcome below would return `url` anyway.
        return url

    for attempt in range(max_retries):
        try:
            # Leaving the `with` closes the response without reading the body.
            with client.stream("GET", scaled_url, timeout=5.0) as response:
                status = response.status_code
        except httpx.TimeoutException:
            print(f"Timeout checking scaled image (attempt {attempt + 1}/{max_retries})")
            if attempt == max_retries - 1:
                print(f"Max retries reached, using original: {url}")
                return url
            continue
        except Exception as e:
            print(f"Error checking scaled image: {type(e).__name__}: {str(e)}")
            return url
        if status == 200:
            print(f"Using scaled image: {scaled_url}")
            return scaled_url
        print(f"Scaled image not available (status {status}), using original: {url}")
        return url
    return url
def download_and_encode_images(image_urls: List[str], marketplace: str) -> List[Dict]:
    """Download images and return them as Gemini inline-image parts.

    Each URL is first swapped for its scaled-down rendition when
    ``try_scaled_image_url`` finds one. That lookup happens once per image,
    not once per download attempt.

    Retry policy:
      * timeouts, other errors and 5xx responses are retried up to 3 times;
      * any other HTTP error status is treated as permanent.

    Images that still fail are skipped, so the result may be shorter than
    ``image_urls``.

    Returns:
        ``{'mime_type': ..., 'data': <base64 str>}`` dicts in input order.
        ``mime_type`` comes from the response's Content-Type header. It falls
        back to ``'image/jpeg'`` when the header is missing or not ``image/*``.
    """
    encoded_images = []
    timeout = httpx.Timeout(10.0, connect=5.0)
    max_retries = 3
    with httpx.Client(timeout=timeout) as client:
        for url in image_urls:
            # Resolve the scaled URL once; it has its own retry loop.
            try:
                final_url = try_scaled_image_url(client, url, marketplace)
            except Exception as e:
                print(f"Error checking scaled image: {type(e).__name__}: {str(e)}")
                final_url = url

            for attempt in range(max_retries):
                try:
                    response = client.get(final_url)
                    response.raise_for_status()
                except httpx.TimeoutException:
                    print(f"Timeout downloading image (attempt {attempt + 1}/{max_retries}): {url}")
                except httpx.HTTPStatusError as e:
                    print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                    if e.response.status_code < 500:
                        # Only server errors are worth retrying; other
                        # statuses (e.g. 404) are treated as permanent.
                        print(f"Not retrying, skipping image: {url}")
                        break
                except Exception as e:
                    print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                else:
                    # The CDN may serve formats other than JPEG (e.g. WebP),
                    # so trust the server's Content-Type when it is an image.
                    content_type = response.headers.get('content-type', '')
                    mime_type = content_type.split(';', 1)[0].strip().lower()
                    if not mime_type.startswith('image/'):
                        mime_type = 'image/jpeg'
                    encoded_images.append({
                        'mime_type': mime_type,
                        'data': base64.b64encode(response.content).decode('utf-8'),
                    })
                    break  # Success, exit retry loop
                if attempt == max_retries - 1:
                    print(f"Max retries reached, skipping image: {url}")
    return encoded_images
def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Ask one Gemini model about a set of images and return its text reply.

    The request content is every part from ``encoded_images``, in order,
    followed by ``prompt`` as the final part. Any failure (creating the model,
    calling it, or reading ``.text``) is returned rather than raised, as the
    string ``"Error with <model_name>: <message>"``.
    """
    try:
        model = genai.GenerativeModel(model_name)
        parts = list(encoded_images)  # image parts first...
        parts.append(prompt)          # ...then the text prompt last
        return model.generate_content(parts).text
    except Exception as err:
        return f"Error with {model_name}: {str(err)}"
def process_input(id_product_money: str, prompt: str) -> Tuple[List[str], str, str]:
    """Run the full image-analysis pipeline for one product.

    Steps:
      1. Look up the product's marketplace and main image.
      2. Gather its gallery images.
      3. Download and encode all images.
      4. Ask two Gemini models the same ``prompt``.

    Returns:
        ``(image_urls, gemini_1_5_response, gemini_2_0_response)``.
        ``image_urls`` is the de-duplicated list of URLs that were requested,
        main image first. On any failure, returns ``([], msg, msg)`` with
        ``msg = "Error: <message>"`` instead of raising.
    """
    try:
        print("Getting marketplace and main image...")
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")
        print("\nGetting additional images...")
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # Remove duplicates while preserving first-seen order (dict keys keep
        # insertion order). Also drop missing/empty URLs, e.g. if the product
        # row has no main image.
        all_image_urls = list(dict.fromkeys(
            url for url in [main_image, *additional_images] if url
        ))
        print(f"\nAll image URLs: {all_image_urls}")

        print("\nDownloading and encoding images...")
        encoded_images = download_and_encode_images(all_image_urls, marketplace)
        print(f"Number of encoded images: {len(encoded_images)}")
        if not encoded_images:
            raise ValueError("No images could be downloaded")

        print("\nGetting Gemini responses...")
        # Get responses from both models
        gemini_1_5_response = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)
        gemini_2_0_response = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)
        return all_image_urls, gemini_1_5_response, gemini_2_0_response
    except Exception as e:
        print(f"\nError in process_input: {str(e)}")
        return [], f"Error: {str(e)}", f"Error: {str(e)}"
def main():
    """Interactive command-line loop for trying the pipeline by hand.

    Each iteration:
      * asks for a product ID (surrounding whitespace is stripped) and a
        prompt; an empty prompt defaults to "What is this?";
      * runs ``process_input`` and prints the image URLs and both model
        responses.

    Exit by typing ``q`` (any case), pressing Ctrl-C, or closing stdin
    (Ctrl-D or end of piped input).
    """
    print("Product Image Analysis with Gemini Models")
    print("-" * 40)
    while True:
        try:
            id_product_money = input("\nEnter product ID (or 'q' to quit): ").strip()
            if id_product_money.lower() == 'q':
                break
            prompt = input("Enter prompt (or press Enter for default 'What is this?'): ")
            if not prompt:
                prompt = "What is this?"
            print("\nProcessing...")
            image_urls, gemini_1_5_response, gemini_2_0_response = process_input(id_product_money, prompt)
            print("\nProduct Images:")
            for i, url in enumerate(image_urls, 1):
                print(f"{i}. {url}")
            print("\nGemini 1.5 Flash Response:")
            print("-" * 30)
            print(gemini_1_5_response)
            print("\nGemini 2.0 Flash Exp Response:")
            print("-" * 30)
            print(gemini_2_0_response)
        except (KeyboardInterrupt, EOFError):
            # EOFError subclasses Exception. If it reached the generic handler
            # below, the loop would print it and call input() again, hitting
            # EOF forever.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {str(e)}")
# Start the interactive CLI only when run as a script, not when imported.
if __name__ == "__main__":
    main()