|
import base64
import json
import os
from typing import List, Tuple, Dict
from urllib.parse import quote_plus

import google.generativeai as genai
import httpx
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
|
|
|
def get_secret(secret_name, service="", username=""):
    """Look up a secret, trying several backends in order.

    1. Google Colab ``userdata`` (only importable inside Colab).
    2. The process environment (``os.environ``), which ``load_dotenv`` may
       have populated from a ``.env`` file.
    3. The system keyring, queried as ``(service or secret_name, username)``.

    Args:
        secret_name: Name of the secret in Colab / the environment.
        service: Keyring service name; defaults to *secret_name* when empty so
            that different secrets map to different keyring entries.
        username: Keyring username.

    Returns:
        The secret string, or whatever ``keyring.get_password`` returns
        (``None`` when no keyring entry exists).

    Raises:
        ImportError: if the keyring fallback is reached and ``keyring`` is
            not installed.
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Not running in Colab, or Colab could not provide the secret
        # (missing / access not granted): fall through to the next backend.
        # `Exception` (not a bare except) so Ctrl+C / SystemExit still propagate.
        pass

    # Membership test rather than .get() so an explicitly empty value is returned.
    if secret_name in os.environ:
        return os.environ[secret_name]

    import keyring
    # Previously every caller queried keyring.get_password("", ""), so all
    # secrets collapsed onto one entry; key the lookup by the secret name.
    return keyring.get_password(service or secret_name, username)
|
|
|
|
|
# Populate os.environ from a .env file (if present) before secrets are read.
load_dotenv()

# Postgres connection settings (read-only user).
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"

# Percent-encode the password so characters such as '@', ':' or '/' cannot
# corrupt the URL. A missing secret (None) becomes an empty password instead
# of the literal text "None".
DATABASE_URL = (
    f"postgresql://{DB_USER}:{quote_plus(DB_PASSWORD or '')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# create_engine() does not connect yet; connections are opened on first use.
db_conn = create_engine(DATABASE_URL)

genai.configure(api_key=get_secret("GEMINI_API_KEY"))
|
|
|
def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Get marketplace and main image URL for a product.

    Args:
        id_product_money: Value matched against
            ``public.products.id_product_money``.

    Returns:
        ``(mp, main_image_url)`` from the first matching row.

    Raises:
        ValueError: if no row matches.
    """
    sql = text("""
        select mp, image as main_image_url
        from public.products
        where id_product_money = :id_product_money
    """)
    params = {"id_product_money": id_product_money}

    with db_conn.connect() as conn:
        row = conn.execute(sql, params).first()
        if row is not None:
            return row.mp, row.main_image_url
        raise ValueError(f"No product found with id_product_money: {id_product_money}")
|
|
|
def _fetch_more_images(query, id_product_money: str):
    """Run *query* for one product; return its ``more_images`` column, or None if no row."""
    with db_conn.connect() as connection:
        row = connection.execute(query, {"id_product_money": id_product_money}).first()
    return row.more_images if row is not None else None


def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Get additional image URLs for a product based on its marketplace.

    - ``'lamoda'``: reads ``info_chrc->'gallery'`` from
      ``public.lamoda_chrc_and_reviews``. Each entry is treated as a path and
      prefixed with ``https://a.lmcdn.ru/product``.
    - ``'wildberries'``: reads ``features->>'images'`` from ``public.wb_chrc``
      as JSON text and collects each item's ``image_url``.
    - Any other marketplace yields ``[]``.

    Malformed or unexpected data (invalid JSON, a non-list value, non-string
    paths, non-dict items) is skipped rather than raised, so the caller still
    gets whatever images could be parsed.
    """
    if marketplace == 'lamoda':
        raw = _fetch_more_images(text("""
            select info_chrc->'gallery' as more_images
            from public.lamoda_chrc_and_reviews
            where id_product_money = :id_product_money
            limit 1
        """), id_product_money)
        if not raw:
            return []
        print(f"Lamoda raw more_images: {raw}")
        try:
            # `->` may come back already decoded or as JSON text depending on the driver.
            paths = json.loads(raw) if isinstance(raw, str) else raw
        except ValueError as e:
            print(f"Error parsing JSON: {str(e)}")
            print(f"Type of more_images: {type(raw)}")
            return []
        if not isinstance(paths, list):
            # A string or object would otherwise be iterated char-by-char / key-by-key.
            return []
        return [f"https://a.lmcdn.ru/product{path}"
                for path in paths if isinstance(path, str) and path]

    if marketplace == 'wildberries':
        raw = _fetch_more_images(text("""
            select features->>'images' as more_images
            from public.wb_chrc
            where id_product_money = :id_product_money
            limit 1
        """), id_product_money)
        if not raw:
            return []
        print(f"Wildberries raw more_images: {raw}")
        try:
            # `->>` returns text, so the JSON must be decoded here.
            data = json.loads(raw)
        except (TypeError, ValueError) as e:
            print(f"Error parsing JSON: {str(e)}")
            print(f"Type of more_images: {type(raw)}")
            return []
        if not isinstance(data, list):
            return []
        # Skip non-dict items individually instead of losing the whole list.
        return [item.get('image_url') for item in data
                if isinstance(item, dict) and item.get('image_url')]

    return []
|
|
|
def try_scaled_image_url(client: httpx.Client, url: str, marketplace: str, max_retries: int = 3) -> str:
    """Try to get a scaled version of the image URL, fall back to original if not available.

    The scaled variant is built by rewriting one size path segment:
    ``/product/`` -> ``/img600x866/`` for Lamoda, and ``/big/`` -> ``/c516x688/``
    for Wildberries. Other marketplaces get *url* back unchanged.

    The scaled URL is probed with a streamed GET; only the status is read, not
    the image body. A 200 selects the scaled URL, and any other status selects
    *url*. Timeouts are retried up to *max_retries* times. Any other error
    falls back to *url* immediately.

    Args:
        client: HTTP client used for the probe.
        url: Original image URL.
        marketplace: ``'lamoda'``, ``'wildberries'`` or anything else.
        max_retries: Number of probe attempts on timeout.

    Returns:
        The scaled URL if the CDN serves it, otherwise *url*.
    """
    if marketplace == 'lamoda':
        # Replace only the first '/product/' segment; the old replace-all also
        # rewrote file names that happened to contain "product".
        scaled_url = url.replace('/product/', '/img600x866/', 1)
    elif marketplace == 'wildberries':
        scaled_url = url.replace('/big/', '/c516x688/', 1)
    else:
        return url

    if scaled_url == url:
        # Nothing to rewrite; probing would just fetch the original image.
        return url

    for attempt in range(max_retries):
        try:
            # Stream so only headers are read; the caller downloads the chosen
            # URL afterwards, so fetching the body here would download it twice.
            with client.stream("GET", scaled_url, timeout=5.0) as response:
                status = response.status_code
        except httpx.TimeoutException:
            print(f"Timeout checking scaled image (attempt {attempt + 1}/{max_retries})")
            continue
        except Exception as e:
            print(f"Error checking scaled image: {type(e).__name__}: {str(e)}")
            return url

        if status == 200:
            print(f"Using scaled image: {scaled_url}")
            return scaled_url
        print(f"Scaled image not available (status {status}), using original: {url}")
        return url

    print(f"Max retries reached, using original: {url}")
    return url
|
|
|
def _image_mime_type(response) -> str:
    """Return the response's ``image/*`` Content-Type (parameters stripped), else ``image/jpeg``."""
    content_type = response.headers.get('content-type', '')
    mime = content_type.split(';', 1)[0].strip().lower()
    return mime if mime.startswith('image/') else 'image/jpeg'


def download_and_encode_images(image_urls: List[str], marketplace: str, max_retries: int = 3) -> List[Dict]:
    """Download images and convert them to base64 format for Gemini.

    For each URL, the scaled variant is resolved once (see
    ``try_scaled_image_url``). The chosen URL is then downloaded with up to
    *max_retries* attempts. Images that still fail are logged and skipped.

    Args:
        image_urls: Image URLs to fetch, in order.
        marketplace: Passed to ``try_scaled_image_url`` to pick the rewrite rule.
        max_retries: Download attempts per image.

    Returns:
        ``{'mime_type': ..., 'data': <base64 str>}`` dicts for the images that
        downloaded successfully, in input order. ``mime_type`` comes from the
        response's Content-Type and falls back to ``image/jpeg``.
    """
    encoded_images = []
    timeout = httpx.Timeout(10.0, connect=5.0)
    with httpx.Client(timeout=timeout) as client:
        for url in image_urls:
            # Resolve once per image: inside the retry loop this re-probed the
            # CDN (with its own retries) on every download attempt.
            try:
                final_url = try_scaled_image_url(client, url, marketplace)
            except Exception as e:
                # e.g. a None/non-str URL; nothing downloadable here.
                print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                continue

            for attempt in range(max_retries):
                try:
                    response = client.get(final_url)
                    response.raise_for_status()
                except httpx.TimeoutException:
                    print(f"Timeout downloading image (attempt {attempt + 1}/{max_retries}): {url}")
                except Exception as e:
                    print(f"Error downloading image: {type(e).__name__}: {str(e)}")
                else:
                    encoded_images.append({
                        # Wildberries/Lamoda may serve webp/png; don't mislabel them as JPEG.
                        'mime_type': _image_mime_type(response),
                        'data': base64.b64encode(response.content).decode('utf-8'),
                    })
                    break
            else:
                print(f"Max retries reached, skipping image: {url}")
    return encoded_images
|
|
|
def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Ask a Gemini model about a set of images and return its text reply.

    Args:
        model_name: Identifier handed to ``genai.GenerativeModel``.
        encoded_images: Inline image parts, sent before the prompt, in order.
        prompt: Text appended after the images.

    Returns:
        ``response.text`` on success. On any ``Exception`` it returns the
        string ``"Error with <model_name>: <message>"`` instead of raising.
    """
    try:
        model = genai.GenerativeModel(model_name)
        parts = list(encoded_images)
        parts.append(prompt)
        return model.generate_content(parts).text
    except Exception as exc:
        return f"Error with {model_name}: {str(exc)}"
|
|
|
def process_input(id_product_money: str, prompt: str) -> Tuple[List[str], str, str]:
    """Main processing function.

    Looks up the product's marketplace and images, downloads and encodes them,
    and asks two Gemini models the same *prompt*.

    Returns:
        ``(image_urls, gemini_1_5_response, gemini_2_0_response)``. On any
        exception it returns ``([], "Error: <message>", "Error: <message>")``
        instead of raising.
    """
    try:
        print("Getting marketplace and main image...")
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")

        print("\nGetting additional images...")
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # Main image first, then extras. Keep first occurrences (dict preserves
        # insertion order) and drop empty/None entries such as a NULL image column.
        all_image_urls = list(dict.fromkeys(
            url for url in [main_image, *additional_images] if url
        ))
        print(f"\nAll image URLs: {all_image_urls}")

        print("\nDownloading and encoding images...")
        encoded_images = download_and_encode_images(all_image_urls, marketplace)
        print(f"Number of encoded images: {len(encoded_images)}")

        if not encoded_images:
            raise ValueError("No images could be downloaded")

        print("\nGetting Gemini responses...")
        gemini_1_5_response = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)
        gemini_2_0_response = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)

        return all_image_urls, gemini_1_5_response, gemini_2_0_response

    except Exception as e:
        # Top-level boundary for the CLI: report the error instead of raising.
        print(f"\nError in process_input: {str(e)}")
        return [], f"Error: {str(e)}", f"Error: {str(e)}"
|
|
|
def main():
    """Command-line interface for testing.

    Repeatedly asks for a product ID and a prompt, runs ``process_input``, and
    prints the image URLs and both model responses. Entering ``q`` (any case),
    pressing Ctrl+C, or reaching end of input (Ctrl+D / closed stdin) exits.
    """
    print("Product Image Analysis with Gemini Models")
    print("-" * 40)

    while True:
        try:
            id_product_money = input("\nEnter product ID (or 'q' to quit): ").strip()
            if id_product_money.lower() == 'q':
                break
            if not id_product_money:
                # Nothing to look up; ask again rather than querying for ''.
                continue

            prompt = input("Enter prompt (or press Enter for default 'What is this?'): ").strip()
            if not prompt:
                prompt = "What is this?"

            print("\nProcessing...")
            image_urls, gemini_1_5_response, gemini_2_0_response = process_input(id_product_money, prompt)

            print("\nProduct Images:")
            for i, url in enumerate(image_urls, 1):
                print(f"{i}. {url}")

            print("\nGemini 1.5 Flash Response:")
            print("-" * 30)
            print(gemini_1_5_response)

            print("\nGemini 2.0 Flash Exp Response:")
            print("-" * 30)
            print(gemini_2_0_response)

        except (KeyboardInterrupt, EOFError):
            # EOFError used to fall into the generic handler below and loop
            # forever once stdin was closed.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {str(e)}")


if __name__ == "__main__":
    main()
|
|