from typing import Any, Optional

from omagent_core.models.encoders.base import EncoderBase
from omagent_core.utils.error import VQLError
from omagent_core.utils.registry import registry
from pydantic import field_validator
from pymilvus import CollectionSchema, DataType, FieldSchema

from .milvus_handler import MilvusHandler


@registry.register_component()
class VideoHandler(MilvusHandler):
    collection_name: str
    text_encoder: Optional[EncoderBase] = None
    dim: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    @field_validator("text_encoder", mode="before")
    @classmethod
    def init_encoder(cls, text_encoder):
        # Accept an already-built encoder, a registry config dict, or None.
        if text_encoder is None or isinstance(text_encoder, EncoderBase):
            return text_encoder
        elif isinstance(text_encoder, dict):
            return registry.get_encoder(text_encoder.get("name"))(**text_encoder)
        else:
            raise ValueError("text_encoder must be an EncoderBase, a dict, or None")

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)

        # The embedding dimension may be given explicitly or taken from the encoder.
        if self.dim is None:
            if self.text_encoder is None:
                raise VQLError(500, detail="Either dim or text_encoder is required")
            self.dim = self.text_encoder.dim

        # Collection schema: auto-generated primary key plus per-clip metadata
        # and the content embedding used for similarity search.
        _uid = FieldSchema(
            name="_uid", dtype=DataType.INT64, is_primary=True, auto_id=True
        )
        video_md5 = FieldSchema(
            name="video_md5", dtype=DataType.VARCHAR, max_length=100
        )
        content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535)
        content_vector = FieldSchema(
            name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        start_time = FieldSchema(
            name="start_time",
            dtype=DataType.FLOAT,
        )
        end_time = FieldSchema(
            name="end_time",
            dtype=DataType.FLOAT,
        )
        schema = CollectionSchema(
            fields=[_uid, video_md5, content, content_vector, start_time, end_time],
            description="video summary vector DB",
            enable_dynamic_field=True,
        )
        self.make_collection(self.collection_name, schema)

    def text_add(self, video_md5, content, start_time, end_time):
        """Encode one text snippet and insert it into the collection."""
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]

        upload_data = [
            {
                "video_md5": video_md5,
                "content": content,
                "content_vector": content_vector,
                "start_time": start_time,
                "end_time": end_time,
            }
        ]

        return self.do_add(self.collection_name, upload_data)

    def text_match(
        self,
        video_md5,
        content,
        threshold: float,
        start_time=None,
        end_time=None,
        res_size: int = 100,
    ):
        """Retrieve stored snippets whose embeddings are similar to `content`."""
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")

        # Build the scalar filter; time bounds are padded by 10 seconds.
        conditions = []
        if video_md5 is not None:
            conditions.append(f"video_md5=='{video_md5}'")
        if start_time is not None:
            conditions.append(f"start_time>={max(0, start_time - 10)}")
        if end_time is not None:
            conditions.append(f"end_time<={end_time + 10}")
        filter_expr = " and ".join(conditions)

        # Vector retrieval stage.
        content_vector = self.text_encoder.infer([content])[0]
        match_res = self.match(
            collection_name=self.collection_name,
            query_vectors=[content_vector],
            query_field="content_vector",
            output_fields=["content", "start_time", "end_time"],
            res_size=res_size,
            threshold=threshold,
            filter_expr=filter_expr,
        )

        return [match["entity"] for match in match_res[0]]
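

# ---------------------------------------------------------------------------
# Usage sketch (not part of the handler): how this class might be wired up.
# The encoder name, connection settings, and sample data below are
# placeholders/assumptions, not values defined in this repository; adapt them
# to your own Milvus deployment and registered encoders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumes MilvusHandler resolves its Milvus connection from its own
    # fields or environment; pass extra connection kwargs here if needed.
    handler = VideoHandler(
        collection_name="video_summary_demo",
        text_encoder={"name": "SomeTextEncoder"},  # hypothetical registered encoder
    )

    # Index one summarized clip, then query it back by semantic similarity.
    handler.text_add(
        video_md5="d41d8cd98f00b204e9800998ecf8427e",
        content="A person walks a dog through the park at sunset.",
        start_time=12.0,
        end_time=24.5,
    )
    hits = handler.text_match(
        video_md5="d41d8cd98f00b204e9800998ecf8427e",
        content="dog walking outdoors",
        threshold=0.4,
        res_size=10,
    )
    print(hits)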