import cv2
import torch
import numpy as np
from transformers import DPTImageProcessor
import gradio as gr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch.nn as nn
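# Run the model on GPU when available; otherwise fall back to CPU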
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Custom-trained student model: a lightweight convolutional encoder-decoder for monocular depth estimation
class CompressedStudentModel(nn.Module):
    def __init__(self):
        super(CompressedStudentModel, self).__init__()
        # Encoder: three conv stages, downsampling the input by 4x via two max-pool layers
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Decoder: two transposed convolutions restore the input resolution, then a 1-channel head predicts depth
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        features = self.encoder(x)
        depth = self.decoder(features)
        return depth
# Initialize and load weights into the student model
model = CompressedStudentModel().to(device)
model.load_state_dict(torch.load("huntrezz_depth_v2.pt", map_location=device))
model.eval()
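# DPT image processor for Intel/dpt-swinv2-tiny-256; instantiated here but not used in the inference path below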
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
def preprocess_image(image):
    # Downscale to the expected input size (cv2.resize takes (width, height))
    image = cv2.resize(image, (128, 72))
    # HWC uint8 array -> NCHW float tensor on the target device
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)
    # Scale pixel values to [0, 1]
    return image / 255.0
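# Render the normalized depth map as a 3D surface textured with the camera frame, and return the plot as an RGB array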
def plot_depth_map(depth_map, original_image):
    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111, projection='3d')
    x, y = np.meshgrid(range(depth_map.shape[1]), range(depth_map.shape[0]))
    # Use the camera frame (resized to the depth-map resolution) as the surface texture
    original_image_resized = cv2.resize(original_image, (depth_map.shape[1], depth_map.shape[0]))
    colors = original_image_resized.reshape(depth_map.shape[0], depth_map.shape[1], 3) / 255.0
    ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
    ax.set_zlim(0, 1)
    ax.view_init(elev=150, azim=90)
    plt.axis('off')
    # Rasterize the figure and read it back as an RGB array
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    # Close the figure so repeated streaming calls do not accumulate open figures
    plt.close(fig)
    return img
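# Convert one webcam frame into a rendered depth view; inference_mode disables gradient tracking for speed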
@torch.inference_mode()
def process_frame(image):
    if image is None:
        return None
    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).squeeze().cpu().numpy()
    # Normalize the prediction to [0, 1]; the small epsilon guards against division by zero on uniform frames
    depth_map = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min() + 1e-8)
    if image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return plot_depth_map(depth_map, image)
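# Gradio interface: stream webcam frames through process_frame and show the rendered depth surface live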
interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs="image",
    live=True
)
interface.launch()