amezi committed
Commit 61b5ef6 · 1 Parent(s): aa822ad

changing to x-clip

Files changed (2)
  1. src/embedder.py +16 -18
  2. src/pipeline.py +3 -5
src/embedder.py CHANGED
@@ -1,28 +1,26 @@
 import torch
 import numpy as np
-from transformers import AutoProcessor, AutoModel
-import decord
+from transformers import XCLIPProcessor, XCLIPModel
+from decord import VideoReader, cpu
 
-class InternVLEmbedder:
-    def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = AutoModel.from_pretrained("OpenGVLab/InternVL2_5-1B-MPO", trust_remote_code=True).to(self.device)
-        self.processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL2_5-1B-MPO", trust_remote_code=True)
+class XCLIPEmbedder:
+    def __init__(self, model_name="microsoft/xclip-base-patch16"):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = XCLIPModel.from_pretrained(model_name).to(self.device)
+        self.processor = XCLIPProcessor.from_pretrained(model_name)
 
     def embed_video(self, video_path):
-        vr = decord.VideoReader(video_path)
-        frames = np.stack([vr[i].asnumpy() for i in np.linspace(0, len(vr)-1, 8).astype(int)])
-        tensor = torch.tensor(frames).permute(0, 3, 1, 2).unsqueeze(0).to(self.device)
+        vr = VideoReader(video_path, ctx=cpu(0))
+        frame_indices = np.linspace(0, len(vr) - 1, num=8, dtype=int)
+        video_frames = vr.get_batch(frame_indices).asnumpy()
 
+        inputs = self.processor(videos=list(video_frames), return_tensors="pt", padding=True).to(self.device)
         with torch.no_grad():
-            video_vector = self.model.get_video_features(tensor).squeeze(0).cpu().numpy()
-
-        return video_vector / np.linalg.norm(video_vector)
+            video_features = self.model.get_video_features(**inputs).squeeze(0).cpu().numpy()
+        return video_features / np.linalg.norm(video_features)
 
     def embed_text(self, text):
-        inputs = self.processor(text=[text], return_tensors="pt").to(self.device)
-
+        inputs = self.processor(text=[text], return_tensors="pt", padding=True).to(self.device)
         with torch.no_grad():
-            text_vector = self.model.get_text_features(**inputs).squeeze(0).cpu().numpy()
-
-        return text_vector / np.linalg.norm(text_vector)
+            text_features = self.model.get_text_features(**inputs).squeeze(0).cpu().numpy()
+        return text_features / np.linalg.norm(text_features)
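
For a quick sanity check of the new embedder, a minimal usage sketch follows. It is not part of the commit: the clip path and query string are made up, and the 512-dim output size assumes the base X-CLIP checkpoint.

# Hypothetical usage sketch; "clips/test_event.mp4" and the query are illustrative.
from src.embedder import XCLIPEmbedder

embedder = XCLIPEmbedder()

# embed_video and embed_text both return L2-normalized numpy vectors
# (512-dim for the base checkpoint), so a plain dot product is cosine similarity.
video_vec = embedder.embed_video("clips/test_event.mp4")
text_vec = embedder.embed_text("a three pointer from the corner")
print(f"cosine similarity: {float(video_vec @ text_vec):.3f}")

Sampling with get_batch over uniformly spaced indices does one batched decode instead of eight individual reads, and handing raw frames to the processor leaves resizing and normalization to X-CLIP's own preprocessing rather than the manual permute/unsqueeze path it replaces.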
 
src/pipeline.py CHANGED
@@ -2,7 +2,7 @@ from src.segmenter import detect_event_segments
 from src.transcriber import transcribe_video
 from src.event_card import parse_game_card
 from src.labeler import TogetherLLMLabeler
-from src.embedder import InternVLEmbedder
+from src.embedder import XCLIPEmbedder
 from src.pinecone_store import PineconeStore
 from src.utils import (
     extract_key_frames, save_frames_locally,
@@ -11,7 +11,7 @@ from src.utils import (
 )
 
 labeler = TogetherLLMLabeler()
-embedder = InternVLEmbedder()
+embedder = XCLIPEmbedder()
 pinecone = PineconeStore()
 
 def run_pipeline(video_path, game_card_str):
@@ -40,7 +40,6 @@ def run_pipeline(video_path, game_card_str):
         clip_path = clip_video_segment(video_path, event['start_sec'], event['end_sec'], event_id)
 
         video_vector = embedder.embed_video(clip_path)
-        text_vector = embedder.embed_text(label)
 
         metadata = {
             "start_sec": event['start_sec'],
@@ -49,7 +48,6 @@ def run_pipeline(video_path, game_card_str):
         }
 
         pinecone.upsert(f"{event_id}-video", video_vector, metadata)
-        pinecone.upsert(f"{event_id}-text", text_vector, metadata)
 
         results.append(metadata)
 
@@ -57,7 +55,7 @@ def run_pipeline(video_path, game_card_str):
 
 def search_highlights(query, top_k=5):
     query_vector = embedder.embed_text(query)
-    results = pinecone.query(query_vector, filter_key="text", top_k=top_k)
+    results = pinecone.query(query_vector, filter_key="video", top_k=top_k)
     return [
         f"{r['label']} ({r['start_sec']}s - {r['end_sec']}s)" for r in results
     ]
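
Since X-CLIP is a contrastive dual encoder, embed_text() queries land in the same embedding space as the stored video vectors, so the per-event text upsert becomes redundant and search_highlights can match text queries directly against the "video" vectors. A hypothetical end-to-end driver, assuming only the run_pipeline and search_highlights signatures shown above (paths, game-card contents, and the sample output are made up):

# Sketch only; not part of the repo.
from src.pipeline import run_pipeline, search_highlights

with open("game_card.txt") as f:  # format is whatever parse_game_card expects
    events = run_pipeline("videos/full_game.mp4", f.read())
print(f"indexed {len(events)} events")

for hit in search_highlights("buzzer beater three pointer", top_k=3):
    print(hit)  # e.g. "Corner three at the buzzer (2710s - 2722s)"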