update code
Browse files- .gitignore +3 -0
- app.py +178 -2
- wait-waiting.gif +0 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
./test/
|
2 |
+
*/test/*
|
3 |
+
test/
|
app.py
CHANGED
@@ -1,4 +1,180 @@
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
|
6 |
+
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
|
7 |
+
import uuid
|
8 |
+
import json
|
9 |
+
import urllib.request
|
10 |
+
import urllib.parse
|
11 |
+
import time
|
12 |
+
import random
|
13 |
+
import io
|
14 |
+
from icecream import ic
|
15 |
+
|
16 |
+
import base64
|
17 |
+
from io import BytesIO
|
18 |
+
|
19 |
+
# Access the secret
|
20 |
+
remote_url = os.environ.get("remote_url")
|
21 |
+
base_pass = os.environ.get("base_pass")
|
22 |
+
TOKEN = os.environ.get("TOKEN")
|
23 |
+
|
24 |
+
|
25 |
+
server_address = remote_url
|
26 |
+
client_id = str(uuid.uuid4())
|
27 |
+
|
28 |
+
|
29 |
+
# Convert PIL Image to a base64 string
|
30 |
+
def image_to_base64_str(image):
|
31 |
+
buffered = BytesIO()
|
32 |
+
image.save(buffered, format="png")
|
33 |
+
img_str = base64.b64encode(buffered.getvalue()).decode()
|
34 |
+
return img_str
|
35 |
+
|
36 |
+
|
37 |
+
# Function to check queue status
|
38 |
+
def get_queue_status(server_address):
|
39 |
+
try:
|
40 |
+
with urllib.request.urlopen(f"http://{server_address}/queue?token={TOKEN}") as response:
|
41 |
+
queue_info = json.loads(response.read())
|
42 |
+
# Debug print to see what queue_info contains
|
43 |
+
# print("Queue Info:", queue_info)
|
44 |
+
running_count = len(queue_info.get('queue_running', []))
|
45 |
+
pending_count = len(queue_info.get('queue_pending', []))
|
46 |
+
return {'queue_running': running_count, 'queue_pending': pending_count}
|
47 |
+
except Exception as e:
|
48 |
+
print(f"Error retrieving queue info: {e}")
|
49 |
+
return {'queue_running': 0, 'queue_pending': 0} # Return 0s if there is an error or connection issue
|
50 |
+
|
51 |
+
def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64):
|
52 |
+
p = {"prompt": {"dummy":"dummy"}, "client_id": client_id, "api":"silosnap_portrait", "workflow":workflow,
|
53 |
+
"bg_color":bg_color, "pos_prompt":pos_prompt, "pass":password, "upimages":images_base64}
|
54 |
+
data = json.dumps(p).encode('utf-8')
|
55 |
+
req = urllib.request.Request("http://{}/prompt?token={}".format(server_address, TOKEN), data=data)
|
56 |
+
# req = urllib.request.Request("http://{}/prompt".format(server_address), data=data, headers={'Content-Type': 'application/json'})
|
57 |
+
return json.loads(urllib.request.urlopen(req).read())
|
58 |
+
|
59 |
+
def get_image(filename, subfolder, folder_type):
|
60 |
+
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
61 |
+
url_values = urllib.parse.urlencode(data)
|
62 |
+
# with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
|
63 |
+
with urllib.request.urlopen("http://{}/view?{}&token={}".format(server_address, url_values, TOKEN)) as response:
|
64 |
+
return response.read()
|
65 |
+
|
66 |
+
def get_history(prompt_id):
|
67 |
+
# with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
|
68 |
+
with urllib.request.urlopen("http://{}/history/{}?token={}".format(server_address, prompt_id, TOKEN)) as response:
|
69 |
+
return json.loads(response.read())
|
70 |
+
|
71 |
+
def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64):
|
72 |
+
prompt_id = queue_prompt(workflow, bg_color, pos_prompt, password, images_base64)['prompt_id']
|
73 |
+
output_images = {}
|
74 |
+
while True:
|
75 |
+
out = ws.recv()
|
76 |
+
if isinstance(out, str):
|
77 |
+
message = json.loads(out)
|
78 |
+
if message['type'] == 'executing':
|
79 |
+
data = message['data']
|
80 |
+
if data['node'] is None and data['prompt_id'] == prompt_id:
|
81 |
+
break #Execution is done
|
82 |
+
else:
|
83 |
+
continue #previews are binary data
|
84 |
+
|
85 |
+
history = get_history(prompt_id)[prompt_id]
|
86 |
+
|
87 |
+
# for o in history['outputs']:
|
88 |
+
for node_id in history['outputs']:
|
89 |
+
node_output = history['outputs'][node_id]
|
90 |
+
|
91 |
+
if 'images' in node_output:
|
92 |
+
images_output = []
|
93 |
+
for image in node_output['images']:
|
94 |
+
image_data = get_image(image['filename'], image['subfolder'], image['type'])
|
95 |
+
images_output.append(image_data)
|
96 |
+
output_images[node_id] = images_output
|
97 |
+
return output_images
|
98 |
+
|
99 |
+
|
100 |
+
#------------------------------------------------------------
|
101 |
+
|
102 |
+
def main():
|
103 |
+
st.title("SILO.AI - AMDsnap ver.1.0.leon ")
|
104 |
+
|
105 |
+
workflow = 'amdsnap_style'
|
106 |
+
bg_color = "white"
|
107 |
+
|
108 |
+
st.text('')
|
109 |
+
st.markdown('##### π€³ You can upload multiple images, Up to 10 π₯Ίππ€ͺπ')
|
110 |
+
uploaded_files = st.file_uploader("**Upload photos of your face**",
|
111 |
+
type=["jpg", "jpeg", "png"],
|
112 |
+
accept_multiple_files=True,
|
113 |
+
help="Select up to 10 images. Maximum file size: 50MB each.")
|
114 |
+
|
115 |
+
if uploaded_files:
|
116 |
+
cols = st.columns(len(uploaded_files[:10]))
|
117 |
+
for i, uploaded_file in enumerate(uploaded_files[:10]):
|
118 |
+
with cols[i]:
|
119 |
+
st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)
|
120 |
+
|
121 |
+
st.text('')
|
122 |
+
text_prompt = st.text_input("**π Input your style prompt. Empty is OK!**")
|
123 |
+
st.text('')
|
124 |
+
password = st.text_input("**π Enter your password:**", type="password")
|
125 |
+
|
126 |
+
queue_status = get_queue_status(server_address)
|
127 |
+
if queue_status:
|
128 |
+
st.write(f"**ππ»ββοΈ Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")
|
129 |
+
|
130 |
+
submit_button = st.button('Submit')
|
131 |
+
|
132 |
+
if submit_button:
|
133 |
+
if password != base_pass:
|
134 |
+
st.error("Incorrect password. Please try again.")
|
135 |
+
elif not uploaded_files:
|
136 |
+
st.error("Please upload your face photo (up to 10).")
|
137 |
+
else:
|
138 |
+
# Process uploads only when submit is pressed
|
139 |
+
images_base64 = []
|
140 |
+
for uploaded_file in uploaded_files[:10]:
|
141 |
+
if uploaded_file.size <= 50 * 1024 * 1024: # 50MB limit
|
142 |
+
image = Image.open(uploaded_file)
|
143 |
+
images_base64.append(image_to_base64_str(image))
|
144 |
+
else:
|
145 |
+
st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
|
146 |
+
|
147 |
+
if not images_base64:
|
148 |
+
st.error("No valid images to process. Please upload images under 50MB each.")
|
149 |
+
return
|
150 |
+
|
151 |
+
st.write(f"Your text prompt: {text_prompt}")
|
152 |
+
with st.spinner('Processing... Please wait.'):
|
153 |
+
ws = websocket.WebSocket()
|
154 |
+
try:
|
155 |
+
ws.connect("ws://{}/ws?clientId={}&token={}".format(server_address, client_id, TOKEN))
|
156 |
+
except websocket._exceptions.WebSocketAddressException as e:
|
157 |
+
st.error(f"Failed to connect to WebSocket server: {e}")
|
158 |
+
st.error(f"Please check your server address: {server_address}")
|
159 |
+
return
|
160 |
+
|
161 |
+
images = get_images(ws, workflow, bg_color, text_prompt, password, images_base64)
|
162 |
+
ws.close()
|
163 |
+
|
164 |
+
if images:
|
165 |
+
st.write("Click on an image to view in full size.")
|
166 |
+
first_node_id, first_image_data_list = list(images.items())[0]
|
167 |
+
if first_image_data_list:
|
168 |
+
first_image_data = first_image_data_list[0]
|
169 |
+
st.image(Image.open(io.BytesIO(first_image_data)), caption=f"Result Image", width=500)
|
170 |
+
st.download_button(label="Download full-size image",
|
171 |
+
data=first_image_data,
|
172 |
+
file_name=f"{client_id}.png",
|
173 |
+
mime="image/png")
|
174 |
+
else:
|
175 |
+
st.write("No images to display.")
|
176 |
+
else:
|
177 |
+
st.write("No images returned.")
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
main()
|
wait-waiting.gif
ADDED
![]() |