Spaces:
Runtime error
Runtime error
Commit
·
a40e4e2
1
Parent(s):
b16d55b
Update app.py
Browse files
app.py
CHANGED
@@ -1,29 +1,30 @@
|
|
1 |
-
import
|
|
|
2 |
import boto3
|
3 |
from botocore.config import Config
|
4 |
from dotenv import load_dotenv
|
5 |
import os
|
6 |
import shutil
|
7 |
-
from typing import List, Tuple
|
8 |
import uuid
|
9 |
-
import zipfile
|
10 |
import argparse
|
11 |
import logging
|
12 |
import sendgrid
|
13 |
from sendgrid.helpers.mail import Mail, Email, To, Content
|
14 |
-
|
15 |
-
|
16 |
-
from
|
17 |
-
from itertools import cycle
|
18 |
|
19 |
import requests
|
20 |
import banana_dev as banana
|
21 |
import streamlit as st
|
22 |
from PIL import Image
|
23 |
-
from st_btn_select import st_btn_select
|
24 |
from streamlit_image_select import image_select
|
25 |
import smart_open
|
26 |
|
|
|
|
|
|
|
27 |
logging.basicConfig()
|
28 |
logger = logging.getLogger(__name__)
|
29 |
logger.setLevel(logging.INFO)
|
@@ -33,29 +34,74 @@ logger.setLevel(logging.INFO)
|
|
33 |
load_dotenv()
|
34 |
|
35 |
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# Command-line arguments to control some stuff for easier local testing.
|
38 |
# Eventually may want to move everything into functions and have a
|
39 |
# if __name__ == "main" setup instead of everything inline.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
if var not in st.session_state:
|
44 |
-
st.session_state[var] = False
|
45 |
-
|
46 |
-
obj_session_variables = ['s3_face_file_path', 's3_theme_file_path', 'captcha_response', 'user_email']
|
47 |
-
for var in obj_session_variables:
|
48 |
-
if var not in st.session_state:
|
49 |
-
st.session_state[var] = None
|
50 |
-
|
51 |
-
if "key" not in st.session_state:
|
52 |
-
st.session_state["key"] = uuid.uuid4().hex
|
53 |
|
54 |
-
if "captcha" not in st.session_state:
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
|
60 |
|
61 |
def bucket_parts(s3_path: str) -> Tuple[str, str]:
|
@@ -129,43 +175,46 @@ def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
|
|
129 |
return upload_url
|
130 |
|
131 |
|
132 |
-
def zip_and_upload_images(identifier: str, uploaded_files: List[
|
133 |
"""Save images as zip file to s3 for use in backend.
|
134 |
|
135 |
Blocks until images are processed, added to zip file, and uploaded to S3.
|
136 |
|
137 |
Args:
|
138 |
identifier: unique identifier for the run, used in s3 link
|
139 |
-
uploaded_files:
|
140 |
image_type: string to identify different batches of images used in the
|
141 |
backend model/training. Currently used values: "face", "theme"
|
142 |
|
143 |
Returns:
|
144 |
S3 location of zip file containing png images.
|
145 |
"""
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
169 |
|
170 |
def send_email(to_email, user_code):
|
171 |
sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
|
@@ -178,9 +227,6 @@ def send_email(to_email, user_code):
|
|
178 |
response = sg.client.mail.send.post(request_body=mail_json)
|
179 |
|
180 |
|
181 |
-
CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
|
182 |
-
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
|
183 |
-
|
184 |
# Create a function to generate a captcha
|
185 |
def generate_captcha():
|
186 |
# Make a GET request to the API endpoint to generate a captcha
|
@@ -211,55 +257,73 @@ def verify_captcha(captcha_id, captcha_response):
|
|
211 |
return {"error": "Failed to verify captcha"}
|
212 |
|
213 |
def train_model(model_inputs):
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
-
|
219 |
-
|
220 |
user_email_input = st.empty()
|
221 |
-
with user_email_input.form(key='
|
222 |
text_input = st.text_input(label='Please Enter Your Email')
|
223 |
submit_button = st.form_submit_button(label='Submit')
|
224 |
-
if submit_button:
|
225 |
-
st.session_state["user_auth_sess"] = True
|
226 |
st.session_state["user_email"] = text_input
|
227 |
-
send_email(
|
228 |
-
|
229 |
-
|
230 |
-
#
|
231 |
-
|
|
|
|
|
|
|
232 |
user_auth = st.empty()
|
233 |
-
with user_auth.form("one-
|
234 |
text_input = st.text_input(label='Please Input One Time Code')
|
235 |
submit_button = st.form_submit_button(label='Submit')
|
236 |
if submit_button:
|
237 |
if text_input == st.session_state["key"]:
|
238 |
-
st.session_state["
|
|
|
239 |
else:
|
240 |
st.markdown("Please Enter Correct Code!")
|
241 |
|
242 |
-
|
243 |
-
|
244 |
identifier = st.session_state["key"]
|
245 |
-
user_auth.empty()
|
246 |
face_images = st.empty()
|
247 |
-
with face_images.form("
|
248 |
uploaded_files = st.file_uploader(
|
249 |
"Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
250 |
)
|
|
|
|
|
251 |
submitted = st.form_submit_button(f"Upload")
|
252 |
if submitted:
|
253 |
with st.spinner('Uploading...'):
|
254 |
-
st.session_state["
|
255 |
identifier, uploaded_files, "face"
|
256 |
)
|
|
|
|
|
257 |
st.success(f'Uploading {len(uploaded_files)} files done!')
|
258 |
-
st.session_state["
|
|
|
|
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
face_images.empty()
|
263 |
preset_theme_images = st.empty()
|
264 |
with preset_theme_images.form("choose-preset-theme"):
|
265 |
img = image_select(
|
@@ -274,6 +338,7 @@ if st.session_state["face_images_uploaded"]:
|
|
274 |
)
|
275 |
|
276 |
col1, col2 = st.columns([0.17, 1])
|
|
|
277 |
with col1:
|
278 |
submitted_3 = st.form_submit_button("Submit!")
|
279 |
if submitted_3:
|
@@ -287,13 +352,19 @@ if st.session_state["face_images_uploaded"]:
|
|
287 |
2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
|
288 |
}
|
289 |
st.session_state["model_inputs"] = {
|
290 |
-
"
|
291 |
-
|
292 |
-
"
|
|
|
|
|
|
|
|
|
293 |
"num_images": 50,
|
|
|
294 |
}
|
295 |
st.success("Success!")
|
296 |
-
st.session_state["
|
|
|
297 |
with col2:
|
298 |
submitted_4 = st.form_submit_button(
|
299 |
"If none of the themes interest you, click here!"
|
@@ -302,6 +373,7 @@ if st.session_state["face_images_uploaded"]:
|
|
302 |
st.session_state["view"] = True
|
303 |
|
304 |
if st.session_state["view"]:
|
|
|
305 |
custom_theme_images = st.empty()
|
306 |
with custom_theme_images.form("input_custom_themes"):
|
307 |
st.markdown("If none of the themes interest you, please input your own!")
|
@@ -310,86 +382,113 @@ if st.session_state["face_images_uploaded"]:
|
|
310 |
accept_multiple_files=True,
|
311 |
type=["png", "jpg", "jpeg"],
|
312 |
)
|
313 |
-
|
|
|
|
|
314 |
submitted_3 = st.form_submit_button("Submit!")
|
315 |
if submitted_3:
|
316 |
with st.spinner('Uploading...'):
|
317 |
-
st.session_state["
|
318 |
identifier, uploaded_files_2, "theme"
|
319 |
)
|
|
|
320 |
st.session_state["model_inputs"] = {
|
321 |
# Use presigned urls since backend does not have credentials
|
322 |
-
"
|
323 |
-
"
|
324 |
-
"
|
|
|
|
|
|
|
325 |
"num_images": 50,
|
|
|
326 |
}
|
327 |
st.success('Done!')
|
328 |
-
st.session_state["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
-
#fifth page - user enters captcha and trains model
|
331 |
-
if st.session_state["theme_images_uploaded"]:
|
332 |
-
if st.session_state["view"]:
|
333 |
-
custom_theme_images.empty()
|
334 |
-
preset_theme_images.empty()
|
335 |
-
train = st.empty()
|
336 |
-
with train.form("training"):
|
337 |
-
submitted = st.form_submit_button("Train Model!")
|
338 |
if submitted:
|
339 |
-
st.session_state["
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
import boto3
|
4 |
from botocore.config import Config
|
5 |
from dotenv import load_dotenv
|
6 |
import os
|
7 |
import shutil
|
8 |
+
from typing import List, Tuple, TYPE_CHECKING
|
9 |
import uuid
|
|
|
10 |
import argparse
|
11 |
import logging
|
12 |
import sendgrid
|
13 |
from sendgrid.helpers.mail import Mail, Email, To, Content
|
14 |
+
from enum import Enum
|
15 |
+
import tempfile
|
16 |
+
from pathlib import Path
|
|
|
17 |
|
18 |
import requests
|
19 |
import banana_dev as banana
|
20 |
import streamlit as st
|
21 |
from PIL import Image
|
|
|
22 |
from streamlit_image_select import image_select
|
23 |
import smart_open
|
24 |
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from io import BytesIO
|
27 |
+
|
28 |
logging.basicConfig()
|
29 |
logger = logging.getLogger(__name__)
|
30 |
logger.setLevel(logging.INFO)
|
|
|
34 |
load_dotenv()
|
35 |
|
36 |
_S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
|
37 |
+
CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
|
38 |
+
VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
|
39 |
+
|
40 |
+
class UxState(str, Enum):
|
41 |
+
LOGIN = "login"
|
42 |
+
VERIFY_EMAIL = "verify_email"
|
43 |
+
UPLOAD1 = "upload1"
|
44 |
+
UPLOAD2 = "upload2"
|
45 |
+
CAPTCHA = "captcha"
|
46 |
+
TRAIN = "train"
|
47 |
|
48 |
# Command-line arguments to control some stuff for easier local testing.
|
49 |
# Eventually may want to move everything into functions and have a
|
50 |
# if __name__ == "main" setup instead of everything inline.
|
51 |
+
parser = argparse.ArgumentParser()
|
52 |
+
parser.add_argument(
|
53 |
+
"--dry-run", action="store_true",
|
54 |
+
help="Skip sending train request to backend server.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--train-endpoint-url", default=None,
|
58 |
+
help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
|
59 |
+
)
|
60 |
+
cli_args = parser.parse_args()
|
61 |
+
|
62 |
+
|
63 |
+
def setup_session_state():
|
64 |
+
if "key" not in st.session_state:
|
65 |
+
st.session_state["key"] = uuid.uuid4().hex
|
66 |
+
|
67 |
+
if "ux_state" not in st.session_state:
|
68 |
+
st.session_state["ux_state"] = UxState.LOGIN
|
69 |
+
|
70 |
+
if "model_inputs" not in st.session_state:
|
71 |
+
st.session_state["model_inputs"] = None
|
72 |
+
|
73 |
+
if "initial_concept_file_path" not in st.session_state:
|
74 |
+
st.session_state["initial_concept_file_path"] = None
|
75 |
+
|
76 |
+
if "initial_token" not in st.session_state:
|
77 |
+
st.session_state["initial_token"] = None
|
78 |
+
|
79 |
+
if "initial_class_token" not in st.session_state:
|
80 |
+
st.session_state["initial_class_token"] = None
|
81 |
+
|
82 |
+
if "secondary_concept_file_path" not in st.session_state:
|
83 |
+
st.session_state["secondary_concept_file_path"] = None
|
84 |
+
|
85 |
+
if "secondary_token" not in st.session_state:
|
86 |
+
st.session_state["secondary_token"] = None
|
87 |
+
|
88 |
+
if "secondary_class_token" not in st.session_state:
|
89 |
+
st.session_state["secondary_class_token"] = None
|
90 |
+
|
91 |
+
if "prompt_keywords" not in st.session_state:
|
92 |
+
st.session_state["prompt_keywords"] = None
|
93 |
+
|
94 |
+
if "view" not in st.session_state:
|
95 |
+
st.session_state["view"] = False
|
96 |
|
97 |
+
if "captcha_response" not in st.session_state:
|
98 |
+
st.session_state["captcha_response"] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
+
if "captcha" not in st.session_state:
|
101 |
+
st.session_state["captcha"] = {}
|
102 |
|
103 |
+
if "user_email" not in st.session_state:
|
104 |
+
st.session_state["user_email"] = None
|
105 |
|
106 |
|
107 |
def bucket_parts(s3_path: str) -> Tuple[str, str]:
|
|
|
175 |
return upload_url
|
176 |
|
177 |
|
178 |
+
def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
|
179 |
"""Save images as zip file to s3 for use in backend.
|
180 |
|
181 |
Blocks until images are processed, added to zip file, and uploaded to S3.
|
182 |
|
183 |
Args:
|
184 |
identifier: unique identifier for the run, used in s3 link
|
185 |
+
uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
|
186 |
image_type: string to identify different batches of images used in the
|
187 |
backend model/training. Currently used values: "face", "theme"
|
188 |
|
189 |
Returns:
|
190 |
S3 location of zip file containing png images.
|
191 |
"""
|
192 |
+
with tempfile.TemporaryDirectory() as temp_dir_name:
|
193 |
+
logger.info(f"Working from temp dir to zip and upload images: {temp_dir_name}")
|
194 |
+
temp_dir = Path(temp_dir_name)
|
195 |
+
if not os.path.exists(temp_dir / identifier):
|
196 |
+
os.makedirs(temp_dir / identifier)
|
197 |
+
|
198 |
+
logger.info("Processing uploaded images")
|
199 |
+
for num, uploaded_file in enumerate(uploaded_files):
|
200 |
+
file_ = Image.open(uploaded_file).convert("RGB")
|
201 |
+
file_.save(temp_dir / identifier / f"{num}_test.png")
|
202 |
+
local_zip_filestem = str(temp_dir / f"{identifier}_{image_type}_images")
|
203 |
+
logger.info("Making zip archive")
|
204 |
+
shutil.make_archive(local_zip_filestem, "zip", temp_dir / identifier)
|
205 |
+
local_zip_filename = f"{local_zip_filestem}.zip"
|
206 |
+
|
207 |
+
logger.info("Uploading zip file to s3")
|
208 |
+
# TODO: can we define expiration when making the s3 path?
|
209 |
+
# Probably if we use the boto3 library instead of smart open
|
210 |
+
s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
|
211 |
+
|
212 |
+
with open(local_zip_filename, "rb") as fin:
|
213 |
+
with smart_open.open(s3_path, "wb") as fout:
|
214 |
+
fout.write(fin.read())
|
215 |
+
logger.info(f"Completed upload to {s3_path}")
|
216 |
+
|
217 |
+
return s3_path
|
218 |
|
219 |
def send_email(to_email, user_code):
|
220 |
sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
|
|
|
227 |
response = sg.client.mail.send.post(request_body=mail_json)
|
228 |
|
229 |
|
|
|
|
|
|
|
230 |
# Create a function to generate a captcha
|
231 |
def generate_captcha():
|
232 |
# Make a GET request to the API endpoint to generate a captcha
|
|
|
257 |
return {"error": "Failed to verify captcha"}
|
258 |
|
259 |
def train_model(model_inputs):
|
260 |
+
if cli_args.dry_run:
|
261 |
+
logger.info("Skipping model training since --dry-run is enabled.")
|
262 |
+
logger.info(f"model_inputs: {model_inputs}")
|
263 |
+
return
|
264 |
+
|
265 |
+
if cli_args.train_endpoint_url is None:
|
266 |
+
# Use banana backend
|
267 |
+
api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
|
268 |
+
model_key = "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b"
|
269 |
+
st.markdown(str(model_inputs))
|
270 |
+
_ = banana.run(api_key, model_key, model_inputs)
|
271 |
+
else:
|
272 |
+
# Send request directly to specified url
|
273 |
+
_ = requests.post(cli_args.train_endpoint_url, json=model_inputs)
|
274 |
|
275 |
+
|
276 |
+
def run_login():
|
277 |
user_email_input = st.empty()
|
278 |
+
with user_email_input.form(key='user_auth'):
|
279 |
text_input = st.text_input(label='Please Enter Your Email')
|
280 |
submit_button = st.form_submit_button(label='Submit')
|
281 |
+
if submit_button:
|
|
|
282 |
st.session_state["user_email"] = text_input
|
283 |
+
send_email(text_input, str(st.session_state["key"]))
|
284 |
+
st.session_state["ux_state"] = UxState.VERIFY_EMAIL
|
285 |
+
# TODO: alternately run this submit log in a callback to the input?
|
286 |
+
# or otherwise ensure we execute the runner for the new state
|
287 |
+
st.experimental_rerun()
|
288 |
+
|
289 |
+
|
290 |
+
def run_verify_email():
|
291 |
user_auth = st.empty()
|
292 |
+
with user_auth.form("one-code"):
|
293 |
text_input = st.text_input(label='Please Input One Time Code')
|
294 |
submit_button = st.form_submit_button(label='Submit')
|
295 |
if submit_button:
|
296 |
if text_input == st.session_state["key"]:
|
297 |
+
st.session_state["ux_state"] = UxState.UPLOAD1
|
298 |
+
st.experimental_rerun()
|
299 |
else:
|
300 |
st.markdown("Please Enter Correct Code!")
|
301 |
|
302 |
+
|
303 |
+
def run_upload_initial():
|
304 |
identifier = st.session_state["key"]
|
|
|
305 |
face_images = st.empty()
|
306 |
+
with face_images.form("my_form"):
|
307 |
uploaded_files = st.file_uploader(
|
308 |
"Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
|
309 |
)
|
310 |
+
initial_concept_token = st.text_input("Token Name")
|
311 |
+
initial_concept_class_token = st.text_input("Token Class")
|
312 |
submitted = st.form_submit_button(f"Upload")
|
313 |
if submitted:
|
314 |
with st.spinner('Uploading...'):
|
315 |
+
st.session_state["initial_concept_file_path"] = zip_and_upload_images(
|
316 |
identifier, uploaded_files, "face"
|
317 |
)
|
318 |
+
st.session_state["initial_token"] = initial_concept_token
|
319 |
+
st.session_state["initial_class_token"] = initial_concept_class_token
|
320 |
st.success(f'Uploading {len(uploaded_files)} files done!')
|
321 |
+
st.session_state["ux_state"] = UxState.UPLOAD2
|
322 |
+
st.experimental_rerun()
|
323 |
+
|
324 |
|
325 |
+
def run_upload_secondary():
|
326 |
+
identifier = st.session_state["key"]
|
|
|
327 |
preset_theme_images = st.empty()
|
328 |
with preset_theme_images.form("choose-preset-theme"):
|
329 |
img = image_select(
|
|
|
338 |
)
|
339 |
|
340 |
col1, col2 = st.columns([0.17, 1])
|
341 |
+
prompt_keywords = st.text_input("Prompt Keywords")
|
342 |
with col1:
|
343 |
submitted_3 = st.form_submit_button("Submit!")
|
344 |
if submitted_3:
|
|
|
352 |
2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
|
353 |
}
|
354 |
st.session_state["model_inputs"] = {
|
355 |
+
"secondary_concept_file_path": dictionary[img][0],
|
356 |
+
# Use presigned url since backend does not have credentials
|
357 |
+
"initial_token": st.session_state["initial_token"],
|
358 |
+
"secondary_token": dictionary[img][1],
|
359 |
+
"initial_class_token": st.session_state["initial_class_token"],
|
360 |
+
"secondary_class_token": 'superhero',
|
361 |
+
"initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
|
362 |
"num_images": 50,
|
363 |
+
"prompt_keywords": prompt_keywords
|
364 |
}
|
365 |
st.success("Success!")
|
366 |
+
st.session_state["ux_state"] = UxState.CAPTCHA
|
367 |
+
st.experimental_rerun()
|
368 |
with col2:
|
369 |
submitted_4 = st.form_submit_button(
|
370 |
"If none of the themes interest you, click here!"
|
|
|
373 |
st.session_state["view"] = True
|
374 |
|
375 |
if st.session_state["view"]:
|
376 |
+
# TODO: split into it's own ux state and function?
|
377 |
custom_theme_images = st.empty()
|
378 |
with custom_theme_images.form("input_custom_themes"):
|
379 |
st.markdown("If none of the themes interest you, please input your own!")
|
|
|
382 |
accept_multiple_files=True,
|
383 |
type=["png", "jpg", "jpeg"],
|
384 |
)
|
385 |
+
secondary_concept_token = st.text_input("Token Name")
|
386 |
+
secondary_concept_class_token = st.text_input("Token Class")
|
387 |
+
prompt_keywords = st.text_input("Prompt Keywords")
|
388 |
submitted_3 = st.form_submit_button("Submit!")
|
389 |
if submitted_3:
|
390 |
with st.spinner('Uploading...'):
|
391 |
+
st.session_state["secondary_concept_file_path"] = zip_and_upload_images(
|
392 |
identifier, uploaded_files_2, "theme"
|
393 |
)
|
394 |
+
#st.markdown(secondary_concept_file_path)
|
395 |
st.session_state["model_inputs"] = {
|
396 |
# Use presigned urls since backend does not have credentials
|
397 |
+
"initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
|
398 |
+
"secondary_concept_file_path": generate_s3_get_url(st.session_state["secondary_concept_file_path"], expiration_seconds=3600),
|
399 |
+
"initial_token": st.session_state["initial_token"],
|
400 |
+
"secondary_token": secondary_concept_token,
|
401 |
+
"initial_class_token": st.session_state["initial_class_token"],
|
402 |
+
"secondary_class_token": secondary_concept_class_token,
|
403 |
"num_images": 50,
|
404 |
+
"prompt_keywords": prompt_keywords
|
405 |
}
|
406 |
st.success('Done!')
|
407 |
+
st.session_state["ux_state"] = UxState.CAPTCHA
|
408 |
+
st.experimental_rerun()
|
409 |
+
|
410 |
+
|
411 |
+
|
412 |
+
def run_captcha():
|
413 |
+
captcha_form = st.empty()
|
414 |
+
with captcha_form.form("captcha_form", clear_on_submit=True):
|
415 |
+
# Create container to create image/text input out of order from the
|
416 |
+
# format submit button. Needed since we need to know the status of the
|
417 |
+
# form submit to know what the captcha should do.
|
418 |
+
captcha_container = st.container()
|
419 |
+
display_captcha = True
|
420 |
+
# TODO: Submit button renders first, then drops down once the image is
|
421 |
+
# fetched leading to page reflow. Would be nice to not have reflow, but
|
422 |
+
# we need to know if the submit button was previously pressed and if the
|
423 |
+
# captcha was solved to generate and display a new captcha or not.
|
424 |
+
# Possible solution is use an on_click callback to set a session_state
|
425 |
+
# variable to access whether the button was pushed or not instead of the
|
426 |
+
# return value here.
|
427 |
+
submitted = st.form_submit_button("Submit Captcha!")
|
428 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
429 |
if submitted:
|
430 |
+
result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
|
431 |
+
del st.session_state["captcha_response"]
|
432 |
+
if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
|
433 |
+
st.session_state['captcha'] = {}
|
434 |
+
display_captcha = False
|
435 |
+
with st.spinner("Model Fine Tuning..."):
|
436 |
+
st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
|
437 |
+
st.session_state["model_inputs"]["email"] = st.session_state["user_email"]
|
438 |
+
s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
|
439 |
+
# The backend does not have s3 credentials, so generate
|
440 |
+
# presigned urls for the backend to use to write and read
|
441 |
+
# the generated images.
|
442 |
+
st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
|
443 |
+
s3_output_path, expiration_seconds=60 * 60 * 24,
|
444 |
+
)
|
445 |
+
st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
|
446 |
+
s3_output_path, expiration_seconds=3600,
|
447 |
+
)
|
448 |
+
train_model(st.session_state["model_inputs"])
|
449 |
+
st.session_state["ux_state"] = UxState.TRAIN
|
450 |
+
st.experimental_rerun()
|
451 |
+
else:
|
452 |
+
st.error(result['error'])
|
453 |
+
|
454 |
+
if display_captcha:
|
455 |
+
# Generate new captcha and display. Occurs on first run of the
|
456 |
+
# captcha state, or after previously failed captcha attempts.
|
457 |
+
result = generate_captcha()
|
458 |
+
captcha_id = result['uuid']
|
459 |
+
captcha_image = result['captcha']
|
460 |
+
|
461 |
+
st.session_state['captcha']['uuid'] = captcha_id
|
462 |
+
st.session_state['captcha']['captcha'] = captcha_image
|
463 |
+
|
464 |
+
captcha_container.image(captcha_image, width=300)
|
465 |
+
|
466 |
+
captcha_container.text_input("Enter the captcha response", key="captcha_response")
|
467 |
+
# Submit button already setup previously.
|
468 |
+
|
469 |
+
|
470 |
+
def run_train():
|
471 |
+
st.write(f"Congratulations, your model is training.")
|
472 |
+
st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
|
473 |
+
st.write("You may close this browser window/tab.")
|
474 |
+
|
475 |
+
|
476 |
+
if __name__ == "__main__":
|
477 |
+
setup_session_state()
|
478 |
+
|
479 |
+
ux_state = st.session_state["ux_state"]
|
480 |
+
|
481 |
+
if ux_state == UxState.LOGIN:
|
482 |
+
run_login()
|
483 |
+
elif ux_state == UxState.VERIFY_EMAIL:
|
484 |
+
run_verify_email()
|
485 |
+
elif ux_state == UxState.UPLOAD1:
|
486 |
+
run_upload_initial()
|
487 |
+
elif ux_state == UxState.UPLOAD2:
|
488 |
+
run_upload_secondary()
|
489 |
+
elif ux_state == UxState.CAPTCHA:
|
490 |
+
run_captcha()
|
491 |
+
elif ux_state == UxState.TRAIN:
|
492 |
+
run_train()
|
493 |
+
else:
|
494 |
+
raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")
|