rayochoajr's picture
Update app.py
3e09e46 verified
Here is the complete `app.py`, now with an option for choosing which server to use:
```python
import subprocess
import threading
import gradio as gr
import websocket
import uuid
import json
import urllib.request
import urllib.parse
from PIL import Image
import io
# 🚀 Chapter 1: Install Necessary Packages 🚀
def install_packages(packages=("gradio", "websocket-client", "pillow")):
    """Install third-party dependencies with pip.

    Args:
        packages: Distribution names passed to ``pip install``. Defaults to
            the same three names the original hard-coded list contained.
            A tuple default keeps the default immutable.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import sys  # local import keeps this block self-contained

    packages = list(packages)
    if not packages:
        # `pip install` with no requirements exits with an error; nothing to do.
        return
    # Run pip via the running interpreter so the packages are installed into
    # this script's environment, not whichever `pip` happens to be first on
    # PATH. One invocation lets pip resolve all requirements together.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])
# Install dependencies before the UI starts. The original started a thread
# and joined it immediately, which blocks exactly like a direct call, so
# nothing actually ran "in the background".
# A failed install is reported and does not stop the app. This mirrors the
# thread version: an uncaught exception in a worker thread is printed but
# never reaches the main thread.
# NOTE(review): gradio, websocket and PIL are already imported at the top of
# this file before this runs, so this step cannot help those imports succeed;
# declaring the packages in requirements.txt is probably the intended fix.
try:
    install_packages()
except (subprocess.CalledProcessError, OSError) as err:
    print(f"Package installation failed: {err}")
# 🌟 Chapter 2: Generate Client ID 🌟
# A random UUID string generated once per process. It is sent with every
# queued prompt and used as the websocket `clientId`, presumably so the server
# routes this session's progress/image messages back to this connection.
client_id = f"{uuid.uuid4()}"
# 🌟 Chapter 3: Queue Prompt Function 🌟
def queue_prompt(prompt, server_address, client=None):
    """Submit a workflow to the server's ``/prompt`` endpoint; return its JSON reply.

    Args:
        prompt: The workflow graph, a JSON-serialisable dict keyed by node id.
        server_address: ``host:port`` of the server, without a scheme.
        client: Client id to attach to the request. Defaults to the
            module-level ``client_id``, which is also the id used for the
            websocket connection in this file.

    Returns:
        The decoded JSON response; callers read its ``"prompt_id"`` key.

    Raises:
        urllib.error.HTTPError: if the server answers with an error status.
        urllib.error.URLError: if the server cannot be reached.
    """
    payload = {
        "prompt": prompt,
        # The module-level id is only looked up when no override is given.
        "client_id": client_id if client is None else client,
    }
    request = urllib.request.Request(
        f"http://{server_address}/prompt",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    # `with` closes the HTTP response (and its socket) once the body is read;
    # the original never closed it.
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
# 🌟 Chapter 4: Get Image Function 🌟
def get_image(filename, subfolder, folder_type, server_address):
    """Download a stored image from the server's ``/view`` endpoint.

    Args:
        filename: Name of the image file on the server.
        subfolder: Sub-folder the file lives in (may be empty).
        folder_type: Value sent as the ``type`` query parameter.
        server_address: ``host:port`` of the server, without a scheme.

    Returns:
        The raw response body (the image bytes).
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as reply:
        image_bytes = reply.read()
    return image_bytes
# 🌟 Chapter 5: Get History Function 🌟
def get_history(prompt_id, server_address):
    """Fetch the ``/history/<prompt_id>`` document from the server and decode it.

    Args:
        prompt_id: Id returned by the server when the prompt was queued.
        server_address: ``host:port`` of the server, without a scheme.

    Returns:
        The parsed JSON response body.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as reply:
        # json.load reads the whole response body, then parses it.
        return json.load(reply)
# 🌟 Chapter 6: Get Images Function 🌟
def get_images(ws, prompt, server_address):
    """Queue ``prompt`` and collect the image frames streamed back over ``ws``.

    ``ws`` should already be connected; in this file ``generate_image``
    connects it with the module's ``client_id`` before calling. Frames are
    read in a loop with no timeout until an ``executing`` message for this
    prompt arrives with ``node`` set to None.

    Text frames are parsed as JSON. Only ``executing`` messages for the queued
    prompt are used, and they record which node is currently running.
    Binary frames received while ``save_image_websocket_node`` is the running
    node are kept, minus their first 8 bytes. Those bytes are a header,
    presumably event type plus image format per ComfyUI's websocket example;
    confirm against the server version in use.

    Returns:
        dict mapping node id to the list of payloads received for it (empty
        if no image frames arrived).
    """
    target_id = queue_prompt(prompt, server_address)['prompt_id']
    collected = {}
    active_node = ""
    while True:
        frame = ws.recv()
        if not isinstance(frame, str):
            # Binary frame: keep it only if it belongs to the save node.
            if active_node == 'save_image_websocket_node':
                collected.setdefault(active_node, []).append(frame[8:])
            continue
        event = json.loads(frame)
        if event['type'] != 'executing':
            continue
        details = event['data']
        if details['prompt_id'] != target_id:
            continue
        if details['node'] is None:
            # Execution of our prompt has finished.
            return collected
        active_node = details['node']
# 🌟 Chapter 7: Generate Image Function 🌟
# Sampling steps for each "Quality" choice offered in the UI.
_QUALITY_STEPS = {"Low": 8, "Medium": 16, "High": 30}

# host:port for each named server choice. Any other label (the UI's
# "R - PC" option) falls back to _DEFAULT_SERVER_ADDRESS.
_SERVER_ADDRESSES = {
    "AWS Server": "3.14.144.23:8188",
    "Home Server": "192.168.50.136:8188",
}
_DEFAULT_SERVER_ADDRESS = "73.206.199.71:18188"


def _build_workflow(text_prompt, seed, steps):
    """Return the text-to-image workflow graph as a dict keyed by node id.

    The graph is built as a Python dict rather than by formatting a JSON
    string, so prompt text containing quotes, backslashes or newlines can
    neither break parsing nor inject extra JSON structure.
    """
    return {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "cfg": 8,
                "denoise": 1,
                "latent_image": ["5", 0],
                "model": ["4", 0],
                "negative": ["7", 0],
                "positive": ["6", 0],
                "sampler_name": "euler",
                "scheduler": "normal",
                "seed": seed,
                "steps": steps,
            },
        },
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {"ckpt_name": "v1-5-pruned-emaonly.ckpt"},
        },
        "5": {
            "class_type": "EmptyLatentImage",
            "inputs": {"batch_size": 1, "height": 512, "width": 768},
        },
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {"clip": ["4", 1], "text": text_prompt},
        },
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {"clip": ["4", 1], "text": "bad hands"},
        },
        "8": {
            "class_type": "VAEDecode",
            "inputs": {"samples": ["3", 0], "vae": ["4", 2]},
        },
        "save_image_websocket_node": {
            "class_type": "SaveImageWebsocket",
            "inputs": {"images": ["8", 0]},
        },
    }


def _resolve_server_address(server):
    """Map a UI server label to ``host:port``; unknown labels use the default."""
    return _SERVER_ADDRESSES.get(server, _DEFAULT_SERVER_ADDRESS)


def _first_image(images):
    """Decode the first payload in a ``{node_id: [payload, ...]}`` mapping, or return None."""
    for payloads in images.values():
        for image_data in payloads:
            return Image.open(io.BytesIO(image_data))
    return None


def generate_image(text_prompt, seed, quality, server):
    """Run the text-to-image workflow on the chosen server and return the first image.

    Args:
        text_prompt: Positive prompt text. It is stored in the workflow as a
            plain string, so any characters are safe.
        seed: Sampler seed, converted with ``int()``. Gradio's Number
            component may deliver it as a float such as ``5.0``.
        quality: "Low", "Medium" or "High"; selects 8, 16 or 30 sampling steps.
        server: Server label from the UI; unrecognised labels use the
            fallback address.

    Returns:
        A PIL image decoded from the first image payload received, or None
        if no image arrived.

    Raises:
        KeyError: if ``quality`` is not one of the known labels.
        TypeError, ValueError: if ``seed`` cannot be converted to int.
    """
    steps = _QUALITY_STEPS[quality]
    prompt = _build_workflow(text_prompt, int(seed), steps)
    server_address = _resolve_server_address(server)

    ws = websocket.WebSocket()
    try:
        ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
        images = get_images(ws, prompt, server_address)
    finally:
        # Always release the connection, including when generation fails;
        # the original never closed it.
        ws.close()
    return _first_image(images)
# 🌟 Chapter 8: Cancel Request Function 🌟
def cancel_request():
    """Return the status text for the "Cancel Request" button.

    This function only produces a fixed message. It sends nothing to the
    server and does not touch any in-flight generation.
    """
    status_message = "Request Cancelled"
    return status_message
# 🌟 Chapter 9: Gradio Interface 🌟
# Build the Gradio UI and start the web server.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with Websockets API")
    gr.Markdown("Generate images using a Websockets API and SaveImageWebsocket node.")
    with gr.Row():
        with gr.Column():
            text_prompt = gr.Textbox(label="Text Prompt", value="masterpiece best quality man")
            # precision=0 makes Gradio deliver the seed as an int instead of a float.
            seed = gr.Number(label="Seed", value=5, precision=0)
            quality = gr.Radio(label="Quality", choices=["Low", "Medium", "High"], value="Low")
            server = gr.Radio(
                label="Server",
                choices=["AWS Server", "Home Server", "R - PC"],
                value="AWS Server",
            )
            generate_button = gr.Button("Generate Image")
            cancel_button = gr.Button("Cancel Request")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
    generate_event = generate_button.click(
        fn=generate_image,
        inputs=[text_prompt, seed, quality, server],
        outputs=output_image,
    )
    # Previously this button ran a function whose return value went to no
    # output, so pressing it had no effect. `cancels=` asks Gradio to cancel
    # the pending generate event.
    # NOTE(review): per the Gradio docs, a function that is already running
    # is allowed to finish, and nothing here tells the ComfyUI server to stop.
    cancel_button.click(
        fn=cancel_request,
        inputs=[],
        outputs=[],
        cancels=[generate_event],
    )

# Event cancellation needs the request queue. It is on by default in Gradio 4
# and must be enabled explicitly in 3.x; calling queue() works in both.
demo.queue()
demo.launch()
```