File size: 7,824 Bytes
5ab4237
90c7d8e
 
 
5ab4237
90c7d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23eadb3
90c7d8e
 
 
 
 
 
 
 
 
 
23eadb3
 
90c7d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import streamlit as st
from PIL import Image
import numpy as np
import os

import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import time
import random
import io
from icecream import ic

import base64
from io import BytesIO

# Deployment secrets are read from the environment; an unset variable gives None.
remote_url, base_pass, TOKEN = (
    os.environ.get(name) for name in ("remote_url", "base_pass", "TOKEN")
)

# Address of the backend server, and a random UUID4 string that identifies
# this client in the prompt payload and in the websocket URL.
server_address = remote_url
client_id = str(uuid.uuid4())


# Serialize an image as PNG and hand it back as base64 text
def image_to_base64_str(image):
    """Return *image* encoded as PNG, as a base64 string.

    ``image`` only needs a PIL-style ``save(fp, format=...)`` method; the
    PNG bytes are written to an in-memory buffer, never to disk.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="png")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()


# Function to check queue status
def get_queue_status(server_address, timeout=10):
    """Return how many jobs the server reports as running and pending.

    Fetches ``http://{server_address}/queue`` (with the module-level
    ``TOKEN`` as the ``token`` query parameter) and counts the entries under
    the ``queue_running`` and ``queue_pending`` keys of the JSON reply.

    This is a best-effort status display: any failure (unreachable server,
    timeout, malformed reply) is printed and reported as zero counts rather
    than raised.

    Args:
        server_address: ``host[:port]`` of the backend server.
        timeout: Seconds to wait on the server before giving up, so an
            unresponsive backend cannot block the page indefinitely.

    Returns:
        dict: ``{'queue_running': int, 'queue_pending': int}``.
    """
    try:
        url = f"http://{server_address}/queue?token={TOKEN}"
        with urllib.request.urlopen(url, timeout=timeout) as response:
            queue_info = json.loads(response.read())
        return {
            'queue_running': len(queue_info.get('queue_running', [])),
            'queue_pending': len(queue_info.get('queue_pending', [])),
        }
    except Exception as e:  # deliberate catch-all: the status line must never break the page
        print(f"Error retrieving queue info: {e}")
        return {'queue_running': 0, 'queue_pending': 0}  # Return 0s if there is an error or connection issue

def queue_prompt(workflow, bg_color, pos_prompt, password, images_base64):
    """Submit a generation job to the server and return its decoded JSON reply.

    The job is POSTed as a JSON body to ``/prompt`` on the module-level
    ``server_address``, authenticated with the ``token`` query parameter.

    Args:
        workflow: Workflow name, sent as ``"workflow"``.
        bg_color: Background colour, sent as ``"bg_color"``.
        pos_prompt: Style prompt text, sent as ``"pos_prompt"``.
        password: User-entered password, sent as ``"pass"``.
        images_base64: Base64-encoded images, sent as ``"upimages"``.

    Returns:
        The server's JSON reply, decoded. ``get_images`` reads its
        ``'prompt_id'`` entry.

    Raises:
        urllib.error.URLError: If the request fails (includes HTTP errors).
    """
    payload = {
        # NOTE(review): "prompt" looks like a placeholder the server requires
        # but ignores in favour of "workflow" -- confirm against the server.
        "prompt": {"dummy":"dummy"},
        "client_id": client_id,
        "api":"silosnap_portrait",
        "workflow":workflow,
        "bg_color":bg_color,
        "pos_prompt":pos_prompt,
        "pass":password,
        "upimages":images_base64,
    }
    body = json.dumps(payload).encode('utf-8')
    request = urllib.request.Request(
        "http://{}/prompt?token={}".format(server_address, TOKEN), data=body)
    # Close the HTTP response deterministically instead of leaking the socket.
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())

def get_image(filename, subfolder, folder_type):
    """Download one file from the server's ``/view`` endpoint as raw bytes.

    ``filename``, ``subfolder`` and ``folder_type`` are URL-encoded into the
    ``filename``, ``subfolder`` and ``type`` query parameters; the
    module-level ``TOKEN`` is appended as ``token``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    url = f"http://{server_address}/view?{query}&token={TOKEN}"
    with urllib.request.urlopen(url) as response:
        content = response.read()
    return content

def get_history(prompt_id):
    """Fetch ``/history/{prompt_id}`` from the server and return the decoded JSON.

    The request is authenticated with the module-level ``TOKEN``.
    """
    url = f"http://{server_address}/history/{prompt_id}?token={TOKEN}"
    with urllib.request.urlopen(url) as response:
        raw = response.read()
    return json.loads(raw)

def get_images(ws, workflow, bg_color, pos_prompt, password, images_base64):
    """Queue a job, wait for it to finish, and download its output images.

    Args:
        ws: A connected websocket whose ``recv()`` yields the server's
            messages.
        workflow, bg_color, pos_prompt, password, images_base64: Passed
            unchanged to ``queue_prompt``.

    Returns:
        dict: Maps each output node id that has an ``'images'`` entry to the
        list of downloaded image bytes for that node.
    """
    queued = queue_prompt(workflow, bg_color, pos_prompt, password, images_base64)
    prompt_id = queued['prompt_id']

    # Read websocket messages until an 'executing' message with node None
    # arrives for our prompt, which marks the end of execution.
    finished = False
    while not finished:
        frame = ws.recv()
        if not isinstance(frame, str):
            continue  # previews are binary data
        message = json.loads(frame)
        if message['type'] != 'executing':
            continue
        data = message['data']
        finished = data['node'] is None and data['prompt_id'] == prompt_id

    node_outputs = get_history(prompt_id)[prompt_id]['outputs']

    output_images = {}
    for node_id, node_output in node_outputs.items():
        if 'images' not in node_output:
            continue
        output_images[node_id] = [
            get_image(entry['filename'], entry['subfolder'], entry['type'])
            for entry in node_output['images']
        ]
    return output_images


#------------------------------------------------------------

def main():
    """Render the Streamlit page and run one generation request on submit.

    The page collects up to 10 face photos, an optional style prompt and a
    password, and shows the server's current queue counts. When "Submit" is
    pressed with a password equal to ``base_pass`` and at least one photo
    under 50MB, the photos are sent to the server as base64 PNGs, the
    function waits over a websocket for the job to finish, and the first
    returned image is shown together with a download button.
    """
    max_uploads = 10
    max_upload_bytes = 50 * 1024 * 1024  # 50MB limit per file

    st.title("SILO.AI - AMDsnap ver.1.0.leon ")

    workflow = 'amdsnap_style'
    bg_color = "white"

    # NOTE(review): the emoji in the labels below were mojibake in the source
    # and have been reconstructed; they are written as escapes so the file
    # stays ASCII-safe. The last face emoji (U+1F64F), the memo (U+1F4DD) and
    # the lock (U+1F510) are best guesses -- confirm against the original.
    st.text('')
    st.markdown('##### \U0001F933 You can upload multiple images, Up to 10 '
                '\U0001F97A\U0001F603\U0001F92A\U0001F64F')
    uploaded_files = st.file_uploader("**Upload photos of your face**",
                                      type=["jpg", "jpeg", "png"],
                                      accept_multiple_files=True,
                                      help="Select up to 10 images. Maximum file size: 50MB each.")

    # Only the first `max_uploads` files are previewed and processed.
    selected_files = uploaded_files[:max_uploads] if uploaded_files else []

    if selected_files:
        if len(uploaded_files) > max_uploads:
            st.warning(f"Only the first {max_uploads} images will be used.")
        cols = st.columns(len(selected_files))
        for col, uploaded_file in zip(cols, selected_files):
            with col:
                st.image(Image.open(uploaded_file), caption=uploaded_file.name, width=100)

    st.text('')
    text_prompt = st.text_input("**\U0001F4DD Input your style prompt. Empty is OK!**")
    st.text('')
    password = st.text_input("**\U0001F510 Enter your password:**", type="password")

    # get_queue_status falls back to zero counts on error, so this is always a dict.
    queue_status = get_queue_status(server_address)
    st.write(f"**\U0001F3C3\U0001F3FB\u200D\u2640\uFE0F Current queue - Running: {queue_status['queue_running']}, Pending: {queue_status['queue_pending']}**")

    if not st.button('Submit'):
        return

    if password != base_pass:
        st.error("Incorrect password. Please try again.")
        return
    if not selected_files:
        st.error("Please upload your face photo (up to 10).")
        return

    # Process uploads only when submit is pressed
    images_base64 = []
    for uploaded_file in selected_files:
        if uploaded_file.size > max_upload_bytes:
            st.warning(f"File {uploaded_file.name} exceeds 50MB limit and was not processed.")
            continue
        images_base64.append(image_to_base64_str(Image.open(uploaded_file)))

    if not images_base64:
        st.error("No valid images to process. Please upload images under 50MB each.")
        return

    st.write(f"Your text prompt: {text_prompt}")
    with st.spinner('Processing... Please wait.'):
        # "/resolve/" serves the raw GIF; "/blob/" is the HTML file-viewer page.
        gif_runner = st.image("https://huggingface.co/spaces/LeonJHK/AMDsnap/resolve/main/wait-waiting.gif")
        ws = websocket.WebSocket()
        try:
            try:
                ws.connect("ws://{}/ws?clientId={}&token={}".format(server_address, client_id, TOKEN))
            except (websocket.WebSocketException, OSError) as e:
                # Covers bad addresses, refused connections and failed handshakes.
                st.error(f"Failed to connect to WebSocket server: {e}")
                st.error(f"Please check your server address: {server_address}")
                return
            images = get_images(ws, workflow, bg_color, text_prompt, password, images_base64)
        finally:
            # Always release the socket and remove the GIF, even on failure.
            ws.close()
            gif_runner.empty()

    if not images:
        st.write("No images returned.")
        return

    st.write("Click on an image to view in full size.")
    # Only the first output node's first image is shown.
    first_image_data_list = next(iter(images.values()))
    if not first_image_data_list:
        st.write("No images to display.")
        return

    first_image_data = first_image_data_list[0]
    st.image(Image.open(io.BytesIO(first_image_data)), caption="Result Image", width=500)
    st.download_button(label="Download full-size image",
                       data=first_image_data,
                       file_name=f"{client_id}.png",
                       mime="image/png")

# Build the page only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()