from typing import List

import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image

import gradio as gr

from modules.api import api
from .global_state import (
    get_all_preprocessor_names,
    get_all_controlnet_names,
    get_preprocessor,
    get_all_preprocessor_tags,
    select_control_type,
)
from .utils import judge_image_type
from .logging import logger


class _JsonAcceptor:
    """Callback sink that remembers the most recent JSON dict handed to it.

    An instance's ``accept`` method is passed to a preprocessor as
    ``json_pose_callback``; ``value`` stays ``None`` until it is called.
    """

    def __init__(self) -> None:
        self.value = None

    def accept(self, json_dict: dict) -> None:
        """Store ``json_dict`` so the caller can read it back via ``value``."""
        self.value = json_dict


def encode_to_base64(image):
    """Encode a detection result for inclusion in a JSON response.

    Args:
        image: A string (returned unchanged), a ``PIL.Image.Image`` or a
            ``numpy.ndarray``.

    Returns:
        * ``image`` itself when it is already a string;
        * the literal ``"Detect result is not image"`` when
          ``judge_image_type`` rejects the value;
        * the result of ``api.encode_pil_to_base64`` for PIL images and
          NumPy arrays;
        * ``""`` (after logging a warning) for any other accepted type.
    """
    if isinstance(image, str):
        return image
    if not judge_image_type(image):
        return "Detect result is not image"
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)

    # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
    logger.warning("Unable to encode image.")
    return ""


def encode_np_to_base64(image):
    """Convert a NumPy array to a PIL image and base64-encode it.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.

    Returns:
        Whatever ``api.encode_pil_to_base64`` returns for the converted image.
    """
    pil = Image.fromarray(image)
    return api.encode_pil_to_base64(pil)


def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ``/controlnet/*`` HTTP routes on ``app``.

    Args:
        _: Gradio blocks object; unused here.
        app: FastAPI application the routes are attached to.
    """

    @app.get("/controlnet/model_list")
    async def model_list():
        """Return the names reported by ``get_all_controlnet_names``."""
        up_to_date_model_list = get_all_controlnet_names()
        logger.debug(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list():
        """Return the names reported by ``get_all_preprocessor_names``."""
        # Named differently from the route function to avoid shadowing it.
        preprocessor_names = get_all_preprocessor_names()
        logger.debug(preprocessor_names)
        return {
            "module_list": preprocessor_names,
            # TODO: Add back module detail.
            # "module_detail": external_code.get_modules_detail(alias_names),
        }

    @app.get("/controlnet/control_types")
    async def control_types():
        """Return, per preprocessor tag, its modules, models and defaults."""

        def format_control_type(
            filtered_preprocessor_list,
            filtered_model_list,
            default_option,
            default_model,
        ):
            # Gives names to the four positional values that
            # ``select_control_type`` is unpacked into below.
            return {
                "module_list": filtered_preprocessor_list,
                "model_list": filtered_model_list,
                "default_option": default_option,
                "default_model": default_model,
            }

        return {
            "control_types": {
                control_type: format_control_type(*select_control_type(control_type))
                for control_type in get_all_preprocessor_tags()
            }
        }

    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title="Controlnet Module"),
        controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
        controlnet_processor_res: int = Body(
            512, title="Controlnet Processor Resolution"
        ),
        controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
        controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
    ):
        """Run a preprocessor over each submitted image.

        Args:
            controlnet_module: Name passed to ``get_preprocessor``.
            controlnet_input_images: Base64-encoded input images.
            controlnet_processor_res: Forwarded as ``resolution``.
            controlnet_threshold_a: Forwarded as ``slider_1``.
            controlnet_threshold_b: Forwarded as ``slider_2``.

        Returns:
            ``{"images": [...], "info": "Success"}``, plus a ``"poses"`` list
            when any pose JSON was collected.

        Raises:
            HTTPException: 422 when the module is unknown or no image was
                supplied; 500 when an ``openpose`` module produced no pose
                data through the callback.
        """
        processor_module = get_preprocessor(controlnet_module)
        if processor_module is None:
            raise HTTPException(status_code=422, detail="Module not available")

        if not controlnet_input_images:
            raise HTTPException(status_code=422, detail="No image selected")

        logger.debug(
            "Detecting %d images with the %s module.",
            len(controlnet_input_images),
            controlnet_module,
        )

        # Loop-invariant: pose JSON is only collected for openpose modules.
        collect_poses = "openpose" in controlnet_module

        results = []
        poses = []

        for input_image in controlnet_input_images:
            img = np.array(api.decode_base64_to_image(input_image)).astype("uint8")

            # Fresh acceptor per image so one image's pose never leaks into
            # the next one.
            json_acceptor = _JsonAcceptor()

            results.append(
                processor_module(
                    img,
                    resolution=controlnet_processor_res,
                    slider_1=controlnet_threshold_a,
                    slider_2=controlnet_threshold_b,
                    json_pose_callback=json_acceptor.accept,
                )
            )

            if collect_poses:
                # Explicit check instead of ``assert``: asserts are removed
                # under ``python -O`` and ``None`` would then be returned to
                # the client as a pose.
                if json_acceptor.value is None:
                    raise HTTPException(
                        status_code=500,
                        detail=(
                            f"Module {controlnet_module} did not report pose data"
                        ),
                    )
                poses.append(json_acceptor.value)

        results64 = [encode_to_base64(result) for result in results]
        res = {"images": results64, "info": "Success"}
        if poses:
            res["poses"] = poses

        return res