huntrezz commited on
Commit
dcf5ae7
·
verified ·
1 Parent(s): edd526e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -19,7 +19,7 @@ parameters_to_prune = [
19
  prune.global_unstructured(
20
  parameters_to_prune,
21
  pruning_method=prune.L1Unstructured,
22
- amount=0.4, # Prune 40% of weights
23
  )
24
 
25
  for module, _ in parameters_to_prune:
@@ -41,15 +41,25 @@ def preprocess_image(image):
41
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
42
  return image / 255.0
43
 
44
- def plot_depth_map(depth_map):
45
- fig = plt.figure(figsize=(16, 9)) # Set figure size to 16:9
46
  ax = fig.add_subplot(111, projection='3d')
47
  x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
48
- ax.plot_surface(x, y, depth_map, cmap='viridis')
49
- ax.view_init(azim=180, elev=0) # Rotate the view forward and clockwise
 
 
 
 
50
  ax.set_zlim(0, 1)
 
51
  plt.close(fig)
52
- return fig
 
 
 
 
 
53
 
54
  @torch.inference_mode()
55
  def process_frame(image):
@@ -62,15 +72,11 @@ def process_frame(image):
62
  # Normalize depth map
63
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
64
 
65
- # Create 3D plot
66
- fig = plot_depth_map(depth_map)
67
-
68
- # Convert plot to image
69
- fig.canvas.draw()
70
- img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
71
- img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
72
 
73
- return img
74
 
75
  interface = gr.Interface(
76
  fn=process_frame,
 
19
  prune.global_unstructured(
20
  parameters_to_prune,
21
  pruning_method=prune.L1Unstructured,
22
+ amount=0.2, # Prune 20% of weights
23
  )
24
 
25
  for module, _ in parameters_to_prune:
 
41
  image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
42
  return image / 255.0
43
 
44
+ def plot_depth_map(depth_map, original_image):
45
+ fig = plt.figure(figsize=(16, 9))
46
  ax = fig.add_subplot(111, projection='3d')
47
  x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
48
+
49
+ # Resize original image to match depth map dimensions
50
+ original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
51
+ colors = original_image_resized.reshape(depth_map.shape[0], depth_map.shape[1], 3) / 255.0
52
+
53
+ ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
54
  ax.set_zlim(0, 1)
55
+ plt.axis('off')
56
  plt.close(fig)
57
+
58
+ fig.canvas.draw()
59
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
60
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
61
+
62
+ return img
63
 
64
  @torch.inference_mode()
65
  def process_frame(image):
 
72
  # Normalize depth map
73
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
74
 
75
+ # Convert BGR to RGB if necessary
76
+ if image.shape[2] == 3: # Check if it's a color image
77
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 
 
 
78
 
79
+ return plot_depth_map(depth_map, image)
80
 
81
  interface = gr.Interface(
82
  fn=process_frame,