pablovela5620 commited on
Commit
0e4dfc4
·
verified ·
1 Parent(s): 9cca3d8

Upload gradio_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_app.py +177 -0
gradio_app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+
4
# NOTE(review): on HF ZeroGPU the `spaces` package is conventionally imported
# before torch so it can patch CUDA init — presumably why torch is imported
# after this guard; confirm against the ZeroGPU docs.
try:
    import spaces  # type: ignore

    IN_SPACES = True
except ImportError:
    # Running locally (or anywhere without the `spaces` package).
    print("Not running on Zero")
    IN_SPACES = False
# Imported after the spaces check on purpose (see note above).
import torch
12
+
13
+ from monopriors.relative_depth_models import (
14
+ DepthAnythingV2Predictor,
15
+ RelativeDepthPrediction,
16
+ UniDepthRelativePredictor,
17
+ get_relative_predictor,
18
+ RELATIVE_PREDICTORS,
19
+ )
20
+ from monopriors.relative_depth_models.base_relative_depth import BaseRelativePredictor
21
+ from monopriors.rr_logging_utils import (
22
+ log_relative_pred,
23
+ create_depth_comparison_blueprint,
24
+ )
25
+ import rerun as rr
26
+ from gradio_rerun import Rerun
27
+ from pathlib import Path
28
+ from typing import Literal, get_args
29
+ import gc
30
+
31
+ from jaxtyping import UInt8
32
+
33
# UI copy shown at the top of the Gradio page.
title = "# Depth Comparison"
description1 = """Demo to help compare different depth models. Including both Scale | Shift Invariant and Metric Depth types."""
description2 = """Invariant models mean they have no true scale and are only relative, where as Metric models have a true scale and are absolute (meters)."""
# Message displayed in the status textbox once models are ready.
model_load_status: str = "Models loaded and ready to use!"
DEVICE: Literal["cuda"] | Literal["cpu"] = (
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Instantiate the two default predictors once at import time; gr.NO_RELOAD
# keeps this from re-running under `gradio` dev-server hot reload.
if gr.NO_RELOAD:
    MODEL_1 = DepthAnythingV2Predictor(device=DEVICE)
    MODEL_2 = UniDepthRelativePredictor(device=DEVICE)
45
def predict_depth(
    model: BaseRelativePredictor, rgb: UInt8[np.ndarray, "h w 3"]
) -> RelativeDepthPrediction:
    """Run a single relative-depth predictor on one RGB image.

    The model is first moved to the module-level ``DEVICE`` (ZeroGPU may
    relocate weights between calls — presumably why this is done per call;
    confirm). The second predictor argument is camera intrinsics, which are
    not available here, so ``None`` is passed.
    """
    model.set_model_device(device=DEVICE)
    return model(rgb, None)
51
+
52
+
53
# On ZeroGPU Spaces, wrap inference in spaces.GPU so it runs inside a
# GPU-attached worker; a plain no-op elsewhere.
if IN_SPACES:
    predict_depth = spaces.GPU(predict_depth)
    # remove any model that fails on zerogpu spaces
56
+
57
+
58
def load_models(
    model_1: RELATIVE_PREDICTORS,
    model_2: RELATIVE_PREDICTORS,
    progress=gr.Progress(),
) -> str:
    """Replace the two global predictors with freshly loaded ones.

    Parameters
    ----------
    model_1, model_2:
        Predictor names chosen in the dropdowns (members of
        ``RELATIVE_PREDICTORS``).
    progress:
        Gradio progress tracker; updated as each model finishes loading.

    Returns
    -------
    str
        ``model_load_status``, shown in the status textbox.
    """
    global MODEL_1, MODEL_2
    # Delete the previous models and clear gpu memory.
    if "MODEL_1" in globals():
        del MODEL_1
    if "MODEL_2" in globals():
        del MODEL_2
    # Collect dropped references first so empty_cache can actually return
    # the freed allocations to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    progress(0, desc="Loading Models please wait...")

    # FIX: annotation was `list[int]` — these are predictor-name literals.
    models: list[RELATIVE_PREDICTORS] = [model_1, model_2]
    loaded_models = []

    for idx, model in enumerate(models):
        loaded_models.append(get_relative_predictor(model)(device=DEVICE))
        # FIX: progress was hard-coded to 0.5 for every model; report the
        # true completed fraction instead.
        progress((idx + 1) / len(models), desc=f"Loaded {model}")

    progress(1, desc="Models Loaded")
    MODEL_1, MODEL_2 = loaded_models

    return model_load_status
86
+
87
+
88
@rr.thread_local_stream("depth")
def on_submit(rgb: UInt8[np.ndarray, "h w 3"]):
    """Generator handler: run both global models on `rgb` and stream the
    rerun log incrementally to the viewer (one yield per model)."""
    stream: rr.BinaryStream = rr.binary_stream()
    models_list = [MODEL_1, MODEL_2]
    # Lay out one comparison view per model before logging anything.
    blueprint = create_depth_comparison_blueprint(models_list)
    rr.send_blueprint(blueprint)
    try:
        for model in models_list:
            # get the name of the model
            parent_log_path = Path(f"{model.__class__.__name__}")
            rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)

            relative_pred: RelativeDepthPrediction = predict_depth(model, rgb)

            log_relative_pred(
                parent_log_path=parent_log_path,
                relative_pred=relative_pred,
                rgb_hw3=rgb,
            )

            # Flush what has been logged so far to the frontend viewer.
            yield stream.read()
    except Exception as e:
        # `model` is the loop variable of the model that failed; surface the
        # failure in the Gradio UI rather than crashing the worker.
        raise gr.Error(f"Error with model {model.__class__.__name__}: {e}")
111
+
112
+
113
# ---- Gradio UI definition -------------------------------------------------
# NOTE(review): nesting below reconstructed from a diff view — confirm the
# Radio/Dropdown rows are meant to sit inside the right-hand column.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description1)
    gr.Markdown(description2)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(
            label="Input Image",
            type="numpy",
            height=300,
        )
        with gr.Column():
            # Depth-type selector; metric models are not wired up yet, so
            # this control is display-only (its value is never read).
            gr.Radio(
                choices=["Scale | Shift Invariant", "Metric (TODO)"],
                label="Depth Type",
                value="Scale | Shift Invariant",
                interactive=True,
            )
            with gr.Row():
                # Choices come from the RELATIVE_PREDICTORS Literal type.
                model_1_dropdown = gr.Dropdown(
                    choices=list(get_args(RELATIVE_PREDICTORS)),
                    label="Model1",
                    value="DepthAnythingV2Predictor",
                )
                model_2_dropdown = gr.Dropdown(
                    choices=list(get_args(RELATIVE_PREDICTORS)),
                    label="Model2",
                    value="UniDepthRelativePredictor",
                )
            model_status = gr.Textbox(
                label="Model Status",
                value=model_load_status,
                interactive=False,
            )

    with gr.Row():
        submit = gr.Button(value="Compute Depth")
        load_models_btn = gr.Button(value="Load Models")
    # Embedded rerun viewer that consumes the bytes yielded by on_submit.
    rr_viewer = Rerun(streaming=True, height=800)

    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=[rr_viewer],
    )

    load_models_btn.click(
        load_models,
        inputs=[model_1_dropdown, model_2_dropdown],
        outputs=[model_status],
    )

    # Example images are read from ./examples at app start; sorted for a
    # deterministic gallery order.
    examples_paths = Path("examples").glob("*.jpeg")
    examples_list = sorted([str(path) for path in examples_paths])
    examples = gr.Examples(
        examples=examples_list,
        inputs=[input_image],
        outputs=[rr_viewer],
        fn=on_submit,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()