from typing import Any, Optional

from omagent_core.models.encoders.base import EncoderBase
from omagent_core.utils.error import VQLError
from omagent_core.utils.registry import registry
from pydantic import field_validator
from pymilvus import CollectionSchema, DataType, FieldSchema

from .milvus_handler import MilvusHandler


@registry.register_component()
class VideoHandler(MilvusHandler):
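    """Milvus-backed store for text summaries of video segments.

    Each record holds the source video's MD5, a text summary of one segment,
    the summary's embedding produced by ``text_encoder``, and the segment's
    start/end times, so segments can be retrieved by semantic similarity.
    """
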
    collection_name: str
    text_encoder: Optional[EncoderBase] = None
    dim: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    @field_validator("text_encoder", mode="before")
    @classmethod
    def init_encoder(cls, text_encoder):
        # Accept a ready-made encoder instance, or a dict config that is
        # resolved to a registered encoder class through the registry.
        if isinstance(text_encoder, EncoderBase):
            return text_encoder
        elif isinstance(text_encoder, dict):
            return registry.get_encoder(text_encoder.get("name"))(**text_encoder)
        else:
            raise ValueError("text_encoder must be an EncoderBase instance or a dict config")

    def __init__(self, **data: Any) -> None:
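        """Build the segment schema and set up the Milvus collection via ``make_collection``."""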
        super().__init__(**data)

        # The vector field dimension is taken from the configured text encoder.
        self.dim = self.text_encoder.dim

        # Each row stores one video segment: the source video's MD5, the
        # summary text, its embedding, and the segment's time range.
        _uid = FieldSchema(
            name="_uid", dtype=DataType.INT64, is_primary=True, auto_id=True
        )
        video_md5 = FieldSchema(
            name="video_md5", dtype=DataType.VARCHAR, max_length=100
        )
        content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535)
        content_vector = FieldSchema(
            name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        start_time = FieldSchema(
            name="start_time",
            dtype=DataType.FLOAT,
        )
        end_time = FieldSchema(
            name="end_time",
            dtype=DataType.FLOAT,
        )
        schema = CollectionSchema(
            fields=[_uid, video_md5, content, content_vector, start_time, end_time],
            description="video summary vector DB",
            enable_dynamic_field=True,
        )
        self.make_collection(self.collection_name, schema)

    def text_add(self, video_md5, content, start_time, end_time):
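        """Embed a segment summary and insert it into the collection.

        Returns:
            The insert result from ``do_add``.
        """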
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]

        upload_data = [
            {
                "video_md5": video_md5,
                "content": content,
                "content_vector": content_vector,
                "start_time": start_time,
                "end_time": end_time,
            }
        ]

        add_detail = self.do_add(self.collection_name, upload_data)
        return add_detail

    def text_match(
        self,
        video_md5,
        content,
        threshold: float,
        start_time=None,
        end_time=None,
        res_size: int = 100,
    ):
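        """Retrieve segments whose summaries are similar to ``content``.

        Results can optionally be restricted to a single video via
        ``video_md5`` and/or to a time window around ``start_time`` and
        ``end_time``; the window is widened by 10 on each side.

        Returns:
            A list of matched entities, each containing ``content``,
            ``start_time`` and ``end_time``.
        """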
        # Build the scalar filter: optionally restrict to one video and/or a
        # time window widened by 10 on each side.
        conditions = []
        if video_md5 is not None:
            conditions.append(f"video_md5=='{video_md5}'")
        if start_time is not None:
            conditions.append(f"start_time>={max(0, start_time - 10)}")
        if end_time is not None:
            conditions.append(f"end_time<={end_time + 10}")
        filter_expr = " and ".join(conditions)

        content_vector = self.text_encoder.infer([content])[0]
        match_res = self.match(
            collection_name=self.collection_name,
            query_vectors=[content_vector],
            query_field="content_vector",
            output_fields=["content", "start_time", "end_time"],
            res_size=res_size,
            threshold=threshold,
            filter_expr=filter_expr,
        )

        output = [match["entity"] for match in match_res[0]]
        return output
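

# Minimal usage sketch (illustration only): the encoder name and argument
# values below are hypothetical, and any connection settings required by
# MilvusHandler are assumed to come from its own defaults/config.
#
#     handler = VideoHandler(
#         collection_name="video_summaries",
#         text_encoder={"name": "SomeRegisteredTextEncoder"},
#     )
#     handler.text_add(
#         video_md5="d41d8cd98f00b204e9800998ecf8427e",
#         content="A person walks a dog in the park.",
#         start_time=12.0,
#         end_time=18.5,
#     )
#     hits = handler.text_match(
#         video_md5="d41d8cd98f00b204e9800998ecf8427e",
#         content="dog walking",
#         threshold=0.3,
#     )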