freddyaboulton committed
Commit 61732db · Parent(s): 6a95f1f
Files changed (1): app.py (+17 −9)

app.py CHANGED
@@ -2,6 +2,9 @@ import spaces
 import gradio as gr
 import cv2
 from PIL import Image
+import torch
+import time
+import numpy as np
 
 from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
 
@@ -16,7 +19,7 @@ def stream_object_detection(video, conf_threshold):
 
     video_codec = cv2.VideoWriter_fourcc(*"x264") # type: ignore
     fps = int(cap.get(cv2.CAP_PROP_FPS))
-    desired_fps = fps // 3
+    desired_fps = fps // 5
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
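Reviewer note: after this change, desired_fps is the post-subsampling frame rate, and the batch check later in this commit (len(batch) == 2 * desired_fps) flushes roughly two seconds of video per inference call. A quick back-of-the-envelope check, assuming a 30 fps source (the numbers are illustrative only):

    fps = 30                          # assumed source frame rate
    desired_fps = fps // 5            # 6 frames/s survive the every-5th-frame subsample
    batch_size = 2 * desired_fps      # 12 frames per inference batch
    seconds_per_batch = batch_size * 5 / fps
    print(seconds_per_batch)          # 2.0 -> each batch covers ~2 s of source video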
@@ -24,28 +27,33 @@ def stream_object_detection(video, conf_threshold):
 
     n_frames = 0
     n_chunks = 0
-    name = str(current_dir / f"output_{n_chunks}.ts")
+
+    name = f"output_{n_chunks}.ts"
     segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore
     batch = []
 
     while iterating:
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        if n_frames % 3 == 0:
+        if n_frames % 5 == 0:
             batch.append(frame)
-        if len(batch) == desired_fps:
+        if len(batch) == 2 * desired_fps:
             inputs = image_processor(images=batch, return_tensors="pt")
 
+            print(f"starting batch of size {len(batch)}")
+            start = time.time()
             with torch.no_grad():
                 outputs = model(**inputs)
+            end = time.time()
+            print("time taken ", end - start)
 
             boxes = image_processor.post_process_object_detection(
                 outputs,
-                target_sizes=torch.tensor([batch[0].shape[::-1]] * len(batch)),
+                target_sizes=torch.tensor([frame[0].shape[:2][::-1]] * len(batch)),
                 threshold=conf_threshold)
 
             for array, box in zip(batch, boxes):
-                pil_image = draw_bounding_boxes(Image.from_array(array), boxes[0], model, 0.3)
-                frame = numpy.array(pil_image)
+                pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
+                frame = np.array(pil_image)
                 # Convert RGB to BGR
                 frame = frame[:, :, ::-1].copy()
                 segment_file.write(frame)
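Reviewer note on the new target_sizes line: transformers documents target_sizes for post_process_object_detection as one (height, width) pair per image. For an H×W×3 frame, frame[0] is the first pixel row (shape (W, 3)), so frame[0].shape[:2][::-1] evaluates to (3, W) rather than (H, W). A minimal sketch of the documented ordering, using a hypothetical frames list standing in for batch:

    import numpy as np
    import torch

    frames = [np.zeros((480, 640, 3), dtype=np.uint8)]          # stand-in for `batch`
    target_sizes = torch.tensor([f.shape[:2] for f in frames])  # (height, width) per image
    print(target_sizes)                                         # tensor([[480, 640]])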
@@ -54,7 +62,7 @@ def stream_object_detection(video, conf_threshold):
             n_frames = 0
             n_chunks += 1
             yield name
-            name = str(current_dir / f"output_{n_chunks}.ts")
+            name = f"output_{n_chunks}.ts"
             segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore
 
         iterating, frame = cap.read()
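Reviewer note: the yield name / fresh VideoWriter hand-off above is the core of the streaming output: each chunk is closed as a standalone .ts file and its path is yielded so the player can start on it while later frames are still being processed. A minimal sketch of the same pattern with the detection step stripped out (file names and the 2-second cutoff are illustrative; the last partial segment is dropped for brevity):

    import cv2

    def stream_segments(video_path):
        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        codec = cv2.VideoWriter_fourcc(*"x264")  # needs an OpenCV build with H.264 support
        n_chunks, n_frames = 0, 0
        name = f"output_{n_chunks}.ts"
        writer = cv2.VideoWriter(name, codec, fps, size)
        ok, frame = cap.read()
        while ok:
            writer.write(frame)
            n_frames += 1
            if n_frames == 2 * fps:           # close a segment every ~2 s of video
                writer.release()
                yield name                    # hand the finished chunk to the player
                n_chunks, n_frames = n_chunks + 1, 0
                name = f"output_{n_chunks}.ts"
                writer = cv2.VideoWriter(name, codec, fps, size)
            ok, frame = cap.read()
        writer.release()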
@@ -83,7 +91,7 @@ with gr.Blocks(css=css) as app:
     """)
     with gr.Column(elem_classes=["my-column"]):
         with gr.Group(elem_classes=["my-group"]):
-            video = gr.Video(label="Video Source")
+            video = gr.Video(label="Video Source", streaming=True, autoplay=True)
             conf_threshold = gr.Slider(
                 label="Confidence Threshold",
                 minimum=0.0,
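Reviewer note: streaming=True and autoplay=True are what let the component consume the yielded .ts chunks as they arrive instead of waiting for a single finished file. A minimal sketch of how such a component is typically wired to the generator in Blocks (the upload event binding here is an assumption; the actual hookup lives outside this hunk):

    import gradio as gr

    with gr.Blocks() as demo:
        video = gr.Video(label="Video Source", streaming=True, autoplay=True)
        conf_threshold = gr.Slider(label="Confidence Threshold",
                                   minimum=0.0, maximum=1.0, value=0.3)
        # Hypothetical binding: stream_object_detection is the generator from
        # app.py above; each yielded segment path is streamed back into `video`.
        video.upload(fn=stream_object_detection,
                     inputs=[video, conf_threshold],
                     outputs=[video])

    demo.launch()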