from operator import itemgetter
from typing import Any, Dict

from langchain.prompts import ChatPromptTemplate
from langchain.schema import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_openai import ChatOpenAI
from src.models.evaluator.text_evaluator import EvaluationChatModelQA
from typing_extensions import override


class EvaluationChatModelImg(EvaluationChatModelQA):
    """Evaluate a user's image description for a Speaking English exam by
    comparing it against an AI vision model's description of the same image."""

    class Input(BaseModel):
        image_url: str
        user_desc: str

    class Output(BaseModel):
        evaluation: str = Field(
            alias="Evaluation", description="Summary of the evaluation of the response (only if there is one)"
        )
        tips: str = Field(
            alias="Tips", description="Tips about complexity and detailed corrections of mistakes (only if there are any)"
        )
        example: str = Field(
            alias="Example",
            description="Example of a response based on the AI image description following the given guidelines (only if there is one)",
        )
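
    # Illustrative shape of the parsed output once aliases are applied (the
    # exact wording is model-dependent; the values below are made up):
    #
    #     {
    #         "Evaluation": "Your description covers the main scene but ...",
    #         "Tips": "Try using more varied connectors ...",
    #         "Example": "The picture shows two people walking ...",
    #     }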

    def __init__(
        self,
        eval_model: str,
        level: str,
        openai_api_key: SecretStr,
        chat_temperature: float = 0.3,
        eval_temperature: float = 0.3,
    ) -> None:
        """
        Initializes the class with the given parameters.

        Args:
            eval_model (str): The vision model to use for image description.
            level (str): The level of the exam.
            openai_api_key (SecretStr): The API key for OpenAI.
            chat_temperature (float, optional): The temperature to use for chat. Defaults to 0.3.
            eval_temperature (float, optional): The temperature to use for the vision model. Defaults to 0.3.

        Returns:
            None
        """
        self.vision_model = ChatOpenAI(
            temperature=eval_temperature, model=eval_model, max_tokens=1024, api_key=openai_api_key
        )
        super().__init__(level, openai_api_key=openai_api_key, chat_temperature=chat_temperature)

    def _get_system_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "{format_instructions}"
                    f"""You will evaluate how close is the the image
                    description given by the user compared to the image
                    description given by the ai vision model. Take into account
                    that this image description will be used for evaluating
                    the user how well can describe this image in the context
                    of an {self.level} Speaking English exam. An Example of
                    an image description will also be provided based on the AI image description.""",
                ),
                ("ai", "AI image description: {ai_img_desc}"),
                ("human", "User image description: {base_response}"),
                ("ai", "Tags:\n{tags}\n\\Extraction:\n{extraction}"),
                (
                    "system",
                    f"""Generate a final response given the question, its response and the detected Tags and Extraction:
                    - correct mistakes (just if there are) based on the response
                        given by the MistakeExtractor according to the {self.level} english level.
                    - give relevant and related tips based on how complete the answer is given the punctuation of
                        the AnswerTagger. Best responses are 7, 8 point questions since they are neither too simple
                        nor too complex.
                        - With too simple questions (1, 2, 3, 4 points) you must suggest an alternative response with a higher
                        degree of complexity.
                        - With too complex questions (9, 10 points) you must highlight which part of the response should be ignored.
                    - An excellent response must be grammatically correct, complete and clear.
                    - Provide an example answer to the question given the above guidelines and the AI image description.
                    """,
                ),
            ]
        )
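
    # Rough shape of the rendered conversation (values are filled in at
    # runtime by the chain; this is a sketch, not captured output):
    #
    #     system: <format instructions> + evaluation task description
    #     ai:     "AI image description: <vision model output>"
    #     human:  "User image description: <user's description>"
    #     ai:     "Tags: ...\nExtraction: ..."
    #     system: final-response guidelines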

    def _get_multi_chain_dict(self) -> Dict[str, Any]:
        # Reuse only the tagger and extractor branches of the parent QA chain.
        multi_chain_dict = super()._get_multi_chain_dict()
        multi_chain_dict = {key: multi_chain_dict[key] for key in ["tags", "extraction"]}

        # "ai_img_desc" is produced at runtime by piping the vision-model input
        # through the vision model and collecting its text output.
        multi_chain_dict.update(
            {
                "ai_img_desc": itemgetter("image_url") | self.vision_model | StrOutputParser(),
                "base_response": itemgetter("base_response"),
            }
        )

        return multi_chain_dict
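
    # Sketch of the variables the prompt template ends up receiving (assuming
    # the parent class pipes this dict into the ChatPromptTemplate, as in
    # text_evaluator.py):
    #
    #     tags          -> AnswerTagger branch output
    #     extraction    -> MistakeExtractor branch output
    #     ai_img_desc   -> the vision model's description of the image
    #     base_response -> the user's own description, passed through unchanged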

    @override
    def predict(self, user_desc: str, image_url: str) -> Dict[str, str]:
        """Make a prediction using the provided input.

        Args:
            user_desc (str): The user's image description.
            image_url (str): The image URL.

        Returns:
            Dict: The output of the prediction.
        """
        input_model = self.Input(user_desc=user_desc, image_url=image_url)
        # Build the multimodal message for the vision model: a text question
        # plus the image itself (detail="auto" lets the API pick resolution).
        vision_model_input = [
            HumanMessage(
                content=[
                    {"type": "text", "text": "What is this image showing?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": input_model.image_url, "detail": "auto"},
                    },
                ]
            )
        ]
        result = self.chain.invoke(
            {"base_response": input_model.user_desc, "image_url": vision_model_input}, config=self.config
        )
        return result.dict(by_alias=True)
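

# Usage sketch (illustrative only; the model name, level and URL below are
# placeholders, and EvaluationChatModelQA is assumed to wire up `self.chain`
# and `self.config`):
#
#     model = EvaluationChatModelImg(
#         eval_model="gpt-4-vision-preview",  # hypothetical model choice
#         level="B2",
#         openai_api_key=SecretStr("sk-..."),
#     )
#     result = model.predict(
#         user_desc="Two people are walking a dog in a park.",
#         image_url="https://example.com/park.jpg",  # hypothetical URL
#     )
#     print(result["Evaluation"], result["Tips"], result["Example"], sep="\n")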