Spaces:
Runtime error
Runtime error
File size: 5,388 Bytes
9206066 097c210 9206066 e52db89 9206066 724e585 9206066 724e585 9206066 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import base64
import os
import shutil
import subprocess
import uuid
import zipfile
from argparse import Namespace
from glob import glob
from io import BytesIO
from itertools import cycle

import banana_dev as banana
import streamlit as st
from PIL import Image
from st_btn_select import st_btn_select
from streamlit_image_select import image_select
# Per-session state. Streamlit reruns this script on every interaction, so each
# key is initialized only when missing, which preserves its value across reruns.
if "key" not in st.session_state:
    # Random per-session id; used below as the local staging directory name and
    # the S3 key prefix for uploaded archives.
    st.session_state["key"] = uuid.uuid4().hex
# Every remaining key is checked on its own. The previous combined test
# (`"s3_face_file_path" not in ... and "s3_theme_file_path" not in ...`) left
# one of the two keys undefined whenever the other already existed.
for _state_key, _default in (
    ("model_inputs", None),
    ("s3_face_file_path", None),
    ("s3_theme_file_path", None),
    ("view", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
def callback():
    """Flag in session state that a button was clicked (``button_clicked = True``).

    NOTE(review): not referenced in the visible code; presumably intended as a
    widget ``on_click`` callback — confirm before removing.
    """
    st.session_state.button_clicked = True
# Raise the AWS CLI's S3 multipart-upload threshold to 200MB (the CLI switches to
# multipart uploads only for files larger than this). Run as an argv list rather
# than through a shell; like the old os.system call, failure here is non-fatal.
try:
    subprocess.run(
        ["aws", "configure", "set", "default.s3.multipart_threshold", "200MB"],
        check=False,
    )
except FileNotFoundError:
    # The CLI is not installed: surface it instead of a stray shell error,
    # since every S3 upload below depends on it.
    st.warning("AWS CLI not found; image uploads will not work.")
def zip_and_upload_images(identifier, uploaded_files, image_type):
    """Convert uploaded images to RGB PNGs, zip them, and copy the zip to S3.

    Args:
        identifier: Per-session id; names the local staging directory and the
            S3 key prefix.
        uploaded_files: Iterable of file-like objects that PIL can open
            (e.g. the list returned by ``st.file_uploader``).
        image_type: Label such as ``"face"`` or ``"theme"``; names the archive.

    Returns:
        The ``s3://`` URI the archive was uploaded to.

    Raises:
        subprocess.CalledProcessError: if ``aws s3 cp`` exits non-zero.
        FileNotFoundError: if the ``aws`` CLI is not installed.
    """
    # Stage each image type in its own, freshly emptied directory. Previously all
    # types shared ``identifier/`` and nothing was cleared, so a theme archive
    # could include leftover face images and re-submissions kept stale files.
    staging_dir = os.path.join(identifier, image_type)
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir)
    for num, uploaded_file in enumerate(uploaded_files):
        # Close each decoded image instead of leaking the handle.
        with Image.open(uploaded_file) as image:
            image.convert("RGB").save(os.path.join(staging_dir, f"{num}_test.png"))
    # The archive is written next to (not inside) the staging directory.
    archive_path = shutil.make_archive(
        f"{identifier}_{image_type}_images", "zip", staging_dir
    )
    s3_uri = f"s3://gretel-image-synthetics/data/{identifier}/{image_type}_images.zip"
    # argv list (no shell) and check=True so a failed upload raises instead of
    # silently returning a URI that points at nothing.
    subprocess.run(
        ["aws", "s3", "cp", archive_path, s3_uri, "--no-sign-request"],
        check=True,
    )
    return s3_uri
def train_model(model_inputs):
    """Display the assembled training request.

    The Banana inference call is still disabled (see the TODO below), so this
    currently only renders ``model_inputs`` as Markdown.

    Args:
        model_inputs: Dict built by the theme forms (theme and face S3 paths,
            prompt, ``num_images``), or ``None`` if no theme was chosen yet.
    """
    # The API/model keys that used to be hard-coded here were exposed in source
    # and should be rotated; the draft below reads them from the environment.
    st.markdown(str(model_inputs))
    # TODO: re-enable the Banana call. Compared with the earlier draft, images
    # are now written into "generated/" (the draft created that directory but
    # saved to the working directory) and each decoded image is closed.
    # api_key = os.environ["BANANA_API_KEY"]
    # model_key = os.environ["BANANA_MODEL_KEY"]
    # out = banana.run(api_key, model_key, model_inputs)
    # os.makedirs("generated", exist_ok=True)
    # for num, img in enumerate(out["modelOutputs"][0]["image_base64"]):
    #     image_bytes = BytesIO(base64.b64decode(img.encode("utf-8")))
    #     with Image.open(image_bytes) as image:
    #         image.save(os.path.join("generated", f"{num}_output.jpg"))
identifier = st.session_state["key"]

# Preset themes, in the same order as the picker's images below:
# (S3 URI of the theme's training archive, prompt used for the theme).
PRESET_THEMES = [
    ("s3://gretel-image-synthetics/data/game-of-thrones.zip", "game-of-thrones"),
    ("s3://gretel-image-synthetics/data/iron-man.zip", "iron-man"),
    ("s3://gretel-image-synthetics/data/thor.zip", "thor"),
]


def _model_inputs_for(theme_file_path, theme_prompt):
    """Build the training request from the uploaded face images and a theme."""
    return {
        "superhero_file_path": theme_file_path,
        "person_file_path": st.session_state["s3_face_file_path"],
        "superhero_prompt": theme_prompt,
        "num_images": 50,
    }


# Step 1: upload face images.
face_images = st.empty()
with face_images.form("my_form"):
    uploaded_files = st.file_uploader(
        "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
    )
    submitted = st.form_submit_button("Submit")
    if submitted:
        if not uploaded_files:
            # Without this guard an empty archive was zipped and uploaded.
            st.warning("Please choose at least one image before submitting.")
        else:
            with st.spinner('Uploading...'):
                st.session_state["s3_face_file_path"] = zip_and_upload_images(
                    identifier, uploaded_files, "face"
                )
            st.success('Done!')

# Step 2a: pick a preset theme, or ask for the custom-theme form.
preset_theme_images = st.empty()
with preset_theme_images.form("choose-preset-theme"):
    img = image_select(
        "Choose a Theme!",
        images=[
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
            "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/thor.png",
        ],
        captions=["Game of Thrones", "Iron Man", "Thor"],
        return_value="index",
    )
    col1, col2 = st.columns([0.15, 1])
    with col1:
        submitted_preset = st.form_submit_button("Submit!")
    with col2:
        show_custom = st.form_submit_button(
            "If none of the themes interest you, click here!"
        )
    if submitted_preset:
        # Without face images the request would carry person_file_path=None.
        if st.session_state["s3_face_file_path"] is None:
            st.warning("Please upload and submit your face images first.")
        else:
            theme_path, theme_prompt = PRESET_THEMES[img]
            st.session_state["model_inputs"] = _model_inputs_for(
                theme_path, theme_prompt
            )
    if show_custom:
        st.session_state["view"] = True

# Step 2b: optional custom theme (images + name).
if st.session_state["view"]:
    custom_theme_images = st.empty()
    with custom_theme_images.form("input_custom_themes"):
        st.markdown("If none of the themes interest you, please input your own!")
        uploaded_files_2 = st.file_uploader(
            "Choose image files",
            accept_multiple_files=True,
            type=["png", "jpg", "jpeg"],
        )
        title = st.text_input("Theme Name")
        submitted_custom = st.form_submit_button("Submit!")
        if submitted_custom:
            if not uploaded_files_2 or not title.strip():
                st.warning("Please add at least one theme image and a theme name.")
            elif st.session_state["s3_face_file_path"] is None:
                st.warning("Please upload and submit your face images first.")
            else:
                with st.spinner('Uploading...'):
                    st.session_state["s3_theme_file_path"] = zip_and_upload_images(
                        identifier, uploaded_files_2, "theme"
                    )
                st.session_state["model_inputs"] = _model_inputs_for(
                    st.session_state["s3_theme_file_path"], title
                )
                st.success('Done!')

# Step 3: train on the assembled request.
train = st.empty()
with train.form("training"):
    submitted = st.form_submit_button("Train Model!")
    if submitted:
        # model_inputs stays None until a theme has been chosen.
        if st.session_state["model_inputs"] is None:
            st.warning(
                "Please upload your face images and choose a theme before training."
            )
        else:
            train_model(st.session_state["model_inputs"])
|