Spaces:
Sleeping
Sleeping
File size: 3,362 Bytes
f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 f8b3886 0143794 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import time
import warnings
import asyncio
import json
import websockets
# Ignore the "already rescaled images" warning so it does not flood the output.
warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DPT depth-estimation model loaded with float16 weights and moved to `device`.
# NOTE(review): float16 is also used on the CPU fallback; some CPU ops may be
# slow or unsupported in half precision - confirm on the target machine.
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float16).to(device)
# Image processor loaded from the same checkpoint as the model.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
# Opens capture device index 0 at import time.
# NOTE(review): presumably the default webcam - confirm.
cap = cv2.VideoCapture(0)
def resize_image(image, target_size=(256, 256)):
    """Return ``image`` resized to ``target_size`` with ``cv2.resize``.

    Args:
        image: Image array accepted by ``cv2.resize``.
        target_size: Output size passed to ``cv2.resize`` as ``dsize``.

    Returns:
        The resized image.
    """
    resized = cv2.resize(image, dsize=target_size)
    return resized
def manual_normalize(depth_map):
    """Min-max scale ``depth_map`` into the 0-255 range as ``uint8``.

    Args:
        depth_map: Array-like of numeric values.

    Returns:
        A ``uint8`` array with the same shape as the input. The minimum maps
        to 0 and the maximum to 255; scaled values are truncated, not
        rounded. Constant or empty input yields an all-zero array.
    """
    values = np.asarray(depth_map)
    # np.min / np.max raise ValueError on zero-size arrays, so bail out early.
    if values.size == 0:
        return np.zeros_like(values, dtype=np.uint8)
    if values.dtype.kind in "biu":
        # Narrow integer types can overflow when computing max - min
        # (e.g. int8 spanning -128..127) and bool does not support
        # subtraction, so do the arithmetic in float64 instead.
        values = values.astype(np.float64)
    min_val = np.min(values)
    max_val = np.max(values)
    if min_val == max_val:
        # A constant map has no range to stretch.
        return np.zeros_like(values, dtype=np.uint8)
    normalized = (values - min_val) / (max_val - min_val)
    return (normalized * 255).astype(np.uint8)
# Only frames whose running count is a multiple of this value are processed
# by process_frames(); the rest are read and discarded.
frame_skip = 4
# INFERNO colour map applied to the 256 possible uint8 values (a lookup table).
# NOTE(review): not referenced anywhere in the visible code - confirm it is still needed.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
# Websocket connections currently open; handler() adds to it and broadcast() iterates it.
connected = set()
async def broadcast(message):
    """Send ``message`` to every currently connected websocket client.

    Clients whose connection turns out to be closed are dropped from
    ``connected``.

    Args:
        message: Payload passed unchanged to each client's ``send``.
    """
    # Iterate over a snapshot: mutating the live set during iteration (here,
    # or from handler() while this coroutine is suspended in send) raises
    # "RuntimeError: Set changed size during iteration".
    for websocket in list(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            # discard rather than remove: handler() may already have dropped
            # this client, and remove() would then raise KeyError.
            connected.discard(websocket)
async def handler(websocket, path=None):
    """Track a client connection for as long as it stays open.

    The connection is added to ``connected`` on entry and dropped again once
    it closes, so ``broadcast`` only targets live clients.

    Args:
        websocket: The connection object supplied by the websockets server.
        path: Request path supplied by older websockets releases; unused.
            Optional because newer releases call the handler with the
            connection only.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # broadcast() may already have removed this client after a failed
        # send; discard() tolerates that where remove() would raise KeyError.
        connected.discard(websocket)
def _depth_to_uint8(depth_map):
    """Scale a raw depth prediction into a uint8 image spanning 0-255."""
    depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
    depth_map = depth_map.astype(np.float32)
    if depth_map.size == 0:
        return np.zeros((256, 256), dtype=np.uint8)
    if np.min(depth_map) == np.max(depth_map):
        # A constant map (including all zeros) has no range to stretch.
        return np.zeros_like(depth_map, dtype=np.uint8)
    return cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def _estimate_depth(frame):
    """Run the depth model on one BGR frame; return (rgb_frame, uint8 depth map)."""
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized_frame = resize_image(rgb_frame)
    inputs = processor(images=resized_frame, return_tensors="pt").to(device)
    # The model is loaded with float16 weights, so the inputs must match.
    inputs = {k: v.to(torch.float16) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
    return rgb_frame, _depth_to_uint8(depth_map)


async def process_frames():
    """Capture frames, estimate depth, and broadcast the results as JSON.

    Reads from the module-level ``cap`` until a read fails or the ``q`` key
    check fires. Only every ``frame_skip``-th frame is processed. Each
    processed frame is sent to all clients as a JSON object with keys
    ``depthMap`` (normalised uint8 depth map) and ``rgbFrame`` (the
    full-size RGB frame), both as nested lists.

    The capture device is released and OpenCV windows are destroyed on exit,
    including when an exception propagates.
    """
    loop = asyncio.get_running_loop()
    frame_count = 0
    try:
        while True:
            # Camera reads and model inference block; run them in the default
            # executor so the event loop can keep accepting and servicing
            # websocket connections in the meantime.
            ret, frame = await loop.run_in_executor(None, cap.read)
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue
            rgb_frame, depth_map = await loop.run_in_executor(
                None, _estimate_depth, frame
            )
            data = {
                'depthMap': depth_map.tolist(),
                'rgbFrame': rgb_frame.tolist()
            }
            await broadcast(json.dumps(data))
            # NOTE(review): no window is created in the visible code, so this
            # key check may never fire - confirm it is still wanted.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
async def main():
    """Start the websocket server on localhost:8765 and stream frames.

    Runs ``process_frames`` until it returns or raises, then shuts the
    server down so the program can exit.
    """
    server = await websockets.serve(handler, "localhost", 8765)
    try:
        await process_frames()
    finally:
        # Nothing else ever closes the server, so waiting on wait_closed()
        # without calling close() first would block forever once the frame
        # loop has ended.
        server.close()
        await server.wait_closed()
# Script entry point: run the server and the frame loop until the loop ends.
if __name__ == "__main__":
    asyncio.run(main())