xiank he committed on
Commit 72cd992 · 1 Parent(s): 89a1e10

distill-any-depth

Files changed (1)
  1. app.py +74 -21
app.py CHANGED
@@ -1,32 +1,42 @@
 import gradio as gr
 import torch
 from PIL import Image
-import cv2
 import numpy as np
 from distillanydepth.modeling.archs.dam.dam import DepthAnything
-from distillanydepth.utils.image_util import colorize_depth_maps
+from distillanydepth.utils.image_util import chw2hwc, colorize_depth_maps
 from distillanydepth.midas.transforms import Resize, NormalizeImage, PrepareForNet
 from torchvision.transforms import Compose
-import os
+import cv2
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file  # import the safetensors library
 
-# Helper function to load model (same as your original code)
+# Helper function to load model from Hugging Face
 def load_model_by_name(arch_name, checkpoint_path, device):
+    model = None
     if arch_name == 'depthanything':
-        if '.safetensors' in checkpoint_path:
-            model = DepthAnything.from_pretrained(os.path.dirname(checkpoint_path)).to(device)
-        else:
-            raise NotImplementedError("Model architecture not implemented.")
+        # Load the model weights with safetensors
+        model_weights = load_file(checkpoint_path)
+
+        # Initialize the model and apply the loaded weights
+        model = DepthAnything(checkpoint_path=None).to(device)
+        model.load_state_dict(model_weights)
+
+        model = model.to(device)  # make sure the model is on the right device
     else:
         raise NotImplementedError(f"Unknown architecture: {arch_name}")
     return model
 
-# Image processing function (same as your original code, modified for Gradio)
+# Image processing function
 def process_image(image, model, device):
+    if model is None:
+        return None
+
     # Preprocess the image
     image_np = np.array(image)[..., ::-1] / 255
+
     transform = Compose([
-        Resize(512, 512, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, image_interpolation_method=cv2.INTER_CUBIC),
-        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        Resize(756, 756, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
+        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         PrepareForNet()
     ])
 
@@ -35,31 +45,74 @@ def process_image(image, model, device):
 
     with torch.no_grad():  # Disable autograd since we don't need gradients on CPU
         pred_disp, _ = model(image_tensor)
-    pred_disp_np = pred_disp.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
-    pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
 
+    # Ensure the depth map is in the correct shape before colorization
+    pred_disp_np = pred_disp.cpu().detach().numpy()[0, 0, :, :]  # Remove extra singleton dimensions
+
+    # Normalize depth map
+    pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
+
     # Colorize depth map
-    cmap = "Spectral_r"  # Default colormap for relative depth
-    depth_colored = colorize_depth_maps(pred_disp[None, ...], 0, 1, cmap=cmap).squeeze()
+    cmap = "Spectral_r"
+    depth_colored = colorize_depth_maps(pred_disp[None, ..., None], 0, 1, cmap=cmap).squeeze()  # Ensure correct dimension
+
+    # Convert to uint8 for image display
     depth_colored = (depth_colored * 255).astype(np.uint8)
 
-    depth_image = Image.fromarray(depth_colored)
+    # Convert to HWC format (height, width, channels)
+    depth_colored_hwc = chw2hwc(depth_colored)
+
+    # Resize to match the original image dimensions (height, width)
+    h, w = image_np.shape[:2]
+    depth_colored_hwc = cv2.resize(depth_colored_hwc, (w, h), interpolation=cv2.INTER_LINEAR)
+
+    # Convert to a PIL image
+    depth_image = Image.fromarray(depth_colored_hwc)
     return depth_image
 
 # Gradio interface function
 def gradio_interface(image):
-    # Set device to CPU explicitly
-    device = torch.device("cpu")  # Force using CPU
-    model = load_model_by_name("depthanything", "your_checkpoint_path_here", device)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model_kwargs = dict(
+        vitb=dict(
+            encoder='vitb',
+            features=128,
+            out_channels=[96, 192, 384, 768],
+        ),
+        vitl=dict(
+            encoder="vitl",
+            features=256,
+            out_channels=[256, 512, 1024, 1024],
+            use_bn=False,
+            use_clstoken=False,
+            max_depth=150.0,
+            mode='disparity',
+            pretrain_type='dinov2',
+            del_mask_token=False
+        )
+    )
+    # Load model
+    model = DepthAnything(**model_kwargs['vitl']).to(device)
+    checkpoint_path = hf_hub_download(repo_id="xingyang1/Distill-Any-Depth", filename="large/model.safetensors", repo_type="model")
 
+    # Load the model weights with safetensors
+    model_weights = load_file(checkpoint_path)
+    model.load_state_dict(model_weights)
+    model = model.to(device)  # make sure the model is on the right device
+
+    if model is None:
+        return None
+
     # Process image and return output
-    return process_image(image, model, device)
+    depth_image = process_image(image, model, device)
+    return depth_image
 
 # Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
     inputs=gr.Image(type="pil"),  # Only image input, no mode selection
-    outputs=gr.Image(type="pil"),
+    outputs=gr.Image(type="pil"),  # Only depth image output, no debug info
     title="Depth Estimation Demo",
     description="Upload an image to see the depth estimation results."
 )
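
A quick way to sanity-check the new loading path outside the Space: the snippet below reuses the repo id and filename from the diff; everything else is stock huggingface_hub and safetensors usage. A minimal sketch, not the app itself:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download the checkpoint referenced in the diff (cached under ~/.cache/huggingface)
checkpoint_path = hf_hub_download(
    repo_id="xingyang1/Distill-Any-Depth",
    filename="large/model.safetensors",
    repo_type="model",
)

# load_file returns a plain dict of tensor name -> torch.Tensor
state_dict = load_file(checkpoint_path)
print(f"{len(state_dict)} tensors, e.g. {next(iter(state_dict))}")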
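
The lines that turn the transformed image into image_tensor (new lines 43-44) fall between the two hunks, so the exact code is not shown. Below is a sketch of the conventional MiDaS-style step, assuming the transforms operate on a dict with an "image" key as in the upstream MiDaS repo:

import numpy as np
import torch
import cv2
from PIL import Image
from torchvision.transforms import Compose
from distillanydepth.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(756, 756, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14,
           resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet()
])

image_np = np.array(Image.open("example.jpg"))[..., ::-1] / 255  # BGR, [0, 1], as in the diff
sample = transform({"image": image_np})                          # dict in, dict out (assumption)
image_tensor = torch.from_numpy(sample["image"]).unsqueeze(0)    # (1, 3, H, W), H/W multiples of 14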
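
colorize_depth_maps and chw2hwc are internal to distillanydepth.utils.image_util. For readers without the repo, a rough stand-in using plain matplotlib (assuming only that the function maps a normalized depth map through the named colormap) might look like:

import numpy as np
import matplotlib
from PIL import Image

def colorize_depth(depth: np.ndarray) -> Image.Image:
    # Min-max normalize to [0, 1], matching the diff's normalization step
    d = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    cmap = matplotlib.colormaps["Spectral_r"]        # same colormap name as the diff
    rgb = (cmap(d)[..., :3] * 255).astype(np.uint8)  # drop alpha, scale to uint8, already HWC
    return Image.fromarray(rgb)

colorize_depth(np.random.rand(240, 320)).save("depth_colored.png")  # stand-in depth map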
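
Neither hunk shows a launch call after the gr.Interface(...) block; if app.py really ends at new line 118, the Space would still need the conventional Gradio entry point, sketched here on that assumption:

# Conventional Gradio entry point (assumption: the diff does not show one)
if __name__ == "__main__":
    iface.launch()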