Santhosh Subramanian commited on
Commit
9206066
·
1 Parent(s): 422fbbd

Add application file

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ import shutil
4
+ import uuid
5
+ import zipfile
6
+ from argparse import Namespace
7
+ from glob import glob
8
+ from io import BytesIO
9
+ from itertools import cycle
10
+
11
+ import banana_dev as banana
12
+ import streamlit as st
13
+ import torch
14
+ from diffusers import StableDiffusionPipeline
15
+ from PIL import Image
16
+ from st_btn_select import st_btn_select
17
+ from streamlit_image_select import image_select
18
+ from streamlit_multipage import MultiPage
19
+ from torch import autocast
20
+
21
+ if "key" not in st.session_state:
22
+ st.session_state["key"] = uuid.uuid4().hex
23
+
24
+ if "model_inputs" not in st.session_state:
25
+ st.session_state["model_inputs"] = None
26
+
27
+ if (
28
+ "s3_face_file_path" not in st.session_state
29
+ and "s3_theme_file_path" not in st.session_state
30
+ ):
31
+ st.session_state["s3_face_file_path"] = None
32
+ st.session_state["s3_theme_file_path"] = None
33
+
34
+ if "view" not in st.session_state:
35
+ st.session_state["view"] = False
36
+
37
+
38
+ def callback():
39
+ st.session_state["button_clicked"] = True
40
+
41
+
42
+ def zip_and_upload_images(identifier, uploaded_files, image_type):
43
+ if not os.path.exists(identifier):
44
+ os.makedirs(identifier)
45
+ for num, uploaded_file in enumerate(uploaded_files):
46
+ file_ = Image.open(uploaded_file).convert("RGB")
47
+ file_.save(f"{identifier}/{num}_test.png")
48
+ shutil.make_archive(f"{identifier}_{image_type}_images", "zip", identifier)
49
+ os.system(
50
+ f"aws s3 cp {identifier}_{image_type}_images.zip s3://gretel-image-synthetics/data/{identifier}/{image_type}_images.zip --no-sign-request"
51
+ )
52
+ return f"s3://gretel-image-synthetics/data/{identifier}/{image_type}_images.zip"
53
+
54
+
55
+ def train_model(model_inputs):
56
+ api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
57
+ model_key = "fb9e7bcc-7291-4af6-b2fc-2e98a3b6e7e5"
58
+ st.markdown(str(model_inputs))
59
+ # out = banana.run(api_key, model_key, model_inputs)
60
+ # if not os.path.exists("generated"):
61
+ # os.makedirs("generated")
62
+ # for num, img in enumerate(out["modelOutputs"][0]["image_base64"]):
63
+ # image_encoded = img.encode("utf-8")
64
+ # image_bytes = BytesIO(base64.b64decode(image_encoded))
65
+ # image = Image.open(image_bytes)
66
+ # image.save(f"{num}_output.jpg")
67
+
68
+
69
+ identifier = st.session_state["key"]
70
+ face_images = st.empty()
71
+ with face_images.form("my_form"):
72
+ uploaded_files = st.file_uploader(
73
+ "Choose image files", accept_multiple_files=True, type=["png", "jpg", "jpeg"]
74
+ )
75
+ submitted = st.form_submit_button("Submit")
76
+ if submitted:
77
+ st.session_state["s3_face_file_path"] = zip_and_upload_images(
78
+ identifier, uploaded_files, "face"
79
+ )
80
+
81
+ preset_theme_images = st.empty()
82
+ with preset_theme_images.form("choose-preset-theme"):
83
+ img = image_select(
84
+ "Choose a Theme!",
85
+ images=[
86
+ "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/got.png",
87
+ "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/ironman.png",
88
+ "https://gretel-image-synthetics.s3.us-west-2.amazonaws.com/theme-images/thor.png",
89
+ ],
90
+ captions=["Game of Thrones", "Iron Man", "Thor"],
91
+ return_value="index",
92
+ )
93
+
94
+ col1, col2 = st.columns([0.15, 1])
95
+ with col1:
96
+ submitted_3 = st.form_submit_button("Submit!")
97
+ if submitted_3:
98
+ dictionary = {
99
+ 0: [
100
+ "s3://gretel-image-synthetics/data/game-of-thrones.zip",
101
+ "game-of-thrones",
102
+ ],
103
+ 1: ["s3://gretel-image-synthetics/data/iron-man.zip", "iron-man"],
104
+ 2: ["s3://gretel-image-synthetics/data/thor.zip", "thor"],
105
+ }
106
+ st.session_state["model_inputs"] = {
107
+ "superhero_file_path": dictionary[img][0],
108
+ "person_file_path": st.session_state["s3_face_file_path"],
109
+ "superhero_prompt": dictionary[img][1],
110
+ "num_images": 50,
111
+ }
112
+ with col2:
113
+ submitted_4 = st.form_submit_button(
114
+ "If none of the themes interest you, click here!"
115
+ )
116
+ if submitted_4:
117
+ st.session_state["view"] = True
118
+
119
+ if st.session_state["view"]:
120
+ custom_theme_images = st.empty()
121
+ with custom_theme_images.form("input_custom_themes"):
122
+ st.markdown("If none of the themes interest you, please input your own!")
123
+ uploaded_files_2 = st.file_uploader(
124
+ "Choose image files",
125
+ accept_multiple_files=True,
126
+ type=["png", "jpg", "jpeg"],
127
+ )
128
+ title = st.text_input("Theme Name")
129
+ submitted_3 = st.form_submit_button("Submit!")
130
+ if submitted_3:
131
+ st.session_state["s3_theme_file_path"] = zip_and_upload_images(
132
+ identifier, uploaded_files_2, "theme"
133
+ )
134
+ st.session_state["model_inputs"] = {
135
+ "superhero_file_path": st.session_state["s3_theme_file_path"],
136
+ "person_file_path": st.session_state["s3_face_file_path"],
137
+ "superhero_prompt": title,
138
+ "num_images": 50,
139
+ }
140
+
141
+ train = st.empty()
142
+ with train.form("training"):
143
+ submitted = st.form_submit_button("Train Model!")
144
+ if submitted:
145
+ train_model(st.session_state["model_inputs"])