File size: 7,256 Bytes
e04dce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# DO NOT HOST PUBLICLY - SECURITY RISKS!
# (the API will only be on with --api starting option)
# Currently no API stability guarantees are provided - API may break on any new commit (but hopefully won't).

import os
import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image

import gradio as gr

from typing import Dict, List
from modules.api import api

from src.core import core_generation_funnel, run_makevideo
from src.misc import SCRIPT_VERSION
from src import backbone
from src.common_constants import GenerationOptions as go


def encode_to_base64(image):
    """Return *image* as a base64 string.

    - ``str``: returned unchanged (treated as already encoded).
    - PIL image: encoded with ``api.encode_pil_to_base64``.
    - ``numpy.ndarray``: encoded with ``encode_np_to_base64``.
    - anything else: the empty string ``""``.

    ``isinstance`` is used instead of an exact ``type(...) is`` comparison so
    that subclasses are accepted as well, e.g. the format-specific
    ``Image.Image`` subclasses that Pillow returns when opening a file.
    An exact type check would silently turn those into ``""``.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""


def encode_np_to_base64(image):
    """Wrap the array *image* in a PIL image and return its base64 encoding."""
    return api.encode_pil_to_base64(Image.fromarray(image))


def to_base64_PIL(encoding: str):
    """Decode the base64 string *encoding* and return it as a PIL image.

    The decoded image is round-tripped through a ``uint8`` numpy array
    before being turned back into a PIL image.
    """
    decoded = api.decode_base64_to_image(encoding)
    pixels = np.array(decoded).astype('uint8')
    return Image.fromarray(pixels)


def depth_api(_: gr.Blocks, app: FastAPI):
    """Register the DepthMap HTTP endpoints on *app*.

    Endpoints added:

    - ``GET  /depth/version``        -> ``{"version": SCRIPT_VERSION}``
    - ``GET  /depth/get_options``    -> sorted, lower-cased option names
    - ``POST /depth/generate``       -> base64 images produced by the
      generation funnel
    - ``POST /depth/generate/video`` -> renders a video from an existing or
      freshly generated inpainted mesh

    The first argument is unused; only *app* is touched.
    """

    # Endpoint: report the script version.
    @app.get("/depth/version")
    async def version():
        return {"version": SCRIPT_VERSION}

    # Endpoint: list the names of all generation options.
    @app.get("/depth/get_options")
    async def get_options():
        return {"options": sorted(x.name.lower() for x in go)}

    # Endpoint: run generation on the supplied base64 images and return every
    # PIL image result as base64.
    # TODO: some potential inputs not supported (like custom depthmaps)
    @app.post("/depth/generate")
    async def process(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        # TODO: restrict mesh options
        # NOTE(review): `options` is forwarded to core_generation_funnel as-is;
        # its declared default is the string "options", not a dict - confirm
        # how the funnel treats a request that omits it.

        if not depth_input_images:
            raise HTTPException(status_code=422, detail="No images supplied")
        print(f"Processing {len(depth_input_images)} images trough the API")

        pil_images = [to_base64_PIL(input_image) for input_image in depth_input_images]
        outpath = backbone.get_outpath()
        gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)

        # Only PIL images are returned; other result kinds are skipped.
        results_based = [
            encode_to_base64(result)
            for _count, _result_type, result in gen_obj
            if isinstance(result, Image.Image)
        ]

        return {"images": results_based, "info": "Success"}

    # Endpoint: render a video from an inpainted mesh. The mesh is either
    # loaded from `video_parameters["mesh_fi_filename"]` (if that file exists)
    # or generated from the input images first.
    @app.post("/depth/generate/video")
    async def process_video(
        depth_input_images: List[str] = Body([], title='Input Images'),
        options: Dict[str, object] = Body("options", title='Generation options'),
    ):
        if not depth_input_images:
            raise HTTPException(status_code=422, detail="No images supplied")
        print(f"Processing {len(depth_input_images)} images trough the API")

        # The declared default for `options` is the string "options", so it is
        # not guaranteed to be a dict here. Reject that with a client error
        # instead of failing later with an unhandled TypeError.
        if not isinstance(options, dict):
            raise HTTPException(status_code=400, detail={'error': "Missing required parameter(s): options"})

        # You can use either these strings, or integers
        available_models = {
            'res101': 0,
            'dpt_beit_large_512': 1,  # midas 3.1
            'dpt_beit_large_384': 2,  # midas 3.1
            'dpt_large_384': 3,  # midas 3.0
            'dpt_hybrid_384': 4,  # midas 3.0
            'midas_v21': 5,
            'midas_v21_small': 6,
            'zoedepth_n': 7,  # indoor
            'zoedepth_k': 8,  # outdoor
            'zoedepth_nk': 9,
            'marigold_v1': 10,
            'depth_anything': 11,
            'depth_anything_v2_small': 12,
            'depth_anything_v2_base': 13,
            'depth_anything_v2_large': 14,
        }

        if "model_type" not in options:
            raise HTTPException(status_code=400, detail={'error': "Missing required parameter(s): model_type"})
        model_type = options["model_type"]

        if isinstance(model_type, str):
            # Model given by name: translate it to its integer id.
            if model_type not in available_models:
                available_strings = list(available_models.keys())
                raise HTTPException(status_code=400, detail={'error': 'Invalid model string', 'available_models': available_strings})
            model_id = available_models[model_type]
        elif isinstance(model_type, int) and not isinstance(model_type, bool):
            # bool is a subclass of int; a JSON true/false is not a model id.
            model_id = model_type
        else:
            raise HTTPException(status_code=400, detail={'error': 'Invalid model parameter type'})

        options["model_type"] = model_id

        video_parameters = options.get("video_parameters")
        if not isinstance(video_parameters, dict):
            raise HTTPException(status_code=400, detail={'error': "Missing or invalid parameter: video_parameters"})

        required_params = ["vid_numframes", "vid_fps", "vid_traj", "vid_shift", "vid_border", "dolly", "vid_format", "vid_ssaa", "output_filename"]

        missing_params = [param for param in required_params if param not in video_parameters]

        if missing_params:
            raise HTTPException(status_code=400, detail={'error': f"Missing required parameter(s): {', '.join(missing_params)}"})

        vid_numframes = video_parameters["vid_numframes"]
        vid_fps = video_parameters["vid_fps"]
        vid_traj = video_parameters["vid_traj"]
        vid_shift = video_parameters["vid_shift"]
        vid_border = video_parameters["vid_border"]
        dolly = video_parameters["dolly"]
        vid_format = video_parameters["vid_format"]
        try:
            vid_ssaa = int(video_parameters["vid_ssaa"])
        except (TypeError, ValueError):
            raise HTTPException(status_code=400, detail={'error': "Parameter 'vid_ssaa' must be an integer"}) from None

        # NOTE(review): output_filename (and mesh_fi_filename below) are
        # client-supplied filesystem paths that are used as-is. This is one of
        # the reasons this API must not be exposed publicly.
        output_filename = video_parameters["output_filename"]
        if not isinstance(output_filename, str):
            raise HTTPException(status_code=400, detail={'error': "Parameter 'output_filename' must be a string"})
        output_path = os.path.dirname(output_filename)
        basename, extension = os.path.splitext(os.path.basename(output_filename))

        # Comparing video_format with the extension (without its leading dot)
        if vid_format != extension[1:]:
            raise HTTPException(status_code=400, detail={'error': f"Video format '{vid_format}' does not match with the extension '{extension}'."})

        pil_images = [to_base64_PIL(input_image) for input_image in depth_input_images]
        outpath = backbone.get_outpath()

        mesh_fi_filename = video_parameters.get('mesh_fi_filename', None)

        if mesh_fi_filename and os.path.exists(mesh_fi_filename):
            mesh_fi = mesh_fi_filename
            print("Loaded existing mesh from: ", mesh_fi)
        else:
            # If there is no mesh file generate it.
            options["GEN_INPAINTED_MESH"] = True

            gen_obj = core_generation_funnel(outpath, pil_images, None, None, options)

            # Take the first 'inpainted_mesh' result and stop consuming.
            mesh_fi = None
            for _count, result_type, result in gen_obj:
                if result_type == 'inpainted_mesh':
                    mesh_fi = result
                    break

            if mesh_fi:
                print("Created mesh in: ", mesh_fi)
            else:
                raise HTTPException(status_code=400, detail={'error': "The mesh has not been created"})

        run_makevideo(mesh_fi, vid_numframes, vid_fps, vid_traj, vid_shift, vid_border, dolly, vid_format, vid_ssaa, output_path, basename)

        return {"info": "Success"}


# Register the API with the host application, but only when the `api`
# command-line option is set. `modules.script_callbacks` is imported here so
# that a failure to import it is handled as "API unavailable" rather than
# breaking the import of this module.
try:
    import modules.script_callbacks as script_callbacks
    if backbone.get_cmd_opt('api', False):
        script_callbacks.on_app_started(depth_api)
        print("Started the depthmap API. DO NOT HOST PUBLICLY - SECURITY RISKS!")
except Exception as exc:
    # Best-effort: the rest of the module stays usable without the API.
    # `Exception` (not a bare `except`) lets KeyboardInterrupt and SystemExit
    # propagate, and the cause is reported instead of being hidden.
    print(f'DepthMap API could not start: {exc!r}')