xingyang1 committed
Commit b47d04c · verified · 1 Parent(s): 7f2e027

Upload 2 files

Files changed (1)
  1. app.py +63 -97
app.py CHANGED
@@ -1,102 +1,68 @@
  import gradio as gr
  import cv2
- import matplotlib
  import numpy as np
  import os
- from PIL import Image
- import spaces
- import torch
- import tempfile
- from gradio_imageslider import ImageSlider
- from huggingface_hub import hf_hub_download
-
- from depth_anything_v2.dpt import DepthAnythingV2
-
- css = """
- #img-display-container {
-     max-height: 100vh;
- }
- #img-display-input {
-     max-height: 80vh;
- }
- #img-display-output {
-     max-height: 80vh;
- }
- #download {
-     height: 62px;
- }
- """
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
- model_configs = {
-     'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
-     'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
-     'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
-     'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
- }
- encoder2name = {
-     'vits': 'Small',
-     'vitb': 'Base',
-     'vitl': 'Large',
-     'vitg': 'Giant',  # we are undergoing company review procedures to release our giant model checkpoint
- }
- encoder = 'vitl'
- model_name = encoder2name[encoder]
- model = DepthAnythingV2(**model_configs[encoder])
- filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
- state_dict = torch.load(filepath, map_location="cpu")
- model.load_state_dict(state_dict)
- model = model.to(DEVICE).eval()
-
- title = "# Depth Anything V2"
- description = """Official demo for **Depth Anything V2**.
- Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
-
- @spaces.GPU
- def predict_depth(image):
-     return model.infer_image(image)
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(title)
-     gr.Markdown(description)
-     gr.Markdown("### Depth Prediction demo")
-
-     with gr.Row():
-         input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
-         depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
-     submit = gr.Button(value="Compute Depth")
-     gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
-     raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
-
-     cmap = matplotlib.colormaps.get_cmap('Spectral_r')
-
-     def on_submit(image):
-         original_image = image.copy()
-
-         h, w = image.shape[:2]
-
-         depth = predict_depth(image[:, :, ::-1])
-
-         raw_depth = Image.fromarray(depth.astype('uint16'))
-         tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
-         raw_depth.save(tmp_raw_depth.name)
-
-         depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
-         depth = depth.astype(np.uint8)
-         colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
-
-         gray_depth = Image.fromarray(depth)
-         tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
-         gray_depth.save(tmp_gray_depth.name)
-
-         return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
-
-     submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
-
-     example_files = os.listdir('assets/examples')
-     example_files.sort()
-     example_files = [os.path.join('assets/examples', filename) for filename in example_files]
-     examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
-

- if __name__ == '__main__':
-     demo.queue().launch(share=True)

  import gradio as gr
+ import torch
+ from PIL import Image
  import cv2
  import numpy as np
+ from geobench.modeling.archs.dam.dam import DepthAnything
+ from geobench.utils.image_util import colorize_depth_maps
+ from geobench.midas.transforms import Resize, NormalizeImage, PrepareForNet
+ from torchvision.transforms import Compose
  import os

+ # Helper function to load the model (same as in the original script)
+ def load_model_by_name(arch_name, checkpoint_path, device):
+     if arch_name == 'depthanything':
+         if '.safetensors' in checkpoint_path:
+             model = DepthAnything.from_pretrained(os.path.dirname(checkpoint_path)).to(device)
+         else:
+             raise NotImplementedError("Model architecture not implemented.")
+     else:
+         raise NotImplementedError(f"Unknown architecture: {arch_name}")
+     return model
+
+ # Image processing function (same as in the original script, adapted for Gradio)
+ def process_image(image, model, device, mode='rel_depth'):
+     # Preprocess the image
+     image_np = np.array(image)[..., ::-1] / 255
+     transform = Compose([
+         Resize(512, 512, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, image_interpolation_method=cv2.INTER_CUBIC),
+         NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+         PrepareForNet()
+     ])
+
+     image_tensor = transform({'image': image_np})['image']
+     image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)
+
+     with torch.no_grad():  # no gradients are needed for inference
+         pred_disp, _ = model(image_tensor)
+     pred_disp_np = pred_disp.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
+     pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
+
+     # Colorize the depth map (the same colormap is used for every mode)
+     cmap = "Spectral_r"
+     depth_colored = colorize_depth_maps(pred_disp[None, ...], 0, 1, cmap=cmap).squeeze()
+     depth_colored = (depth_colored * 255).astype(np.uint8)
+
+     depth_image = Image.fromarray(depth_colored)
+     return depth_image
+
+ # Gradio interface function
+ def gradio_interface(image, mode='rel_depth'):
+     # Run on CPU explicitly
+     device = torch.device("cpu")
+     model = load_model_by_name("depthanything", "your_checkpoint_path_here", device)
+
+     # Process the image and return the colorized depth map
+     return process_image(image, model, device, mode)
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[gr.Image(type="pil"), gr.Dropdown(choices=['rel_depth', 'metric_depth', 'disparity'], label="Mode")],
+     outputs=gr.Image(type="pil"),
+     title="Depth Estimation Demo",
+     description="Upload an image to see the depth estimation results."
+ )
+
+ # Launch the Gradio interface
+ iface.launch()
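
As written, gradio_interface reloads the checkpoint on every request, and the placeholder path "your_checkpoint_path_here" does not contain ".safetensors", so load_model_by_name would raise NotImplementedError until a real checkpoint path is filled in. Below is a minimal sketch of loading the model once at startup instead, reusing the process_image function from the diff; it assumes the same DepthAnything.from_pretrained API shown above, and the checkpoint directory name is a hypothetical placeholder, not the path used by this Space.

import torch
from geobench.modeling.archs.dam.dam import DepthAnything

DEVICE = torch.device("cpu")
# "checkpoints/depth_anything" is a placeholder; point it at the directory that
# actually holds the .safetensors checkpoint for this Space.
MODEL = DepthAnything.from_pretrained("checkpoints/depth_anything").to(DEVICE).eval()

def gradio_interface(image, mode='rel_depth'):
    # Reuse the module-level model instead of reloading it on every request.
    return process_image(image, MODEL, DEVICE, mode)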