huntrezz commited on
Commit
e687f81
1 Parent(s): 2a0602d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  import matplotlib.pyplot as plt
7
  from mpl_toolkits.mplot3d import Axes3D
8
  import torch.nn as nn
 
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
@@ -55,19 +56,38 @@ def preprocess_image(image):
55
  return image / 255.0
56
 
57
  def plot_depth_map(depth_map, original_image):
58
- fig = plt.figure(figsize=(16, 9))
59
- ax = fig.add_subplot(111, projection='3d')
60
- x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
61
 
 
 
 
 
 
 
 
 
62
  original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
63
- colors = original_image_resized.reshape(depth_map.shape[0], depth_map.shape[1], 3) / 255.0
 
 
64
 
65
- ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
66
- ax.set_zlim(0, 1)
 
 
 
 
 
67
 
68
- ax.view_init(elev=150, azim=90)
 
 
 
 
 
69
  plt.axis('off')
70
 
 
71
  plt.show()
72
 
73
  fig.canvas.draw()
 
6
  import matplotlib.pyplot as plt
7
  from mpl_toolkits.mplot3d import Axes3D
8
  import torch.nn as nn
9
+ from scipy.interpolate import interp2d
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
 
56
  return image / 255.0
57
 
58
  def plot_depth_map(depth_map, original_image):
59
+ fig = plt.figure(figsize=(32, 9))
 
 
60
 
61
+ # Increase resolution of the meshgrid
62
+ x, y = np.meshgrid(np.linspace(0, depth_map.shape[1]-1, 255), np.linspace(0, depth_map.shape[0]-1, 255))
63
+
64
+ # Interpolate depth map
65
+ depth_interp = interp2d(np.arange(depth_map.shape[1]), np.arange(depth_map.shape[0]), depth_map)
66
+ z = depth_interp(np.linspace(0, depth_map.shape[1]-1, 255), np.linspace(0, depth_map.shape[0]-1, 255))
67
+
68
+ # Interpolate colors
69
  original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
70
+ colors = original_image_resized.reshape(-1, original_image_resized.shape[1], 3) / 255.0
71
+ colors_interp = interp2d(np.arange(colors.shape[1]), np.arange(colors.shape[0]), colors.reshape(-1, colors.shape[1]), kind='linear')
72
+ new_colors = colors_interp(np.linspace(0, colors.shape[1]-1, 255), np.linspace(0, colors.shape[0]-1, 255))
73
 
74
+ # Plot with depth map color
75
+ ax1 = fig.add_subplot(121, projection='3d')
76
+ surf1 = ax1.plot_surface(x, y, z, facecolors=plt.cm.viridis(z), shade=False)
77
+ ax1.set_zlim(0, 1)
78
+ ax1.view_init(elev=150, azim=90)
79
+ ax1.set_title("Depth Map Color")
80
+ plt.axis('off')
81
 
82
+ # Plot with RGB color
83
+ ax2 = fig.add_subplot(122, projection='3d')
84
+ surf2 = ax2.plot_surface(x, y, z, facecolors=new_colors, shade=False)
85
+ ax2.set_zlim(0, 1)
86
+ ax2.view_init(elev=150, azim=90)
87
+ ax2.set_title("RGB Color")
88
  plt.axis('off')
89
 
90
+ plt.tight_layout()
91
  plt.show()
92
 
93
  fig.canvas.draw()