"""Streamlit front-end for themed image synthesis.

Flow: the user uploads face images and then either picks a preset theme or
uploads custom theme images. Each image set is zipped and copied to S3 with
the AWS CLI, and the resulting S3 paths are assembled into ``model_inputs``.
``train_model`` currently only echoes those inputs; the Banana inference call
is left commented out.
"""
import base64
import os
import shutil
import subprocess
import uuid
import zipfile
from argparse import Namespace
from glob import glob
from io import BytesIO
from itertools import cycle

import banana_dev as banana
import streamlit as st
from PIL import Image
from st_btn_select import st_btn_select
from streamlit_image_select import image_select

S3_BUCKET = "gretel-image-synthetics"

# Preset themes offered by the image picker, keyed by the index returned by
# ``image_select(..., return_value="index")``:
# index -> (zipped theme images on S3, prompt name for the theme).
PRESET_THEMES = {
    0: (
        "s3://gretel-image-synthetics/data/game-of-thrones.zip",
        "game-of-thrones",
    ),
    1: ("s3://gretel-image-synthetics/data/iron-man.zip", "iron-man"),
    2: ("s3://gretel-image-synthetics/data/thor.zip", "thor"),
}

# --- Per-session state -------------------------------------------------------
# Streamlit re-runs this whole script on every interaction, so each value is
# only initialised when it is missing. Keys are checked one by one so that a
# single missing key is never skipped.
if "key" not in st.session_state:
    # Unique id used for the local staging directory and the S3 prefix.
    st.session_state["key"] = uuid.uuid4().hex
for state_key, default in (
    ("model_inputs", None),
    ("s3_face_file_path", None),
    ("s3_theme_file_path", None),
    ("view", False),
    ("aws_configured", False),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = default


def callback():
    """Set the ``button_clicked`` flag in session state.

    Not referenced in this script; presumably meant as a widget ``on_click``
    handler.
    """
    st.session_state["button_clicked"] = True


if not st.session_state["aws_configured"]:
    # Raise the AWS CLI multipart threshold to 200MB. `aws configure set`
    # persists the setting in the CLI config, so running it once per session
    # (instead of on every rerun) is enough.
    try:
        subprocess.run(
            ["aws", "configure", "set", "default.s3.multipart_threshold", "200MB"],
            check=False,
        )
    except OSError as err:
        st.warning(f"Could not configure the AWS CLI: {err}")
    st.session_state["aws_configured"] = True


def zip_and_upload_images(identifier, uploaded_files, image_type):
    """Zip uploaded images and copy the archive to S3.

    Each file is converted to RGB and saved as ``{num}_test.png`` in a fresh
    staging directory ``{identifier}/{image_type}``. Using one directory per
    image type, cleared before each call, keeps face images out of the theme
    archive and drops stale files from earlier submissions.

    Args:
        identifier: Per-session id. It names the staging directory and the
            local archive, and is used as the S3 prefix.
        uploaded_files: Iterable of file-like objects readable by PIL.
        image_type: Label for the image set, e.g. ``"face"`` or ``"theme"``.

    Returns:
        The ``s3://`` URI of the uploaded archive.

    Raises:
        RuntimeError: If ``aws s3 cp`` exits with a non-zero status.
        OSError: If an image cannot be read or the AWS CLI cannot be started.
    """
    staging_dir = os.path.join(identifier, image_type)
    # Start from an empty directory so no earlier images leak into this zip.
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir, exist_ok=True)
    for num, uploaded_file in enumerate(uploaded_files):
        with Image.open(uploaded_file) as image:
            image.convert("RGB").save(os.path.join(staging_dir, f"{num}_test.png"))

    archive_path = shutil.make_archive(
        f"{identifier}_{image_type}_images", "zip", staging_dir
    )
    s3_uri = f"s3://{S3_BUCKET}/data/{identifier}/{image_type}_images.zip"
    # Argument list (no shell) and an explicit exit-code check, so a failed
    # upload is reported instead of silently returning a dead S3 path.
    result = subprocess.run(
        ["aws", "s3", "cp", archive_path, s3_uri, "--no-sign-request"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        detail = result.stderr.strip() or f"exit code {result.returncode}"
        raise RuntimeError(f"Upload to {s3_uri} failed: {detail}")
    return s3_uri


def _upload_with_feedback(identifier, uploaded_files, image_type):
    """Upload ``uploaded_files`` with UI feedback; return the S3 URI or None."""
    if not uploaded_files:
        st.warning("Please choose at least one image before submitting.")
        return None
    try:
        with st.spinner('Uploading...'):
            s3_path = zip_and_upload_images(identifier, uploaded_files, image_type)
    except (OSError, RuntimeError) as err:
        st.error(f"Upload failed: {err}")
        return None
    st.success('Done!')
    return s3_path


def train_model(model_inputs):
    """Train on ``model_inputs``; currently this only renders the inputs.

    The Banana call is disabled below. Credentials are read from the
    ``BANANA_API_KEY`` / ``BANANA_MODEL_KEY`` environment variables instead of
    being committed to source.
    """
    # NOTE(review): keys were previously hard-coded here; rotate any key that
    # was committed, and set these variables before re-enabling the call.
    api_key = os.environ.get("BANANA_API_KEY")
    model_key = os.environ.get("BANANA_MODEL_KEY")
    st.markdown(str(model_inputs))
    # out = banana.run(api_key, model_key, model_inputs)
    # if not os.path.exists("generated"):
    #     os.makedirs("generated")
    # for num, img in enumerate(out["modelOutputs"][0]["image_base64"]):
    #     image_encoded = img.encode("utf-8")
    #     image_bytes = BytesIO(base64.b64decode(image_encoded))
    #     image = Image.open(image_bytes)
    #     image.save(f"{num}_output.jpg")


identifier = st.session_state["key"]

# --- Step 1: face images -----------------------------------------------------
face_images = st.empty()
with face_images.form("my_form"):
    uploaded_files = st.file_uploader(
        "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
    )
    submitted = st.form_submit_button("Submit")
    if submitted:
        face_path = _upload_with_feedback(identifier, uploaded_files, "face")
        if face_path is not None:
            st.session_state["s3_face_file_path"] = face_path

# --- Step 2a: preset theme ---------------------------------------------------
preset_theme_images = st.empty()
with preset_theme_images.form("choose-preset-theme"):
    img = image_select(
        "Choose a Theme!",
        images=[
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/thor.png",
        ],
        captions=["Game of Thrones", "Iron Man", "Thor"],
        return_value="index",
    )
    col1, col2 = st.columns([0.15, 1])
    with col1:
        submitted_3 = st.form_submit_button("Submit!")
        if submitted_3:
            theme_path, theme_prompt = PRESET_THEMES[img]
            st.session_state["model_inputs"] = {
                "superhero_file_path": theme_path,
                "person_file_path": st.session_state["s3_face_file_path"],
                "superhero_prompt": theme_prompt,
                "num_images": 50,
            }
    with col2:
        submitted_4 = st.form_submit_button(
            "If none of the themes interest you, click here!"
        )
        if submitted_4:
            st.session_state["view"] = True

# --- Step 2b: custom theme (shown after opting out of the presets) ------------
if st.session_state["view"]:
    custom_theme_images = st.empty()
    with custom_theme_images.form("input_custom_themes"):
        st.markdown("If none of the themes interest you, please input your own!")
        uploaded_files_2 = st.file_uploader(
            "Choose image files",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
        )
        title = st.text_input("Theme Name")
        submitted_custom = st.form_submit_button("Submit!")
        if submitted_custom:
            if not title.strip():
                st.warning("Please enter a theme name.")
            else:
                theme_path = _upload_with_feedback(
                    identifier, uploaded_files_2, "theme"
                )
                if theme_path is not None:
                    st.session_state["s3_theme_file_path"] = theme_path
                    st.session_state["model_inputs"] = {
                        "superhero_file_path": theme_path,
                        "person_file_path": st.session_state["s3_face_file_path"],
                        "superhero_prompt": title,
                        "num_images": 50,
                    }

# --- Step 3: train -----------------------------------------------------------
train = st.empty()
with train.form("training"):
    submitted = st.form_submit_button("Train Model!")
    if submitted:
        model_inputs = st.session_state["model_inputs"]
        face_path = st.session_state["s3_face_file_path"]
        if model_inputs is None:
            st.warning("Please choose or upload a theme before training.")
        elif face_path is None:
            st.warning("Please upload face images before training.")
        else:
            # Refresh the face path: the theme may have been chosen before the
            # face images were uploaded, or the faces may have been re-uploaded.
            model_inputs = {**model_inputs, "person_file_path": face_path}
            st.session_state["model_inputs"] = model_inputs
            train_model(model_inputs)