huntrezz's picture
Update app.py
0143794 verified
raw
history blame
3.36 kB
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import time
import warnings
import asyncio
import json
import websockets
# Suppress warnings whose message starts with the text below; all other
# warnings are still reported.
warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Depth-estimation model, loaded with float16 weights and moved to `device`.
# NOTE(review): float16 is requested even when `device` is the CPU — confirm
# half-precision inference is supported/acceptable there.
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float16).to(device)
# Image processor loaded from the same checkpoint name as the model.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
# Video capture opened on device index 0 (presumably the default camera —
# confirm on the target machine). Opened at import time.
cap = cv2.VideoCapture(0)
def resize_image(image, target_size=(256, 256)):
    """Return `image` resized to `target_size`.

    `target_size` is forwarded unchanged as the size argument of
    `cv2.resize`; no other resize options are set.
    """
    resized = cv2.resize(image, target_size)
    return resized
def manual_normalize(depth_map):
    """Min-max scale `depth_map` to the range 0-255 and return it as uint8.

    Args:
        depth_map: Array-like of numeric values.

    Returns:
        A uint8 array with the same shape as `depth_map`. The smallest input
        value maps to 0 and the largest to 255; values in between are
        truncated, not rounded. A constant input yields all zeros and an
        empty input yields an empty array.
    """
    values = np.asarray(depth_map)
    # np.min / np.max raise ValueError on zero-size arrays, so answer early.
    if values.size == 0:
        return np.zeros_like(values, dtype=np.uint8)
    # Promote boolean/integer inputs before subtracting: arithmetic in a
    # narrow integer dtype can wrap around (e.g. int8 100 - (-100)), and
    # NumPy rejects `-` on booleans.
    if values.dtype.kind in "biu":
        values = values.astype(np.float64)
    min_val = np.min(values)
    max_val = np.max(values)
    if min_val == max_val:
        # No dynamic range: return zeros instead of dividing by zero.
        return np.zeros_like(values, dtype=np.uint8)
    normalized = (values - min_val) / (max_val - min_val)
    return (normalized * 255).astype(np.uint8)
# Only frames whose running count is a multiple of `frame_skip` are processed;
# the others are read and dropped.
frame_skip = 4
# INFERNO colour map applied to the values 0..255 (one entry per uint8 value).
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# it is still needed.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
# Currently open WebSocket connections: populated by `handler`, iterated by
# `broadcast`.
connected = set()
async def broadcast(message):
    """Send `message` to every client currently in `connected`.

    Clients whose connection turns out to be closed are dropped from
    `connected`; the remaining clients still receive the message.

    Args:
        message: Payload passed unchanged to each connection's `send()`.
    """
    # Iterate over a snapshot: `send()` suspends this coroutine, and while it
    # is suspended other coroutines may add or remove clients; closed clients
    # are also removed below. Mutating a set during iteration raises
    # RuntimeError.
    for websocket in tuple(connected):
        try:
            await websocket.send(message)
        except websockets.exceptions.ConnectionClosed:
            # `discard`, not `remove`: the client may already have been
            # dropped elsewhere, which must not raise KeyError here.
            connected.discard(websocket)
async def handler(websocket, path=None):
    """Keep a client registered in `connected` until its connection closes.

    Args:
        websocket: Connection object supplied by the websockets server.
        path: Request path; unused. It defaults to None so the handler works
            both with websockets versions that pass `(websocket, path)` and
            with versions that pass only the connection.
    """
    connected.add(websocket)
    try:
        await websocket.wait_closed()
    finally:
        # `discard`, not `remove`: the client may already have been dropped
        # from the set after a failed send, and cleanup must not raise.
        connected.discard(websocket)
def _estimate_depth(rgb_frame):
    """Run the depth model on one RGB frame and return a uint8 depth map.

    This is a blocking call (preprocessing plus model inference).

    Args:
        rgb_frame: Image array in RGB channel order.

    Returns:
        A uint8 array min-max scaled to 0-255, or all zeros when the model
        output is constant. An empty model output yields a 256x256 zero map.
    """
    resized_frame = resize_image(rgb_frame)
    inputs = processor(images=resized_frame, return_tensors="pt").to(device)
    # Cast the inputs to float16 to match the dtype the model was loaded with.
    inputs = {k: v.to(torch.float16) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    depth_map = outputs.predicted_depth.squeeze().cpu().numpy()
    # Replace NaN/inf with 0 so they cannot distort the min-max scaling.
    depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
    depth_map = depth_map.astype(np.float32)
    if depth_map.size == 0:
        return np.zeros((256, 256), dtype=np.uint8)
    if np.min(depth_map) == np.max(depth_map):
        # Constant output: nothing to scale.
        return np.zeros_like(depth_map, dtype=np.uint8)
    return cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)


async def process_frames():
    """Capture frames, estimate depth and broadcast both to all clients.

    Reads from the module-level `cap` until a read fails. Every
    `frame_skip`-th frame is converted to RGB, passed through the depth
    model, and sent to the connected clients as a JSON object with the keys
    'depthMap' and 'rgbFrame' (nested lists). The capture device is released
    when the loop ends, including on error or cancellation.
    """
    loop = asyncio.get_running_loop()
    frame_count = 0
    try:
        while True:
            # Yield on every iteration: without a real suspension point the
            # event loop never runs, so the WebSocket server could not accept
            # or service connections (skipped frames and an empty client set
            # would otherwise never suspend this coroutine).
            await asyncio.sleep(0)
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % frame_skip != 0:
                continue
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Inference is blocking; run it in a worker thread so the event
            # loop stays responsive while it runs.
            depth_map = await loop.run_in_executor(None, _estimate_depth, rgb_frame)
            # NOTE(review): 'rgbFrame' is the full-size captured frame, while
            # 'depthMap' is computed from the 256x256-resized copy — confirm
            # the client expects differently sized arrays.
            data = {
                'depthMap': depth_map.tolist(),
                'rgbFrame': rgb_frame.tolist(),
            }
            await broadcast(json.dumps(data))
    finally:
        cap.release()
async def main():
    """Serve WebSocket clients on localhost:8765 while frames are processed.

    The server is used as an async context manager, so it is closed as soon
    as `process_frames` returns or raises; the program then exits instead of
    waiting indefinitely on a server that nothing would ever close.
    """
    async with websockets.serve(handler, "localhost", 8765):
        await process_frames()
# Script entry point: start the event loop and run `main` only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())