# AMDsnap / app.py
# Author: LeonJHKIM - commit 9dcd503 ("pause service")
import streamlit as st
from PIL import Image
import numpy as np
import os
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import time
import random
import io
from icecream import ic
import base64
from io import BytesIO
# Deployment secrets come from environment variables; each is None when unset.
remote_url = os.getenv("remote_url")
base_pass = os.getenv("base_pass")
TOKEN = os.getenv("TOKEN")
# host[:port] of the generation backend; interpolated into the http:// and
# ws:// URLs built by the helpers below.
server_address = remote_url
# Random id sent as the WebSocket clientId and with each queued prompt; a new
# one is generated every time this module-level code runs.
client_id = str(uuid.uuid4())
def image_to_base64_str(image):
    """Encode an image as PNG and return the PNG bytes as base64 text.

    Args:
        image: an object with PIL's ``save(fp, format=...)`` method,
            normally a ``PIL.Image.Image``.

    Returns:
        str: base64 encoding of the PNG bytes.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="png")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def get_queue_status(server_address, timeout=5):
    """Return how many jobs the backend is running and how many are pending.

    Best-effort: any failure (network, HTTP, bad JSON) is printed and zero
    counts are returned, so the page can always render its status line.

    Args:
        server_address: host[:port] of the backend.
        timeout: seconds to wait for the HTTP response. main() calls this on
            every run of the page, so without a timeout an unreachable
            backend would block rendering indefinitely.

    Returns:
        dict with int values under 'queue_running' and 'queue_pending'.
    """
    # URL-encode the token so characters such as '&' or '+' cannot corrupt
    # the query string.
    query = urllib.parse.urlencode({"token": TOKEN})
    try:
        with urllib.request.urlopen(f"http://{server_address}/queue?{query}", timeout=timeout) as response:
            queue_info = json.loads(response.read())
        running_count = len(queue_info.get('queue_running', []))
        pending_count = len(queue_info.get('queue_pending', []))
        return {'queue_running': running_count, 'queue_pending': pending_count}
    except Exception as e:  # deliberate best-effort: status display must never break the page
        print(f"Error retrieving queue info: {e}")
        return {'queue_running': 0, 'queue_pending': 0}
def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64, id_strength, consistency_magic):
    """POST one generation job to the backend's /prompt endpoint.

    Every argument is sent verbatim in the JSON body: workflow -> "workflow",
    bg_color -> "bg_color", pos_prompt -> "pos_prompt", password -> "pass",
    images_base64 -> "upimages", id_strength -> "id_strength",
    consistency_magic -> "consistency_magic". The module-level client_id is
    included so the server can route progress messages to our WebSocket.

    Returns:
        The decoded JSON response; get_images() reads its 'prompt_id' key.
    """
    p = {"prompt": {"dummy": "dummy"}, "client_id": client_id, "api": "silosnap_portrait", "workflow": workflow,
         "bg_color": bg_color, "pos_prompt": pos_prompt, "pass": password, "upimages": images_base64,
         "id_strength": id_strength, "consistency_magic": consistency_magic}
    data = json.dumps(p).encode('utf-8')
    query = urllib.parse.urlencode({"token": TOKEN})
    # NOTE(review): password and token travel over plain http:// here.
    req = urllib.request.Request(
        f"http://{server_address}/prompt?{query}",
        data=data,
        # Without this, urllib labels the POST body as form-urlencoded.
        headers={"Content-Type": "application/json"},
    )
    # `with` closes the HTTP response; the previous version leaked it.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Download one output image from the backend's /view endpoint.

    Args:
        filename, subfolder, folder_type: the 'filename', 'subfolder' and
            'type' fields of an image entry in the job history.

    Returns:
        bytes: the raw response body (the encoded image file).
    """
    # The token goes through urlencode with the other parameters so that
    # special characters are escaped; for plain tokens the URL is unchanged.
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type, "token": TOKEN}
    )
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as response:
        return response.read()
def get_history(prompt_id):
    """Fetch the backend's /history record for *prompt_id*.

    Returns:
        The decoded JSON response; get_images() indexes it by prompt_id.
    """
    # Escape both the path segment and the token so neither can alter the URL.
    path_id = urllib.parse.quote(str(prompt_id), safe="")
    query = urllib.parse.urlencode({"token": TOKEN})
    with urllib.request.urlopen(f"http://{server_address}/history/{path_id}?{query}") as response:
        return json.loads(response.read())
def _wait_for_prompt(ws, prompt_id):
    """Block on *ws* until the server reports that *prompt_id* stopped running."""
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # binary frames are previews
        message = json.loads(out)
        data = message.get('data')
        if not isinstance(data, dict) or data.get('prompt_id') != prompt_id:
            continue  # message about another job, or without job info
        msg_type = message.get('type')
        if msg_type == 'executing' and data.get('node') is None:
            return  # execution is done
        if msg_type in ('execution_error', 'execution_interrupted'):
            # Assumes ComfyUI-style failure events (confirm the backend sends
            # them). Stop waiting instead of looping forever on a failed job.
            return


def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64, id_strength, consistency_magic):
    """Queue a job, wait on *ws* until it stops, then download its output images.

    *ws* is expected to be connected with the same client_id that
    queue_prompt() sends (as main() does), so progress messages arrive on it.
    The remaining arguments are forwarded unchanged to queue_prompt().

    Returns:
        dict mapping output node id -> list of raw image bytes (from
        get_image) for every history output that lists 'images'. Empty when
        the job produced no image outputs.
    """
    prompt_id = queue_prompt(workflow, bg_color, pos_prompt, password, images_base64,
                             id_strength, consistency_magic)['prompt_id']
    _wait_for_prompt(ws, prompt_id)

    history = get_history(prompt_id)[prompt_id]
    output_images = {}
    for node_id, node_output in history.get('outputs', {}).items():
        if 'images' in node_output:
            output_images[node_id] = [
                get_image(image['filename'], image['subfolder'], image['type'])
                for image in node_output['images']
            ]
    return output_images
#------------------------------------------------------------
# Upload limits enforced by the UI below.
_MAX_UPLOADS = 10
_MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50MB per file


def _password_ok(password):
    """Return True if *password* matches the ``base_pass`` secret."""
    import hmac  # local import: only needed for this check

    # An unset secret (None) rejects every password, as `password != base_pass` did.
    if base_pass is None:
        return False
    # compare_digest is designed to resist timing attacks. Compare UTF-8
    # bytes because compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(password.encode("utf-8"), base_pass.encode("utf-8"))


def _encode_uploads(uploaded_files):
    """Return base64 PNGs for the first _MAX_UPLOADS uploads, warning about skipped files."""
    images_base64 = []
    for uploaded_file in uploaded_files[:_MAX_UPLOADS]:
        if uploaded_file.size > _MAX_UPLOAD_BYTES:
            st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
            continue
        try:
            image = Image.open(uploaded_file)
            # PNG cannot store CMYK (common in print-oriented JPEGs); convert first.
            if image.mode == "CMYK":
                image = image.convert("RGB")
            images_base64.append(image_to_base64_str(image))
        except OSError as e:  # includes PIL.UnidentifiedImageError for non-images
            st.warning(f"File {uploaded_file.name} could not be read as an image and was not processed: {e}")
    return images_base64


def _run_job(workflow, bg_color, text_prompt, password, images_base64, id_strength, consistency_magic):
    """Run one job over a fresh WebSocket; return its outputs, or None after showing an error."""
    ws = websocket.WebSocket()
    query = urllib.parse.urlencode({"clientId": client_id, "token": TOKEN})
    try:
        try:
            ws.connect(f"ws://{server_address}/ws?{query}")
        except (websocket.WebSocketException, OSError) as e:
            # Covers bad addresses (WebSocketAddressException) and also
            # refused or unreachable hosts, which surface as OSError.
            st.error(f"Failed to connect to WebSocket server: {e}")
            st.error(f"Please check your server address: {server_address}")
            return None
        return get_images(ws, workflow, bg_color, text_prompt, password,
                          images_base64, id_strength, consistency_magic)
    except (websocket.WebSocketException, OSError, ValueError, KeyError) as e:
        # Dropped socket, HTTP error, malformed JSON or a missing history
        # entry: report it rather than crashing the page with a traceback.
        st.error(f"Image generation failed: {e}")
        return None
    finally:
        ws.close()  # always release the socket, even when the job failed


def _render_result(images):
    """Show the first image of the first output node, with a download button."""
    if not images:
        st.write("No images returned.")
        return
    st.write("Click on an image to view in full size.")
    first_image_data_list = next(iter(images.values()))
    if not first_image_data_list:
        st.write("No images to display.")
        return
    first_image_data = first_image_data_list[0]
    st.image(Image.open(io.BytesIO(first_image_data)), caption="Result Image", width=500)
    st.download_button(label="Download full-size image",
                       data=first_image_data,
                       file_name=f"{client_id}.png",
                       mime="image/png")


def main():
    """Render the AMDsnap page and, when Submit is pressed, run one generation job."""
    st.title("SILO.AI - AMDsnap ver.1.2.leon ✨")
    workflow = 'amdsnap_style'
    bg_color = "white"

    st.text('')
    st.markdown('##### 🤳 You can upload multiple images, Up to 10 🥺😃🤪🙏')
    uploaded_files = st.file_uploader("**Upload photos of your face**",
                                      type=["jpg", "jpeg", "png"],
                                      accept_multiple_files=True,
                                      help="Select up to 10 images. Maximum file size: 50MB each.")
    if uploaded_files:
        if len(uploaded_files) > _MAX_UPLOADS:
            # Extra files are ignored; tell the user instead of dropping them silently.
            st.warning(f"Only the first {_MAX_UPLOADS} images will be used.")
        previews = uploaded_files[:_MAX_UPLOADS]
        for col, uploaded_file in zip(st.columns(len(previews)), previews):
            with col:
                st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)

    st.text('')
    text_prompt = st.text_input("**📝 Input your style prompt. ex)background color, hair style...**")
    st.text('')
    password = st.text_input("**🔐 Enter your password:**", type="password")

    #------------------------------------------------------------------------
    id_strength = st.slider(
        "💉 Identity Extraction Strength",
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.01
    )
    if id_strength > 0.75:
        st.warning("High strength: This will heavily focus on extracted identity.")
    elif id_strength < 0.55:
        st.warning("Low strength: This will mostly adapt to the prompt.")
    else:
        st.success("Balanced strength: This provides a mix of identity extraction and prompt adaptation.")

    #------------------------------------------------------------------------
    # Seed sent to the backend; None (empty input) lets the server pick one.
    consistency_magic = st.number_input(
        "🦄 Style Consistency Magic Value (1 ~ 4,294,967,294):",
        min_value=1,
        max_value=4294967294,
        value=None,
        step=1,
        help="Enter a value between 1 and 4,294,967,294 to use as a seed for random number generation. Leave empty for a server-generated random value."
    )
    st.caption("If left empty, the server will generate a random value.")

    #------------------------------------------------------------------------
    queue_status = get_queue_status(server_address)
    if queue_status:
        st.write(f"**🏃🏻‍♀️ Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")

    if not st.button('Submit'):
        return
    if not _password_ok(password):
        st.error("Incorrect password. Please try again.")
        return
    if not uploaded_files:
        st.error("Please upload your face photo (up to 10).")
        return

    # Uploads are only decoded and encoded once Submit is pressed.
    images_base64 = _encode_uploads(uploaded_files)
    if not images_base64:
        st.error("No valid images to process. Please upload images under 50MB each.")
        return

    with st.spinner('Processing... Please wait.'):
        gif_runner = st.image("https://huggingface.co/spaces/LeonJHK/AMDsnap/resolve/main/wait-waiting.gif")
        try:
            images = _run_job(workflow, bg_color, text_prompt, password,
                              images_base64, id_strength, consistency_magic)
        finally:
            gif_runner.empty()  # remove the waiting GIF even if the job failed

    if images is not None:  # None: _run_job already reported the error
        _render_result(images)
if __name__ == "__main__":
    # Service is paused: main() is intentionally not called, so visitors only
    # see this farewell banner. Call main() here instead to bring the app back.
    st.title("See you in the next version. Thank you for your interest! 👋🏼🙇🏻‍♂️ -Leon-")