Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
from mpl_toolkits.mplot3d import Axes3D
|
8 |
import torch.nn as nn
|
|
|
9 |
|
10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
|
@@ -55,19 +56,38 @@ def preprocess_image(image):
|
|
55 |
return image / 255.0
|
56 |
|
57 |
def plot_depth_map(depth_map, original_image):
|
58 |
-
fig = plt.figure(figsize=(
|
59 |
-
ax = fig.add_subplot(111, projection='3d')
|
60 |
-
x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
|
63 |
-
colors = original_image_resized.reshape(
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
69 |
plt.axis('off')
|
70 |
|
|
|
71 |
plt.show()
|
72 |
|
73 |
fig.canvas.draw()
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
from mpl_toolkits.mplot3d import Axes3D
|
8 |
import torch.nn as nn
|
9 |
+
from scipy.interpolate import interp2d
|
10 |
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
|
|
56 |
return image / 255.0
|
57 |
|
58 |
def plot_depth_map(depth_map, original_image):
|
59 |
+
fig = plt.figure(figsize=(32, 9))
|
|
|
|
|
60 |
|
61 |
+
# Increase resolution of the meshgrid
|
62 |
+
x, y = np.meshgrid(np.linspace(0, depth_map.shape[1]-1, 255), np.linspace(0, depth_map.shape[0]-1, 255))
|
63 |
+
|
64 |
+
# Interpolate depth map
|
65 |
+
depth_interp = interp2d(np.arange(depth_map.shape[1]), np.arange(depth_map.shape[0]), depth_map)
|
66 |
+
z = depth_interp(np.linspace(0, depth_map.shape[1]-1, 255), np.linspace(0, depth_map.shape[0]-1, 255))
|
67 |
+
|
68 |
+
# Interpolate colors
|
69 |
original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
|
70 |
+
colors = original_image_resized.reshape(-1, original_image_resized.shape[1], 3) / 255.0
|
71 |
+
colors_interp = interp2d(np.arange(colors.shape[1]), np.arange(colors.shape[0]), colors.reshape(-1, colors.shape[1]), kind='linear')
|
72 |
+
new_colors = colors_interp(np.linspace(0, colors.shape[1]-1, 255), np.linspace(0, colors.shape[0]-1, 255))
|
73 |
|
74 |
+
# Plot with depth map color
|
75 |
+
ax1 = fig.add_subplot(121, projection='3d')
|
76 |
+
surf1 = ax1.plot_surface(x, y, z, facecolors=plt.cm.viridis(z), shade=False)
|
77 |
+
ax1.set_zlim(0, 1)
|
78 |
+
ax1.view_init(elev=150, azim=90)
|
79 |
+
ax1.set_title("Depth Map Color")
|
80 |
+
plt.axis('off')
|
81 |
|
82 |
+
# Plot with RGB color
|
83 |
+
ax2 = fig.add_subplot(122, projection='3d')
|
84 |
+
surf2 = ax2.plot_surface(x, y, z, facecolors=new_colors, shade=False)
|
85 |
+
ax2.set_zlim(0, 1)
|
86 |
+
ax2.view_init(elev=150, azim=90)
|
87 |
+
ax2.set_title("RGB Color")
|
88 |
plt.axis('off')
|
89 |
|
90 |
+
plt.tight_layout()
|
91 |
plt.show()
|
92 |
|
93 |
fig.canvas.draw()
|