huntrezz commited on
Commit
8ad09ea
1 Parent(s): 40334e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -38
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
7
- import open3d as o3d
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
@@ -32,31 +31,11 @@ model = model.to(device)
32
 
33
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
34
 
35
- color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
36
- color_map = torch.from_numpy(color_map).to(device)
37
-
38
  def preprocess_image(image):
39
  image = cv2.resize(image, (128, 128))
40
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
41
  return image / 255.0
42
 
43
- def create_point_cloud(depth_map, color_image):
44
- rows, cols = depth_map.shape
45
- c, r = np.meshgrid(np.arange(cols), np.arange(rows), sparse=True)
46
- valid = (depth_map > 0) & (depth_map < 1000)
47
- z = np.where(valid, depth_map, 0)
48
- x = np.where(valid, z * (c - cols / 2) / cols, 0)
49
- y = np.where(valid, z * (r - rows / 2) / rows, 0)
50
-
51
- points = np.dstack((x, y, z)).reshape(-1, 3)
52
- colors = color_image.reshape(-1, 3)
53
-
54
- pcd = o3d.geometry.PointCloud()
55
- pcd.points = o3d.utility.Vector3dVector(points)
56
- pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)
57
-
58
- return pcd
59
-
60
  @torch.inference_mode()
61
  def process_frame(image):
62
  if image is None:
@@ -68,24 +47,14 @@ def process_frame(image):
68
  # Normalize depth map
69
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
70
 
71
- # Create point cloud
72
- pcd = create_point_cloud(depth_map, image)
73
-
74
- # Visualize point cloud
75
- vis = o3d.visualization.Visualizer()
76
- vis.create_window()
77
- vis.add_geometry(pcd)
78
- vis.poll_events()
79
- vis.update_renderer()
80
-
81
- # Capture the visualization as an image
82
- image = vis.capture_screen_float_buffer(False)
83
- vis.destroy_window()
84
 
85
- # Convert the image to numpy array
86
- point_cloud_image = (np.asarray(image) * 255).astype(np.uint8)
 
87
 
88
- return point_cloud_image
89
 
90
  interface = gr.Interface(
91
  fn=process_frame,
@@ -94,4 +63,4 @@ interface = gr.Interface(
94
  live=True
95
  )
96
 
97
- interface.launch()
 
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
 
7
 
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
 
31
 
32
  processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
33
 
 
 
 
34
  def preprocess_image(image):
35
  image = cv2.resize(image, (128, 128))
36
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
37
  return image / 255.0
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  @torch.inference_mode()
40
  def process_frame(image):
41
  if image is None:
 
47
  # Normalize depth map
48
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
49
 
50
+ # Create a more visually informative depth map
51
+ depth_color = cv2.applyColorMap((depth_map * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Blend original image with depth map for context
54
+ original_resized = cv2.resize(image, (128, 128))
55
+ blended = cv2.addWeighted(original_resized, 0.6, depth_color, 0.4, 0)
56
 
57
+ return blended
58
 
59
  interface = gr.Interface(
60
  fn=process_frame,
 
63
  live=True
64
  )
65
 
66
+ interface.launch()