|
import json
import re
from pathlib import Path
from typing import List, Optional

import json_repair
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.prompt import PromptTemplate
from omagent_core.tool_system.base import ArgSchema, BaseTool
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
from pydantic import Field
from scenedetect import FrameTimecode

from ...misc.scene import VideoScenes
|
|
|
# Directory containing this module; the Rewinder's prompt template files are
# looked up relative to it.
CURRENT_PATH = Path(__file__).parent

# Argument schema for the Rewinder tool, unpacked into ``ArgSchema`` as the
# tool's ``args_schema``. Each entry gives the argument's type, a
# human-readable description, and whether it is required.
ARGSCHEMA = dict(
    start_time=dict(
        type="number",
        description="Start time (in seconds) of the video to extract frames from.",
        required=True,
    ),
    end_time=dict(
        type="number",
        description="End time (in seconds) of the video to extract frames from.",
        required=True,
    ),
    number=dict(
        type="number",
        description="Number of frames of extraction. More frames means more details but more cost. Do not exceed 10.",
        required=True,
    ),
)
|
|
|
|
|
@registry.register_tool()
class Rewinder(BaseTool, BaseLLMBackend):
    """Re-inspect a time range of the video already loaded into memory.

    Samples up to 10 frames from the requested range of the video stored in
    short-term memory under ``"video"``. It sends them, each preceded by a
    ``timestamp_<t>`` label, to the LLM through the rewinder prompt templates
    and returns the model's description of the frames.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "Rollback and extract frames from video which is already loaded to get more specific details for further analysis."
    )
    # System and user prompt templates, loaded from files next to this module.
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_sys_prompt.prompt"),
                role="system",
            ),
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_user_prompt.prompt"),
                role="user",
            ),
        ]
    )

    def _run(
        self,
        start_time: float = 0.0,
        end_time: Optional[float] = None,
        number: int = 1,
    ) -> str:
        """Describe frames sampled from a time range of the loaded video.

        Args:
            start_time: Start of the range, in seconds. Negative values are
                clamped to 0.
            end_time: End of the range, in seconds. ``None`` means the end of
                the video stream. If the end lies before the start, the two
                are swapped.
            number: How many frames to sample. Values above 10 are reduced to
                10 (with a warning); fractional values are truncated, and
                anything below 1 becomes 1.

        Returns:
            A string embedding the LLM's description of the sampled frames,
            parsed with ``json_repair``.

        Raises:
            ValueError: If no video is stored in short-term memory.
        """
        memory = self.stm(self.workflow_instance_id)
        if memory.get("video", None) is None:
            raise ValueError("No video is loaded.")
        video: VideoScenes = VideoScenes.from_serializable(memory["video"])

        if number > 10:
            logging.warning("Number of frames exceeds 10. Will extract 10 frames.")
            number = 10
        # Tool arguments come from the LLM and may be fractional, zero or
        # negative. A count below 1 would divide by zero (or yield a negative
        # step) when the sampling interval is computed below.
        number = max(1, int(number))

        fps = video.stream.frame_rate
        # FrameTimecode reads an ``int`` timecode as a FRAME NUMBER and a
        # ``float`` as seconds. JSON tool arguments such as ``12`` arrive as
        # ints, so convert explicitly to honour the "seconds" contract.
        # FrameTimecode also rejects negative times, hence the clamp.
        start = FrameTimecode(timecode=max(0.0, float(start_time)), fps=fps)
        if end_time is None:
            end = video.stream.duration
        else:
            end = FrameTimecode(timecode=max(0.0, float(end_time)), fps=fps)

        # Tolerate a reversed range rather than producing a negative step.
        if end.get_frames() < start.get_frames():
            start, end = end, start

        if start.get_frames() == end.get_frames():
            # Single-instant request (compared by frame, not by raw float):
            # widen the window by one frame. The step is one second's worth of
            # frames, so (assuming get_video_frames treats interval as a frame
            # step) the one-frame window yields a single frame.
            frames, time_stamps = video.get_video_frames((start, end + 1), fps)
        else:
            span = end.get_frames() - start.get_frames()
            # Keep the step at least 1 frame when the window holds fewer
            # frames than requested.
            interval = max(1, span // number)
            frames, time_stamps = video.get_video_frames((start, end), interval)

        # Interleave "timestamp_<t>" labels with their frames for the prompt.
        payload = []
        for frame, time_stamp in zip(frames, time_stamps):
            payload.extend((f"timestamp_{time_stamp}", frame))

        response = self.infer(input_list=[{"timestamp_with_images": payload}])
        res = response[0]["choices"][0]["message"]["content"]
        image_contents = json_repair.loads(res)
        # NOTE(review): resets the image cache after inference, presumably to
        # drop images registered for this call - confirm against infer().
        self.stm(self.workflow_instance_id)["image_cache"] = {}
        return f"extracted_frames described as: {image_contents}."

    async def _arun(
        self,
        start_time: float = 0.0,
        end_time: Optional[float] = None,
        number: int = 1,
    ) -> str:
        """Async entry point; delegates to the synchronous ``_run``.

        The work (frame extraction and LLM call) runs inline, not in an
        executor, so it blocks the event loop for its duration.
        """
        return self._run(start_time, end_time, number=number)
|
|