from typing import Any, Optional

from omagent_core.models.encoders.base import EncoderBase
from omagent_core.utils.error import VQLError
from omagent_core.utils.registry import registry
from pydantic import field_validator
from pymilvus import CollectionSchema, DataType, FieldSchema

from .milvus_handler import MilvusHandler


@registry.register_component()
class VideoHandler(MilvusHandler):
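    """Milvus-backed store for video summary segments.

    Each row keeps the source video's MD5, a piece of summary text, its embedding,
    and the segment's start/end times; `text_match` retrieves rows by vector
    similarity over the summary text.
    """
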
    collection_name: str
    text_encoder: Optional[EncoderBase] = None
    dim: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    @field_validator("text_encoder", mode="before")
    @classmethod
    def init_encoder(cls, text_encoder):
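        """Accept `text_encoder` either as an EncoderBase instance or as a config
        dict whose `name` resolves the encoder class through the registry."""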
        if isinstance(text_encoder, EncoderBase):
            return text_encoder
        elif isinstance(text_encoder, dict):
            return registry.get_encoder(text_encoder.get("name"))(**text_encoder)
        else:
            raise ValueError("text_encoder must be an EncoderBase instance or a dict")

    def __init__(self, **data: Any) -> None:
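        """Define the collection schema and create it via `make_collection`.

        The embedding dimension is taken from the configured text encoder.
        """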
        super().__init__(**data)
        self.dim = self.text_encoder.dim
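        # Schema: auto-generated integer primary key, source video MD5, summary text,
        # its embedding, and the segment's start/end times.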
        _uid = FieldSchema(
            name="_uid", dtype=DataType.INT64, is_primary=True, auto_id=True
        )
        video_md5 = FieldSchema(
            name="video_md5", dtype=DataType.VARCHAR, max_length=100
        )
        content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535)
        content_vector = FieldSchema(
            name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        start_time = FieldSchema(
            name="start_time",
            dtype=DataType.FLOAT,
        )
        end_time = FieldSchema(
            name="end_time",
            dtype=DataType.FLOAT,
        )
        schema = CollectionSchema(
            fields=[_uid, video_md5, content, content_vector, start_time, end_time],
            description="video summary vector DB",
            enable_dynamic_field=True,
        )
        self.make_collection(self.collection_name, schema)

    def text_add(self, video_md5, content, start_time, end_time):
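        """Embed `content` with the text encoder and insert one segment row."""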
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]
        upload_data = [
            {
                "video_md5": video_md5,
                "content": content,
                "content_vector": content_vector,
                "start_time": start_time,
                "end_time": end_time,
            }
        ]
        add_detail = self.do_add(self.collection_name, upload_data)
        # assert add_detail.succ_count == len(upload_data)

    def text_match(
        self,
        video_md5,
        content,
        threshold: float,
        start_time=None,
        end_time=None,
        res_size: int = 100,
    ):
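        """Vector-search `content` against the stored summary segments.

        Hits can be restricted to a single video via `video_md5`; when `start_time`
        and/or `end_time` are given, the time filter is padded by 10 on each side.
        Returns the matched entities (content, start_time, end_time).
        """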
filter_expr = ""
if video_md5 is not None:
filter_expr = f"video_md5=='{video_md5}'"
if start_time is not None and end_time is not None:
filter_expr += f" and (start_time>={max(0, start_time - 10)} and end_time<={end_time + 10})"
elif start_time is not None:
filter_expr += f" and start_time>={max(0, start_time - 10)}"
elif end_time is not None:
filter_expr += f" and end_time<={end_time + 10}"
        # Text retrieval stage: embed the query and search by vector similarity.
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]
        match_res = self.match(
            collection_name=self.collection_name,
            query_vectors=[content_vector],
            query_field="content_vector",
            output_fields=["content", "start_time", "end_time"],
            res_size=res_size,
            threshold=threshold,
            filter_expr=filter_expr,
        )
        output = []
        for match in match_res[0]:
            output.append(match["entity"])
        return output
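
# Minimal usage sketch (illustrative only): the encoder name, MD5 value, and
# threshold below are placeholders, and MilvusHandler's own connection settings
# are assumed to be configured elsewhere.
#
#     handler = VideoHandler(
#         collection_name="video_summary",
#         text_encoder={"name": "SomeRegisteredEncoder"},
#     )
#     handler.text_add("<video-md5>", "A person opens the door.", 0.0, 4.5)
#     hits = handler.text_match("<video-md5>", "door opening", threshold=0.5)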