wsj-server / wsj-server.py
vettorazi's picture
new dockerfile
a756e18
#This is an example that uses the websockets api to know when a prompt execution is done
#Once the prompt execution is done it downloads the images using the /history endpoint
import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
import io
from fastapi import FastAPI, Response
from pydantic import BaseModel
from PIL import Image
server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
return response.read()
def get_history(prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
return json.loads(response.read())
def get_images(ws, prompt):
prompt_id = queue_prompt(prompt)['prompt_id']
output_images = {}
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
break #Execution is done
else:
# If you want to be able to decode the binary stream for latent previews, here is how you can do it:
# bytesIO = BytesIO(out[8:])
# preview_image = Image.open(bytesIO) # This is your preview in PIL image format, store it in a global
continue #previews are binary data
history = get_history(prompt_id)[prompt_id]
for node_id in history['outputs']:
node_output = history['outputs'][node_id]
images_output = []
if 'images' in node_output:
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_output.append(image_data)
output_images[node_id] = images_output
return output_images
app = FastAPI()
class PromptRequest(BaseModel):
prompt: str
@app.post("/generate-image")
async def generate_image(prompt_request: PromptRequest):
# Load the workflow JSON
with open("wsj-api-rnd-v2.json", "r", encoding="utf-8") as f:
workflow_jsondata = f.read()
jsonwf = json.loads(workflow_jsondata)
# Set the text prompt
jsonwf["6"]["inputs"]["text"] = prompt_request.prompt
# Set a random seed
seednum = random.randint(0, 10000)
jsonwf["25"]["inputs"]["noise_seed"] = seednum
# Connect to WebSocket
ws = websocket.WebSocket()
ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
# Generate images
images = get_images(ws, jsonwf)
ws.close()
# Assuming we want to return the first image from the first node
first_node = next(iter(images))
image_data = images[first_node][0]
# Convert image data to PIL Image
image = Image.open(io.BytesIO(image_data))
# Convert PIL Image to bytes
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
return Response(content=img_byte_arr, media_type="image/png")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)