qubvel-hf (HF Staff) committed
Commit 3fd48f0 · 1 Parent(s): 8102816
Files changed (1)
  1. app.py +68 -74
app.py CHANGED
@@ -9,6 +9,7 @@ import torch
 
 import spaces
 import gradio as gr
+import numpy as np
 
 from pathlib import Path
 from functools import lru_cache
@@ -69,38 +70,43 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-@lru_cache(maxsize=3)
-def get_model_and_image_processor(checkpoint: str, device: str = "cpu"):
-    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE).to(device)
-    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
-    return model, image_processor
-
 @spaces.GPU(duration=20)
 def detect_objects(
     checkpoint: str,
-    images: Optional[List[Image.Image]] = None,
+    images: List[np.ndarray],
     confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
-    target_sizes: Optional[List[Tuple[int, int]]] = None,
+    target_size: Optional[Tuple[int, int]] = None,
+    batch_size: int = BATCH_SIZE,
 ):
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model, image_processor = get_model_and_image_processor(checkpoint, device=device)
+    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE).to(device)
+    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
 
-    # preprocess images
-    inputs = image_processor(images=images, return_tensors="pt")
-    inputs = inputs.to(device).to(TORCH_DTYPE)
+    batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
 
-    # forward pass
-    with torch.no_grad():
-        outputs = model(**inputs)
+    results = []
+    for batch in tqdm.tqdm(batches, desc="Processing frames"):
 
-    # postprocess outputs
-    if not target_sizes:
-        target_sizes = [(image.height, image.width) for image in images]
+        # preprocess images
+        inputs = image_processor(images=batch, return_tensors="pt")
+        inputs = inputs.to(device).to(TORCH_DTYPE)
 
-    results = image_processor.post_process_object_detection(
-        outputs, target_sizes=target_sizes, threshold=confidence_threshold
-    )
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # postprocess outputs
+        if target_size:
+            target_sizes = [target_size] * len(batch)
+        else:
+            target_sizes = [(image.shape[0], image.shape[1]) for image in batch]
+
+        batch_results = image_processor.post_process_object_detection(
+            outputs, target_sizes=target_sizes, threshold=confidence_threshold
+        )
+
+        results.extend(batch_results)
 
     return results, model.config.id2label
 
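With batching now internal to detect_objects, one call can take any number of frames and returns one result dict per frame, in order. A minimal usage sketch against the new signature (the checkpoint name and threshold are illustrative, not from this commit; the result keys follow transformers' post_process_object_detection):

import numpy as np

frames = [np.zeros((480, 640, 3), dtype=np.uint8)]  # one dummy RGB frame, HWC uint8
results, id2label = detect_objects(
    checkpoint="PekingU/rtdetr_r50vd",  # example checkpoint, assumed for illustration
    images=frames,
    confidence_threshold=0.3,
)
print(results[0]["scores"], results[0]["labels"], results[0]["boxes"])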
@@ -120,7 +126,7 @@ def process_image(
 
     results, id2label = detect_objects(
         checkpoint=checkpoint,
-        images=[image],
+        images=[np.array(image)],
         confidence_threshold=confidence_threshold,
     )
     result = results[0]  # first image in batch (we have batch size 1)
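Passing np.array(image) keeps process_image on the new List[np.ndarray] contract: converting a PIL image yields an HWC uint8 array in the same RGB channel order, so image.shape[0] and image.shape[1] in the post-processing branch are height and width, mirroring the old (image.height, image.width). A quick check:

from PIL import Image
import numpy as np

image = Image.new("RGB", (640, 480))  # PIL size is (width, height)
array = np.array(image)
print(array.shape, array.dtype)  # (480, 640, 3) uint8 -> (height, width, channels)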
@@ -150,6 +156,25 @@ def get_target_size(image_height, image_width, max_size: int):
     new_height = int(image_height * max_size / image_width)
     return new_width, new_height
 
+
+def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
+    cap = cv2.VideoCapture(video_path)
+    frames = []
+    i = 0
+    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
+    while cap.isOpened() and len(frames) < k:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if i % read_every_i_frame == 0:
+            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            progress_bar.update(1)
+        i += 1
+    cap.release()
+    progress_bar.close()
+    return frames
+
+
 def process_video(
     video_path: str,
     checkpoint: str,
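read_video_k_frames decodes sequentially and keeps every i-th frame, converted BGR -> RGB, which is how process_video downsamples to roughly 25 fps: a 50 fps source gives read_each_i_frame = 50 // 25 = 2, so every second frame is kept and written back at 25 fps. A usage sketch with a hypothetical path:

frames = read_video_k_frames("input.mp4", k=100, read_every_i_frame=2)  # hypothetical file
print(len(frames))      # at most 100 (fewer if the video is shorter)
print(frames[0].shape)  # (height, width, 3), RGB uint8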
@@ -164,69 +189,38 @@ def process_video(
     if ext not in ALLOWED_VIDEO_EXTENSIONS:
         raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")
 
-    cap = cv2.VideoCapture(video_path)
-    if not cap.isOpened():
-        raise ValueError(f"Failed to open video: {video_path}")
-
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    video_info = sv.VideoInfo.from_video_path(video_path)
+    read_each_i_frame = video_info.fps // 25
+    target_fps = video_info.fps / read_each_i_frame
+    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)
 
-    process_each_frame = fps // 25
-    target_fps = fps / process_each_frame
-    target_width, target_height = get_target_size(height, width, 1080)
+    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
+    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
 
     # Use H.264 codec for browser compatibility
-    fourcc = cv2.VideoWriter_fourcc(*"MJPG")
-    temp_file = tempfile.NamedTemporaryFile(suffix=".avi", delete=False)
+    fourcc = cv2.VideoWriter_fourcc(*"H264")
+    temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
     writer = cv2.VideoWriter(temp_file.name, fourcc, target_fps, (target_width, target_height))
 
     box_annotator = sv.BoxAnnotator(thickness=1)
     label_annotator = sv.LabelAnnotator(text_scale=0.5)
-
-    if not writer.isOpened():
-        cap.release()
-        temp_file.close()
-        os.unlink(temp_file.name)
-        raise ValueError("Failed to initialize video writer")
-
-    frames_to_process = int(min(MAX_NUM_FRAMES * process_each_frame, num_frames))
-    batch = []
-
-    for i in tqdm.tqdm(range(frames_to_process), desc="Processing video"):
-
-        ok, frame = cap.read()
-        if not ok:
-            break
-
-        if not i % process_each_frame == 0:
-            continue
 
-        if len(batch) < BATCH_SIZE:
-            frame = frame[:, :, ::-1].copy()  # BGR to RGB
-            batch.append(frame)
-            continue
-
-        results, id2label = detect_objects(
-            images=[Image.fromarray(frame) for frame in batch],
-            checkpoint=checkpoint,
-            confidence_threshold=confidence_threshold,
-            target_sizes=[(target_height, target_width)] * len(batch),
-        )
-
-        for frame, result in zip(batch, results):
-            frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
-            detections = sv.Detections.from_transformers(result, id2label=id2label)
-            detections = detections.with_nms(threshold=0.95, class_agnostic=True)
-            annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
-            annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
-            writer.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
+    results, id2label = detect_objects(
+        images=frames,
+        checkpoint=checkpoint,
+        confidence_threshold=confidence_threshold,
+        target_size=(target_height, target_width),
+    )
 
-        batch = []
+    for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
+        frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
+        detections = sv.Detections.from_transformers(result, id2label=id2label)
+        detections = detections.with_nms(threshold=0.95, class_agnostic=True)
+        annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
+        writer.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
 
     writer.release()
-    cap.release()
    temp_file.close()
 
     # Copy to persistent directory for Gradio access
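One caveat on the codec switch: cv2.VideoWriter_fourcc(*"H264") depends on the local OpenCV build, and stock pip opencv-python wheels often ship without an H.264 encoder, in which case the writer fails to open. A defensive sketch, reusing target_fps/target_width/target_height from above and falling back through codec candidates:

import cv2
import tempfile

temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
for codec in ("H264", "avc1", "mp4v"):  # try H.264 variants first, then MPEG-4
    writer = cv2.VideoWriter(
        temp_file.name, cv2.VideoWriter_fourcc(*codec), target_fps, (target_width, target_height)
    )
    if writer.isOpened():
        break
else:
    raise ValueError("Failed to initialize video writer with any known codec")

Relatedly, read_each_i_frame = video_info.fps // 25 evaluates to 0 for sources under 25 fps; guarding with max(1, video_info.fps // 25) would avoid the division by zero in target_fps.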