|
import asyncio |
|
import json |
|
import logging |
|
import random |
|
import urllib.parse |
|
import urllib.request |
|
from typing import Optional |
|
|
|
import websocket |
|
from open_webui.env import SRC_LOG_LEVELS |
|
from pydantic import BaseModel |
|
|
|
# Module-level logger; its level is taken from the "COMFYUI" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])

# Headers passed to the urllib requests built by the helpers in this module.
default_headers = {"User-Agent": "Mozilla/5.0"}
|
|
|
|
|
def queue_prompt(prompt, client_id, base_url):
    """Submit a workflow to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow sent under the ``"prompt"`` key.
        client_id: Identifier sent under the ``"client_id"`` key.
        base_url: Base HTTP URL of the server (no trailing ``/prompt``).

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: Any error raised while building the request, performing it
            or decoding the response is logged and re-raised unchanged.
    """
    log.info("queue_prompt")
    data = json.dumps({"prompt": prompt, "client_id": client_id}).encode("utf-8")
    log.debug(f"queue_prompt data: {data}")
    try:
        req = urllib.request.Request(
            f"{base_url}/prompt", data=data, headers=default_headers
        )
        # The context manager closes the HTTP response even if reading or
        # decoding fails, matching how get_image/get_history handle it.
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())
    except Exception as e:
        log.exception(f"Error while queuing prompt: {e}")
        # Bare raise re-raises the active exception without rebinding it.
        raise
|
|
|
|
|
def get_image(filename, subfolder, folder_type, base_url):
    """Fetch the raw bytes served by the ``/view`` endpoint for one image.

    The three identifying values are URL-encoded into the query string as
    ``filename``, ``subfolder`` and ``type``.

    Returns:
        The response body as returned by ``read()``.
    """
    log.info("get_image")
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"{base_url}/view?{query}"
    request = urllib.request.Request(view_url, headers=default_headers)
    with urllib.request.urlopen(request) as resp:
        body = resp.read()
    return body
|
|
|
|
|
def get_image_url(filename, subfolder, folder_type, base_url):
    """Build the ``/view`` URL for one image without fetching it.

    Args:
        filename: Value for the ``filename`` query parameter.
        subfolder: Value for the ``subfolder`` query parameter.
        folder_type: Value for the ``type`` query parameter.
        base_url: Base HTTP URL of the server.

    Returns:
        ``"{base_url}/view?<url-encoded query>"`` as a string.
    """
    # Log this function's own name; the message previously read "get_image",
    # which made the two helpers indistinguishable in the logs.
    log.info("get_image_url")
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    return f"{base_url}/view?{query}"
|
|
|
|
|
def get_history(prompt_id, base_url):
    """Fetch and decode the ``/history/{prompt_id}`` JSON document.

    Returns:
        The parsed JSON body of the response.
    """
    log.info("get_history")

    history_url = f"{base_url}/history/{prompt_id}"
    request = urllib.request.Request(history_url, headers=default_headers)
    with urllib.request.urlopen(request) as resp:
        raw_body = resp.read()
    return json.loads(raw_body)
|
|
|
|
|
def get_images(ws, prompt, client_id, base_url):
    """Queue ``prompt`` and collect the URLs of the images it produced.

    Blocks on ``ws.recv()`` until a text frame of type ``"executing"`` arrives
    whose ``data["node"]`` is ``None`` and whose ``data["prompt_id"]`` matches
    the queued prompt. It then reads that prompt's history entry and builds
    one ``/view`` URL per image listed under its outputs.

    Args:
        ws: Connected websocket object exposing ``recv()``.
        prompt: Workflow forwarded to ``queue_prompt``.
        client_id: Client identifier forwarded to ``queue_prompt``.
        base_url: Base HTTP URL of the server.

    Returns:
        ``{"data": [{"url": ...}, ...]}`` with one entry per output image.
    """
    prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]

    while True:
        out = ws.recv()
        # Only text frames carry JSON status messages; anything else is skipped
        # (presumably binary preview data).
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message["type"] != "executing":
            continue
        data = message["data"]
        if data["node"] is None and data["prompt_id"] == prompt_id:
            break

    history = get_history(prompt_id, base_url)[prompt_id]

    # Visit each output node exactly once. An extra outer loop over the same
    # mapping used to repeat this traversal once per node, so every image was
    # appended N times when the workflow had N output nodes.
    output_images = []
    for node_output in history["outputs"].values():
        if "images" not in node_output:
            continue
        for image in node_output["images"]:
            url = get_image_url(
                image["filename"], image["subfolder"], image["type"], base_url
            )
            output_images.append({"url": url})
    return {"data": output_images}
|
|
|
|
|
class ComfyUINodeInput(BaseModel):
    """One override to write into a workflow: which nodes to patch and how."""

    # Kind of value to inject ("model", "prompt", "negative_prompt", "width",
    # "height", "n", "steps" or "seed" are handled by comfyui_generate_image);
    # when empty, ``value`` is written as-is.
    type: Optional[str] = None
    # IDs of the workflow nodes whose ``inputs`` are modified.
    node_ids: list[str] = []
    # Name of the node input to overwrite.
    key: Optional[str] = "text"
    # Literal value used when ``type`` is not set.
    value: Optional[str] = None
|
|
|
|
|
class ComfyUIWorkflow(BaseModel):
    """A workflow definition together with the node overrides to apply to it."""

    # JSON-encoded workflow; comfyui_generate_image parses it with json.loads.
    workflow: str
    # Overrides written into the parsed workflow before it is queued.
    nodes: list[ComfyUINodeInput]
|
|
|
|
|
class ComfyUIGenerateImageForm(BaseModel):
    """Parameters of one image-generation request for comfyui_generate_image."""

    # Workflow to run and the nodes that receive the values below.
    workflow: ComfyUIWorkflow

    # Written into nodes typed "prompt" / "negative_prompt".
    prompt: str
    negative_prompt: Optional[str] = None
    # Written into nodes typed "width" / "height".
    width: int
    height: int
    # Written into nodes typed "n" (default input key "batch_size").
    n: int = 1

    # Written into nodes typed "steps".
    steps: Optional[int] = None
    # Written into nodes typed "seed"; a falsy value (None or 0) is replaced
    # by a random seed in comfyui_generate_image.
    seed: Optional[int] = None
|
|
|
|
|
# Node types whose value is read straight from the request payload:
# node type -> (input key used when ``node.key`` is empty, payload attribute).
# "model" and "seed" are handled separately because they always use
# ``node.key`` exactly as given.
_PAYLOAD_NODE_FIELDS = {
    "prompt": ("text", "prompt"),
    "negative_prompt": ("text", "negative_prompt"),
    "width": ("width", "width"),
    "height": ("height", "height"),
    "n": ("batch_size", "n"),
    "steps": ("steps", "steps"),
}


def _apply_node_overrides(workflow, nodes, model, payload):
    """Write the request's values into the inputs of the configured nodes."""
    for node in nodes:
        if not node.type:
            # Untyped node: copy the literal value configured on the node.
            key, value = node.key, node.value
        elif node.type == "model":
            key, value = node.key, model
        elif node.type == "seed":
            # NOTE(review): a seed of 0 is falsy and is therefore replaced by
            # a random one, same as None — confirm this is intended.
            value = (
                payload.seed
                if payload.seed
                else random.randint(0, 18446744073709551614)
            )
            key = node.key
        elif node.type in _PAYLOAD_NODE_FIELDS:
            default_key, attr = _PAYLOAD_NODE_FIELDS[node.type]
            key = node.key if node.key else default_key
            value = getattr(payload, attr)
        else:
            # Unrecognised node types are ignored.
            continue

        for node_id in node.node_ids:
            workflow[node_id]["inputs"][key] = value


async def comfyui_generate_image(
    model: str, payload: ComfyUIGenerateImageForm, client_id, base_url
):
    """Run the payload's workflow on the server and return the image URLs.

    The JSON workflow in ``payload.workflow.workflow`` is parsed, the node
    overrides in ``payload.workflow.nodes`` are applied, a websocket is opened
    to ``{base_url}/ws`` (with the scheme switched to ``ws``/``wss``), and
    ``get_images`` is run in a worker thread.

    Args:
        model: Value written into nodes typed ``"model"``.
        payload: Request parameters and workflow definition.
        client_id: Identifier sent as ``clientId`` on the websocket URL and
            forwarded to ``get_images``.
        base_url: Base HTTP(S) URL of the server.

    Returns:
        The result of ``get_images``; ``None`` if the websocket connection
        could not be established or an error occurred while waiting for the
        images (both cases are logged).
    """
    ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
    workflow = json.loads(payload.workflow.workflow)

    _apply_node_overrides(workflow, payload.workflow.nodes, model, payload)

    try:
        ws = websocket.WebSocket()
        ws.connect(f"{ws_url}/ws?clientId={client_id}")
        log.info("WebSocket connection established.")
    except Exception as e:
        log.exception(f"Failed to connect to WebSocket server: {e}")
        return None

    try:
        log.info("Sending workflow to WebSocket server.")
        log.info(f"Workflow: {workflow}")
        images = await asyncio.to_thread(get_images, ws, workflow, client_id, base_url)
    except Exception as e:
        log.exception(f"Error while receiving images: {e}")
        images = None
    finally:
        # Close the socket on every exit path, including task cancellation
        # (asyncio.CancelledError is not an Exception subclass and would
        # otherwise skip the close).
        ws.close()

    return images
|
|