File size: 4,414 Bytes
ad93086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from typing import List

import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import gradio as gr

from modules.api import api
from .global_state import (
    get_all_preprocessor_names,
    get_all_controlnet_names,
    get_preprocessor,
    get_all_preprocessor_tags,
    select_control_type,
)
from .utils import judge_image_type
from .logging import logger


def encode_to_base64(image):
    """Turn a detection result into a string suitable for a JSON response.

    Args:
        image: The value to encode. Handled cases are a ``str``, a
            ``PIL.Image.Image`` and a ``numpy.ndarray``.

    Returns:
        * the input unchanged when it is already a ``str``;
        * the literal message ``"Detect result is not image"`` when
          ``judge_image_type`` rejects the value;
        * the base64 encoding produced by ``api.encode_pil_to_base64`` for
          PIL images and numpy arrays;
        * an empty string (after logging a warning) for any other type.
    """
    if isinstance(image, str):
        # Strings are passed through untouched (presumably already encoded
        # by the caller -- this function does not verify that).
        return image
    if not judge_image_type(image):
        return "Detect result is not image"
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)

    # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
    logger.warning("Unable to encode image.")
    return ""


def encode_np_to_base64(image):
    """Return the base64 encoding of a numpy array image.

    The array is first wrapped in a PIL image via ``Image.fromarray`` and
    then handed to ``api.encode_pil_to_base64``.
    """
    return api.encode_pil_to_base64(Image.fromarray(image))


def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ControlNet HTTP endpoints on ``app``.

    Routes added:
        * ``GET  /controlnet/model_list``
        * ``GET  /controlnet/module_list``
        * ``GET  /controlnet/control_types``
        * ``POST /controlnet/detect``

    Args:
        _: Gradio ``Blocks`` instance; not used by this function.
        app: FastAPI application the routes are attached to.
    """

    class JsonAcceptor:
        """Callback target that remembers the last JSON dict it was given."""

        def __init__(self) -> None:
            self.value = None

        def accept(self, json_dict: dict) -> None:
            self.value = json_dict

    @app.get("/controlnet/model_list")
    async def model_list():
        """Return the current ControlNet model names."""
        up_to_date_model_list = get_all_controlnet_names()
        logger.debug(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list():
        """Return the names of all preprocessors."""
        preprocessor_names = get_all_preprocessor_names()
        logger.debug(preprocessor_names)

        return {
            "module_list": preprocessor_names,
            # TODO: Add back module detail.
            # "module_detail": external_code.get_modules_detail(alias_names),
        }

    @app.get("/controlnet/control_types")
    async def control_types():
        """Return, per preprocessor tag, the matching modules/models and defaults."""

        def format_control_type(
            filtered_preprocessor_list,
            filtered_model_list,
            default_option,
            default_model,
        ):
            # Name the positional results of ``select_control_type``.
            return {
                "module_list": filtered_preprocessor_list,
                "model_list": filtered_model_list,
                "default_option": default_option,
                "default_model": default_model,
            }

        return {
            "control_types": {
                control_type: format_control_type(*select_control_type(control_type))
                for control_type in get_all_preprocessor_tags()
            }
        }

    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title="Controlnet Module"),
        controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
        controlnet_processor_res: int = Body(
            512, title="Controlnet Processor Resolution"
        ),
        controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
        controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
    ):
        """Run a preprocessor over base64 images and return the encoded results.

        Responds with ``{"images": [...], "info": "Success"}``; a ``"poses"``
        list is added when the module name contains ``"openpose"``.

        Raises:
            HTTPException: 422 when the module is unknown or no image is
                supplied; 500 when an openpose module finishes without
                delivering pose JSON through the callback.
        """
        processor_module = get_preprocessor(controlnet_module)
        if processor_module is None:
            raise HTTPException(status_code=422, detail="Module not available")

        if not controlnet_input_images:
            raise HTTPException(status_code=422, detail="No image selected")

        logger.debug(
            f"Detecting {len(controlnet_input_images)} images with the {controlnet_module} module."
        )

        # Loop-invariant: pose JSON is only collected for openpose modules.
        collect_poses = "openpose" in controlnet_module

        results = []
        poses = []

        for input_image in controlnet_input_images:
            img = np.array(api.decode_base64_to_image(input_image)).astype('uint8')

            # Fresh acceptor per image so a pose from a previous image can
            # never be attributed to the current one.
            json_acceptor = JsonAcceptor()

            results.append(
                processor_module(
                    img,
                    resolution=controlnet_processor_res,
                    slider_1=controlnet_threshold_a,
                    slider_2=controlnet_threshold_b,
                    json_pose_callback=json_acceptor.accept,
                )
            )

            if collect_poses:
                # Explicit check instead of ``assert``: asserts are removed
                # under ``python -O``, which would let ``None`` slip into the
                # response.
                if json_acceptor.value is None:
                    raise HTTPException(
                        status_code=500,
                        detail="Preprocessor did not return pose data",
                    )
                poses.append(json_acceptor.value)

        results64 = [encode_to_base64(result) for result in results]
        res = {"images": results64, "info": "Success"}
        if poses:
            res["poses"] = poses

        return res