Files changed (1)
  1. app.py +29 -78
app.py CHANGED
@@ -2,20 +2,12 @@ import cv2
 import torch
 import numpy as np
 from transformers import DPTForDepthEstimation, DPTImageProcessor
-import time
-import warnings
-import asyncio
-import json
-import websockets
-
-warnings.filterwarnings("ignore", message="It looks like you are trying to rescale already rescaled images.")
+import gradio as gr
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float16).to(device)
 processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
 
-cap = cv2.VideoCapture(0)
-
 def resize_image(image, target_size=(256, 256)):
     return cv2.resize(image, target_size)
 
@@ -28,84 +20,43 @@ def manual_normalize(depth_map):
     else:
         return np.zeros_like(depth_map, dtype=np.uint8)
 
-frame_skip = 4
 color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
 
-connected = set()
-
-async def broadcast(message):
-    for websocket in connected:
-        try:
-            await websocket.send(message)
-        except websockets.exceptions.ConnectionClosed:
-            connected.remove(websocket)
-
-async def handler(websocket, path):
-    connected.add(websocket)
-    try:
-        await websocket.wait_closed()
-    finally:
-        connected.remove(websocket)
-
-async def process_frames():
-    frame_count = 0
-    prev_frame_time = 0
-
-    while True:
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        frame_count += 1
-        if frame_count % frame_skip != 0:
-            continue
-
-        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        resized_frame = resize_image(rgb_frame)
-
-        inputs = processor(images=resized_frame, return_tensors="pt").to(device)
-        inputs = {k: v.to(torch.float16) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = model(**inputs)
-            predicted_depth = outputs.predicted_depth
-
-        depth_map = predicted_depth.squeeze().cpu().numpy()
-
-        depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
-        depth_map = depth_map.astype(np.float32)
-
-        if depth_map.size == 0:
-            depth_map = np.zeros((256, 256), dtype=np.uint8)
-        else:
-            if np.any(depth_map) and np.min(depth_map) != np.max(depth_map):
-                depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
-            else:
-                depth_map = np.zeros_like(depth_map, dtype=np.uint8)
-
-        if np.all(depth_map == 0):
-            depth_map = manual_normalize(depth_map)
-
-        data = {
-            'depthMap': depth_map.tolist(),
-            'rgbFrame': rgb_frame.tolist()
-        }
-
-        await broadcast(json.dumps(data))
-
-        new_frame_time = time.time()
-        fps = 1 / (new_frame_time - prev_frame_time)
-        prev_frame_time = new_frame_time
-
-        if cv2.waitKey(1) & 0xFF == ord('q'):
-            break
-
-    cap.release()
-    cv2.destroyAllWindows()
-
-async def main():
-    server = await websockets.serve(handler, "localhost", 8765)
-    await asyncio.gather(server.wait_closed(), process_frames())
-
-if __name__ == "__main__":
-    asyncio.run(main())
+def process_frame(image):
+    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    resized_frame = resize_image(rgb_frame)
+
+    inputs = processor(images=resized_frame, return_tensors="pt").to(device)
+    inputs = {k: v.to(torch.float16) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+        predicted_depth = outputs.predicted_depth
+
+    depth_map = predicted_depth.squeeze().cpu().numpy()
+
+    depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
+    depth_map = depth_map.astype(np.float32)
+
+    if depth_map.size == 0:
+        depth_map = np.zeros((256, 256), dtype=np.uint8)
+    else:
+        if np.any(depth_map) and np.min(depth_map) != np.max(depth_map):
+            depth_map = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+        else:
+            depth_map = np.zeros_like(depth_map, dtype=np.uint8)
+
+    if np.all(depth_map == 0):
+        depth_map = manual_normalize(depth_map)
+
+    depth_map_colored = cv2.applyColorMap(depth_map, color_map)
+    return cv2.cvtColor(depth_map_colored, cv2.COLOR_BGR2RGB)
+
+interface = gr.Interface(
+    fn=process_frame,
+    inputs=gr.Image(source="webcam", streaming=True),
+    outputs="image",
+    live=True
+)
+
+interface.launch()
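
Note: manual_normalize is called as a fallback in both the old and the new code, but only its closing else branch is visible in the hunk context above; the rest of the function sits outside this diff. A plausible reconstruction, assuming a plain min-max rescale to 8-bit (the rescaling branch is an assumption, not taken from the source file):

def manual_normalize(depth_map):
    # Hypothetical body -- only the else branch below appears in the diff
    # context; the min-max rescaling branch is assumed.
    min_val, max_val = float(np.min(depth_map)), float(np.max(depth_map))
    if max_val > min_val:
        scaled = (depth_map - min_val) / (max_val - min_val) * 255.0
        return scaled.astype(np.uint8)
    else:
        return np.zeros_like(depth_map, dtype=np.uint8)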
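
The unchanged color_map line builds a 256-entry INFERNO lookup table by colorizing the identity ramp np.arange(256); the new process_frame then applies it through the user-LUT overload of cv2.applyColorMap, available since OpenCV 3.3. A minimal standalone sketch of the same trick, using an explicit 256x1 ramp:

import cv2
import numpy as np

# Colorize the identity ramp once to obtain a 256x1x3 BGR lookup table.
ramp = np.arange(256, dtype=np.uint8).reshape(256, 1)
lut = cv2.applyColorMap(ramp, cv2.COLORMAP_INFERNO)

# Any 8-bit single-channel image can now be colorized with the user-LUT
# overload of applyColorMap; the result matches COLORMAP_INFERNO directly.
gray = np.full((64, 64), 128, dtype=np.uint8)
colored = cv2.applyColorMap(gray, lut)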
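
The new gr.Interface wiring uses the Gradio 3.x signature: gr.Image(source="webcam", streaming=True) feeds webcam frames to process_frame, and live=True re-runs it on every update. In Gradio 4.x the source argument was replaced by a sources list, so the launch block would need adjusting if the app is ever pinned to a newer Gradio; a sketch under that assumption:

import gradio as gr

# Gradio 4.x sketch (assumption): "source" became the "sources" list.
interface = gr.Interface(
    fn=process_frame,
    inputs=gr.Image(sources=["webcam"], streaming=True),
    outputs="image",
    live=True,
)

if __name__ == "__main__":
    interface.launch()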