|
import base64
import json
import mimetypes
import os
from typing import Dict, List, Tuple
from urllib.parse import quote_plus

import google.generativeai as genai
import gradio as gr
import httpx
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
|
|
|
def get_secret(secret_name, service="", username=""):
    """Look up a secret from Colab, the environment, or the system keyring.

    Sources are tried in order:

    1. ``google.colab.userdata.get(secret_name)`` (importable only in Colab);
    2. the ``secret_name`` environment variable;
    3. ``keyring.get_password(service, username)``.

    Args:
        secret_name: Key used for the Colab and environment lookups.
        service: Keyring service name for the final fallback.
        username: Keyring user name for the final fallback.

    Returns:
        The secret value. The keyring fallback may return ``None``.
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Not running in Colab (ImportError) or userdata.get() failed:
        # fall through to the next source. Unlike a bare ``except:``, this
        # does not swallow KeyboardInterrupt / SystemExit.
        pass

    value = os.environ.get(secret_name)
    if value is not None:
        return value

    # Imported lazily so keyring is only required when actually needed.
    import keyring
    return keyring.get_password(service, username)
|
|
|
|
|
load_dotenv()

# Connection parameters for the product database (read-only user).
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"

if not DB_PASSWORD:
    print("Warning: FASHION_PG_PASS is not set; database connections will likely fail.")

# Percent-encode the credentials so characters such as '@', ':' or '/' in the
# password cannot break the URL. A missing secret becomes an empty password.
DATABASE_URL = (
    f"postgresql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD or '')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# NOTE: despite its name, db_conn is a SQLAlchemy Engine; the query helpers
# open short-lived connections from it via ``db_conn.connect()``.
db_conn = create_engine(DATABASE_URL)

genai.configure(api_key=get_secret("GEMINI_API_KEY"))
|
|
|
def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Fetch the marketplace code and main image URL of a product.

    Args:
        id_product_money: Value matched against
            ``public.products.id_product_money``.

    Returns:
        ``(mp, main_image_url)`` from the first matching row.

    Raises:
        ValueError: If no product row matches.
    """
    sql = text("""
        select mp, image as main_image_url
        from public.products
        where id_product_money = :id_product_money
    """)
    params = {"id_product_money": id_product_money}

    with db_conn.connect() as conn:
        row = conn.execute(sql, params).first()

    if row is None:
        raise ValueError(f"No product found with id_product_money: {id_product_money}")
    marketplace, image_url = row.mp, row.main_image_url
    return marketplace, image_url
|
|
|
def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Return extra gallery image URLs for a product, per marketplace.

    * ``'lamoda'``: reads ``info_chrc->'gallery'`` from
      ``public.lamoda_chrc_and_reviews``. Each string entry is appended to
      ``https://a.lmcdn.ru/product``.
    * ``'wildberries'``: reads ``features->>'images'`` (JSON text) from
      ``public.wb_chrc`` and collects each item's ``image_url``.

    Any other marketplace, a missing row, an empty value, or malformed data
    yields an empty list (parse problems are printed, not raised).
    """
    if marketplace == 'lamoda':
        raw = _query_more_images("""
            select info_chrc->'gallery' as more_images
            from public.lamoda_chrc_and_reviews
            where id_product_money = :id_product_money
            limit 1
        """, id_product_money)
        if not raw:
            return []
        print(f"Lamoda raw more_images: {raw}")
        return _lamoda_urls_from_gallery(raw)

    if marketplace == 'wildberries':
        raw = _query_more_images("""
            select features->>'images' as more_images
            from public.wb_chrc
            where id_product_money = :id_product_money
            limit 1
        """, id_product_money)
        if not raw:
            return []
        print(f"Wildberries raw more_images: {raw}")
        return _wb_urls_from_images(raw)

    return []


def _query_more_images(sql: str, id_product_money: str):
    """Run a single-row query aliased ``more_images``; return that value or None."""
    with db_conn.connect() as connection:
        row = connection.execute(text(sql), {"id_product_money": id_product_money}).first()
    return None if row is None else row.more_images


def _decode_json_list(raw) -> list:
    """Decode ``raw`` (JSON text or an already-decoded value); return it if it is a list, else []."""
    try:
        data = json.loads(raw) if isinstance(raw, (str, bytes, bytearray)) else raw
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        print(f"Error parsing JSON: {str(e)}")
        print(f"Type of more_images: {type(raw)}")
        return []
    if not isinstance(data, list):
        print(f"Expected a JSON list, got {type(data).__name__}")
        return []
    return data


def _lamoda_urls_from_gallery(raw) -> List[str]:
    """Build Lamoda CDN URLs from a gallery value; non-string or empty entries are skipped."""
    return [
        f"https://a.lmcdn.ru/product{path}"
        for path in _decode_json_list(raw)
        if isinstance(path, str) and path
    ]


def _wb_urls_from_images(raw) -> List[str]:
    """Collect ``image_url`` values from Wildberries image items; non-dict items are skipped."""
    return [
        item['image_url']
        for item in _decode_json_list(raw)
        if isinstance(item, dict) and item.get('image_url')
    ]
|
|
|
def try_scaled_image_url(client: httpx.Client, url: str, marketplace: str, max_retries: int = 3) -> str:
    """Return a smaller CDN rendition of ``url`` when one is available.

    The scaled URL is derived by rewriting one path segment (first
    occurrence only):

    * ``lamoda``: ``/product/`` -> ``/img600x866/``
    * ``wildberries``: ``/big/`` -> ``/c516x688/``

    For other marketplaces, or URLs lacking that segment, ``url`` is returned
    without any network request. Otherwise the scaled URL is probed with a
    streamed GET whose body is never read, and it is used only on HTTP 200.
    Timeouts are retried up to ``max_retries`` times. Any other status or
    request error falls back to ``url`` (errors are printed, not raised).

    Args:
        client: Shared HTTP client used for the probe.
        url: Original image URL.
        marketplace: Marketplace name selecting the rewrite rule.
        max_retries: Probe attempts allowed when the probe times out.

    Returns:
        The scaled URL if it answered 200, otherwise ``url``.
    """
    if marketplace == 'lamoda':
        # Only swap the "/product/" path segment, not every "product" substring.
        scaled_url = url.replace('/product/', '/img600x866/', 1)
    elif marketplace == 'wildberries':
        scaled_url = url.replace('/big/', '/c516x688/', 1)
    else:
        return url

    if scaled_url == url:
        # URL not in the expected shape: nothing to probe.
        return url

    for attempt in range(max_retries):
        try:
            # Streaming reads just the status/headers; the caller downloads
            # the body afterwards, so this avoids fetching the image twice.
            with client.stream("GET", scaled_url, timeout=5.0) as response:
                status = response.status_code
        except httpx.TimeoutException:
            print(f"Timeout checking scaled image (attempt {attempt + 1}/{max_retries})")
            if attempt == max_retries - 1:
                print(f"Max retries reached, using original: {url}")
                return url
            continue
        except Exception as e:
            print(f"Error checking scaled image: {type(e).__name__}: {str(e)}")
            return url

        if status == 200:
            print(f"Using scaled image: {scaled_url}")
            return scaled_url
        print(f"Scaled image not available (status {status}), using original: {url}")
        return url

    return url
|
|
|
def download_and_encode_images(image_urls: List[str], marketplace: str) -> Tuple[List[Dict], List[str]]:
    """Download images and base64-encode them as inline parts for Gemini.

    For each URL a scaled rendition is chosen once via
    ``try_scaled_image_url``. The chosen URL is then downloaded with up to
    three attempts. Images that cannot be fetched are skipped, so one failure
    does not abort the batch.

    Args:
        image_urls: Image URLs to fetch, in order.
        marketplace: Marketplace name passed to ``try_scaled_image_url``.

    Returns:
        ``(encoded_images, successful_urls)`` as parallel lists. Each encoded
        image is ``{'mime_type': ..., 'data': <base64 str>}``, and
        ``successful_urls[i]`` is the URL actually downloaded for it.
    """
    encoded_images = []
    successful_urls = []
    max_retries = 3
    timeout = httpx.Timeout(10.0, connect=5.0)

    with httpx.Client(timeout=timeout) as client:
        for url in image_urls:
            # Resolve the scaled variant once per image rather than on every
            # download attempt; it has its own retry loop.
            final_url = try_scaled_image_url(client, url, marketplace)
            response = _download_image(client, final_url, url, max_retries)
            if response is None:
                continue
            encoded_images.append({
                'mime_type': _image_mime_type(response, final_url),
                'data': base64.b64encode(response.content).decode('utf-8'),
            })
            successful_urls.append(final_url)

    return encoded_images, successful_urls


def _download_image(client: "httpx.Client", final_url: str, original_url: str, max_retries: int):
    """GET ``final_url``, retrying transient failures; return the response or None."""
    for attempt in range(max_retries):
        try:
            response = client.get(final_url)
            response.raise_for_status()
            return response
        except httpx.TimeoutException:
            print(f"Timeout downloading image (attempt {attempt + 1}/{max_retries}): {original_url}")
        except httpx.HTTPStatusError as e:
            print(f"Error downloading image: {type(e).__name__}: {str(e)}")
            if e.response.status_code < 500:
                # 4xx (e.g. 404/403) will not succeed on retry.
                print(f"Client error, skipping image: {original_url}")
                return None
        except Exception as e:
            print(f"Error downloading image: {type(e).__name__}: {str(e)}")
    print(f"Max retries reached, skipping image: {original_url}")
    return None


def _image_mime_type(response: "httpx.Response", url: str) -> str:
    """Pick the image MIME type: Content-Type header, then URL extension, else image/jpeg."""
    content_type = response.headers.get('content-type', '').split(';', 1)[0].strip().lower()
    if content_type.startswith('image/'):
        return content_type
    guessed, _ = mimetypes.guess_type(url)
    if guessed and guessed.startswith('image/'):
        return guessed
    return 'image/jpeg'
|
|
|
def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Send the images followed by the prompt to one Gemini model.

    Args:
        model_name: Gemini model identifier, e.g. ``"gemini-1.5-flash"``.
        encoded_images: Inline image parts (``{'mime_type', 'data'}`` dicts).
        prompt: Text placed after all images in the request.

    Returns:
        The model's text reply. On any failure, an
        ``"Error with <model_name>: ..."`` string is returned instead of raising.
    """
    try:
        gemini = genai.GenerativeModel(model_name)
        request_parts = list(encoded_images)
        request_parts.append(prompt)
        return gemini.generate_content(request_parts).text
    except Exception as err:
        return f"Error with {model_name}: {err}"
|
|
|
def process_input(id_product_money: str, prompt: str, progress=gr.Progress()) -> Tuple[List[str], str, str, str]:
    """Run the full analysis pipeline for one product (Gradio click handler).

    Steps:
    1. Look up the product's marketplace and main image.
    2. Collect additional gallery images.
    3. Download and encode the images.
    4. Query ``gemini-1.5-flash`` and ``gemini-2.0-flash-exp`` with the same
       images and prompt.

    Progress is reported through ``progress``.

    Args:
        id_product_money: Product ID entered by the user; surrounding
            whitespace is ignored.
        prompt: Text prompt sent to both models after the images.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(image_urls, gemini_1_5_text, gemini_2_0_text, status_message)``.
        On any failure the URL list is empty and the strings carry an
        ``"Error: ..."`` message; exceptions are not propagated.
    """
    try:
        # Pasted IDs often carry stray whitespace; an empty ID would only
        # produce a confusing "No product found" from the database.
        id_product_money = (id_product_money or "").strip()
        if not id_product_money:
            raise ValueError("Please enter a product ID")

        status_msg = "Getting product data from database..."
        progress(0, desc=status_msg)
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")

        status_msg = "Fetching additional product images..."
        progress(0.2, desc=status_msg)
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # Order-preserving de-duplication (main image first); empty entries
        # such as a NULL products.image are dropped.
        all_image_urls = list(dict.fromkeys(u for u in [main_image, *additional_images] if u))
        print(f"\nAll image URLs: {all_image_urls}")

        status_msg = "Downloading and processing images..."
        progress(0.4, desc=status_msg)
        encoded_images, successful_urls = download_and_encode_images(all_image_urls, marketplace)
        print(f"Number of encoded images: {len(encoded_images)}")

        if not encoded_images:
            raise ValueError("No images could be downloaded")

        # Show only what was actually sent to the models (possibly scaled URLs).
        all_image_urls = successful_urls

        status_msg = "Getting response from Gemini 1.5 Flash..."
        progress(0.6, desc=status_msg)
        gemini_1_5_response = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)

        status_msg = "Getting response from Gemini 2.0 Flash Exp..."
        progress(0.8, desc=status_msg)
        gemini_2_0_response = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)

        status_msg = "Analysis complete!"
        progress(1.0, desc=status_msg)
        return all_image_urls, gemini_1_5_response, gemini_2_0_response, status_msg

    except Exception as e:
        print(f"\nError in process_input: {str(e)}")
        status_msg = f"Error: {str(e)}"
        progress(1.0, desc=status_msg)
        return [], f"Error: {str(e)}", f"Error: {str(e)}", status_msg
|
|
|
|
|
# Gradio UI: product ID and prompt inputs, an Analyze button, a status line,
# an image gallery, and side-by-side text outputs for the two Gemini models.
# Component creation order below determines the on-page layout.
with gr.Blocks() as demo:
    gr.Markdown("# Product Image Analysis with Gemini Models")

    with gr.Row():
        id_input = gr.Textbox(label="Product ID (id_product_money)")
        prompt_input = gr.Textbox(label="Prompt for VLMs", value="What is this?")

    submit_btn = gr.Button("Analyze")

    # Read-only status line; process_input returns its final message as the
    # fourth output.
    status = gr.Textbox(label="Status", value="Waiting for input...", interactive=False)

    with gr.Row():
        image_gallery = gr.Gallery(label="Product Images", show_label=True)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Gemini 1.5 Flash Response")
            gemini_1_5_output = gr.Textbox(label="", show_copy_button=True)
        with gr.Column():
            gr.Markdown("### Gemini 2.0 Flash Exp Response")
            gemini_2_0_output = gr.Textbox(label="", show_copy_button=True)

    # The outputs list must match process_input's return tuple:
    # (image URLs, 1.5 response, 2.0 response, status message).
    submit_btn.click(
        fn=process_input,
        inputs=[id_input, prompt_input],
        outputs=[image_gallery, gemini_1_5_output, gemini_2_0_output, status],
        show_progress="full"
    )

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # that exposing this database- and API-key-backed app publicly is intended.
    demo.launch(share=True)
|
|