# Add a CLI testing file - alexander-lazarin, commit 1b944da (7.89 kB)
import base64
import json
import os
from typing import List, Tuple, Dict
from urllib.parse import quote

import google.generativeai as genai
import httpx
from dotenv import load_dotenv
from sqlalchemy import create_engine, text
def get_secret(secret_name, service="", username=""):
    """Look up a secret from Colab, the environment, or the system keyring.

    Sources are tried in order and the first one that answers wins:

    1. ``google.colab.userdata`` (importable only inside Google Colab),
    2. the environment variable named ``secret_name``,
    3. ``keyring.get_password(service, username)``.

    Args:
        secret_name: Name of the Colab secret / environment variable.
        service: Keyring service name, used only for the last fallback.
        username: Keyring user name, used only for the last fallback.

    Returns:
        Whatever the first successful source returns.
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Not running in Colab, or the secret is unavailable there.
        # ``Exception`` (not a bare ``except``) lets Ctrl-C / SystemExit
        # propagate instead of being silently swallowed.
        pass

    try:
        return os.environ[secret_name]
    except KeyError:
        # Imported lazily so keyring is only required when actually needed.
        import keyring
        return keyring.get_password(service, username)
# Load environment variables; this must run before the get_secret() calls
# below, which read os.environ.
load_dotenv()

# Database configuration
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"

# Percent-encode the password before embedding it in the URL: characters such
# as "@", ":", "/" or "%" would otherwise be parsed as URL syntax and corrupt
# the connection string. A missing password is encoded as an empty string.
_encoded_password = quote(DB_PASSWORD or "", safe="")
DATABASE_URL = f"postgresql://{DB_USER}:{_encoded_password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

# Create the SQLAlchemy engine
db_conn = create_engine(DATABASE_URL)

# Configure Gemini API
genai.configure(api_key=get_secret("GEMINI_API_KEY"))
def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Look up a product's marketplace and its main image URL.

    Args:
        id_product_money: Value matched against ``public.products.id_product_money``.

    Returns:
        A ``(mp, main_image_url)`` pair taken from the first matching row.

    Raises:
        ValueError: If no row matches ``id_product_money``.
    """
    product_sql = text("""
        select mp, image as main_image_url
        from public.products
        where id_product_money = :id_product_money
    """)
    params = {"id_product_money": id_product_money}

    with db_conn.connect() as conn:
        row = conn.execute(product_sql, params).first()

    if row is not None:
        return row.mp, row.main_image_url
    raise ValueError(f"No product found with id_product_money: {id_product_money}")
def _lamoda_gallery_urls(raw_gallery) -> List[str]:
    """Convert a Lamoda ``gallery`` value into absolute CDN image URLs."""
    # Handle both string JSON and direct list cases
    paths = json.loads(raw_gallery) if isinstance(raw_gallery, str) else raw_gallery
    if not isinstance(paths, list):
        return []
    return [f"https://a.lmcdn.ru/product{path}" for path in paths if path]


def _wildberries_image_urls(raw_images: str) -> List[str]:
    """Convert a Wildberries ``images`` JSON string into a list of image URLs."""
    urls = json.loads(raw_images)
    if not isinstance(urls, list) or not urls:
        return []
    # The first element holds every URL, separated by semicolons. Drop empty
    # fragments (e.g. from a trailing ";") so no blank URL is ever returned.
    fragments = (fragment.strip() for fragment in urls[0].split(';'))
    return [fragment for fragment in fragments if fragment]


def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Get additional image URLs for a product, based on its marketplace.

    Args:
        id_product_money: Product identifier used to filter the marketplace table.
        marketplace: ``'lamoda'`` or ``'wildberries'``; any other value
            yields an empty list without touching the database.

    Returns:
        A list of image URLs. Empty when the marketplace is unknown, no row
        or no image data is found, or the stored data cannot be parsed.
    """
    if marketplace == 'lamoda':
        label = "Lamoda"
        parse = _lamoda_gallery_urls
        query = text("""
            select info_chrc->'gallery' as more_images
            from public.lamoda_chrc_and_reviews
            where id_product_money = :id_product_money
            limit 1
        """)
    elif marketplace == 'wildberries':
        label = "Wildberries"
        parse = _wildberries_image_urls
        query = text("""
            select features->>'images' as more_images
            from public.wb_chrc
            where id_product_money = :id_product_money
            limit 1
        """)
    else:
        return []

    with db_conn.connect() as connection:
        result = connection.execute(query, {"id_product_money": id_product_money}).first()

    if not result or not result.more_images:
        return []

    print(f"{label} raw more_images: {result.more_images}")
    try:
        return parse(result.more_images)
    except (ValueError, TypeError, AttributeError) as e:
        # Malformed image data degrades to "no extra images" for every
        # marketplace, so the caller can still proceed with the main image.
        print(f"Error parsing JSON: {str(e)}")
        print(f"Type of more_images: {type(result.more_images)}")
        return []
def _image_mime_type(content_type: str) -> str:
    """Return the image MIME type from a Content-Type header, or JPEG if unknown."""
    mime_type = content_type.split(';', 1)[0].strip().lower()
    return mime_type if mime_type.startswith('image/') else 'image/jpeg'


def download_and_encode_images(image_urls: List[str]) -> List[Dict]:
    """Download images and convert them to base64 format for Gemini.

    Downloads are best-effort: a URL that fails is reported on stdout and
    skipped, so the result may be shorter than ``image_urls``.

    Args:
        image_urls: URLs to fetch, in the order they should be returned.

    Returns:
        One ``{'mime_type': ..., 'data': ...}`` dict per successfully
        downloaded image, where ``data`` is the base64-encoded body as text.
    """
    if not image_urls:
        # Nothing to fetch; avoid opening an HTTP client for no work.
        return []

    encoded_images = []
    with httpx.Client() as client:
        for url in image_urls:
            try:
                response = client.get(url)
                response.raise_for_status()
                encoded_images.append({
                    # Use the type the server reports (PNG, WebP, ...) and
                    # fall back to JPEG when the header is missing or is not
                    # an image type.
                    'mime_type': _image_mime_type(response.headers.get('content-type', '')),
                    'data': base64.b64encode(response.content).decode('utf-8'),
                })
            except Exception as e:
                print(f"Error downloading image {url}: {str(e)}")
    return encoded_images
def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Ask a single Gemini model about the images and return its text reply.

    The request consists of every image, in the given order, followed by the
    prompt. Any exception raised along the way is turned into an
    ``"Error with <model_name>: <message>"`` string instead of propagating.
    """
    try:
        gemini = genai.GenerativeModel(model_name)
        # Images first, prompt last.
        parts = [*encoded_images, prompt]
        return gemini.generate_content(parts).text
    except Exception as err:
        return f"Error with {model_name}: {str(err)}"
def process_input(id_product_money: str, prompt: str) -> Tuple[List[str], str, str]:
    """Run the full pipeline for one product and one prompt.

    Looks up the product's images, downloads them, and sends them with the
    prompt to two Gemini models, printing progress along the way.

    Returns:
        ``(image_urls, gemini_1_5_text, gemini_2_0_text)``. If any step
        raises, the result is ``([], "Error: <msg>", "Error: <msg>")``.
    """
    try:
        print("Getting marketplace and main image...")
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")

        print("\nGetting additional images...")
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # dict keys keep insertion order, so this drops duplicates while
        # preserving the first occurrence of each URL.
        all_image_urls = list(dict.fromkeys([main_image] + additional_images))
        print(f"\nAll image URLs: {all_image_urls}")

        print("\nDownloading and encoding images...")
        encoded_images = download_and_encode_images(all_image_urls)
        print(f"Number of encoded images: {len(encoded_images)}")
        if not encoded_images:
            raise ValueError("No images could be downloaded")

        print("\nGetting Gemini responses...")
        flash_1_5_text = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)
        flash_2_0_text = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)
        return all_image_urls, flash_1_5_text, flash_2_0_text
    except Exception as e:
        print(f"\nError in process_input: {str(e)}")
        failure = f"Error: {str(e)}"
        return [], failure, failure
def main():
    """Command-line interface for testing.

    Repeatedly asks for a product ID and a prompt, runs ``process_input``
    and prints the image URLs and both model responses. The loop ends when
    the user enters ``q``, presses Ctrl-C, or stdin reaches end-of-file.
    Any other error is printed and the loop continues.
    """
    print("Product Image Analysis with Gemini Models")
    print("-" * 40)
    while True:
        try:
            # Strip so that stray whitespace (e.g. "q ") neither prevents
            # quitting nor ends up in the database lookup.
            id_product_money = input("\nEnter product ID (or 'q' to quit): ").strip()
            if id_product_money.lower() == 'q':
                break

            prompt = input("Enter prompt (or press Enter for default 'What is this?'): ")
            if not prompt:
                prompt = "What is this?"

            print("\nProcessing...")
            image_urls, gemini_1_5_response, gemini_2_0_response = process_input(id_product_money, prompt)

            print("\nProduct Images:")
            for i, url in enumerate(image_urls, 1):
                print(f"{i}. {url}")

            print("\nGemini 1.5 Flash Response:")
            print("-" * 30)
            print(gemini_1_5_response)

            print("\nGemini 2.0 Flash Exp Response:")
            print("-" * 30)
            print(gemini_2_0_response)
        except (KeyboardInterrupt, EOFError):
            # EOFError must exit too: once stdin is closed every input() call
            # raises it again, so treating it as a generic error would spin
            # forever.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {str(e)}")
# Start the interactive CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()