import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import time
import warnings
import asyncio
import json
import websockets

# Suppress a noisy transformers/DPT preprocessing warning about double
# rescaling; frames are handed to the processor as-is.
warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")

# Depth model: DPT SwinV2-tiny loaded in fp16 and moved to GPU when one
# is available (fp16 halves weight memory; CPU fallback still works).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float16).to(device)
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")

# Default webcam (device index 0).
cap = cv2.VideoCapture(0)

def resize_image(image, target_size=(256, 256), interpolation=cv2.INTER_LINEAR):
    """Resize *image* to *target_size* (width, height).

    Generalized to accept an OpenCV interpolation flag; the default,
    cv2.INTER_LINEAR, is what cv2.resize used implicitly before, so
    existing callers are unaffected.
    """
    return cv2.resize(image, target_size, interpolation=interpolation)

def manual_normalize(depth_map):
    """Stretch a depth map to the full uint8 range [0, 255].

    A constant-valued input has no dynamic range to stretch, so it maps
    to an all-zero uint8 array of the same shape.
    """
    lo = np.min(depth_map)
    hi = np.max(depth_map)
    if lo == hi:
        return np.zeros_like(depth_map, dtype=np.uint8)
    scaled = (depth_map - lo) / (hi - lo)
    return (scaled * 255).astype(np.uint8)

# Only every 4th captured frame is processed (see process_frames).
frame_skip = 4
# 256-entry inferno lookup table.
# NOTE(review): built here but not referenced anywhere in the visible
# code — presumably intended for client-side depth coloring; confirm
# before removing.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)

# Live websocket client connections, shared by handler() and broadcast().
connected = set()

async def broadcast(message):
    """Send *message* to every connected client, pruning dead sockets.

    Iterates over a snapshot of ``connected``: the previous version
    mutated the set while iterating it, which raises
    ``RuntimeError: Set changed size during iteration`` as soon as a
    client disconnects mid-broadcast. ``discard`` (not ``remove``) keeps
    pruning idempotent if the handler cleans up concurrently.
    """
    for websocket in list(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            connected.discard(websocket)

async def handler(websocket, path):
    """Register a client for broadcasts for the lifetime of its connection.

    ``path`` is unused but kept for compatibility with the two-argument
    handler signature this server was written against.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # discard, not remove: broadcast() may already have pruned this
        # socket after a failed send, and remove() would raise KeyError.
        connected.discard(websocket)

async def process_frames():
    """Capture webcam frames, estimate depth, and broadcast JSON to clients.

    Runs until the camera stops delivering frames. Every ``frame_skip``-th
    frame is resized to 256x256, run through the DPT model in fp16,
    min-max normalized to uint8, and sent together with the full-resolution
    RGB frame to all websocket clients as JSON.

    Fixes over the previous version:
    - Guarantees one await point per iteration (``asyncio.sleep(0)``).
      ``broadcast`` with zero clients never actually awaits, so the old
      loop could starve handler tasks indefinitely (``cap.read`` blocks).
    - Drops dead code: the fps computation whose result was discarded;
      the ``cv2.waitKey``/'q' check and ``destroyAllWindows`` (no HighGUI
      window is ever created, so waitKey always returns -1 and the quit
      branch was unreachable); the trailing ``manual_normalize`` pass on
      an all-zero map (a no-op: constant input yields zeros); and the
      redundant ``np.any`` guard (all-zero input already implies
      min == max).
    """
    frame_count = 0

    while True:
        # cap.read() is a blocking call; explicitly yield so the
        # websocket server and handler tasks get scheduled every cycle.
        await asyncio.sleep(0)

        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue

        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized_frame = resize_image(rgb_frame)

        # Cast processor output (fp32) to fp16 to match the model weights.
        inputs = processor(images=resized_frame, return_tensors="pt").to(device)
        inputs = {k: v.to(torch.float16) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            predicted_depth = outputs.predicted_depth

        depth_map = predicted_depth.squeeze().cpu().numpy()

        # fp16 inference can emit NaN/inf; clamp before normalizing.
        depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
        depth_map = depth_map.astype(np.float32)

        if depth_map.size == 0:
            depth_map = np.zeros((256, 256), dtype=np.uint8)
        elif np.min(depth_map) != np.max(depth_map):
            depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        else:
            # Constant map: no dynamic range to stretch.
            depth_map = np.zeros_like(depth_map, dtype=np.uint8)

        data = {
            'depthMap': depth_map.tolist(),
            'rgbFrame': rgb_frame.tolist()
        }

        await broadcast(json.dumps(data))

    cap.release()

async def main():
    """Run the websocket server and the frame-processing loop side by side."""
    server = await websockets.serve(handler, "localhost", 8765)
    await asyncio.gather(process_frames(), server.wait_closed())

if __name__ == "__main__":
    # Entry point: serve websockets and stream depth frames until the
    # camera stops delivering frames or the process is interrupted.
    asyncio.run(main())