huntrezz commited on
Commit
3548ace
·
verified ·
1 Parent(s): 9815468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -4,7 +4,8 @@ import numpy as np
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
7
- from DepthVisualizer import DepthVisualizer
 
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
@@ -42,6 +43,15 @@ def preprocess_image(image):
42
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
43
  return image / 255.0
44
 
 
 
 
 
 
 
 
 
 
45
  @torch.inference_mode()
46
  def process_frame(image):
47
  if image is None:
@@ -53,13 +63,15 @@ def process_frame(image):
53
  # Normalize depth map
54
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
55
 
56
- # Convert depth map to point cloud
57
- point_cloud = visualizer.depth_map_to_point_cloud(depth_map)
58
 
59
- # Render point cloud
60
- rendered_image = visualizer.render_frame(point_cloud)
 
 
61
 
62
- return rendered_image
63
 
64
  interface = gr.Interface(
65
  fn=process_frame,
 
4
  from transformers import DPTForDepthEstimation, DPTImageProcessor
5
  import gradio as gr
6
  import torch.nn.utils.prune as prune
7
+ import matplotlib.pyplot as plt
8
+ from mpl_toolkits.mplot3d import Axes3D
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
 
43
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
44
  return image / 255.0
45
 
46
+ def plot_depth_map(depth_map):
47
+ fig = plt.figure()
48
+ ax = fig.add_subplot(111, projection='3d')
49
+ x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
50
+ ax.plot_surface(x, y, depth_map, cmap='viridis')
51
+ ax.set_zlim(0, 1)
52
+ plt.close(fig)
53
+ return fig
54
+
55
  @torch.inference_mode()
56
  def process_frame(image):
57
  if image is None:
 
63
  # Normalize depth map
64
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
65
 
66
+ # Create 3D plot
67
+ fig = plot_depth_map(depth_map)
68
 
69
+ # Convert plot to image
70
+ fig.canvas.draw()
71
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
72
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
73
 
74
+ return img
75
 
76
  interface = gr.Interface(
77
  fn=process_frame,