LeonJHKIM commited on
Commit
90c7d8e
Β·
1 Parent(s): 5ab4237

update code

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +178 -2
  3. wait-waiting.gif +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ./test/
2
+ */test/*
3
+ test/
app.py CHANGED
@@ -1,4 +1,180 @@
1
  import streamlit as st
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ import os
5
 
6
+ import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
7
+ import uuid
8
+ import json
9
+ import urllib.request
10
+ import urllib.parse
11
+ import time
12
+ import random
13
+ import io
14
+ from icecream import ic
15
+
16
+ import base64
17
+ from io import BytesIO
18
+
19
+ # Access the secret
20
+ remote_url = os.environ.get("remote_url")
21
+ base_pass = os.environ.get("base_pass")
22
+ TOKEN = os.environ.get("TOKEN")
23
+
24
+
25
+ server_address = remote_url
26
+ client_id = str(uuid.uuid4())
27
+
28
+
29
+ # Convert PIL Image to a base64 string
30
+ def image_to_base64_str(image):
31
+ buffered = BytesIO()
32
+ image.save(buffered, format="png")
33
+ img_str = base64.b64encode(buffered.getvalue()).decode()
34
+ return img_str
35
+
36
+
37
+ # Function to check queue status
38
+ def get_queue_status(server_address):
39
+ try:
40
+ with urllib.request.urlopen(f"http://{server_address}/queue?token={TOKEN}") as response:
41
+ queue_info = json.loads(response.read())
42
+ # Debug print to see what queue_info contains
43
+ # print("Queue Info:", queue_info)
44
+ running_count = len(queue_info.get('queue_running', []))
45
+ pending_count = len(queue_info.get('queue_pending', []))
46
+ return {'queue_running': running_count, 'queue_pending': pending_count}
47
+ except Exception as e:
48
+ print(f"Error retrieving queue info: {e}")
49
+ return {'queue_running': 0, 'queue_pending': 0} # Return 0s if there is an error or connection issue
50
+
51
+ def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64):
52
+ p = {"prompt": {"dummy":"dummy"}, "client_id": client_id, "api":"silosnap_portrait", "workflow":workflow,
53
+ "bg_color":bg_color, "pos_prompt":pos_prompt, "pass":password, "upimages":images_base64}
54
+ data = json.dumps(p).encode('utf-8')
55
+ req = urllib.request.Request("http://{}/prompt?token={}".format(server_address, TOKEN), data=data)
56
+ # req = urllib.request.Request("http://{}/prompt".format(server_address), data=data, headers={'Content-Type': 'application/json'})
57
+ return json.loads(urllib.request.urlopen(req).read())
58
+
59
+ def get_image(filename, subfolder, folder_type):
60
+ data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
61
+ url_values = urllib.parse.urlencode(data)
62
+ # with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
63
+ with urllib.request.urlopen("http://{}/view?{}&token={}".format(server_address, url_values, TOKEN)) as response:
64
+ return response.read()
65
+
66
+ def get_history(prompt_id):
67
+ # with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
68
+ with urllib.request.urlopen("http://{}/history/{}?token={}".format(server_address, prompt_id, TOKEN)) as response:
69
+ return json.loads(response.read())
70
+
71
+ def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64):
72
+ prompt_id = queue_prompt(workflow, bg_color, pos_prompt, password, images_base64)['prompt_id']
73
+ output_images = {}
74
+ while True:
75
+ out = ws.recv()
76
+ if isinstance(out, str):
77
+ message = json.loads(out)
78
+ if message['type'] == 'executing':
79
+ data = message['data']
80
+ if data['node'] is None and data['prompt_id'] == prompt_id:
81
+ break #Execution is done
82
+ else:
83
+ continue #previews are binary data
84
+
85
+ history = get_history(prompt_id)[prompt_id]
86
+
87
+ # for o in history['outputs']:
88
+ for node_id in history['outputs']:
89
+ node_output = history['outputs'][node_id]
90
+
91
+ if 'images' in node_output:
92
+ images_output = []
93
+ for image in node_output['images']:
94
+ image_data = get_image(image['filename'], image['subfolder'], image['type'])
95
+ images_output.append(image_data)
96
+ output_images[node_id] = images_output
97
+ return output_images
98
+
99
+
100
+ #------------------------------------------------------------
101
+
102
+ def main():
103
+ st.title("SILO.AI - AMDsnap ver.1.0.leon ")
104
+
105
+ workflow = 'amdsnap_style'
106
+ bg_color = "white"
107
+
108
+ st.text('')
109
+ st.markdown('##### 🀳 You can upload multiple images, Up to 10 πŸ₯ΊπŸ˜ƒπŸ€ͺπŸ™')
110
+ uploaded_files = st.file_uploader("**Upload photos of your face**",
111
+ type=["jpg", "jpeg", "png"],
112
+ accept_multiple_files=True,
113
+ help="Select up to 10 images. Maximum file size: 50MB each.")
114
+
115
+ if uploaded_files:
116
+ cols = st.columns(len(uploaded_files[:10]))
117
+ for i, uploaded_file in enumerate(uploaded_files[:10]):
118
+ with cols[i]:
119
+ st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)
120
+
121
+ st.text('')
122
+ text_prompt = st.text_input("**πŸ“ Input your style prompt. Empty is OK!**")
123
+ st.text('')
124
+ password = st.text_input("**πŸ” Enter your password:**", type="password")
125
+
126
+ queue_status = get_queue_status(server_address)
127
+ if queue_status:
128
+ st.write(f"**πŸƒπŸ»β€β™€οΈ Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")
129
+
130
+ submit_button = st.button('Submit')
131
+
132
+ if submit_button:
133
+ if password != base_pass:
134
+ st.error("Incorrect password. Please try again.")
135
+ elif not uploaded_files:
136
+ st.error("Please upload your face photo (up to 10).")
137
+ else:
138
+ # Process uploads only when submit is pressed
139
+ images_base64 = []
140
+ for uploaded_file in uploaded_files[:10]:
141
+ if uploaded_file.size <= 50 * 1024 * 1024: # 50MB limit
142
+ image = Image.open(uploaded_file)
143
+ images_base64.append(image_to_base64_str(image))
144
+ else:
145
+ st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
146
+
147
+ if not images_base64:
148
+ st.error("No valid images to process. Please upload images under 50MB each.")
149
+ return
150
+
151
+ st.write(f"Your text prompt: {text_prompt}")
152
+ with st.spinner('Processing... Please wait.'):
153
+ ws = websocket.WebSocket()
154
+ try:
155
+ ws.connect("ws://{}/ws?clientId={}&token={}".format(server_address, client_id, TOKEN))
156
+ except websocket._exceptions.WebSocketAddressException as e:
157
+ st.error(f"Failed to connect to WebSocket server: {e}")
158
+ st.error(f"Please check your server address: {server_address}")
159
+ return
160
+
161
+ images = get_images(ws, workflow, bg_color, text_prompt, password, images_base64)
162
+ ws.close()
163
+
164
+ if images:
165
+ st.write("Click on an image to view in full size.")
166
+ first_node_id, first_image_data_list = list(images.items())[0]
167
+ if first_image_data_list:
168
+ first_image_data = first_image_data_list[0]
169
+ st.image(Image.open(io.BytesIO(first_image_data)), caption=f"Result Image", width=500)
170
+ st.download_button(label="Download full-size image",
171
+ data=first_image_data,
172
+ file_name=f"{client_id}.png",
173
+ mime="image/png")
174
+ else:
175
+ st.write("No images to display.")
176
+ else:
177
+ st.write("No images returned.")
178
+
179
+ if __name__ == "__main__":
180
+ main()
wait-waiting.gif ADDED