alon-astria commited on
Commit
bd63d41
·
verified ·
1 Parent(s): 3a9b7c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -75
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # app.py
2
  import gradio as gr
 
3
  import requests
4
  import os
5
  import time
@@ -13,10 +14,10 @@ ASTRIA_API_KEY = os.environ.get("ASTRIA_API_KEY")
13
 
14
  # --- Core Application Logic ---
15
 
16
- def virtual_tryon(human_img_path, garment_img_path, garment_type):
17
  """
18
  This function handles the entire process:
19
- 1. Validates inputs and API key.
20
  2. Constructs and sends a request to create a Tune and associated Prompt.
21
  3. Polls the Prompt endpoint until the model is trained and image is generated.
22
  4. Yields status updates to the Gradio UI.
@@ -29,20 +30,56 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
29
  if not human_img_path or not garment_img_path:
30
  raise gr.Error("Please upload both a human image and a garment image.")
31
 
32
- yield None, "Step 1/3: Preparing and uploading your images..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
35
  prompt_id = None
36
 
37
  try:
38
- # 2. Construct and Send API Request to create Tune and Prompt
39
  with open(human_img_path, "rb") as human_f, open(garment_img_path, "rb") as garment_f:
40
  files = [
41
  ("tune[images][]", ("garment.jpg", garment_f.read(), "image/jpeg")),
42
  ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
43
  ]
44
  data = {
45
- "tune[title]": f"Gradio VTO - {int(time.time())}",
46
  "tune[name]": garment_type,
47
  "tune[model_type]": "faceid",
48
  "tune[base_tune_id]": "1504944", # This is the base model for Flux although no image generation is done in this case
@@ -72,9 +109,9 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
72
  except Exception as e:
73
  raise gr.Error(f"An unexpected error occurred during upload: {e}")
74
 
75
- yield None, f"Step 2/3: Task created (Prompt ID: {prompt_id}). Waiting for model training and generation..."
76
 
77
- # 3. Polling the Prompt Endpoint for Results
78
  start_time = time.time()
79
  timeout = 360 # 6-minute timeout for training and generation
80
 
@@ -90,13 +127,13 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
90
  images = poll_data.get("images", [])
91
  if images and images[0]:
92
  output_image_url = images[0]
93
- yield output_image_url, "Step 3/3: Generation successful!"
94
  return # End the function
95
  else:
96
  # This case handles when training is done but image isn't ready yet.
97
- yield None, "Step 2/3: Model trained. Finalizing image..."
98
  else:
99
- yield None, "Step 2/3: Model is training. Checking again in 10 seconds..."
100
 
101
  time.sleep(10) # Wait before polling again
102
 
@@ -127,6 +164,9 @@ ASTRIA_API_KEY = os.environ.get("ASTRIA_API_KEY") # Recommended: Load from envir
127
  human_img_path = "path/to/your/human.jpg"
128
  garment_img_path = "path/to/your/garment.jpg"
129
  garment_type = "shirt" # Can be "shirt", "pants", or "dress"
 
 
 
130
 
131
  headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
132
 
@@ -139,7 +179,7 @@ with open(human_img_path, "rb") as human_f, open(garment_img_path, "rb") as garm
139
  ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
140
  ]
141
  data = {
142
- "tune[title]": f"My VTO API Test - {int(time.time())}",
143
  "tune[name]": garment_type,
144
  "tune[model_type]": "faceid",
145
  "tune[base_tune_id]": "1504944", # VTO Base Model ID
@@ -157,66 +197,4 @@ print(response.json())
157
  # See this Space's app.py for a full polling example.
158
  """
159
 
160
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
161
- # --- Banner and Main Title ---
162
- gr.HTML(
163
- """
164
- <div style="text-align: center; padding: 20px; background-color: #f0f8ff; border-radius: 10px; margin-bottom: 20px;">
165
- <h2 style="color: #333;">Powered by the Astria.ai API</h2>
166
- <p style="color: #555; font-size: 1.1em;">This Virtual Try-On demo uses a powerful, single-call API endpoint.
167
- <br>
168
- You can integrate this functionality directly into your own applications.
169
- </p>
170
- <a href="https://docs.astria.ai/docs/use-cases/virtual-try-on/" target="_blank" style="text-decoration: none;">
171
- <button style="padding: 10px 20px; font-size: 1em; color: white; background-color: #007bff; border: none; border-radius: 5px; cursor: pointer;">
172
- Read the Full API Documentation
173
- </button>
174
- </a>
175
- </div>
176
- """
177
- )
178
- gr.Markdown("# Virtual Fashion Try-On")
179
- gr.Markdown(
180
- """
181
- **Instructions:** Upload a photo of a person and a photo of a piece of clothing to see them wear it.
182
-
183
- This space is made available through the Astria API, which allows you to create virtual try-on experiences with just a few lines of code.
184
- Astria fine-tuning API allows you to create custom AI models from your own images, and then use Virtual Try-On to make those models wear different clothing items.
185
- This can be used for fashion e-commerce, virtual fitting rooms, music festivals and concerts and sport events activation campaigns. See more in [Astria Virtual Try-On documentation](https://docs.astria.ai/docs/use-cases/virtual-try-on/).
186
- """
187
- )
188
-
189
- # --- Main Application UI ---
190
- with gr.Row():
191
- with gr.Column(scale=1):
192
- human_image = gr.Image(type="filepath", label="Human Image", height=300)
193
- garment_image = gr.Image(type="filepath", label="Garment Image", height=300)
194
- garment_type = gr.Dropdown(
195
- ["shirt", "pants", "dress"],
196
- label="Garment Type",
197
- info="Select the type of clothing item you uploaded.",
198
- value="shirt"
199
- )
200
- submit_btn = gr.Button("Generate", variant="primary")
201
- with gr.Column(scale=1):
202
- status_text = gr.Textbox(label="Status", interactive=False, lines=2, placeholder="Upload images and click Generate...")
203
- result_image = gr.Image(label="Result", height=615, interactive=False)
204
-
205
- # --- API Usage Code Snippet ---
206
- with gr.Accordion("Show API Usage Code (Python)", open=False):
207
- gr.Code(
208
- value=code_snippet,
209
- language="python",
210
- label="Python Request Example",
211
- interactive=False
212
- )
213
-
214
- # --- Event Listener ---
215
- submit_btn.click(
216
- fn=virtual_tryon,
217
- inputs=[human_image, garment_image, garment_type],
218
- outputs=[result_image, status_text]
219
- )
220
-
221
- if __name__ == "__main__":
222
- demo.launch()
 
1
  # app.py
2
  import gradio as gr
3
+ from gradio import Request
4
  import requests
5
  import os
6
  import time
 
14
 
15
  # --- Core Application Logic ---
16
 
17
+ def virtual_tryon(human_img_path, garment_img_path, garment_type, request: Request):
18
  """
19
  This function handles the entire process:
20
+ 1. Validates inputs, API key, and user quota by using the request object.
21
  2. Constructs and sends a request to create a Tune and associated Prompt.
22
  3. Polls the Prompt endpoint until the model is trained and image is generated.
23
  4. Yields status updates to the Gradio UI.
 
30
  if not human_img_path or not garment_img_path:
31
  raise gr.Error("Please upload both a human image and a garment image.")
32
 
33
+ user_ip = request.client.host if request else "unknown"
34
+ if user_ip == "unknown":
35
+ # This error is not shown to the user but helps in debugging
36
+ print("Warning: Could not determine user IP.")
37
+ # Proceed without a quota check if IP is not found
38
+ pass
39
+ else:
40
+ headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
41
+ tune_title = f"Gradio VTO {user_ip}"
42
+
43
+ # 2. Usage Quota Check
44
+ yield None, f"Step 1/4: Checking usage quota..."
45
+ try:
46
+ # Query the API for tunes created by this user IP
47
+ quota_check_url = f"{ASTRIA_API_BASE_URL}/tunes"
48
+ params = {"title": tune_title}
49
+ quota_response = requests.get(quota_check_url, headers=headers, params=params)
50
+ quota_response.raise_for_status()
51
+ user_tunes = quota_response.json()
52
+
53
+ # Enforce a limit of 5 trials
54
+ if len(user_tunes) > 4:
55
+ raise gr.Error(
56
+ f"Usage quota reached (5 trials). To continue, please use the API directly at astria.ai."
57
+ )
58
+
59
+ trials_used = len(user_tunes)
60
+ yield None, f"Quota check passed ({trials_used}/5 trials used). Proceeding to step 2..."
61
+
62
+ except requests.exceptions.RequestException as e:
63
+ raise gr.Error(f"API request failed during quota check: {e}")
64
+ except Exception as e:
65
+ # Catch the gr.Error from the check and re-raise it
66
+ if isinstance(e, gr.Error):
67
+ raise e
68
+ raise gr.Error(f"An unexpected error occurred during quota check: {e}")
69
+
70
 
71
+ yield None, "Step 2/4: Preparing and uploading your images..."
72
  prompt_id = None
73
 
74
  try:
75
+ # 3. Construct and Send API Request to create Tune and Prompt
76
  with open(human_img_path, "rb") as human_f, open(garment_img_path, "rb") as garment_f:
77
  files = [
78
  ("tune[images][]", ("garment.jpg", garment_f.read(), "image/jpeg")),
79
  ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
80
  ]
81
  data = {
82
+ "tune[title]": f"Gradio VTO - {user_ip if user_ip != 'unknown' else 'anonymous'}",
83
  "tune[name]": garment_type,
84
  "tune[model_type]": "faceid",
85
  "tune[base_tune_id]": "1504944", # This is the base model for Flux although no image generation is done in this case
 
109
  except Exception as e:
110
  raise gr.Error(f"An unexpected error occurred during upload: {e}")
111
 
112
+ yield None, f"Step 3/4: Task created (Prompt ID: {prompt_id}). Waiting for model training and generation..."
113
 
114
+ # 4. Polling the Prompt Endpoint for Results
115
  start_time = time.time()
116
  timeout = 360 # 6-minute timeout for training and generation
117
 
 
127
  images = poll_data.get("images", [])
128
  if images and images[0]:
129
  output_image_url = images[0]
130
+ yield output_image_url, "Step 4/4: Generation successful!"
131
  return # End the function
132
  else:
133
  # This case handles when training is done but image isn't ready yet.
134
+ yield None, "Step 3/4: Model trained. Finalizing image..."
135
  else:
136
+ yield None, "Step 3/4: Model is training. Checking again in 10 seconds..."
137
 
138
  time.sleep(10) # Wait before polling again
139
 
 
164
  human_img_path = "path/to/your/human.jpg"
165
  garment_img_path = "path/to/your/garment.jpg"
166
  garment_type = "shirt" # Can be "shirt", "pants", or "dress"
167
+ # To track usage, you can include a user identifier in the title
168
+ user_identifier = "user-ip-or-id"
169
+ tune_title = f"My App VTO - {user_identifier} - {int(time.time())}"
170
 
171
  headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
172
 
 
179
  ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
180
  ]
181
  data = {
182
+ "tune[title]": tune_title,
183
  "tune[name]": garment_type,
184
  "tune[model_type]": "faceid",
185
  "tune[base_tune_id]": "1504944", # VTO Base Model ID
 
197
  # See this Space's app.py for a full polling example.
198
  """
199
 
200
+ with gr.Blocks(