File size: 5,075 Bytes
acb3380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import base64
import re
from typing import Optional

from ctm.utils.exponential_backoff import exponential_backoff


class BaseProcessor(object):
    """Base class and name-keyed factory for processors.

    Subclasses register themselves under a name with ``register_processor``.
    Calling ``BaseProcessor(<name>, ...)`` then allocates an instance of the
    class registered under that name (see ``__new__``).

    The scoring helpers read ``self.model``, which this class never assigns.
    NOTE(review): presumably a subclass's ``set_model`` sets it to an
    OpenAI-style client exposing ``chat.completions.create`` -- confirm
    against the subclasses.
    """

    # Maps a registered processor name to the class implementing it. The dict
    # lives on BaseProcessor, so every subclass that does not redefine it
    # shares the same registry.
    _processor_registry = {}  # type: ignore[var-annotated] # FIX ME

    # Settings for the scoring requests; subclasses may override them.
    _SCORE_MODEL = "gpt-4-0125-preview"
    _SCORE_MAX_TOKENS = 50

    @classmethod
    def register_processor(cls, processor_name):  # type: ignore[no-untyped-def] # FIX ME
        """Return a class decorator that registers a class as ``processor_name``.

        Registering a name that is already present replaces the earlier entry.
        The decorated class is returned unchanged.
        """

        def decorator(subclass):  # type: ignore[no-untyped-def] # FIX ME
            cls._processor_registry[processor_name] = subclass
            return subclass

        return decorator

    def __new__(cls, processor_name, *args, **kwargs):  # type: ignore[no-untyped-def] # FIX ME
        """Allocate an instance of the class registered as ``processor_name``.

        ``*args`` and ``**kwargs`` are not used here. When the new object is an
        instance of ``cls``, Python passes them (together with
        ``processor_name``) on to ``__init__``.

        Raises:
            ValueError: if no class is registered under ``processor_name``.
        """
        if processor_name not in cls._processor_registry:
            raise ValueError(
                f"No processor registered with name '{processor_name}'"
            )
        return super(BaseProcessor, cls).__new__(
            cls._processor_registry[processor_name]
        )

    def set_model(self):  # type: ignore[no-untyped-def] # FIX ME
        """Hook for derived classes; the base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "The 'set_model' method must be implemented in derived classes."
        )

    @staticmethod
    def process_image(image_path):  # type: ignore[no-untyped-def] # FIX ME
        """Read the file at ``image_path`` and return it base64-encoded as text."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @staticmethod
    def process_audio(audio_path):  # type: ignore[no-untyped-def] # FIX ME
        """Placeholder: ignores ``audio_path`` and returns ``None``."""
        return None

    @staticmethod
    def process_video(video_path):  # type: ignore[no-untyped-def] # FIX ME
        """Placeholder: ignores ``video_path`` and returns ``None``."""
        return None

    def ask(self, query, context, image_path, audio_path, video_path):  # type: ignore[no-untyped-def] # FIX ME
        """Obtain a gist for ``query`` and score it.

        ``ask_info`` is called with the five arguments positionally, in the
        order ``(query, context, image_path, audio_path, video_path)``, so an
        overriding ``ask_info`` has to accept them in that order. The score is
        computed with ``verbose=True``, which prints the three components.

        Returns:
            A ``(gist, score)`` tuple.
        """
        gist = self.ask_info(  # type: ignore[no-untyped-call] # FIX ME
            query, context, image_path, audio_path, video_path
        )
        score = self.ask_score(query, gist, verbose=True)  # type: ignore[no-untyped-call] # FIX ME
        return gist, score

    @staticmethod
    def _parse_score(reply: Optional[str]) -> float:
        """Turn a model reply into a score by dividing its rating by 5.

        A reply that is exactly an integer is used as is. Otherwise the first
        run of digits is taken, so replies such as ``"4."`` or ``"Score: 4"``
        still yield a score instead of failing.

        Raises:
            ValueError: if the reply is missing or contains no integer.
        """
        text = (reply or "").strip()
        try:
            rating = int(text)
        except ValueError:
            found = re.search(r"\d+", text)
            if found is None:
                raise ValueError(
                    f"Expected a number from 0 to 5, got {reply!r}"
                ) from None
            rating = int(found.group())
        # The prompts ask for a rating on a 0-5 scale.
        return rating / 5

    def _ask_for_rating(self, prompt: str) -> float:
        """Send ``prompt`` as a single user message and parse the reply."""
        response = self.model.chat.completions.create(  # type: ignore[attr-defined] # FIX ME
            model=self._SCORE_MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=self._SCORE_MAX_TOKENS,
        )
        return self._parse_score(response.choices[0].message.content)

    @exponential_backoff(retries=5, base_wait_time=1)  # type: ignore[misc, no-untyped-call] # FIX ME
    def ask_relevance(self, query: str, gist: str) -> float:
        """Ask the model how related ``gist`` is to ``query``.

        Returns:
            The model's rating divided by 5.
        """
        prompt = (
            "How related is the information ({}) with the query ({})? "
            "Answer with a number from 0 to 5 and do not add any other thing."
        ).format(gist, query)
        return self._ask_for_rating(prompt)

    @exponential_backoff(retries=5, base_wait_time=1)  # type: ignore[misc, no-untyped-call] # FIX ME
    def ask_confidence(self, query: str, gist: str) -> float:
        """Ask the model for a confidence rating of ``gist``.

        Returns:
            The model's rating divided by 5.
        """
        # NOTE(review): the prompt has a single placeholder, so ``query`` never
        # reaches the model, and "mustk" looks like a typo. The wording is kept
        # byte-for-byte until the intended prompt is confirmed.
        prompt = (
            "How confidence do you think the information ({}) is a mustk? "
            "Answer with a number from 0 to 5 and do not add any other thing."
        ).format(gist)
        return self._ask_for_rating(prompt)

    @exponential_backoff(retries=5, base_wait_time=1)  # type: ignore[misc, no-untyped-call] # FIX ME
    def ask_surprise(
        self, query: str, gist: str, history_gists: Optional[str] = None
    ) -> float:
        """Ask the model how surprising ``gist`` is as a processor output.

        Returns:
            The model's rating divided by 5.
        """
        # NOTE(review): neither ``query`` nor ``history_gists`` is part of the
        # prompt; both are accepted but unused here.
        prompt = (
            "How surprise do you think the information ({}) is as an output "
            "of the processor? Answer with a number from 0 to 5 and do not "
            "add any other thing."
        ).format(gist)
        return self._ask_for_rating(prompt)

    def ask_score(self, query, gist, verbose=False, *args, **kwargs):  # type: ignore[no-untyped-def] # FIX ME
        """Combine relevance, confidence and surprise into one score.

        Extra ``*args`` / ``**kwargs`` are forwarded unchanged to all three
        ``ask_*`` methods.

        Args:
            query: the query the gist is scored against.
            gist: the information being scored.
            verbose: when true, print the three component scores.

        Returns:
            The product ``relevance * confidence * surprise``.
        """
        relevance = self.ask_relevance(query, gist, *args, **kwargs)
        confidence = self.ask_confidence(query, gist, *args, **kwargs)
        surprise = self.ask_surprise(query, gist, *args, **kwargs)
        if verbose:
            print(
                f"Relevance: {relevance}, Confidence: {confidence}, Surprise: {surprise}"
            )
        return relevance * confidence * surprise

    def ask_info(self, query, image_path, *args, **kwargs):  # type: ignore[no-untyped-def] # FIX ME
        """Hook for derived classes; the base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "The 'ask_info' method must be implemented in derived classes."
        )