Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
# app.py
import gradio as gr
+from gradio import Request
import requests
import os
import time
@@ -13,10 +14,10 @@ ASTRIA_API_KEY = os.environ.get("ASTRIA_API_KEY")

# --- Core Application Logic ---

-def virtual_tryon(human_img_path, garment_img_path, garment_type):
+def virtual_tryon(human_img_path, garment_img_path, garment_type, request: Request):
    """
    This function handles the entire process:
-    1. Validates inputs and
+    1. Validates inputs, API key, and user quota by using the request object.
    2. Constructs and sends a request to create a Tune and associated Prompt.
    3. Polls the Prompt endpoint until the model is trained and image is generated.
    4. Yields status updates to the Gradio UI.
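For context on the signature change above: Gradio injects a request object into any event handler that declares a parameter annotated with gr.Request (imported here as Request), which is how the handler can read the caller's IP. A minimal standalone sketch, separate from this Space (the function and labels below are illustrative only):

import gradio as gr

def greet(name, request: gr.Request):
    # Gradio fills `request` automatically because of the type annotation.
    caller_ip = request.client.host if request else "unknown"
    return f"Hello {name} (request seen from {caller_ip})"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()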
@@ -29,20 +30,56 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
    if not human_img_path or not garment_img_path:
        raise gr.Error("Please upload both a human image and a garment image.")

-
+    user_ip = request.client.host if request else "unknown"
+    if user_ip == "unknown":
+        # This error is not shown to the user but helps in debugging
+        print("Warning: Could not determine user IP.")
+        # Proceed without a quota check if IP is not found
+        pass
+    else:
+        headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
+        tune_title = f"Gradio VTO {user_ip}"
+
+        # 2. Usage Quota Check
+        yield None, f"Step 1/4: Checking usage quota..."
+        try:
+            # Query the API for tunes created by this user IP
+            quota_check_url = f"{ASTRIA_API_BASE_URL}/tunes"
+            params = {"title": tune_title}
+            quota_response = requests.get(quota_check_url, headers=headers, params=params)
+            quota_response.raise_for_status()
+            user_tunes = quota_response.json()
+
+            # Enforce a limit of 5 trials
+            if len(user_tunes) > 4:
+                raise gr.Error(
+                    f"Usage quota reached (5 trials). To continue, please use the API directly at astria.ai."
+                )
+
+            trials_used = len(user_tunes)
+            yield None, f"Quota check passed ({trials_used}/5 trials used). Proceeding to step 2..."
+
+        except requests.exceptions.RequestException as e:
+            raise gr.Error(f"API request failed during quota check: {e}")
+        except Exception as e:
+            # Catch the gr.Error from the check and re-raise it
+            if isinstance(e, gr.Error):
+                raise e
+            raise gr.Error(f"An unexpected error occurred during quota check: {e}")
+

-
+    yield None, "Step 2/4: Preparing and uploading your images..."
    prompt_id = None

    try:
-        #
+        # 3. Construct and Send API Request to create Tune and Prompt
        with open(human_img_path, "rb") as human_f, open(garment_img_path, "rb") as garment_f:
            files = [
                ("tune[images][]", ("garment.jpg", garment_f.read(), "image/jpeg")),
                ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
            ]
            data = {
-                "tune[title]": f"Gradio VTO - {
+                "tune[title]": f"Gradio VTO - {user_ip if user_ip != 'unknown' else 'anonymous'}",
                "tune[name]": garment_type,
                "tune[model_type]": "faceid",
                "tune[base_tune_id]": "1504944", # This is the base model for Flux although no image generation is done in this case
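The hunk above implements the new per-IP quota: each tune is created with a title that embeds the caller's IP, and the handler counts existing tunes with that title before allowing another run. A condensed sketch of the same check as a standalone helper; it assumes, as the new code does, that GET /tunes accepts a title filter and returns a JSON list of matching tunes, and that the base URL constant points at the Astria API:

import os
import requests

ASTRIA_API_BASE_URL = "https://api.astria.ai"  # assumed here; reuse the constant app.py defines
ASTRIA_API_KEY = os.environ.get("ASTRIA_API_KEY")

def trials_used(tune_title: str, limit: int = 5) -> int:
    # Count tunes whose title matches this caller; raise once the limit is reached.
    headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}
    resp = requests.get(
        f"{ASTRIA_API_BASE_URL}/tunes",
        headers=headers,
        params={"title": tune_title},
        timeout=30,
    )
    resp.raise_for_status()
    used = len(resp.json())  # assumes the response body is a list of tunes
    if used >= limit:
        raise RuntimeError(f"Usage quota reached ({limit} trials).")
    return used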
@@ -72,9 +109,9 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred during upload: {e}")

-    yield None, f"Step
+    yield None, f"Step 3/4: Task created (Prompt ID: {prompt_id}). Waiting for model training and generation..."

-    #
+    # 4. Polling the Prompt Endpoint for Results
    start_time = time.time()
    timeout = 360 # 6-minute timeout for training and generation

@@ -90,13 +127,13 @@ def virtual_tryon(human_img_path, garment_img_path, garment_type):
            images = poll_data.get("images", [])
            if images and images[0]:
                output_image_url = images[0]
-                yield output_image_url, "Step
+                yield output_image_url, "Step 4/4: Generation successful!"
                return # End the function
            else:
                # This case handles when training is done but image isn't ready yet.
-                yield None, "Step
+                yield None, "Step 3/4: Model trained. Finalizing image..."
        else:
-            yield None, "Step
+            yield None, "Step 3/4: Model is training. Checking again in 10 seconds..."

        time.sleep(10) # Wait before polling again

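The status messages changed above sit inside a polling loop that is mostly unchanged context, so the loop itself is not visible in this diff. For readers adapting the pattern, a rough sketch of such a loop; the poll URL is a placeholder for whatever endpoint the create-tune response points at, and only the "images" field is taken from the fragments shown here:

import time
import requests

def wait_for_image(poll_url: str, headers: dict, timeout: int = 360, interval: int = 10) -> str:
    # Poll until an image URL shows up or the timeout (6 minutes by default) expires.
    start_time = time.time()
    while time.time() - start_time < timeout:
        resp = requests.get(poll_url, headers=headers, timeout=30)
        resp.raise_for_status()
        poll_data = resp.json()
        images = poll_data.get("images", [])
        if images and images[0]:
            return images[0]  # first generated image URL
        time.sleep(interval)  # wait before polling again
    raise TimeoutError("Timed out waiting for the try-on image.")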
@@ -127,6 +164,9 @@ ASTRIA_API_KEY = os.environ.get("ASTRIA_API_KEY") # Recommended: Load from envir
human_img_path = "path/to/your/human.jpg"
garment_img_path = "path/to/your/garment.jpg"
garment_type = "shirt" # Can be "shirt", "pants", or "dress"
+# To track usage, you can include a user identifier in the title
+user_identifier = "user-ip-or-id"
+tune_title = f"My App VTO - {user_identifier} - {int(time.time())}"

headers = {"Authorization": f"Bearer {ASTRIA_API_KEY}"}

@@ -139,7 +179,7 @@ with open(human_img_path, "rb") as human_f, open(garment_img_path, "rb") as garm
    ("tune[prompts_attributes][][input_image]", ("human.jpg", human_f.read(), "image/jpeg")),
]
data = {
-    "tune[title]":
+    "tune[title]": tune_title,
    "tune[name]": garment_type,
    "tune[model_type]": "faceid",
    "tune[base_tune_id]": "1504944", # VTO Base Model ID
@@ -157,66 +197,4 @@ print(response.json())
# See this Space's app.py for a full polling example.
"""

-with gr.Blocks(
-    # --- Banner and Main Title ---
-    gr.HTML(
-        """
-        <div style="text-align: center; padding: 20px; background-color: #f0f8ff; border-radius: 10px; margin-bottom: 20px;">
-            <h2 style="color: #333;">Powered by the Astria.ai API</h2>
-            <p style="color: #555; font-size: 1.1em;">This Virtual Try-On demo uses a powerful, single-call API endpoint.
-            <br>
-            You can integrate this functionality directly into your own applications.
-            </p>
-            <a href="https://docs.astria.ai/docs/use-cases/virtual-try-on/" target="_blank" style="text-decoration: none;">
-                <button style="padding: 10px 20px; font-size: 1em; color: white; background-color: #007bff; border: none; border-radius: 5px; cursor: pointer;">
-                    Read the Full API Documentation
-                </button>
-            </a>
-        </div>
-        """
-    )
-    gr.Markdown("# Virtual Fashion Try-On")
-    gr.Markdown(
-        """
-        **Instructions:** Upload a photo of a person and a photo of a piece of clothing to see them wear it.
-
-        This space is made available through the Astria API, which allows you to create virtual try-on experiences with just a few lines of code.
-        Astria fine-tuning API allows you to create custom AI models from your own images, and then use Virtual Try-On to make those models wear different clothing items.
-        This can be used for fashion e-commerce, virtual fitting rooms, music festivals and concerts and sport events activation campaigns. See more in [Astria Virtual Try-On documentation](https://docs.astria.ai/docs/use-cases/virtual-try-on/).
-        """
-    )
-
-    # --- Main Application UI ---
-    with gr.Row():
-        with gr.Column(scale=1):
-            human_image = gr.Image(type="filepath", label="Human Image", height=300)
-            garment_image = gr.Image(type="filepath", label="Garment Image", height=300)
-            garment_type = gr.Dropdown(
-                ["shirt", "pants", "dress"],
-                label="Garment Type",
-                info="Select the type of clothing item you uploaded.",
-                value="shirt"
-            )
-            submit_btn = gr.Button("Generate", variant="primary")
-        with gr.Column(scale=1):
-            status_text = gr.Textbox(label="Status", interactive=False, lines=2, placeholder="Upload images and click Generate...")
-            result_image = gr.Image(label="Result", height=615, interactive=False)
-
-    # --- API Usage Code Snippet ---
-    with gr.Accordion("Show API Usage Code (Python)", open=False):
-        gr.Code(
-            value=code_snippet,
-            language="python",
-            label="Python Request Example",
-            interactive=False
-        )
-
-    # --- Event Listener ---
-    submit_btn.click(
-        fn=virtual_tryon,
-        inputs=[human_image, garment_image, garment_type],
-        outputs=[result_image, status_text]
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+with gr.Blocks(