File size: 3,304 Bytes
f8b3886
 
 
d5bc6d9
e8486cb
3548ace
 
4f3e058
f8b3886
 
893be2d
d5bc6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
893be2d
d5bc6d9
 
 
 
fd26002
d5bc6d9
 
 
 
893be2d
7b83683
f8b3886
a42d79c
edd2a5b
99bbe3e
 
a42d79c
dcf5ae7
 
3548ace
 
dcf5ae7
 
 
 
 
ba195a7
772e909
2a0602d
dcf5ae7
2505cbf
 
dcf5ae7
 
 
 
 
 
3548ace
79684c1
d6a18b3
1f906f0
 
cafea28
d5bc6d9
40334e7
d5bc6d9
40334e7
d5bc6d9
dcf5ae7
edd2a5b
dcf5ae7
7bc8ed0
4f1fd81
 
d6a18b3
4f1fd81
 
 
f8b3886
f170544
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import cv2
import torch
import numpy as np
from transformers import DPTImageProcessor
import gradio as gr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch.nn as nn

# Use the CUDA device when PyTorch reports one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load your custom trained model
class CompressedStudentModel(nn.Module):
    def __init__(self):
        super(CompressedStudentModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        features = self.encoder(x)
        depth = self.decoder(features)
        return depth

# Build the student network on the selected device and load its trained weights.
model = CompressedStudentModel().to(device)
# The checkpoint is a state_dict (it goes straight into load_state_dict), which
# contains only tensors and plain containers. weights_only=True restricts
# unpickling to those types instead of executing arbitrary pickled objects.
model.load_state_dict(
    torch.load("huntrezz_depth_v2.pt", map_location=device, weights_only=True)
)
model.eval()  # switch to evaluation mode for inference

# NOTE(review): `processor` is not referenced anywhere else in the visible code;
# loading it costs startup time and a model-hub download. Confirm it is unused
# before removing it.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")

def preprocess_image(image):
    """Turn an H x W x C frame into a (1, C, 72, 128) float32 tensor on `device`.

    The frame is resized to 128 px wide by 72 px high (cv2 takes dsize as
    (width, height)), moved to channels-first layout, given a leading batch
    axis, cast to float32 and scaled by 1/255.

    NOTE(review): the 1/255 scaling assumes 8-bit pixel values (0..255), and
    the permute assumes a 3-D (H, W, C) frame.
    """
    target_wh = (128, 72)
    resized = cv2.resize(image, target_wh)
    chw = torch.from_numpy(resized).permute(2, 0, 1)
    batched = chw.unsqueeze(0).float().to(device)
    return batched / 255.0

def plot_depth_map(depth_map, original_image):
    """Render ``depth_map`` as a 3-D surface coloured by ``original_image``.

    Args:
        depth_map: 2-D array of shape (rows, cols). The z-axis is fixed to
            [0, 1], so the map is expected to be normalised to that range.
        original_image: frame used for the surface colours. It is resized to
            (cols, rows), and only its first three channels are used, so an
            alpha channel is ignored.
            NOTE(review): assumed to be uint8 RGB(A), since values are divided
            by 255.

    Returns:
        A uint8 array of shape (height, width, 3) containing the RGB pixels of
        the rendered 16x9-inch figure.
    """
    # Render without pyplot: a bare Figure on an Agg canvas never opens a GUI
    # window, is safe to use from server worker threads, and is not kept alive
    # by pyplot's figure registry, so per-frame figures are freed instead of
    # accumulating.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    rows, cols = depth_map.shape
    fig = Figure(figsize=(16, 9))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111, projection='3d')
    x, y = np.meshgrid(range(cols), range(rows))

    # cv2.resize takes dsize as (width, height).
    original_image_resized = cv2.resize(original_image, (cols, rows))
    colors = original_image_resized[:, :, :3] / 255.0

    ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
    ax.set_zlim(0, 1)
    ax.view_init(elev=150, azim=90)
    ax.set_axis_off()

    # tostring_rgb() was removed in Matplotlib 3.10; buffer_rgba() is the
    # supported way to read the rendered pixels. Drop alpha and copy so the
    # result is contiguous and independent of the canvas buffer.
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[:, :, :3].copy()

@torch.inference_mode()
def process_frame(image):
    """Predict depth for one webcam frame and return a rendered 3-D view of it.

    Args:
        image: numpy frame from the Gradio image component, or None when no
            frame is available yet.

    Returns:
        The image produced by ``plot_depth_map``, or None if ``image`` is None.
    """
    if image is None:
        return None

    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).squeeze().cpu().numpy()

    # Min-max normalise to [0, 1] to match the fixed z-range of the plot. A
    # flat (or non-finite) prediction has no usable range; dividing by it
    # would give a NaN map, so render a flat zero surface instead.
    depth_min = predicted_depth.min()
    depth_range = predicted_depth.max() - depth_min
    if depth_range > 0:
        depth_map = (predicted_depth - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(predicted_depth)

    # Gradio delivers numpy frames in RGB order, which is what the surface
    # colours need, so the frame is passed through unchanged. A BGR->RGB
    # conversion here would swap the red and blue channels.
    return plot_depth_map(depth_map, image)

# Live demo: each streamed webcam frame is passed to process_frame and the
# returned image is displayed.
webcam = gr.Image(sources="webcam", streaming=True)
interface = gr.Interface(
    fn=process_frame,
    inputs=webcam,
    outputs="image",
    live=True,
)
interface.launch()