changing to x-clip
Browse files- src/embedder.py +16 -18
- src/pipeline.py +3 -5
src/embedder.py
CHANGED
@@ -1,28 +1,26 @@
|
|
1 |
import torch
|
2 |
import numpy as np
|
3 |
-
from transformers import
|
4 |
-
import
|
5 |
|
6 |
-
class
|
7 |
-
def __init__(self):
|
8 |
-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
9 |
-
self.model =
|
10 |
-
self.processor =
|
11 |
|
12 |
def embed_video(self, video_path):
|
13 |
-
vr =
|
14 |
-
|
15 |
-
|
16 |
|
|
|
17 |
with torch.no_grad():
|
18 |
-
|
19 |
-
|
20 |
-
return video_vector / np.linalg.norm(video_vector)
|
21 |
|
22 |
def embed_text(self, text):
|
23 |
-
inputs = self.processor(text=[text], return_tensors="pt").to(self.device)
|
24 |
-
|
25 |
with torch.no_grad():
|
26 |
-
|
27 |
-
|
28 |
-
return text_vector / np.linalg.norm(text_vector)
|
|
|
1 |
import torch
|
2 |
import numpy as np
|
3 |
+
from transformers import XCLIPProcessor, XCLIPModel
|
4 |
+
from decord import VideoReader, cpu
|
5 |
|
6 |
+
class XCLIPEmbedder:
    """Embed video clips and text queries with an X-CLIP checkpoint.

    Both ``embed_video`` and ``embed_text`` return a 1-D NumPy vector scaled
    to unit L2 norm.
    """

    def __init__(self, model_name="microsoft/xclip-base-patch14"):
        """Load the model and processor for ``model_name``.

        Args:
            model_name: Hugging Face checkpoint identifier passed to
                ``from_pretrained`` for both the model and the processor.
        """
        # NOTE(review): confirm that the default checkpoint id exists on the
        # Hub; the published base checkpoints appear to be patch16/patch32
        # (patch14 looks like the "large" variant) -- TODO verify.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = XCLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = XCLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _unit_vector(vector):
        """Return ``vector`` scaled to unit L2 norm; a zero vector is returned as-is."""
        norm = np.linalg.norm(vector)
        if norm == 0:
            # Dividing by zero would fill the embedding with NaNs.
            return vector
        return vector / norm

    def embed_video(self, video_path, num_frames=8):
        """Return the unit-normalised embedding of the video at ``video_path``.

        Args:
            video_path: Path of the video file to decode.
            num_frames: Number of evenly spaced frames sampled from the clip.
                Presumably this must match the frame count the checkpoint was
                trained with (8 was the previously hard-coded value) -- verify
                against the model card before changing it.

        Returns:
            A 1-D NumPy array with unit L2 norm.

        Raises:
            ValueError: If ``num_frames`` is less than 1 or the video holds
                no frames.
        """
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")

        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError(f"Video has no frames: {video_path!r}")

        # Evenly spaced indices across the clip; clips shorter than
        # ``num_frames`` simply repeat some frames.
        frame_indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
        video_frames = vr.get_batch(frame_indices).asnumpy()

        inputs = self.processor(
            videos=list(video_frames), return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            video_features = (
                self.model.get_video_features(**inputs).squeeze(0).cpu().numpy()
            )
        return self._unit_vector(video_features)

    def embed_text(self, text):
        """Return the unit-normalised embedding of ``text``.

        Args:
            text: A single string; it is wrapped in a one-element batch.

        Returns:
            A 1-D NumPy array with unit L2 norm.
        """
        inputs = self.processor(
            text=[text], return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            text_features = (
                self.model.get_text_features(**inputs).squeeze(0).cpu().numpy()
            )
        return self._unit_vector(text_features)
|
|
src/pipeline.py
CHANGED
@@ -2,7 +2,7 @@ from src.segmenter import detect_event_segments
|
|
2 |
from src.transcriber import transcribe_video
|
3 |
from src.event_card import parse_game_card
|
4 |
from src.labeler import TogetherLLMLabeler
|
5 |
-
from src.embedder import
|
6 |
from src.pinecone_store import PineconeStore
|
7 |
from src.utils import (
|
8 |
extract_key_frames, save_frames_locally,
|
@@ -11,7 +11,7 @@ from src.utils import (
|
|
11 |
)
|
12 |
|
13 |
labeler = TogetherLLMLabeler()
|
14 |
-
embedder =
|
15 |
pinecone = PineconeStore()
|
16 |
|
17 |
def run_pipeline(video_path, game_card_str):
|
@@ -40,7 +40,6 @@ def run_pipeline(video_path, game_card_str):
|
|
40 |
clip_path = clip_video_segment(video_path, event['start_sec'], event['end_sec'], event_id)
|
41 |
|
42 |
video_vector = embedder.embed_video(clip_path)
|
43 |
-
text_vector = embedder.embed_text(label)
|
44 |
|
45 |
metadata = {
|
46 |
"start_sec": event['start_sec'],
|
@@ -49,7 +48,6 @@ def run_pipeline(video_path, game_card_str):
|
|
49 |
}
|
50 |
|
51 |
pinecone.upsert(f"{event_id}-video", video_vector, metadata)
|
52 |
-
pinecone.upsert(f"{event_id}-text", text_vector, metadata)
|
53 |
|
54 |
results.append(metadata)
|
55 |
|
@@ -57,7 +55,7 @@ def run_pipeline(video_path, game_card_str):
|
|
57 |
|
58 |
def search_highlights(query, top_k=5):
|
59 |
query_vector = embedder.embed_text(query)
|
60 |
-
results = pinecone.query(query_vector, filter_key="
|
61 |
return [
|
62 |
f"{r['label']} ({r['start_sec']}s - {r['end_sec']}s)" for r in results
|
63 |
]
|
|
|
2 |
from src.transcriber import transcribe_video
|
3 |
from src.event_card import parse_game_card
|
4 |
from src.labeler import TogetherLLMLabeler
|
5 |
+
from src.embedder import XCLIPEmbedder
|
6 |
from src.pinecone_store import PineconeStore
|
7 |
from src.utils import (
|
8 |
extract_key_frames, save_frames_locally,
|
|
|
11 |
)
|
12 |
|
13 |
labeler = TogetherLLMLabeler()
|
14 |
+
embedder = XCLIPEmbedder()
|
15 |
pinecone = PineconeStore()
|
16 |
|
17 |
def run_pipeline(video_path, game_card_str):
|
|
|
40 |
clip_path = clip_video_segment(video_path, event['start_sec'], event['end_sec'], event_id)
|
41 |
|
42 |
video_vector = embedder.embed_video(clip_path)
|
|
|
43 |
|
44 |
metadata = {
|
45 |
"start_sec": event['start_sec'],
|
|
|
48 |
}
|
49 |
|
50 |
pinecone.upsert(f"{event_id}-video", video_vector, metadata)
|
|
|
51 |
|
52 |
results.append(metadata)
|
53 |
|
|
|
55 |
|
56 |
def search_highlights(query, top_k=5):
    """Return formatted summaries of the stored highlights matching ``query``.

    The query text is embedded with the module-level ``embedder`` and looked
    up through the module-level ``pinecone`` store using
    ``filter_key="video"``.

    Args:
        query: Free-text search string.
        top_k: Value forwarded to the store query as ``top_k``.

    Returns:
        A list of strings of the form ``"<label> (<start>s - <end>s)"``, one
        per result returned by the store.
    """
    matches = pinecone.query(
        embedder.embed_text(query), filter_key="video", top_k=top_k
    )
    summaries = []
    for match in matches:
        summaries.append(
            f"{match['label']} ({match['start_sec']}s - {match['end_sec']}s)"
        )
    return summaries
|