import cv2
import torch
import numpy as np
from transformers import DPTImageProcessor
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the "3d" projection on older Matplotlib
import torch.nn as nn
from scipy.interpolate import RegularGridInterpolator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of samples per axis used when resampling the depth map and colours for plotting.
GRID_SIZE = 255


# Load your custom trained model
class CompressedStudentModel(nn.Module):
    """Small encoder/decoder CNN that predicts a one-channel depth map.

    The encoder halves the spatial size twice (two ``MaxPool2d(2)``) and the
    decoder doubles it twice (two stride-2 ``ConvTranspose2d``), so an input
    whose height and width are divisible by 4 comes back at the same size.
    """

    def __init__(self):
        super(CompressedStudentModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return the raw (un-normalised) depth prediction, shape (N, 1, H, W) for input (N, 3, H, W)."""
        features = self.encoder(x)
        depth = self.decoder(features)
        return depth


# Initialize and load weights into the student model
model = CompressedStudentModel().to(device)
model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device))
model.eval()

# NOTE(review): `processor` is not used anywhere in this file; loading it only adds a
# download at start-up. Kept so any external code referring to it keeps working.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")


def preprocess_image(image):
    """Convert an HxWx3 frame into a (1, 3, 72, 128) float tensor on `device`.

    The frame is resized to 128x72 (width x height) and divided by 255, so a
    uint8 frame ends up in [0, 1].
    """
    image = cv2.resize(image, (128, 72))
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
    return image / 255.0


def _to_rgb(image):
    """Return `image` as a contiguous HxWx3 array (grayscale replicated, alpha channel dropped)."""
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    if image.shape[2] == 4:
        return np.ascontiguousarray(image[:, :, :3])
    return image


def _resample_grid(values, size=GRID_SIZE):
    """Bilinearly resample the first two axes of `values` onto a `size` x `size` grid.

    Sample positions are ``linspace(0, dim - 1, size)`` along each axis, so the
    grid spans the same extent as the input and keeps its corner values; this
    matches the former ``interp2d(..., kind='linear')`` evaluation. Any trailing
    axes (e.g. colour channels) are interpolated independently.
    """
    rows, cols = values.shape[:2]
    interpolator = RegularGridInterpolator((np.arange(rows), np.arange(cols)), values)
    row_pos, col_pos = np.meshgrid(
        np.linspace(0, rows - 1, size), np.linspace(0, cols - 1, size), indexing="ij"
    )
    return interpolator(np.stack([row_pos, col_pos], axis=-1))


def plot_depth_map(depth_map, original_image):
    """Render `depth_map` as two side-by-side 3D surfaces and return the picture.

    Args:
        depth_map: 2-D float array, expected to be normalised to [0, 1]
            (both plots fix their z-limits to that range).
        original_image: HxWx3 uint8 RGB image. It is resized to the depth
            map's size and used as the surface colours of the right-hand plot.

    Returns:
        An (height, width, 3) uint8 RGB array of the rendered figure.
    """
    # A standalone Agg figure (instead of pyplot + plt.show()) is used because this runs
    # once per streamed frame: pyplot would register every figure globally (never freed)
    # and may try to open GUI windows from Gradio's worker threads.
    fig = Figure(figsize=(32, 9))
    canvas = FigureCanvasAgg(fig)

    height, width = depth_map.shape
    # Increase resolution of the meshgrid
    x, y = np.meshgrid(
        np.linspace(0, width - 1, GRID_SIZE), np.linspace(0, height - 1, GRID_SIZE)
    )
    # Interpolate depth map
    z = _resample_grid(depth_map)

    # Interpolate colors (all three channels at once -> GRID_SIZE x GRID_SIZE x 3)
    original_image_resized = cv2.resize(original_image, (width, height))
    new_colors = np.clip(_resample_grid(original_image_resized / 255.0), 0.0, 1.0)

    # Plot with depth map color
    ax1 = fig.add_subplot(121, projection="3d")
    ax1.plot_surface(x, y, z, facecolors=plt.cm.viridis(z), shade=False)
    ax1.set_zlim(0, 1)
    ax1.view_init(elev=150, azim=90)
    ax1.set_title("Depth Map Color")
    ax1.set_axis_off()

    # Plot with RGB color
    ax2 = fig.add_subplot(122, projection="3d")
    ax2.plot_surface(x, y, z, facecolors=new_colors, shade=False)
    ax2.set_zlim(0, 1)
    ax2.view_init(elev=150, azim=90)
    ax2.set_title("RGB Color")
    ax2.set_axis_off()

    fig.tight_layout()
    canvas.draw()
    # buffer_rgba() has the canvas's true pixel shape; tostring_rgb() was deprecated in
    # Matplotlib 3.8 and has since been removed.
    return np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[:, :, :3])


@torch.inference_mode()
def process_frame(image):
    """Estimate depth for one webcam frame and return its 3D visualisation.

    Returns None when Gradio passes no frame.
    """
    if image is None:
        return None
    # gr.Image (type="numpy", the default) already delivers RGB arrays, so no BGR->RGB
    # conversion is applied; the former cvtColor(BGR2RGB) swapped red and blue in the plot.
    image = _to_rgb(image)
    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).squeeze().cpu().numpy()

    depth_min = predicted_depth.min()
    depth_range = predicted_depth.max() - depth_min
    if depth_range > 0:
        depth_map = (predicted_depth - depth_min) / depth_range
    else:
        # A perfectly flat prediction carries no relative depth; avoid 0/0 -> NaN.
        depth_map = np.zeros_like(predicted_depth)
    return plot_depth_map(depth_map, image)


interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources="webcam", streaming=True),
    outputs="image",
    live=True,
)
interface.launch()