alexander-lazarin committed on
Commit
1b944da
·
1 Parent(s): b11f8f0

Add a CLI testing file

Browse files
Files changed (1) hide show
  1. test.py +207 -0
test.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ from typing import List, Tuple, Dict
5
+ import httpx
6
+ from sqlalchemy import create_engine, text
7
+ from dotenv import load_dotenv
8
+ import google.generativeai as genai
9
+
10
def get_secret(secret_name, service="", username=""):
    """Look up a secret from Colab, the environment, or the system keyring.

    Sources are tried in this order; the first one that succeeds wins:

    1. ``google.colab.userdata.get(secret_name)``
    2. ``os.environ[secret_name]``
    3. ``keyring.get_password(service, username)``

    Args:
        secret_name: Name of the secret for the Colab and environment lookups.
        service: Service name passed to ``keyring.get_password``.
        username: User name passed to ``keyring.get_password``.

    Returns:
        Whatever the first successful source returns.
    """
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        # Colab is unavailable or the lookup failed. Catch Exception rather
        # than using a bare except so Ctrl-C / SystemExit still propagate.
        pass

    try:
        return os.environ[secret_name]
    except (KeyError, TypeError):
        # Not set in the environment (or not a usable key): fall through.
        pass

    # Imported lazily so keyring is only required when it is actually used.
    import keyring
    return keyring.get_password(service, username)
20
+
21
from urllib.parse import quote

# Load environment variables before any secret lookup: get_secret() falls
# back to os.environ.
load_dotenv()

# Database configuration
DB_NAME = "kroyscappingdb"
DB_USER = "read_only"
DB_PASSWORD = get_secret('FASHION_PG_PASS')
DB_HOST = "rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net"
DB_PORT = "6432"

# Percent-encode the password so characters such as '@', ':', '/' or '%'
# cannot be misread as URL delimiters or escapes when the DSN is parsed.
DATABASE_URL = (
    f"postgresql://{DB_USER}:{quote(str(DB_PASSWORD), safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

# Create the SQLAlchemy engine
db_conn = create_engine(DATABASE_URL)

# Configure Gemini API
genai.configure(api_key=get_secret("GEMINI_API_KEY"))
38
+
39
def get_marketplace_and_main_image(id_product_money: str) -> Tuple[str, str]:
    """Fetch the marketplace and main image URL of a product.

    Runs a parameterized query against ``public.products`` and uses the
    first matching row.

    Args:
        id_product_money: Value matched against the ``id_product_money`` column.

    Returns:
        A ``(mp, main_image_url)`` tuple taken from the first matching row.

    Raises:
        ValueError: If no row matches ``id_product_money``.
    """
    stmt = text("""
        select mp, image as main_image_url
        from public.products
        where id_product_money = :id_product_money
    """)
    bind_params = {"id_product_money": id_product_money}

    with db_conn.connect() as connection:
        row = connection.execute(stmt, bind_params).first()
        if row is not None:
            return row.mp, row.main_image_url
        raise ValueError(f"No product found with id_product_money: {id_product_money}")
52
+
53
def get_additional_images(id_product_money: str, marketplace: str) -> List[str]:
    """Get additional image URLs for a product, based on its marketplace.

    * ``'lamoda'``: reads ``info_chrc->'gallery'`` from
      ``public.lamoda_chrc_and_reviews`` and prefixes every path with
      ``https://a.lmcdn.ru/product``.
    * ``'wildberries'``: reads ``features->>'images'`` from ``public.wb_chrc``,
      parses it as JSON and splits the first list entry on ``;``.
    * any other marketplace: no query is made.

    Args:
        id_product_money: Value matched against the ``id_product_money`` column.
        marketplace: Marketplace name as stored in ``public.products.mp``.

    Returns:
        A list of image URLs. It is empty when the marketplace is unknown,
        nothing is stored for the product, or the stored value cannot be parsed.
    """
    params = {"id_product_money": id_product_money}

    if marketplace == 'lamoda':
        query = text("""
            select info_chrc->'gallery' as more_images
            from public.lamoda_chrc_and_reviews
            where id_product_money = :id_product_money
            limit 1
        """)
        with db_conn.connect() as connection:
            result = connection.execute(query, params).first()
        if result and result.more_images:
            print(f"Lamoda raw more_images: {result.more_images}")
            paths = result.more_images
            # Handle both string JSON and direct list cases
            if isinstance(paths, str):
                try:
                    paths = json.loads(paths)
                except ValueError as e:
                    # Same best-effort behavior as the wildberries branch:
                    # report the problem and carry on with no extra images.
                    print(f"Error parsing JSON: {str(e)}")
                    return []
            if not isinstance(paths, (list, tuple)):
                return []
            # Skip null/empty entries, which would produce bogus URLs.
            return [
                f"https://a.lmcdn.ru/product{path}"
                for path in paths
                if isinstance(path, str) and path
            ]

    elif marketplace == 'wildberries':
        query = text("""
            select features->>'images' as more_images
            from public.wb_chrc
            where id_product_money = :id_product_money
            limit 1
        """)
        with db_conn.connect() as connection:
            result = connection.execute(query, params).first()
        if result and result.more_images:
            print(f"Wildberries raw more_images: {result.more_images}")
            try:
                urls = json.loads(result.more_images)
            except (TypeError, ValueError) as e:
                # ValueError covers json.JSONDecodeError; TypeError covers a
                # value that is not str/bytes.
                print(f"Error parsing JSON: {str(e)}")
                print(f"Type of more_images: {type(result.more_images)}")
                return []
            if isinstance(urls, list) and urls and isinstance(urls[0], str):
                # Split the URLs by semicolons; drop blanks left by a leading
                # or trailing ';' or by stray whitespace.
                return [part.strip() for part in urls[0].split(';') if part.strip()]
            return []

    return []
96
+
97
def download_and_encode_images(image_urls: List[str]) -> List[Dict]:
    """Download images and convert them to base64 parts for Gemini.

    A URL that fails to download is reported on stdout and skipped, so the
    result may be shorter than ``image_urls``.

    Args:
        image_urls: URLs to fetch, in the order they should appear in the result.

    Returns:
        A list of ``{'mime_type': ..., 'data': ...}`` dicts, where ``data`` is
        the base64-encoded response body as text. ``mime_type`` comes from the
        response's Content-Type header when it names an image type, and falls
        back to ``'image/jpeg'`` otherwise.
    """
    if not image_urls:
        # Nothing to fetch: avoid opening an HTTP client at all.
        return []

    encoded_images = []
    with httpx.Client() as client:
        for url in image_urls:
            try:
                response = client.get(url)
                response.raise_for_status()
            except Exception as e:
                print(f"Error downloading image {url}: {str(e)}")
                continue

            # Use the server-declared type (PNG, WebP, ...) instead of
            # labelling every image as JPEG; drop parameters like charset.
            content_type = response.headers.get('content-type', '')
            mime_type = content_type.split(';', 1)[0].strip().lower()
            if not mime_type.startswith('image/'):
                mime_type = 'image/jpeg'  # Fallback when the header is missing/unhelpful

            encoded_images.append({
                'mime_type': mime_type,
                'data': base64.b64encode(response.content).decode('utf-8'),
            })
    return encoded_images
113
+
114
def get_gemini_response(model_name: str, encoded_images: List[Dict], prompt: str) -> str:
    """Ask one Gemini model about a set of images.

    The request content is every entry of ``encoded_images`` in order,
    followed by ``prompt`` as the last part.

    Args:
        model_name: Name passed to ``genai.GenerativeModel``.
        encoded_images: Image parts to send ahead of the prompt.
        prompt: Text placed after the images.

    Returns:
        The response text, or ``"Error with <model_name>: <message>"`` if
        anything raises. This function never raises ``Exception`` itself.
    """
    try:
        gemini = genai.GenerativeModel(model_name)
        # Images first, prompt last.
        parts = [*encoded_images, prompt]
        reply = gemini.generate_content(parts)
        return reply.text
    except Exception as err:
        return f"Error with {model_name}: {str(err)}"
130
+
131
def process_input(id_product_money: str, prompt: str) -> Tuple[List[str], str, str]:
    """Collect a product's images and ask two Gemini models about them.

    Steps: look up the marketplace and main image, fetch the
    marketplace-specific extra images, de-duplicate the URLs, download and
    encode them, then query ``gemini-1.5-flash`` and ``gemini-2.0-flash-exp``
    with the same images and prompt. Progress is printed to stdout.

    Args:
        id_product_money: Product identifier used for the database lookups.
        prompt: Text prompt sent to both models together with the images.

    Returns:
        ``(image_urls, gemini_1_5_response, gemini_2_0_response)``. If any
        step raises, the error is printed and
        ``([], "Error: <message>", "Error: <message>")`` is returned instead.
    """
    try:
        print("Getting marketplace and main image...")
        marketplace, main_image = get_marketplace_and_main_image(id_product_money)
        print(f"Marketplace: {marketplace}")
        print(f"Main image: {main_image}")

        print("\nGetting additional images...")
        additional_images = get_additional_images(id_product_money, marketplace)
        print(f"Additional images: {additional_images}")

        # Combine all images and remove duplicates while preserving order
        # (dict keys keep insertion order). Missing/empty URLs are skipped so
        # a NULL main image is neither downloaded nor reported to the caller.
        candidates = [main_image, *additional_images]
        all_image_urls = list(dict.fromkeys(url for url in candidates if url))
        print(f"\nAll image URLs: {all_image_urls}")

        print("\nDownloading and encoding images...")
        encoded_images = download_and_encode_images(all_image_urls)
        print(f"Number of encoded images: {len(encoded_images)}")

        if not encoded_images:
            raise ValueError("No images could be downloaded")

        print("\nGetting Gemini responses...")
        # Get responses from both models
        gemini_1_5_response = get_gemini_response("gemini-1.5-flash", encoded_images, prompt)
        gemini_2_0_response = get_gemini_response("gemini-2.0-flash-exp", encoded_images, prompt)

        return all_image_urls, gemini_1_5_response, gemini_2_0_response

    except Exception as e:
        print(f"\nError in process_input: {str(e)}")
        return [], f"Error: {str(e)}", f"Error: {str(e)}"
169
+
170
def main():
    """Command-line interface for testing.

    Repeatedly prompts for a product ID and a prompt, runs ``process_input``
    and prints the image URLs and both model responses. The loop ends when
    the user enters ``q``, presses Ctrl-C, or stdin reaches end-of-file.
    """
    print("Product Image Analysis with Gemini Models")
    print("-" * 40)

    while True:
        try:
            # Strip so that input such as " q " or "123 " behaves as expected.
            id_product_money = input("\nEnter product ID (or 'q' to quit): ").strip()
            if id_product_money.lower() == 'q':
                break

            prompt = input("Enter prompt (or press Enter for default 'What is this?'): ").strip()
            if not prompt:
                prompt = "What is this?"

            print("\nProcessing...")
            image_urls, gemini_1_5_response, gemini_2_0_response = process_input(id_product_money, prompt)

            print("\nProduct Images:")
            for i, url in enumerate(image_urls, 1):
                print(f"{i}. {url}")

            print("\nGemini 1.5 Flash Response:")
            print("-" * 30)
            print(gemini_1_5_response)

            print("\nGemini 2.0 Flash Exp Response:")
            print("-" * 30)
            print(gemini_2_0_response)

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (piped input ran out). Without this,
            # the generic handler below would catch it on every iteration and
            # the loop would never terminate.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {str(e)}")


if __name__ == "__main__":
    main()