amezi committed on
Commit
6557d31
·
1 Parent(s): 73e4bdc

fixing some stuff

Browse files
Files changed (4) hide show
  1. requirements.txt +1 -0
  2. src/labeler.py +5 -4
  3. src/segmenter.py +23 -34
  4. src/utils.py +1 -1
requirements.txt CHANGED
@@ -13,3 +13,4 @@ together
13
  einops
14
  opencv-python
15
  timm
 
 
13
  einops
14
  opencv-python
15
  timm
16
+ #inference
src/labeler.py CHANGED
@@ -16,9 +16,6 @@ class TogetherLLMLabeler:
16
  Commentary:
17
  {transcript}
18
 
19
- Spatial Context (object detections per frame):
20
- {spatial_context}
21
-
22
  Instructions:
23
  - Summarize this event in factual soccer terminology.
24
  - Focus on the play's significance to the score.
@@ -35,4 +32,8 @@ class TogetherLLMLabeler:
35
  max_tokens=200
36
  )
37
 
38
- return response.choices[0].message["content"].strip()
 
 
 
 
 
16
  Commentary:
17
  {transcript}
18
 
 
 
 
19
  Instructions:
20
  - Summarize this event in factual soccer terminology.
21
  - Focus on the play's significance to the score.
 
32
  max_tokens=200
33
  )
34
 
35
+ return response.choices[0].message["content"].strip()
36
+
37
+ #after commentary:
38
+ # Spatial Context (object detections per frame):
39
+ # {spatial_context}
src/segmenter.py CHANGED
@@ -1,51 +1,40 @@
1
  import cv2
2
- import os
3
- from roboflow import Roboflow
4
- from dotenv import load_dotenv
5
-
6
- load_dotenv()
7
- ## When the ball is no longer detected, we start a new segment
8
-
9
- def detect_event_segments(video_path, confidence=0.4):
10
- rf = Roboflow(api_key=os.getenv("ROBOFLOW_API_KEY"))
11
- project = rf.workspace().project("soccer-players-ckbru/15")
12
- model = project.version(1).model
13
 
 
14
  cap = cv2.VideoCapture(video_path)
15
  fps = cap.get(cv2.CAP_PROP_FPS)
16
 
17
- events = []
18
- active_event = None
19
- frame_data = []
 
 
20
 
21
  while cap.isOpened():
22
  ret, frame = cap.read()
23
  if not ret:
24
  break
25
 
26
- frame_number = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
27
- detections = model.predict(frame, confidence=confidence).json().get('predictions', [])
28
- frame_data.append({"frame": frame_number, "objects": detections})
29
-
30
- ball_detected = any(obj['class'] == 'ball' for obj in detections)
31
- goal_area_activity = any(obj['class'] == 'goal' for obj in detections) and ball_detected
32
-
33
- if goal_area_activity and active_event is None:
34
- active_event = {"start_frame": frame_number, "frames": []}
35
 
36
- if active_event:
37
- active_event["frames"].append(frame_data[-1])
 
 
 
 
 
 
38
 
39
- if active_event and not ball_detected:
40
- active_event["end_frame"] = frame_number
41
- events.append(active_event)
42
- active_event = None
43
 
44
  cap.release()
45
 
46
- # Convert frames to timestamps
47
- for event in events:
48
- event['start_sec'] = event['start_frame'] / fps
49
- event['end_sec'] = event['end_frame'] / fps
 
50
 
51
- return events
 
1
  import cv2
 
 
 
 
 
 
 
 
 
 
 
2
 
3
def detect_event_segments(video_path, segment_duration=5):
    """Split a video into consecutive fixed-length segments.

    Args:
        video_path: Path of the video, passed to ``cv2.VideoCapture``.
        segment_duration: Length of each segment in seconds. Defaults to 5,
            the value that was previously hard-coded.

    Returns:
        A list of dicts, one per segment, in playback order. Each dict has
        ``"start_sec"`` (timestamp of the segment's first frame),
        ``"end_sec"`` (timestamp of the segment's last frame, clamped to the
        decoded length of the video for the final segment) and ``"frames"``
        (the decoded frames belonging to the segment). The list is empty
        when the video cannot be opened or yields no frames.

    Raises:
        ValueError: If ``segment_duration`` is not positive, or if the
            opened video reports a non-positive frame rate (timestamps
            cannot be computed without one).
    """
    if segment_duration <= 0:
        raise ValueError("segment_duration must be a positive number of seconds")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return []

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            raise ValueError(f"Video reports an invalid frame rate: {fps!r}")

        # At least one frame per segment; a zero here would make the modulo
        # below raise ZeroDivisionError for very low frame rates.
        frames_per_segment = max(1, int(segment_duration * fps))

        segments = []
        frames_read = 0  # zero-based index of the frame about to be stored

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # A zero-based index keeps the boundary test valid even when
            # frames_per_segment is 1 (a one-based "% n == 1" never matches).
            if frames_read % frames_per_segment == 0:
                segments.append({
                    "start_sec": frames_read / fps,
                    "end_sec": (frames_read + frames_per_segment - 1) / fps,
                    "frames": [],  # This can hold keyframes later if needed
                })

            # NOTE: every decoded frame is kept in memory; long videos may
            # need sampling here instead.
            segments[-1]["frames"].append(frame)
            frames_read += 1
    finally:
        # Release the capture even if decoding raises part-way through.
        cap.release()

    # Clamp the final (possibly partial) segment to the real video length.
    # The count of frames actually decoded is used instead of querying
    # CAP_PROP_FRAME_COUNT: the previous code queried it after release(),
    # when the capture no longer reports it, which collapsed end_sec to 0.
    if segments:
        total_duration = frames_read / fps
        segments[-1]["end_sec"] = min(segments[-1]["end_sec"], total_duration)

    return segments
src/utils.py CHANGED
@@ -35,7 +35,7 @@ def generate_frame_urls(frame_paths):
35
  base_url = os.getenv("SPACE_URL", "http://localhost:8000")
36
  return [f"{base_url}/data/{os.path.basename(path)}" for path in frame_paths]
37
 
38
- def match_transcript_to_events(events, transcript):
39
  for event in events:
40
  matched_lines = [
41
  line["text"] for line in transcript
 
35
  base_url = os.getenv("SPACE_URL", "http://localhost:8000")
36
  return [f"{base_url}/data/{os.path.basename(path)}" for path in frame_paths]
37
 
38
+ def match_transcript_to_segments(events, transcript):
39
  for event in events:
40
  matched_lines = [
41
  line["text"] for line in transcript