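"""Milvus-backed handler for video summary segments.

Stores encoded text descriptions of video clips together with the source
video's MD5 and the segment's time range, and retrieves them by vector
similarity, optionally filtered by video and time window.
"""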

from typing import Any, Optional

from omagent_core.models.encoders.base import EncoderBase
from omagent_core.utils.error import VQLError
from omagent_core.utils.registry import registry
from pydantic import field_validator
from pymilvus import CollectionSchema, DataType, FieldSchema

from .milvus_handler import MilvusHandler


class VideoHandler(MilvusHandler):
    collection_name: str
    text_encoder: Optional[EncoderBase] = None
    dim: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    @field_validator("text_encoder", mode="before")
    @classmethod
    def init_encoder(cls, text_encoder):
        """Accept an EncoderBase instance or a dict config naming a registered encoder."""
        if isinstance(text_encoder, EncoderBase):
            return text_encoder
        elif isinstance(text_encoder, dict):
            return registry.get_encoder(text_encoder.get("name"))(**text_encoder)
        else:
            raise ValueError("text_encoder must be an EncoderBase instance or a dict")

    def __init__(self, **data: Any) -> None:
        """Build the collection schema and create the collection if it does not exist."""
        super().__init__(**data)
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        self.dim = self.text_encoder.dim
        # One row per described video segment: video id, text, embedding, time range.
        _uid = FieldSchema(
            name="_uid", dtype=DataType.INT64, is_primary=True, auto_id=True
        )
        video_md5 = FieldSchema(
            name="video_md5", dtype=DataType.VARCHAR, max_length=100
        )
        content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535)
        content_vector = FieldSchema(
            name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        start_time = FieldSchema(
            name="start_time",
            dtype=DataType.FLOAT,
        )
        end_time = FieldSchema(
            name="end_time",
            dtype=DataType.FLOAT,
        )
        schema = CollectionSchema(
            fields=[_uid, video_md5, content, content_vector, start_time, end_time],
            description="video summary vector DB",
            enable_dynamic_field=True,
        )
        self.make_collection(self.collection_name, schema)

    def text_add(self, video_md5, content, start_time, end_time):
        """Encode a text segment and insert it with its video id and time range."""
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]
        upload_data = [
            {
                "video_md5": video_md5,
                "content": content,
                "content_vector": content_vector,
                "start_time": start_time,
                "end_time": end_time,
            }
        ]
        add_detail = self.do_add(self.collection_name, upload_data)
        return add_detail

    def text_match(
        self,
        video_md5,
        content,
        threshold: float,
        start_time=None,
        end_time=None,
        res_size: int = 100,
    ):
        """Retrieve segments whose embeddings are similar to `content`, optionally
        filtered by video id and a time window padded by 10 seconds on each side."""
        # Build an optional boolean filter over the video id and the padded time window.
        clauses = []
        if video_md5 is not None:
            clauses.append(f"video_md5=='{video_md5}'")
        if start_time is not None and end_time is not None:
            clauses.append(
                f"(start_time>={max(0, start_time - 10)} and end_time<={end_time + 10})"
            )
        elif start_time is not None:
            clauses.append(f"start_time>={max(0, start_time - 10)}")
        elif end_time is not None:
            clauses.append(f"end_time<={end_time + 10}")
        filter_expr = " and ".join(clauses)
        # Text retrieval stage: encode the query and run a vector search.
        content_vector = self.text_encoder.infer([content])[0]
        match_res = self.match(
            collection_name=self.collection_name,
            query_vectors=[content_vector],
            query_field="content_vector",
            output_fields=["content", "start_time", "end_time"],
            res_size=res_size,
            threshold=threshold,
            filter_expr=filter_expr,
        )
        return [match["entity"] for match in match_res[0]]
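

# A minimal usage sketch, assuming a registered encoder named
# "OpenaiTextEmbeddingV3" and Milvus connection settings handled by
# MilvusHandler's defaults; both are placeholders to adapt to your deployment.
if __name__ == "__main__":
    handler = VideoHandler(
        collection_name="video_summary",
        text_encoder={"name": "OpenaiTextEmbeddingV3"},  # hypothetical encoder config
    )
    # Index one described segment of a video.
    handler.text_add(
        video_md5="d41d8cd98f00b204e9800998ecf8427e",
        content="A person opens the door and walks into the room.",
        start_time=12.0,
        end_time=18.5,
    )
    # Query for similar segments within a padded time window of the same video.
    hits = handler.text_match(
        video_md5="d41d8cd98f00b204e9800998ecf8427e",
        content="someone entering a room",
        threshold=0.3,
        start_time=10.0,
        end_time=20.0,
    )
    for hit in hits:
        print(hit["content"], hit["start_time"], hit["end_time"])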