GtyTongji commited on
Commit
74d5114
·
verified ·
1 Parent(s): 4b84b1c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ os.system("pip freeze")
4
+ import spaces
5
+
6
+ import gradio as gr
7
+ import torch as torch
8
+ from diffusers import MarigoldDepthPipeline, DDIMScheduler
9
+ from gradio_dualvision import DualVisionApp
10
+ from huggingface_hub import login
11
+ from PIL import Image
12
+
13
+ CHECKPOINT = "prs-eth/marigold-depth-v1-1"
14
+
15
+ if "Gty20030709" in os.environ:
16
+ login(token=os.environ["Gty20030709"])
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
20
+
21
+ pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
22
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
23
+ pipe = pipe.to(device=device, dtype=dtype)
24
+ try:
25
+ import xformers
26
+ pipe.enable_xformers_memory_efficient_attention()
27
+ except:
28
+ pass
29
+
30
+
31
+ class MarigoldDepthApp(DualVisionApp):
32
+ DEFAULT_SEED = 2024
33
+ DEFAULT_ENSEMBLE_SIZE = 1
34
+ DEFAULT_DENOISE_STEPS = 4
35
+ DEFAULT_PROCESSING_RES = 768
36
+
37
+ def make_header(self):
38
+ gr.Markdown(
39
+ """
40
+ <h2><a href="https://huggingface.co/spaces/prs-eth/marigold" style="color: black;">Marigold Depth Estimation</a></h2>
41
+ """
42
+ )
43
+ with gr.Row(elem_classes="remove-elements"):
44
+ gr.Markdown(
45
+ )
46
+
47
+ def build_user_components(self):
48
+ with gr.Column():
49
+ ensemble_size = gr.Slider(
50
+ label="Ensemble size",
51
+ minimum=1,
52
+ maximum=10,
53
+ step=1,
54
+ value=self.DEFAULT_ENSEMBLE_SIZE,
55
+ )
56
+ denoise_steps = gr.Slider(
57
+ label="Number of denoising steps",
58
+ minimum=1,
59
+ maximum=20,
60
+ step=1,
61
+ value=self.DEFAULT_DENOISE_STEPS,
62
+ )
63
+ processing_res = gr.Radio(
64
+ [
65
+ ("Native", 0),
66
+ ("Recommended", 768),
67
+ ],
68
+ label="Processing resolution",
69
+ value=self.DEFAULT_PROCESSING_RES,
70
+ )
71
+ return {
72
+ "ensemble_size": ensemble_size,
73
+ "denoise_steps": denoise_steps,
74
+ "processing_res": processing_res,
75
+ }
76
+
77
+ def process(self, image_in: Image.Image, **kwargs):
78
+ ensemble_size = kwargs.get("ensemble_size", self.DEFAULT_ENSEMBLE_SIZE)
79
+ denoise_steps = kwargs.get("denoise_steps", self.DEFAULT_DENOISE_STEPS)
80
+ processing_res = kwargs.get("processing_res", self.DEFAULT_PROCESSING_RES)
81
+ generator = torch.Generator(device=device).manual_seed(self.DEFAULT_SEED)
82
+
83
+ pipe_out = pipe(
84
+ image_in,
85
+ ensemble_size=ensemble_size,
86
+ num_inference_steps=denoise_steps,
87
+ processing_resolution=processing_res,
88
+ batch_size=1 if processing_res == 0 else 2,
89
+ output_uncertainty=ensemble_size >= 3,
90
+ generator=generator,
91
+ )
92
+
93
+ depth_vis = pipe.image_processor.visualize_depth(pipe_out.prediction)[0]
94
+ depth_16bit = pipe.image_processor.export_depth_to_16bit_png(pipe_out.prediction)[0]
95
+
96
+ out_modalities = {
97
+ "Depth Visualization": depth_vis,
98
+ "Depth 16-bit": depth_16bit,
99
+ }
100
+ if ensemble_size >= 3:
101
+ uncertainty = pipe.image_processor.visualize_uncertainty(pipe_out.uncertainty)[0]
102
+ out_modalities["Uncertainty"] = uncertainty
103
+
104
+ out_settings = {
105
+ "ensemble_size": ensemble_size,
106
+ "denoise_steps": denoise_steps,
107
+ "processing_res": processing_res,
108
+ }
109
+ return out_modalities, out_settings
110
+
111
+
112
+ with MarigoldDepthApp(
113
+ title="Marigold Depth",
114
+ examples_path="files",
115
+ examples_per_page=12,
116
+ squeeze_canvas=True,
117
+ spaces_zero_gpu_enabled=True,
118
+ ) as demo:
119
+ demo.queue(
120
+ api_open=False,
121
+ ).launch(
122
+ server_name="0.0.0.0",
123
+ server_port=7860,
124
+ )