qubvel-hf (HF Staff) committed
Commit 9d19ec6 · 1 Parent(s): 9f2cbff
Files changed (1):
  1. app.py +17 -7
app.py CHANGED
@@ -53,7 +53,7 @@ IMAGE_EXAMPLES = [
 ]
 
 # Video
-MAX_NUM_FRAMES = 500
+MAX_NUM_FRAMES = 250
 BATCH_SIZE = 4
 ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
 VIDEO_OUTPUT_DIR = Path("static/videos")
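Halving MAX_NUM_FRAMES (500 → 250) bounds how many frames ever reach the detector, and with it the GPU time spent per request. For context, a minimal sketch of how such a cap pairs with stride-based sampling; this read_video_k_frames is an illustrative OpenCV stand-in, not the helper defined elsewhere in app.py:

    import cv2

    MAX_NUM_FRAMES = 250  # upper bound on frames handed to the detector

    def read_video_k_frames(video_path: str, k: int, read_each_i: int):
        """Illustrative stand-in: collect k frames, keeping every read_each_i-th one."""
        cap = cv2.VideoCapture(video_path)
        frames, idx = [], 0
        while len(frames) < k:
            ok, frame = cap.read()
            if not ok:  # stream exhausted before k frames were gathered
                break
            if idx % read_each_i == 0:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            idx += 1
        cap.release()
        return frames

    # Caller side, mirroring process_video:
    # n = min(MAX_NUM_FRAMES, total_frames // stride)
    # frames = read_video_k_frames(path, n, stride)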
@@ -70,18 +70,28 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
+@lru_cache(maxsize=3)
+def get_model_and_processor(checkpoint: str):
+    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE)
+    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+    return model, image_processor
+
+
 @spaces.GPU(duration=20)
 def detect_objects(
     checkpoint: str,
-    images: List[np.ndarray],
+    images: List[np.ndarray] | np.ndarray,
     confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
     target_size: Optional[Tuple[int, int]] = None,
     batch_size: int = BATCH_SIZE,
 ):
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE).to(device)
-    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+    model, image_processor = get_model_and_processor(checkpoint)
+    model = model.to(device)
+
+    if isinstance(images, np.ndarray) and images.ndim == 4:
+        images = [x for x in images]  # split video array into list of images
 
     batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
 
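The new get_model_and_processor helper memoizes the expensive from_pretrained calls with functools.lru_cache, so each checkpoint is loaded once per process and later requests reuse the cached CPU copy; only the cheap .to(device) move happens inside the GPU-decorated function. A self-contained sketch of the pattern, where TORCH_DTYPE and the checkpoint name are assumptions:

    from functools import lru_cache

    import torch
    from transformers import AutoImageProcessor, AutoModelForObjectDetection

    TORCH_DTYPE = torch.float32  # assumption; app.py defines its own TORCH_DTYPE

    @lru_cache(maxsize=3)  # keep at most three checkpoints resident at once
    def get_model_and_processor(checkpoint: str):
        # lru_cache keys on the checkpoint string, so each checkpoint loads once.
        model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE)
        image_processor = AutoImageProcessor.from_pretrained(checkpoint)
        return model, image_processor

    # Device placement stays outside the cache so the cached weights remain
    # on CPU and can be moved wherever the current call runs.
    model, processor = get_model_and_processor("PekingU/rtdetr_r50vd")  # example checkpoint
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")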
 
@@ -205,12 +215,13 @@ def process_video(
 
     n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
     frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
+    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]
 
     box_annotator = sv.BoxAnnotator(thickness=1)
     label_annotator = sv.LabelAnnotator(text_scale=0.5)
 
     results, id2label = detect_objects(
-        images=frames,
+        images=np.array(frames),
         checkpoint=checkpoint,
         confidence_threshold=confidence_threshold,
         target_size=(target_height, target_width),
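Resizing every frame before the detector call gives them one common shape, so np.array(frames) stacks cleanly into a single (N, H, W, C) array; that array is exactly what the new images.ndim == 4 branch in detect_objects splits back into a frame list. A small sketch of this shape contract, with made-up sizes:

    import cv2
    import numpy as np

    # Hypothetical frames of mixed sizes, as decoded from a video.
    frames = [
        np.zeros((480, 640, 3), dtype=np.uint8),
        np.zeros((720, 1280, 3), dtype=np.uint8),
    ]

    target_width, target_height = 640, 480
    # cv2.resize takes (width, height); INTER_CUBIC matches the commit's choice.
    frames = [cv2.resize(f, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for f in frames]

    video = np.array(frames)                # stacks only because shapes now match
    assert video.shape == (2, 480, 640, 3)  # (N, H, W, C)

    # Receiving side: a 4-D array is unpacked back into per-frame arrays.
    if isinstance(video, np.ndarray) and video.ndim == 4:
        images = [x for x in video]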
@@ -218,7 +229,6 @@ def process_video(
 
     annotated_frames = []
     for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
-        frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_AREA)
         detections = sv.Detections.from_transformers(result, id2label=id2label)
         detections = detections.with_nms(threshold=0.95, class_agnostic=True)
         annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
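With frames pre-resized, the per-frame INTER_AREA resize inside the loop was redundant, so it is dropped. What remains is the usual supervision flow: convert a transformers post-processed result to Detections, suppress near-duplicate boxes, then draw. The sketch below fabricates one result dict (the scores/labels/boxes layout produced by post_process_object_detection) to show those three steps:

    import numpy as np
    import supervision as sv
    import torch

    # Fabricated post-processed output for a single frame.
    result = {
        "scores": torch.tensor([0.92, 0.88]),
        "labels": torch.tensor([0, 1]),
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0], [50.0, 60.0, 150.0, 260.0]]),
    }
    id2label = {0: "person", 1: "car"}
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder frame

    detections = sv.Detections.from_transformers(result, id2label=id2label)
    # class_agnostic=True merges overlapping boxes even across different classes.
    detections = detections.with_nms(threshold=0.95, class_agnostic=True)

    annotated = sv.BoxAnnotator(thickness=1).annotate(scene=frame, detections=detections)
    annotated = sv.LabelAnnotator(text_scale=0.5).annotate(scene=annotated, detections=detections)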
@@ -226,7 +236,7 @@ def process_video(
         annotated_frames.append(annotated_frame)
 
     output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
-    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264") #, pixelformat="yuv420p")
+    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename
 
 
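The dead pixelformat comment is removed rather than left trailing the call. Writing itself is unchanged: a list of equally-sized RGB frames goes to imageio with fps and codec keywords, which imageio forwards to its video backend. A minimal sketch with synthetic frames, assuming the same call shape as the commit:

    import os
    import uuid

    import imageio.v3 as iio
    import numpy as np

    # Synthetic stand-in for annotated frames: 24 black 480x640 RGB images.
    annotated_frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(24)]

    os.makedirs("static/videos", exist_ok=True)
    output_filename = os.path.join("static/videos", f"output_{uuid.uuid4()}.mp4")
    # fps/codec are forwarded to the ffmpeg-backed writer; "h264" mirrors the
    # codec string this commit uses.
    iio.imwrite(output_filename, annotated_frames, fps=24, codec="h264")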