|
|
|
|
|
|
|
import websocket |
|
import uuid |
|
import json |
|
import urllib.request |
|
import urllib.parse |
|
import random |
|
import io |
|
from fastapi import FastAPI, Response |
|
from pydantic import BaseModel |
|
from PIL import Image |
|
|
|
# host:port of the backend that every HTTP and websocket URL in this module targets.
server_address = "127.0.0.1:8188"

# Random identifier generated once at import time; it is sent with each queued
# prompt and as the ``clientId`` query parameter of the websocket URL.
client_id = str(uuid.uuid4())
|
|
|
def queue_prompt(prompt):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow graph to enqueue.

    Returns:
        The server's JSON reply decoded into Python objects. ``get_images``
        reads its ``'prompt_id'`` entry.

    Raises:
        urllib.error.URLError: If the HTTP request fails.
        json.JSONDecodeError: If the reply body is not valid JSON.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    # Supplying ``data`` makes urllib send the request as a POST.
    request = urllib.request.Request(f"http://{server_address}/prompt", data=payload)
    # Close the HTTP response deterministically rather than leaving the
    # connection open until garbage collection.
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
|
|
|
def get_image(filename, subfolder, folder_type):
    """Download one file from the server's ``/view`` endpoint as raw bytes.

    The three arguments are sent as the ``filename``, ``subfolder`` and
    ``type`` query parameters respectively.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload
|
|
|
def get_history(prompt_id):
    """Fetch ``/history/<prompt_id>`` from the server and decode the JSON body."""
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)
|
|
|
def get_images(ws, prompt):
    """Queue *prompt*, wait until it has finished, and download its images.

    Args:
        ws: Connected websocket whose ``recv()`` yields the server's messages.
        prompt: Workflow graph handed to ``queue_prompt``.

    Returns:
        dict mapping every node id found in the history's ``outputs`` to a
        list of image byte strings; nodes without an ``images`` entry map to
        an empty list.
    """
    prompt_id = queue_prompt(prompt)["prompt_id"]

    # Read messages until an "executing" message reports node=None for our
    # prompt id, which is the condition this loop treats as completion.
    finished = False
    while not finished:
        frame = ws.recv()
        if not isinstance(frame, str):
            # Non-text frames are skipped.
            continue
        message = json.loads(frame)
        if message["type"] != "executing":
            continue
        data = message["data"]
        finished = data["node"] is None and data["prompt_id"] == prompt_id

    outputs = get_history(prompt_id)[prompt_id]["outputs"]

    output_images = {}
    for node_id, node_output in outputs.items():
        image_refs = node_output["images"] if "images" in node_output else []
        output_images[node_id] = [
            get_image(ref["filename"], ref["subfolder"], ref["type"])
            for ref in image_refs
        ]
    return output_images
|
|
|
# ASGI application object; routes are registered on it via decorators and it is
# passed to uvicorn when the module is run as a script.
app = FastAPI()
|
|
|
# Request body model for the POST /generate-image endpoint.
# (Documented with comments rather than a docstring so the generated schema is
# left exactly as it was.)
class PromptRequest(BaseModel):
    # Text the handler writes into the workflow before queueing it.
    prompt: str
|
|
|
@app.post("/generate-image")
async def generate_image(prompt_request: PromptRequest):
    """Run the stored workflow with the caller's prompt and return a PNG.

    The workflow template is read from ``wsj-api-rnd-v2.json`` on every call.
    The prompt text is written to node ``"6"`` and a random ``noise_seed`` to
    node ``"25"`` before the workflow is queued.

    Args:
        prompt_request: Request body carrying the ``prompt`` text.

    Returns:
        A ``Response`` whose body is the first produced image re-encoded as
        PNG, with media type ``image/png``.

    Raises:
        HTTPException: With status 500 if the workflow finished without
            producing any image.
    """
    # Function-scope import keeps this block self-contained; fastapi is
    # already a dependency of this module.
    from fastapi import HTTPException

    with open("wsj-api-rnd-v2.json", "r", encoding="utf-8") as f:
        workflow = json.load(f)

    workflow["6"]["inputs"]["text"] = prompt_request.prompt
    workflow["25"]["inputs"]["noise_seed"] = random.randint(0, 10000)

    # NOTE(review): the websocket and urllib calls below are synchronous, so
    # this coroutine does not yield to the event loop while it waits. Requests
    # are effectively handled one at a time, and all of them share the single
    # module-level client_id -- confirm that is intended before offloading
    # this work to threads.
    ws = websocket.WebSocket()
    ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
    try:
        images = get_images(ws, workflow)
    finally:
        # Always release the socket, even when queueing or downloading fails.
        ws.close()

    # Take the first image of the first node that actually produced one;
    # get_images maps nodes without images to empty lists.
    image_data = next(
        (node_images[0] for node_images in images.values() if node_images),
        None,
    )
    if image_data is None:
        raise HTTPException(
            status_code=500,
            detail="Workflow finished without producing an image",
        )

    # Re-encode through PIL so the response is PNG whatever format was produced.
    image = Image.open(io.BytesIO(image_data))
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")

    return Response(content=png_buffer.getvalue(), media_type="image/png")
|
|
|
# Script entry point: serve the app with uvicorn on port 7860.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # host="0.0.0.0" listens on all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|