|
import streamlit as st |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
import websocket |
|
import uuid |
|
import json |
|
import urllib.request |
|
import urllib.parse |
|
import time |
|
import random |
|
import io |
|
from icecream import ic |
|
|
|
import base64 |
|
from io import BytesIO |
|
|
|
|
|
# Deployment settings come from environment variables; each is None when unset.
# remote_url is used as a bare ``host[:port]`` (the request helpers prepend
# "http://" / "ws://" themselves).
remote_url = os.getenv("remote_url")
base_pass = os.getenv("base_pass")  # password the UI checks before submitting
TOKEN = os.getenv("TOKEN")  # sent as the ``token`` query parameter on every request

server_address = remote_url

# Identifier sent with jobs and the websocket handshake; also names the
# downloaded result file.
client_id = str(uuid.uuid4())
|
|
|
|
|
|
|
def image_to_base64_str(image):
    """Serialize an image to PNG in memory and return it as base64 text.

    PNG can only store a subset of Pillow's modes. An image in any other mode
    (for example ``CMYK``, which some JPEG uploads use) is converted to RGB
    first, because ``save(..., format="png")`` would otherwise raise
    ``OSError``. Any alpha channel in such an uncommon mode is dropped.

    Args:
        image: A Pillow ``Image`` (anything with ``mode``, ``convert`` and
            ``save`` works).

    Returns:
        The base64-encoded PNG bytes, decoded to a ``str`` so the value can be
        embedded in a JSON payload.
    """
    if image.mode not in {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="png")
    return base64.b64encode(buffered.getvalue()).decode()
|
|
|
|
|
|
|
def get_queue_status(server_address, timeout=10):
    """Return how many jobs are running and pending on the generation server.

    This is best-effort. The result only feeds a status line in the UI, so any
    failure (unreachable server, timeout, malformed reply) is printed to
    stdout and reported as an empty queue instead of being raised.

    Args:
        server_address: ``host[:port]`` of the server, without a scheme.
        timeout: Seconds to wait for the ``/queue`` request. ``main`` calls
            this on every render, so an unbounded wait would freeze the page
            whenever the server hangs.

    Returns:
        ``{'queue_running': int, 'queue_pending': int}``. The values are the
        lengths of the server's ``queue_running`` and ``queue_pending`` lists
        (a missing list counts as 0). Both values are 0 on error.
    """
    try:
        # urlencode escapes the token so characters such as '+', '&' or '='
        # reach the server intact.
        query = urllib.parse.urlencode({"token": TOKEN})
        with urllib.request.urlopen(
            f"http://{server_address}/queue?{query}", timeout=timeout
        ) as response:
            queue_info = json.loads(response.read())
        running_count = len(queue_info.get('queue_running', []))
        pending_count = len(queue_info.get('queue_pending', []))
    except Exception as e:  # UI boundary: a status probe must never break the page
        print(f"Error retrieving queue info: {e}")
        return {'queue_running': 0, 'queue_pending': 0}
    return {'queue_running': running_count, 'queue_pending': pending_count}
|
|
|
def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64, id_strength, consistency_magic):
    """POST one generation job to the server's ``/prompt`` endpoint.

    Args:
        workflow: Server-side workflow name (``main`` passes ``'amdsnap_style'``).
        bg_color: Background color name.
        pos_prompt: The user's free-text style prompt.
        password: The password the user entered. It is forwarded to the server
            in the ``pass`` field.
        images_base64: List of base64 PNG strings, one per uploaded photo.
        id_strength: Identity-extraction strength from the UI slider.
        consistency_magic: Optional seed. ``None`` is sent as JSON ``null``;
            per the UI text, the server then picks a random value.

    Returns:
        The decoded JSON reply. ``get_images`` reads its ``'prompt_id'`` key.

    Raises:
        urllib.error.HTTPError / urllib.error.URLError: if the request fails.
    """
    payload = {"prompt": {"dummy": "dummy"}, "client_id": client_id, "api": "silosnap_portrait", "workflow": workflow,
               "bg_color": bg_color, "pos_prompt": pos_prompt, "pass": password, "upimages": images_base64,
               "id_strength": id_strength, "consistency_magic": consistency_magic}
    data = json.dumps(payload).encode('utf-8')

    query = urllib.parse.urlencode({"token": TOKEN})
    req = urllib.request.Request(
        f"http://{server_address}/prompt?{query}",
        data=data,
        # The body is JSON; without this header urllib labels it as form data.
        headers={"Content-Type": "application/json"},
    )

    # ``with`` releases the connection even if the reply fails to decode.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
|
|
|
def get_image(filename, subfolder, folder_type):
    """Download one generated image from the server's ``/view`` endpoint.

    Args:
        filename, subfolder, folder_type: The identifiers reported for the
            image in the job history (``type`` is sent as ``folder_type``).

    Returns:
        The raw bytes of the image file as served.
    """
    # The token goes through urlencode together with the other parameters so
    # reserved characters in it are escaped rather than corrupting the query.
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type, "token": TOKEN}
    )
    with urllib.request.urlopen(f"http://{server_address}/view?{query}") as response:
        return response.read()
|
|
|
def get_history(prompt_id):
    """Fetch the server's execution history for ``prompt_id``.

    Returns:
        The decoded JSON reply. ``get_images`` indexes it by ``prompt_id`` and
        reads the ``'outputs'`` entry.
    """
    # Escape the token so reserved characters in it reach the server intact.
    query = urllib.parse.urlencode({"token": TOKEN})
    with urllib.request.urlopen(f"http://{server_address}/history/{prompt_id}?{query}") as response:
        return json.loads(response.read())
|
|
|
def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64, id_strength, consistency_magic):
    """Submit a job, wait until the server reports it finished, and fetch its images.

    Args:
        ws: A websocket already connected to the server (``main`` connects it).
        The remaining arguments are forwarded unchanged to ``queue_prompt``.

    Returns:
        A dict mapping each output node id whose history entry has an
        ``'images'`` list to the raw bytes of those images, in server order.

    NOTE(review): the wait below has no timeout. It returns only when the
    matching message arrives, or raises if ``ws.recv()`` raises.
    """
    reply = queue_prompt(workflow, bg_color, pos_prompt, password,
                         images_base64, id_strength, consistency_magic)
    prompt_id = reply['prompt_id']

    # Consume websocket traffic until an "executing" message with node=None
    # arrives for this prompt (presumably the server's end-of-run signal).
    # Non-text frames are skipped.
    while True:
        frame = ws.recv()
        if not isinstance(frame, str):
            continue
        message = json.loads(frame)
        if message['type'] != 'executing':
            continue
        status = message['data']
        if status['node'] is None and status['prompt_id'] == prompt_id:
            break

    outputs = get_history(prompt_id)[prompt_id]['outputs']
    return {
        node_id: [get_image(img['filename'], img['subfolder'], img['type'])
                  for img in node_output['images']]
        for node_id, node_output in outputs.items()
        if 'images' in node_output
    }
|
|
|
|
|
|
|
|
|
# Upload limits enforced by main(); they match the limits stated in the UI text.
_MAX_UPLOADS = 10
_MAX_UPLOAD_BYTES = 50 * 1024 * 1024


def _password_ok(password):
    """Return True when ``password`` equals ``base_pass`` (constant-time compare).

    Every password is rejected when ``base_pass`` is not configured.
    """
    import hmac

    if base_pass is None or password is None:
        return False
    return hmac.compare_digest(password.encode("utf-8"), base_pass.encode("utf-8"))


def _encode_uploads(uploaded_files):
    """Base64-encode each upload within the size limit; warn about and skip larger ones."""
    images_base64 = []
    for uploaded_file in uploaded_files:
        if uploaded_file.size > _MAX_UPLOAD_BYTES:
            st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
            continue
        images_base64.append(image_to_base64_str(Image.open(uploaded_file)))
    return images_base64


def _generate(workflow, bg_color, text_prompt, password, images_base64, id_strength, consistency_magic):
    """Run one job over a websocket and return ``get_images``' result.

    Returns None, after showing an error, if the server cannot be reached or
    the transfer fails. The websocket is always closed.
    """
    query = urllib.parse.urlencode({"clientId": client_id, "token": TOKEN})
    ws = websocket.WebSocket()
    try:
        ws.connect(f"ws://{server_address}/ws?{query}")
    except (websocket.WebSocketException, OSError) as e:
        # Covers bad addresses, refused connections and rejected handshakes.
        st.error(f"Failed to connect to WebSocket server: {e}")
        st.error(f"Please check your server address: {server_address}")
        return None

    try:
        return get_images(ws, workflow, bg_color, text_prompt, password,
                          images_base64, id_strength, consistency_magic)
    except (websocket.WebSocketException, OSError) as e:
        # OSError includes urllib's HTTPError/URLError from the HTTP helpers.
        st.error(f"Generation failed: {e}")
        return None
    finally:
        ws.close()


def _show_result(images):
    """Display the first image of the first output node, with a download button."""
    if not images:
        st.write("No images returned.")
        return

    st.write("Click on an image to view in full size.")
    first_image_data_list = next(iter(images.values()))
    if not first_image_data_list:
        st.write("No images to display.")
        return

    first_image_data = first_image_data_list[0]
    st.image(Image.open(io.BytesIO(first_image_data)), caption="Result Image", width=500)
    st.download_button(label="Download full-size image",
                       data=first_image_data,
                       file_name=f"{client_id}.png",
                       mime="image/png")


def main():
    """Render the AMDsnap page and run one generation when Submit is pressed.

    The page collects:
      * up to ``_MAX_UPLOADS`` face photos,
      * a style prompt and a password,
      * an identity-strength slider and an optional seed ("consistency magic").

    It also shows the server's queue length. On Submit it checks the password
    and uploads, sends the job to the server, waits for the result, and shows
    the first returned image with a download button.
    """
    workflow = 'amdsnap_style'
    bg_color = "white"

    st.title("SILO.AI - AMDsnap ver.1.2.leon β¨")

    st.text('')
    st.markdown('##### π€³ You can upload multiple images, Up to 10 π₯Ίππ€ͺπ')
    uploaded_files = st.file_uploader("**Upload photos of your face**",
                                      type=["jpg", "jpeg", "png"],
                                      accept_multiple_files=True,
                                      help="Select up to 10 images. Maximum file size: 50MB each.")

    # Only the first _MAX_UPLOADS files are previewed and sent.
    selected_files = (uploaded_files or [])[:_MAX_UPLOADS]
    if selected_files:
        if len(uploaded_files) > _MAX_UPLOADS:
            st.warning(f"Only the first {_MAX_UPLOADS} images will be used.")
        cols = st.columns(len(selected_files))
        for col, uploaded_file in zip(cols, selected_files):
            with col:
                st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)

    st.text('')
    text_prompt = st.text_input("**π Input your style prompt. ex)background color, hair style...**")

    st.text('')
    password = st.text_input("**π Enter your password:**", type="password")

    id_strength = st.slider(
        "π Identity Extraction Strength",
        min_value=0.0,
        max_value=1.0,
        value=0.75,
        step=0.01
    )
    if id_strength > 0.75:
        st.warning("High strength: This will heavily focus on extracted identity.")
    elif id_strength < 0.55:
        st.warning("Low strength: This will mostly adapt to the prompt.")
    else:
        st.success("Balanced strength: This provides a mix of identity extraction and prompt adaptation.")

    # value=None lets the field stay empty; None is then sent to the server.
    consistency_magic = st.number_input(
        "π¦ Style Consistency Magic Value (1 ~ 4,294,967,294):",
        min_value=1,
        max_value=4294967294,
        value=None,
        step=1,
        help="Enter a value between 1 and 4,294,967,294 to use as a seed for random number generation. Leave empty for a server-generated random value."
    )
    st.caption("If left empty, the server will generate a random value.")

    queue_status = get_queue_status(server_address)
    if queue_status:
        st.write(f"**ππ»ββοΈ Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")

    if not st.button('Submit'):
        return

    if not _password_ok(password):
        st.error("Incorrect password. Please try again.")
        return
    if not selected_files:
        st.error("Please upload your face photo (up to 10).")
        return

    images_base64 = _encode_uploads(selected_files)
    if not images_base64:
        st.error("No valid images to process. Please upload images under 50MB each.")
        return

    with st.spinner('Processing... Please wait.'):
        gif_runner = st.image("https://huggingface.co/spaces/LeonJHK/AMDsnap/resolve/main/wait-waiting.gif")
        try:
            images = _generate(workflow, bg_color, text_prompt, password,
                               images_base64, id_strength, consistency_magic)
        finally:
            # Remove the waiting animation even if generation fails.
            gif_runner.empty()

    if images is None:
        return  # _generate has already shown the error
    _show_result(images)
|
|
|
if __name__ == "__main__":
    # main() is not invoked here: the page only shows a farewell banner
    # (presumably the service has been retired). Call main() to re-enable the app.
    farewell_banner = "See you on next version. Thank you for interesting! ππΌππ»ββοΈ -Leon-"
    st.title(farewell_banner)