File size: 2,470 Bytes
b5ba7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI, Response, Form 
from fastapi.responses import JSONResponse 
from modules import sd_models, shared, scripts
import asyncio
import gradio as gr


def test_api(_: gr.Blocks, app: FastAPI):
    """Register the openOutpaint UNET-inspection endpoint on the webui's FastAPI app.

    Meant to be passed to ``script_callbacks.on_app_started``. The webui calls
    app-started callbacks with two positional arguments (the Gradio Blocks demo
    and the FastAPI app), so the unused first parameter must still be accepted;
    without it registration fails with
    ``TypeError: test_api() takes 1 positional argument but 2 were given``.

    Args:
        _: The Gradio Blocks instance supplied by the callback (unused).
        app: The FastAPI application the route is attached to.
    """

    # Plain ``def`` rather than ``async def``: FastAPI runs synchronous path
    # operations in a worker threadpool, so the blocking checkpoint read below
    # does not stall the server's event loop.
    @app.post("/openOutpaint/unet-count")
    def return_model_unet_channel_count(
        model_name: str = Form(description="the model to be inspected")
    ):
        """Report a checkpoint's UNET input-channel count and a guessed model type.

        If ``model_name`` is not a key of ``sd_models.checkpoints_list``, the
        checkpoint returned by ``get_current_model()`` is inspected instead and
        ``additional_data`` carries a note saying so.
        """
        err_msg = ""
        try:
            model = sd_models.checkpoints_list[model_name]
        except KeyError:
            # Only a missing dict key is expected here; anything else should surface.
            err_msg = "submitted model failed loading, falling back to loaded model"
            model = sd_models.checkpoints_list[get_current_model()]
        # Loads the whole state dict onto the CPU just to read one tensor's shape.
        theta_0 = sd_models.read_state_dict(model.filename, map_location='cpu')
        # Assumes a Conv2d weight laid out (out_ch, in_ch, kH, kW), so shape[1]
        # is the number of channels the UNET's first layer consumes.
        channel_count = theta_0["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        return {
            "unet_channels": channel_count,
            "estimated_type": switchAssumption(channel_count),
            "tested_model": model,
            "additional_data": err_msg
        }

# Heuristic map from a UNET's input-channel count to the model family it most
# likely belongs to. Built once at import time instead of on every call.
_UNET_TYPE_BY_CHANNELS = {
    4: "traditional",
    5: "sdv2 depth2img",
    7: "sdv2 upscale 4x",
    8: "instruct-pix2pix",
    9: "inpainting",
}


def switchAssumption(channelCount):
    """Return a best-guess model type label for a UNET input-channel count.

    Args:
        channelCount: Number of input channels of the UNET's first layer.

    Returns:
        The label from the known mapping, or a shrug string for unknown counts.
    """
    # The backslash is escaped explicitly: a bare ``\_`` is an invalid escape
    # sequence (DeprecationWarning / SyntaxWarning), though the value is the same.
    return _UNET_TYPE_BY_CHANNELS.get(channelCount, "¯\\_(ツ)_/¯")

def get_current_model():
    """Return the value of the webui's ``sd_model_checkpoint`` option.

    The endpoint above uses this value as a key into
    ``sd_models.checkpoints_list``. The option is read directly instead of
    copying every option into a temporary dict first.

    Returns:
        The stored option value, or the option's registered default when it
        has not been set in ``shared.opts.data``.

    Raises:
        KeyError: if the option is neither set nor registered in
            ``shared.opts.data_labels``.
    """
    key = "sd_model_checkpoint"
    data = shared.opts.data
    if key in data:
        return data[key]
    metadata = shared.opts.data_labels.get(key)
    if metadata is None:
        raise KeyError(key)
    return metadata.default


# Register the API endpoint when the webui's FastAPI app starts. This is
# best-effort: a failure must not break the rest of the webui, but the cause is
# reported so it can be diagnosed.
try:
    import modules.script_callbacks as script_callbacks
    script_callbacks.on_app_started(test_api)
except Exception as e:
    print(f"[openOutpaint-webui-extension] UNET API failed to initialize: {e}")