File size: 7,824 Bytes
5ab4237 90c7d8e 5ab4237 90c7d8e 23eadb3 90c7d8e 23eadb3 90c7d8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import streamlit as st
from PIL import Image
import numpy as np
import os
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import time
import random
import io
from icecream import ic
import base64
from io import BytesIO
# Deployment configuration, read from environment variables.
# Each value is None when the corresponding variable is not set.
TOKEN = os.getenv("TOKEN")
base_pass = os.getenv("base_pass")
remote_url = os.getenv("remote_url")
# Host[:port] of the backend; it is inserted after "http://" / "ws://" in the
# request URLs built further down in this file.
server_address = remote_url
# Random identifier generated once each time this module's top level runs; it
# is sent with every queued prompt and on the WebSocket URL.
client_id = f"{uuid.uuid4()}"
def image_to_base64_str(image):
    """Serialize *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: an object with a PIL-style ``save(fp, format=...)`` method
            (a ``PIL.Image.Image`` in this app).

    Returns:
        The base64 encoding of the PNG bytes, as ``str``.
    """
    with BytesIO() as buffer:
        image.save(buffer, format="png")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def get_queue_status(server_address, timeout=10):
    """Return the backend's current queue sizes, or zeros if they can't be fetched.

    Args:
        server_address: host[:port] of the backend, inserted after ``http://``.
        timeout: seconds to wait for the HTTP response. The page calls this on
            every render, so an unresponsive server must not block it forever.

    Returns:
        dict with integer ``'queue_running'`` and ``'queue_pending'`` counts
        (lengths of the corresponding lists in the server's JSON reply; a
        missing list counts as 0). Both are 0 if anything goes wrong.
    """
    try:
        # urlencode escapes characters in the token that would break the URL.
        query = urllib.parse.urlencode({"token": TOKEN})
        url = f"http://{server_address}/queue?{query}"
        with urllib.request.urlopen(url, timeout=timeout) as response:
            queue_info = json.loads(response.read())
        return {
            'queue_running': len(queue_info.get('queue_running', [])),
            'queue_pending': len(queue_info.get('queue_pending', [])),
        }
    except Exception as e:
        # Deliberately best-effort: the queue display is informational only,
        # so report the problem and show zeros rather than breaking the page.
        print(f"Error retrieving queue info: {e}")
        return {'queue_running': 0, 'queue_pending': 0}
def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64):
    """POST a job to the backend's ``/prompt`` endpoint and return its JSON reply.

    Args:
        workflow: workflow name sent as ``"workflow"``.
        bg_color: background colour sent as ``"bg_color"``.
        pos_prompt: user style prompt sent as ``"pos_prompt"``.
        password: user password sent as ``"pass"``.
            NOTE(review): this travels in the request body over plain http://.
        images_base64: list of base64-encoded images sent as ``"upimages"``.

    Returns:
        The decoded JSON response (callers in this file read ``'prompt_id'``).

    Raises:
        urllib.error.URLError / urllib.error.HTTPError on transport or HTTP
        failures; json.JSONDecodeError if the reply is not JSON.
    """
    payload = {
        # Placeholder graph; presumably the server builds the real one from
        # "workflow"/"api" — TODO confirm against the backend.
        "prompt": {"dummy": "dummy"},
        "client_id": client_id,
        "api": "silosnap_portrait",
        "workflow": workflow,
        "bg_color": bg_color,
        "pos_prompt": pos_prompt,
        "pass": password,
        "upimages": images_base64,
    }
    data = json.dumps(payload).encode('utf-8')
    query = urllib.parse.urlencode({"token": TOKEN})
    req = urllib.request.Request(f"http://{server_address}/prompt?{query}", data=data)
    # Use the response as a context manager so the connection is released.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download one output image from the backend's ``/view`` endpoint.

    Args:
        filename, subfolder, folder_type: values sent as the ``filename``,
            ``subfolder`` and ``type`` query parameters (taken from the job
            history by the caller).

    Returns:
        The raw response body as ``bytes``.
    """
    # The token goes through urlencode together with the other parameters so
    # that every query value is escaped the same way.
    query = urllib.parse.urlencode({
        "filename": filename,
        "subfolder": subfolder,
        "type": folder_type,
        "token": TOKEN,
    })
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as response:
        return response.read()
def get_history(prompt_id):
    """Fetch the backend's ``/history/<prompt_id>`` record as decoded JSON.

    Args:
        prompt_id: id returned by ``queue_prompt``; it is percent-encoded
            before being placed in the URL path.

    Returns:
        The decoded JSON object (callers in this file index it by
        ``prompt_id`` and then ``'outputs'``).
    """
    path_id = urllib.parse.quote(str(prompt_id), safe="")
    query = urllib.parse.urlencode({"token": TOKEN})
    with urllib.request.urlopen(f"http://{server_address}/history/{path_id}?{query}") as response:
        return json.loads(response.read())
def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64):
    """Queue a job, wait on *ws* until it finishes, and download its output images.

    Args:
        ws: an already-connected websocket-client ``WebSocket`` opened with
            this module's ``client_id``.
        workflow, bg_color, pos_prompt, password, images_base64: forwarded
            unchanged to ``queue_prompt``.

    Returns:
        dict mapping each output node id in the job's history to a list of raw
        image bytes, one per image that node produced. Nodes without an
        ``'images'`` entry are omitted.

    Raises:
        ConnectionError: if the socket yields an empty message (see below)
            before the completion message arrives.
    """
    prompt_id = queue_prompt(workflow, bg_color, pos_prompt, password, images_base64)['prompt_id']

    # Wait for an 'executing' message whose node is None for our prompt_id,
    # which this code treats as "execution is done". There is no timeout here.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # binary frames carry previews; ignore them
        if out == '':
            # websocket-client's recv() returns '' for non-data frames such as
            # a close frame; json.loads('') would otherwise raise an opaque error.
            raise ConnectionError("WebSocket closed before the prompt finished")
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message['data']
        if data['node'] is None and data['prompt_id'] == prompt_id:
            break

    history = get_history(prompt_id)[prompt_id]
    output_images = {}
    for node_id, node_output in history['outputs'].items():
        if 'images' in node_output:
            output_images[node_id] = [
                get_image(image['filename'], image['subfolder'], image['type'])
                for image in node_output['images']
            ]
    return output_images
#------------------------------------------------------------
# Upload limits enforced by the UI (the user-facing strings below say 10 / 50MB).
_MAX_FILES = 10
_MAX_FILE_BYTES = 50 * 1024 * 1024  # 50MB


def _encode_uploads(uploaded_files):
    """Return base64 PNG strings for the first ``_MAX_FILES`` uploads.

    Files larger than ``_MAX_FILE_BYTES`` are skipped with a Streamlit warning.
    """
    images_base64 = []
    for uploaded_file in uploaded_files[:_MAX_FILES]:
        if uploaded_file.size <= _MAX_FILE_BYTES:
            images_base64.append(image_to_base64_str(Image.open(uploaded_file)))
        else:
            st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
    return images_base64


def _run_job(workflow, bg_color, text_prompt, password, images_base64):
    """Run one job over a WebSocket while a waiting GIF is shown.

    Returns the dict produced by ``get_images``, or None if connecting or
    processing failed (an ``st.error`` has then been shown). The socket is
    always closed and the GIF always removed, whatever the outcome.
    """
    with st.spinner('Processing... Please wait.'):
        # /resolve/ serves the raw file; /blob/ would return the HTML viewer page.
        gif_runner = st.image("https://huggingface.co/spaces/LeonJHK/AMDsnap/resolve/main/wait-waiting.gif")
        ws = websocket.WebSocket()
        try:
            query = urllib.parse.urlencode({"clientId": client_id, "token": TOKEN})
            try:
                ws.connect(f"ws://{server_address}/ws?{query}")
            except (websocket.WebSocketException, OSError) as e:
                # OSError covers refused connections and timeouts, which are
                # not WebSocketException subclasses.
                st.error(f"Failed to connect to WebSocket server: {e}")
                st.error(f"Please check your server address: {server_address}")
                return None
            try:
                return get_images(ws, workflow, bg_color, text_prompt, password, images_base64)
            except (websocket.WebSocketException, OSError) as e:
                st.error(f"Processing failed: {e}")
                return None
        finally:
            ws.close()
            gif_runner.empty()  # remove the GIF whether or not processing succeeded


def _show_first_result(images):
    """Show the first image of the first output node, with a download button."""
    if not images:
        st.write("No images returned.")
        return
    st.write("Click on an image to view in full size.")
    first_image_data_list = next(iter(images.values()))
    if not first_image_data_list:
        st.write("No images to display.")
        return
    first_image_data = first_image_data_list[0]
    st.image(Image.open(io.BytesIO(first_image_data)), caption="Result Image", width=500)
    st.download_button(label="Download full-size image",
                       data=first_image_data,
                       file_name=f"{client_id}.png",
                       mime="image/png")


def main():
    """Build the AMDsnap Streamlit page.

    Shows the uploader, prompt and password inputs, and the backend queue
    size. When Submit is pressed, it validates the input, runs the job, and
    displays the first result image.
    """
    st.title("SILO.AI - AMDsnap ver.1.0.leon ")
    workflow = 'amdsnap_style'
    bg_color = "white"
    st.text('')
    # NOTE(review): the emoji in these UI strings appear mis-encoded (UTF-8
    # bytes shown as ISO-8859-7). The originals can't be recovered
    # unambiguously, so the strings are left unchanged.
    st.markdown('##### π€³ You can upload multiple images, Up to 10 π₯Ίππ€ͺπ')
    uploaded_files = st.file_uploader("**Upload photos of your face**",
                                      type=["jpg", "jpeg", "png"],
                                      accept_multiple_files=True,
                                      help="Select up to 10 images. Maximum file size: 50MB each.")
    if uploaded_files:
        shown_files = uploaded_files[:_MAX_FILES]
        if len(uploaded_files) > _MAX_FILES:
            st.warning(f"Only the first {_MAX_FILES} images will be used.")
        cols = st.columns(len(shown_files))
        for col, uploaded_file in zip(cols, shown_files):
            with col:
                st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)
    st.text('')
    text_prompt = st.text_input("**π Input your style prompt. Empty is OK!**")
    st.text('')
    password = st.text_input("**π Enter your password:**", type="password")

    # get_queue_status always returns a dict (zeros on failure), so there is
    # nothing to guard here.
    queue_status = get_queue_status(server_address)
    st.write(f"**ππ»ββοΈ Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")

    if not st.button('Submit'):
        return
    if password != base_pass:
        st.error("Incorrect password. Please try again.")
        return
    if not uploaded_files:
        st.error("Please upload your face photo (up to 10).")
        return

    # Process uploads only when submit is pressed.
    images_base64 = _encode_uploads(uploaded_files)
    if not images_base64:
        st.error("No valid images to process. Please upload images under 50MB each.")
        return

    st.write(f"Your text prompt: {text_prompt}")
    images = _run_job(workflow, bg_color, text_prompt, password, images_base64)
    if images is not None:
        _show_first_result(images)


if __name__ == "__main__":
    main()