"""Gradio webcam demo: predict a depth map per frame and render it as a 3D surface."""

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (kept from the original import list)
from transformers import DPTImageProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# (width, height) passed to cv2.resize before a frame is fed to the model.
INPUT_SIZE = (128, 72)


class CompressedStudentModel(nn.Module):
    """Small convolutional encoder/decoder producing a 1-channel map from a 3-channel image.

    The encoder halves the spatial resolution twice (two ``MaxPool2d(2)``
    layers) and the decoder doubles it twice (two stride-2 transposed
    convolutions), so the output has the input's height and width whenever
    both are divisible by 4.

    The layer order inside each ``nn.Sequential`` determines the parameter
    names in the ``state_dict``; do not reorder layers without re-exporting
    the checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Run the encoder then the decoder on ``x`` and return the decoder output."""
        features = self.encoder(x)
        depth = self.decoder(features)
        return depth


# Initialize the student model and load its trained weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True if the installed torch version supports it).
model = CompressedStudentModel().to(device)
model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device))
model.eval()

# NOTE(review): `processor` is not used anywhere in this file; it is kept only
# so the module's public names are unchanged. Confirm before removing.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")


def preprocess_image(image, size=INPUT_SIZE):
    """Resize a frame and convert it to a normalized model input tensor.

    Args:
        image: Array indexed as (row, column, channel); the channel axis is
            moved to the front before batching.
        size: ``(width, height)`` target passed to ``cv2.resize``.

    Returns:
        A float tensor on ``device`` with a leading batch axis of 1, channels
        first, and values divided by 255.
    """
    resized = cv2.resize(image, size)
    tensor = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float().to(device)
    return tensor / 255.0


def plot_depth_map(depth_map, original_image):
    """Render ``depth_map`` as a 3D surface textured with ``original_image``.

    The figure is drawn off-screen on an Agg canvas that is not registered
    with pyplot, so nothing is shown in a window and the figure is released
    by normal garbage collection after each call (important because this
    runs once per streamed frame).

    Args:
        depth_map: 2D array of surface heights; the z-axis is clipped to
            ``[0, 1]``.
        original_image: 3-channel image used for the surface colors; it is
            resized to ``depth_map``'s width and height and divided by 255.

    Returns:
        A ``uint8`` array of shape ``(canvas_height, canvas_width, 3)``
        holding the rendered RGB pixels.
    """
    height, width = depth_map.shape[:2]

    fig = Figure(figsize=(16, 9))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111, projection="3d")

    x, y = np.meshgrid(np.arange(width), np.arange(height))
    resized = cv2.resize(original_image, (width, height))
    colors = resized.reshape(height, width, 3) / 255.0

    ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
    ax.set_zlim(0, 1)
    ax.view_init(elev=150, azim=90)
    ax.set_axis_off()

    canvas.draw()
    # buffer_rgba() replaces the removed tostring_rgb(); drop the alpha channel
    # and copy so the result does not alias the canvas's internal buffer.
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[..., :3].copy()


@torch.inference_mode()
def process_frame(image):
    """Predict depth for one webcam frame and return the rendered 3D plot.

    Args:
        image: Frame supplied by the Gradio image input, or ``None`` when no
            frame is available.

    Returns:
        The RGB image produced by ``plot_depth_map``, or ``None`` if
        ``image`` is ``None``.
    """
    if image is None:
        return None

    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).squeeze().cpu().numpy()

    # Min-max normalize to [0, 1]. A constant (or NaN) prediction has no
    # positive range; use a flat map instead of dividing by zero.
    depth_min = predicted_depth.min()
    depth_range = predicted_depth.max() - depth_min
    if depth_range > 0:
        depth_map = (predicted_depth - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(predicted_depth)

    # NOTE(review): Gradio image inputs are presumably already RGB, in which
    # case this swap reverses red and blue in the surface colors. Kept as in
    # the original; confirm against the Gradio Image documentation.
    if image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return plot_depth_map(depth_map, image)


interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources="webcam", streaming=True),
    outputs="image",
    live=True,
)

interface.launch()