santhosh97 commited on
Commit
a40e4e2
·
1 Parent(s): b16d55b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -151
app.py CHANGED
@@ -1,29 +1,30 @@
1
- import base64
 
2
  import boto3
3
  from botocore.config import Config
4
  from dotenv import load_dotenv
5
  import os
6
  import shutil
7
- from typing import List, Tuple
8
  import uuid
9
- import zipfile
10
  import argparse
11
  import logging
12
  import sendgrid
13
  from sendgrid.helpers.mail import Mail, Email, To, Content
14
-
15
- from glob import glob
16
- from io import BytesIO
17
- from itertools import cycle
18
 
19
  import requests
20
  import banana_dev as banana
21
  import streamlit as st
22
  from PIL import Image
23
- from st_btn_select import st_btn_select
24
  from streamlit_image_select import image_select
25
  import smart_open
26
 
 
 
 
27
  logging.basicConfig()
28
  logger = logging.getLogger(__name__)
29
  logger.setLevel(logging.INFO)
@@ -33,29 +34,74 @@ logger.setLevel(logging.INFO)
33
  load_dotenv()
34
 
35
  _S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
 
 
 
 
 
 
 
 
 
 
36
 
37
  # Command-line arguments to control some stuff for easier local testing.
38
  # Eventually may want to move everything into functions and have a
39
  # if __name__ == "main" setup instead of everything inline.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- bool_session_variables = ['captcha_bool', 'view', 'train_view', 'login', 'user_auth_sess', 'face_images_uploaded', 'theme_images_uploaded']
42
- for var in bool_session_variables:
43
- if var not in st.session_state:
44
- st.session_state[var] = False
45
-
46
- obj_session_variables = ['s3_face_file_path', 's3_theme_file_path', 'captcha_response', 'user_email']
47
- for var in obj_session_variables:
48
- if var not in st.session_state:
49
- st.session_state[var] = None
50
-
51
- if "key" not in st.session_state:
52
- st.session_state["key"] = uuid.uuid4().hex
53
 
54
- if "captcha" not in st.session_state:
55
- st.session_state["captcha"] = {}
56
 
57
- def callback():
58
- st.session_state["button_clicked"] = True
59
 
60
 
61
  def bucket_parts(s3_path: str) -> Tuple[str, str]:
@@ -129,43 +175,46 @@ def generate_s3_put_url(s3_path: str, expiration_seconds: int) -> str:
129
  return upload_url
130
 
131
 
132
- def zip_and_upload_images(identifier: str, uploaded_files: List[str], image_type: str) -> str:
133
  """Save images as zip file to s3 for use in backend.
134
 
135
  Blocks until images are processed, added to zip file, and uploaded to S3.
136
 
137
  Args:
138
  identifier: unique identifier for the run, used in s3 link
139
- uploaded_files: list of file names
140
  image_type: string to identify different batches of images used in the
141
  backend model/training. Currently used values: "face", "theme"
142
 
143
  Returns:
144
  S3 location of zip file containing png images.
145
  """
146
- if not os.path.exists(identifier):
147
- os.makedirs(identifier)
148
-
149
- logger.info("Processing uploaded images")
150
- for num, uploaded_file in enumerate(uploaded_files):
151
- file_ = Image.open(uploaded_file).convert("RGB")
152
- file_.save(f"{identifier}/{num}_test.png")
153
- local_zip_filestem = f"{identifier}_{image_type}_images"
154
- logger.info("Making zip archive")
155
- shutil.make_archive(local_zip_filestem, "zip", identifier)
156
- local_zip_filename = f"{local_zip_filestem}.zip"
157
-
158
- logger.info("Uploading zip file to s3")
159
- # TODO: can we define expiration when making the s3 path?
160
- # Probably if we use the boto3 library instead of smart open
161
- s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
162
-
163
- with open(local_zip_filename, "rb") as fin:
164
- with smart_open.open(s3_path, "wb") as fout:
165
- fout.write(fin.read())
166
- logger.info(f"Completed upload to {s3_path}")
167
-
168
- return s3_path
 
 
 
169
 
170
  def send_email(to_email, user_code):
171
  sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
@@ -178,9 +227,6 @@ def send_email(to_email, user_code):
178
  response = sg.client.mail.send.post(request_body=mail_json)
179
 
180
 
181
- CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
182
- VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
183
-
184
  # Create a function to generate a captcha
185
  def generate_captcha():
186
  # Make a GET request to the API endpoint to generate a captcha
@@ -211,55 +257,73 @@ def verify_captcha(captcha_id, captcha_response):
211
  return {"error": "Failed to verify captcha"}
212
 
213
  def train_model(model_inputs):
214
- api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
215
- model_key = "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d"
216
- _ = banana.run(api_key, model_key, model_inputs)
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- #first page - user provides email
219
- if st.session_state["user_email"] is None:
220
  user_email_input = st.empty()
221
- with user_email_input.form(key='user_email_provided'):
222
  text_input = st.text_input(label='Please Enter Your Email')
223
  submit_button = st.form_submit_button(label='Submit')
224
- if submit_button:
225
- st.session_state["user_auth_sess"] = True
226
  st.session_state["user_email"] = text_input
227
- send_email(st.session_state["user_email"], str(st.session_state["key"]))
228
- user_email_input.empty()
229
-
230
- #second page - user inputs one time code
231
- if st.session_state["user_auth_sess"]:
 
 
 
232
  user_auth = st.empty()
233
- with user_auth.form("one-time-code"):
234
  text_input = st.text_input(label='Please Input One Time Code')
235
  submit_button = st.form_submit_button(label='Submit')
236
  if submit_button:
237
  if text_input == st.session_state["key"]:
238
- st.session_state["login"] = True
 
239
  else:
240
  st.markdown("Please Enter Correct Code!")
241
 
242
- #third page - user inputs face images
243
- if st.session_state["login"]:
244
  identifier = st.session_state["key"]
245
- user_auth.empty()
246
  face_images = st.empty()
247
- with face_images.form("face_images_upload"):
248
  uploaded_files = st.file_uploader(
249
  "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
250
  )
 
 
251
  submitted = st.form_submit_button(f"Upload")
252
  if submitted:
253
  with st.spinner('Uploading...'):
254
- st.session_state["s3_face_file_path"] = zip_and_upload_images(
255
  identifier, uploaded_files, "face"
256
  )
 
 
257
  st.success(f'Uploading {len(uploaded_files)} files done!')
258
- st.session_state["face_images_uploaded"] = True
 
 
259
 
260
- #fourth page - user inputs theme images
261
- if st.session_state["face_images_uploaded"]:
262
- face_images.empty()
263
  preset_theme_images = st.empty()
264
  with preset_theme_images.form("choose-preset-theme"):
265
  img = image_select(
@@ -274,6 +338,7 @@ if st.session_state["face_images_uploaded"]:
274
  )
275
 
276
  col1, col2 = st.columns([0.17, 1])
 
277
  with col1:
278
  submitted_3 = st.form_submit_button("Submit!")
279
  if submitted_3:
@@ -287,13 +352,19 @@ if st.session_state["face_images_uploaded"]:
287
  2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
288
  }
289
  st.session_state["model_inputs"] = {
290
- "superhero_file_path": dictionary[img][0],
291
- "person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
292
- "superhero_prompt": dictionary[img][1],
 
 
 
 
293
  "num_images": 50,
 
294
  }
295
  st.success("Success!")
296
- st.session_state["theme_images_uploaded"] = True
 
297
  with col2:
298
  submitted_4 = st.form_submit_button(
299
  "If none of the themes interest you, click here!"
@@ -302,6 +373,7 @@ if st.session_state["face_images_uploaded"]:
302
  st.session_state["view"] = True
303
 
304
  if st.session_state["view"]:
 
305
  custom_theme_images = st.empty()
306
  with custom_theme_images.form("input_custom_themes"):
307
  st.markdown("If none of the themes interest you, please input your own!")
@@ -310,86 +382,113 @@ if st.session_state["face_images_uploaded"]:
310
  accept_multiple_files=True,
311
  type=["png", "jpg", "jpeg"],
312
  )
313
- title = st.text_input("Theme Name")
 
 
314
  submitted_3 = st.form_submit_button("Submit!")
315
  if submitted_3:
316
  with st.spinner('Uploading...'):
317
- st.session_state["s3_theme_file_path"] = zip_and_upload_images(
318
  identifier, uploaded_files_2, "theme"
319
  )
 
320
  st.session_state["model_inputs"] = {
321
  # Use presigned urls since backend does not have credentials
322
- "superhero_file_path": generate_s3_get_url(st.session_state["s3_theme_file_path"], expiration_seconds=3600),
323
- "person_file_path": generate_s3_get_url(st.session_state["s3_face_file_path"], expiration_seconds=3600),
324
- "superhero_prompt": title,
 
 
 
325
  "num_images": 50,
 
326
  }
327
  st.success('Done!')
328
- st.session_state["theme_images_uploaded"] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
- #fifth page - user enters captcha and trains model
331
- if st.session_state["theme_images_uploaded"]:
332
- if st.session_state["view"]:
333
- custom_theme_images.empty()
334
- preset_theme_images.empty()
335
- train = st.empty()
336
- with train.form("training"):
337
- submitted = st.form_submit_button("Train Model!")
338
  if submitted:
339
- st.session_state["captcha_bool"] = True
340
-
341
- if st.session_state["captcha_bool"]:
342
- captcha_form = st.empty()
343
- with captcha_form.form("captcha_form", clear_on_submit=True):
344
- # Create container to create image/text input out of order from the
345
- # format submit button. Needed since we need to know the status of the
346
- # form submit to know what the captcha should do.
347
- captcha_container = st.container()
348
- display_captcha = True
349
- # TODO: Submit button renders first, then drops down once the image is
350
- # fetched leading to page reflow. Would be nice to not have reflow, but
351
- # we need to know if the submit button was previously pressed and if the
352
- # captcha was solved to generate and display a new captcha or not.
353
- # Possible solution is use an on_click callback to set a session_state
354
- # variable to access whether the button was pushed or not instead of the
355
- # return value here.
356
- submitted = st.form_submit_button("Submit Captcha!")
357
-
358
- if submitted:
359
- result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
360
- del st.session_state["captcha_response"]
361
- if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
362
- st.session_state['captcha'] = {}
363
- display_captcha = False
364
- with st.spinner("Model Fine Tuning..."):
365
- st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
366
- st.session_state["model_inputs"]["email"] = st.session_state["user_email"]
367
- s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
368
- # The backend does not have s3 credentials, so generate
369
- # presigned urls for the backend to use to write and read
370
- # the generated images.
371
- st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
372
- s3_output_path, expiration_seconds=60 * 60 * 24,
373
- )
374
- st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
375
- s3_output_path, expiration_seconds=3600,
376
- )
377
- train_model(st.session_state["model_inputs"])
378
- st.session_state["train_view"] = True
379
- else:
380
- st.error(result['error'])
381
-
382
- if display_captcha:
383
- # Generate new captcha and display. Occurs on first load with the
384
- # captcha_bool=True, or after previously failed captcha attempts.
385
- result = generate_captcha()
386
- captcha_id = result['uuid']
387
- captcha_image = result['captcha']
388
-
389
- st.session_state['captcha']['uuid'] = captcha_id
390
- st.session_state['captcha']['captcha'] = captcha_image
391
-
392
- captcha_container.image(captcha_image, width=300)
393
-
394
- captcha_container.text_input("Enter the captcha response", key="captcha_response")
395
- # Submit button already setup previously.
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
  import boto3
4
  from botocore.config import Config
5
  from dotenv import load_dotenv
6
  import os
7
  import shutil
8
+ from typing import List, Tuple, TYPE_CHECKING
9
  import uuid
 
10
  import argparse
11
  import logging
12
  import sendgrid
13
  from sendgrid.helpers.mail import Mail, Email, To, Content
14
+ from enum import Enum
15
+ import tempfile
16
+ from pathlib import Path
 
17
 
18
  import requests
19
  import banana_dev as banana
20
  import streamlit as st
21
  from PIL import Image
 
22
  from streamlit_image_select import image_select
23
  import smart_open
24
 
25
+ if TYPE_CHECKING:
26
+ from io import BytesIO
27
+
28
  logging.basicConfig()
29
  logger = logging.getLogger(__name__)
30
  logger.setLevel(logging.INFO)
 
34
  load_dotenv()
35
 
36
  _S3_PATH_OUTPUT = "s3://gretel-image-synthetics-use2/data/{identifier}/{image_type}_images.zip"
37
+ CAPTCHA_ENDPOINT = "https://captcha-api.akshit.me/v2/generate"
38
+ VERIFY_ENDPOINT = "https://captcha-api.akshit.me/v2/verify"
39
+
40
+ class UxState(str, Enum):
41
+ LOGIN = "login"
42
+ VERIFY_EMAIL = "verify_email"
43
+ UPLOAD1 = "upload1"
44
+ UPLOAD2 = "upload2"
45
+ CAPTCHA = "captcha"
46
+ TRAIN = "train"
47
 
48
  # Command-line arguments to control some stuff for easier local testing.
49
  # Eventually may want to move everything into functions and have a
50
  # if __name__ == "main" setup instead of everything inline.
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument(
53
+ "--dry-run", action="store_true",
54
+ help="Skip sending train request to backend server.",
55
+ )
56
+ parser.add_argument(
57
+ "--train-endpoint-url", default=None,
58
+ help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
59
+ )
60
+ cli_args = parser.parse_args()
61
+
62
+
63
+ def setup_session_state():
64
+ if "key" not in st.session_state:
65
+ st.session_state["key"] = uuid.uuid4().hex
66
+
67
+ if "ux_state" not in st.session_state:
68
+ st.session_state["ux_state"] = UxState.LOGIN
69
+
70
+ if "model_inputs" not in st.session_state:
71
+ st.session_state["model_inputs"] = None
72
+
73
+ if "initial_concept_file_path" not in st.session_state:
74
+ st.session_state["initial_concept_file_path"] = None
75
+
76
+ if "initial_token" not in st.session_state:
77
+ st.session_state["initial_token"] = None
78
+
79
+ if "initial_class_token" not in st.session_state:
80
+ st.session_state["initial_class_token"] = None
81
+
82
+ if "secondary_concept_file_path" not in st.session_state:
83
+ st.session_state["secondary_concept_file_path"] = None
84
+
85
+ if "secondary_token" not in st.session_state:
86
+ st.session_state["secondary_token"] = None
87
+
88
+ if "secondary_class_token" not in st.session_state:
89
+ st.session_state["secondary_class_token"] = None
90
+
91
+ if "prompt_keywords" not in st.session_state:
92
+ st.session_state["prompt_keywords"] = None
93
+
94
+ if "view" not in st.session_state:
95
+ st.session_state["view"] = False
96
 
97
+ if "captcha_response" not in st.session_state:
98
+ st.session_state["captcha_response"] = None
 
 
 
 
 
 
 
 
 
 
99
 
100
+ if "captcha" not in st.session_state:
101
+ st.session_state["captcha"] = {}
102
 
103
+ if "user_email" not in st.session_state:
104
+ st.session_state["user_email"] = None
105
 
106
 
107
  def bucket_parts(s3_path: str) -> Tuple[str, str]:
 
175
  return upload_url
176
 
177
 
178
+ def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_type: str) -> str:
179
  """Save images as zip file to s3 for use in backend.
180
 
181
  Blocks until images are processed, added to zip file, and uploaded to S3.
182
 
183
  Args:
184
  identifier: unique identifier for the run, used in s3 link
185
+ uploaded_files: BytesIO or UploadedFile from streamlit fileuploader
186
  image_type: string to identify different batches of images used in the
187
  backend model/training. Currently used values: "face", "theme"
188
 
189
  Returns:
190
  S3 location of zip file containing png images.
191
  """
192
+ with tempfile.TemporaryDirectory() as temp_dir_name:
193
+ logger.info(f"Working from temp dir to zip and upload images: {temp_dir_name}")
194
+ temp_dir = Path(temp_dir_name)
195
+ if not os.path.exists(temp_dir / identifier):
196
+ os.makedirs(temp_dir / identifier)
197
+
198
+ logger.info("Processing uploaded images")
199
+ for num, uploaded_file in enumerate(uploaded_files):
200
+ file_ = Image.open(uploaded_file).convert("RGB")
201
+ file_.save(temp_dir / identifier / f"{num}_test.png")
202
+ local_zip_filestem = str(temp_dir / f"{identifier}_{image_type}_images")
203
+ logger.info("Making zip archive")
204
+ shutil.make_archive(local_zip_filestem, "zip", temp_dir / identifier)
205
+ local_zip_filename = f"{local_zip_filestem}.zip"
206
+
207
+ logger.info("Uploading zip file to s3")
208
+ # TODO: can we define expiration when making the s3 path?
209
+ # Probably if we use the boto3 library instead of smart open
210
+ s3_path = _S3_PATH_OUTPUT.format(identifier=identifier, image_type=image_type)
211
+
212
+ with open(local_zip_filename, "rb") as fin:
213
+ with smart_open.open(s3_path, "wb") as fout:
214
+ fout.write(fin.read())
215
+ logger.info(f"Completed upload to {s3_path}")
216
+
217
+ return s3_path
218
 
219
  def send_email(to_email, user_code):
220
  sg = sendgrid.SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY'))
 
227
  response = sg.client.mail.send.post(request_body=mail_json)
228
 
229
 
 
 
 
230
  # Create a function to generate a captcha
231
  def generate_captcha():
232
  # Make a GET request to the API endpoint to generate a captcha
 
257
  return {"error": "Failed to verify captcha"}
258
 
259
  def train_model(model_inputs):
260
+ if cli_args.dry_run:
261
+ logger.info("Skipping model training since --dry-run is enabled.")
262
+ logger.info(f"model_inputs: {model_inputs}")
263
+ return
264
+
265
+ if cli_args.train_endpoint_url is None:
266
+ # Use banana backend
267
+ api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
268
+ model_key = "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b"
269
+ st.markdown(str(model_inputs))
270
+ _ = banana.run(api_key, model_key, model_inputs)
271
+ else:
272
+ # Send request directly to specified url
273
+ _ = requests.post(cli_args.train_endpoint_url, json=model_inputs)
274
 
275
+
276
+ def run_login():
277
  user_email_input = st.empty()
278
+ with user_email_input.form(key='user_auth'):
279
  text_input = st.text_input(label='Please Enter Your Email')
280
  submit_button = st.form_submit_button(label='Submit')
281
+ if submit_button:
 
282
  st.session_state["user_email"] = text_input
283
+ send_email(text_input, str(st.session_state["key"]))
284
+ st.session_state["ux_state"] = UxState.VERIFY_EMAIL
285
+ # TODO: alternately run this submit log in a callback to the input?
286
+ # or otherwise ensure we execute the runner for the new state
287
+ st.experimental_rerun()
288
+
289
+
290
+ def run_verify_email():
291
  user_auth = st.empty()
292
+ with user_auth.form("one-code"):
293
  text_input = st.text_input(label='Please Input One Time Code')
294
  submit_button = st.form_submit_button(label='Submit')
295
  if submit_button:
296
  if text_input == st.session_state["key"]:
297
+ st.session_state["ux_state"] = UxState.UPLOAD1
298
+ st.experimental_rerun()
299
  else:
300
  st.markdown("Please Enter Correct Code!")
301
 
302
+
303
+ def run_upload_initial():
304
  identifier = st.session_state["key"]
 
305
  face_images = st.empty()
306
+ with face_images.form("my_form"):
307
  uploaded_files = st.file_uploader(
308
  "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
309
  )
310
+ initial_concept_token = st.text_input("Token Name")
311
+ initial_concept_class_token = st.text_input("Token Class")
312
  submitted = st.form_submit_button(f"Upload")
313
  if submitted:
314
  with st.spinner('Uploading...'):
315
+ st.session_state["initial_concept_file_path"] = zip_and_upload_images(
316
  identifier, uploaded_files, "face"
317
  )
318
+ st.session_state["initial_token"] = initial_concept_token
319
+ st.session_state["initial_class_token"] = initial_concept_class_token
320
  st.success(f'Uploading {len(uploaded_files)} files done!')
321
+ st.session_state["ux_state"] = UxState.UPLOAD2
322
+ st.experimental_rerun()
323
+
324
 
325
+ def run_upload_secondary():
326
+ identifier = st.session_state["key"]
 
327
  preset_theme_images = st.empty()
328
  with preset_theme_images.form("choose-preset-theme"):
329
  img = image_select(
 
338
  )
339
 
340
  col1, col2 = st.columns([0.17, 1])
341
+ prompt_keywords = st.text_input("Prompt Keywords")
342
  with col1:
343
  submitted_3 = st.form_submit_button("Submit!")
344
  if submitted_3:
 
352
  2: ["https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/data/thor.zip", "thor"],
353
  }
354
  st.session_state["model_inputs"] = {
355
+ "secondary_concept_file_path": dictionary[img][0],
356
+ # Use presigned url since backend does not have credentials
357
+ "initial_token": st.session_state["initial_token"],
358
+ "secondary_token": dictionary[img][1],
359
+ "initial_class_token": st.session_state["initial_class_token"],
360
+ "secondary_class_token": 'superhero',
361
+ "initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
362
  "num_images": 50,
363
+ "prompt_keywords": prompt_keywords
364
  }
365
  st.success("Success!")
366
+ st.session_state["ux_state"] = UxState.CAPTCHA
367
+ st.experimental_rerun()
368
  with col2:
369
  submitted_4 = st.form_submit_button(
370
  "If none of the themes interest you, click here!"
 
373
  st.session_state["view"] = True
374
 
375
  if st.session_state["view"]:
376
+ # TODO: split into it's own ux state and function?
377
  custom_theme_images = st.empty()
378
  with custom_theme_images.form("input_custom_themes"):
379
  st.markdown("If none of the themes interest you, please input your own!")
 
382
  accept_multiple_files=True,
383
  type=["png", "jpg", "jpeg"],
384
  )
385
+ secondary_concept_token = st.text_input("Token Name")
386
+ secondary_concept_class_token = st.text_input("Token Class")
387
+ prompt_keywords = st.text_input("Prompt Keywords")
388
  submitted_3 = st.form_submit_button("Submit!")
389
  if submitted_3:
390
  with st.spinner('Uploading...'):
391
+ st.session_state["secondary_concept_file_path"] = zip_and_upload_images(
392
  identifier, uploaded_files_2, "theme"
393
  )
394
+ #st.markdown(secondary_concept_file_path)
395
  st.session_state["model_inputs"] = {
396
  # Use presigned urls since backend does not have credentials
397
+ "initial_concept_file_path": generate_s3_get_url(st.session_state["initial_concept_file_path"], expiration_seconds=3600),
398
+ "secondary_concept_file_path": generate_s3_get_url(st.session_state["secondary_concept_file_path"], expiration_seconds=3600),
399
+ "initial_token": st.session_state["initial_token"],
400
+ "secondary_token": secondary_concept_token,
401
+ "initial_class_token": st.session_state["initial_class_token"],
402
+ "secondary_class_token": secondary_concept_class_token,
403
  "num_images": 50,
404
+ "prompt_keywords": prompt_keywords
405
  }
406
  st.success('Done!')
407
+ st.session_state["ux_state"] = UxState.CAPTCHA
408
+ st.experimental_rerun()
409
+
410
+
411
+
412
+ def run_captcha():
413
+ captcha_form = st.empty()
414
+ with captcha_form.form("captcha_form", clear_on_submit=True):
415
+ # Create container to create image/text input out of order from the
416
+ # format submit button. Needed since we need to know the status of the
417
+ # form submit to know what the captcha should do.
418
+ captcha_container = st.container()
419
+ display_captcha = True
420
+ # TODO: Submit button renders first, then drops down once the image is
421
+ # fetched leading to page reflow. Would be nice to not have reflow, but
422
+ # we need to know if the submit button was previously pressed and if the
423
+ # captcha was solved to generate and display a new captcha or not.
424
+ # Possible solution is use an on_click callback to set a session_state
425
+ # variable to access whether the button was pushed or not instead of the
426
+ # return value here.
427
+ submitted = st.form_submit_button("Submit Captcha!")
428
 
 
 
 
 
 
 
 
 
429
  if submitted:
430
+ result = verify_captcha(st.session_state['captcha']['uuid'], st.session_state["captcha_response"])
431
+ del st.session_state["captcha_response"]
432
+ if 'message' in result and result['message'] == 'CAPTCHA_SOLVED':
433
+ st.session_state['captcha'] = {}
434
+ display_captcha = False
435
+ with st.spinner("Model Fine Tuning..."):
436
+ st.session_state["model_inputs"]["identifier"] = st.session_state["key"]
437
+ st.session_state["model_inputs"]["email"] = st.session_state["user_email"]
438
+ s3_output_path = _S3_PATH_OUTPUT.format(identifier=st.session_state["key"], image_type="generated")
439
+ # The backend does not have s3 credentials, so generate
440
+ # presigned urls for the backend to use to write and read
441
+ # the generated images.
442
+ st.session_state["model_inputs"]["output_s3_url_get"] = generate_s3_get_url(
443
+ s3_output_path, expiration_seconds=60 * 60 * 24,
444
+ )
445
+ st.session_state["model_inputs"]["output_s3_url_put"] = generate_s3_put_url(
446
+ s3_output_path, expiration_seconds=3600,
447
+ )
448
+ train_model(st.session_state["model_inputs"])
449
+ st.session_state["ux_state"] = UxState.TRAIN
450
+ st.experimental_rerun()
451
+ else:
452
+ st.error(result['error'])
453
+
454
+ if display_captcha:
455
+ # Generate new captcha and display. Occurs on first run of the
456
+ # captcha state, or after previously failed captcha attempts.
457
+ result = generate_captcha()
458
+ captcha_id = result['uuid']
459
+ captcha_image = result['captcha']
460
+
461
+ st.session_state['captcha']['uuid'] = captcha_id
462
+ st.session_state['captcha']['captcha'] = captcha_image
463
+
464
+ captcha_container.image(captcha_image, width=300)
465
+
466
+ captcha_container.text_input("Enter the captcha response", key="captcha_response")
467
+ # Submit button already setup previously.
468
+
469
+
470
+ def run_train():
471
+ st.write(f"Congratulations, your model is training.")
472
+ st.write(f"We'll send an email to {st.session_state['user_email']} when it's finished, usually about 20-30 minutes.")
473
+ st.write("You may close this browser window/tab.")
474
+
475
+
476
+ if __name__ == "__main__":
477
+ setup_session_state()
478
+
479
+ ux_state = st.session_state["ux_state"]
480
+
481
+ if ux_state == UxState.LOGIN:
482
+ run_login()
483
+ elif ux_state == UxState.VERIFY_EMAIL:
484
+ run_verify_email()
485
+ elif ux_state == UxState.UPLOAD1:
486
+ run_upload_initial()
487
+ elif ux_state == UxState.UPLOAD2:
488
+ run_upload_secondary()
489
+ elif ux_state == UxState.CAPTCHA:
490
+ run_captcha()
491
+ elif ux_state == UxState.TRAIN:
492
+ run_train()
493
+ else:
494
+ raise ValueError(f"Internal app error, unknown ux_state='{ux_state}'")