File size: 8,228 Bytes
18c0acd
 
 
 
 
 
 
 
 
 
 
c39d1b7
18c0acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c39d1b7
 
 
 
 
 
 
 
18c0acd
 
 
 
 
 
 
c39d1b7
 
 
18c0acd
 
 
 
 
 
c39d1b7
18c0acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import logging
from operator import itemgetter
from typing import Any, Dict, Optional, Type

# from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain.chains import create_extraction_chain_pydantic, create_tagging_chain_pydantic
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from src.models.base.base_model import EvaluationChatModel

logger = logging.getLogger(__name__)


class EvaluationChatModelQA(EvaluationChatModel):
    """Chat model that evaluates a user's English answer to a question.

    Pipeline (built in ``_create_chain``):
      1. ``checker_llm`` tags the answer's complexity (``AnswerTagger``) and
         extracts grammar mistakes (``MistakeExtractor``) in parallel.
      2. The main chat LLM receives the question, the answer, and both
         intermediate results, and produces a structured ``Output``
         (evaluation / tips / example) parsed by a ``PydanticOutputParser``.
    """

    class Input(BaseModel):
        # Raw user answer plus the (optional) question it responds to.
        response: str
        question: Optional[str] = Field(default="")

    class Output(BaseModel):
        # Aliases are capitalized on purpose: the format instructions derived
        # from this schema tell the LLM to emit exactly these field names,
        # and `predict` returns the dict keyed by alias.
        evaluation: str = Field(
            alias="Evaluation",
            description="Summarize evaluation of the response (just if there is)",
            default="",
        )
        tips: str = Field(
            alias="Tips",
            description="tips about complexity and detailed mistakes correction (just if there are)",
            default="",
        )
        example: str = Field(
            alias="Example",
            description="Example of the response following the given  guidelines (just if there is)",
            default="",
        )

    class AnswerTagger(BaseModel):
        """
        Tags the answer considering the following aspects:

        - complexity
        """

        # `enum` constrains the model to an integer score in [0, 10].
        answer_complexity: int = Field(
            description="describes how complex the answer is. It is a number between 0 (simpler) and 10 (more complex)",
            enum=list(range(11)),
        )

    class MistakeExtractor(BaseModel):
        """
        Extracts the mistakes from the text.
        """

        grammar_mistake: Optional[str] = Field(
            description="Grammar syntax mistakes detected in text, just if there are, if not return empty string.",
            default="",
        )

    def __init__(
        self,
        level: str,
        openai_api_key: SecretStr,
        eval_model: str = "gpt-3.5-turbo",
        chat_temperature: float = 0.3,
        eval_temperature: float = 0.3,
    ) -> None:
        """
        Initializes the class with the given parameters.

        Args:
            level (str): The level of the exam.
            openai_api_key (SecretStr): The API key for OpenAI.
            eval_model (str, optional): The model to use for evaluation. Defaults to "gpt-3.5-turbo".
            chat_temperature (float, optional): The temperature to use for chat. Defaults to 0.3.
            eval_temperature (float, optional): The temperature to use for evaluation. Defaults to 0.3.

        Returns:
            None
        """
        super().__init__(level=level, openai_api_key=openai_api_key, chat_temperature=chat_temperature)

        # Separate (typically cheaper) LLM used only for tagging/extraction;
        # the base class provides `chat_llm` for the final response.
        self.checker_llm = ChatOpenAI(api_key=self.openai_api_key, temperature=eval_temperature, model=eval_model)
        self.prompt = self._get_system_prompt()

        self.chain = self._create_chain()

        self.config = RunnableConfig({})
        # {"callbacks": [ConsoleCallbackHandler()]}

    def _get_multi_chain_dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary containing three chains for extracting mistakes, tagging responses, and retrieving relevant information from an item.

        The dictionary has the following keys:
        - "tags": A chain for tagging responses.
        - "extraction": A chain for extracting mistakes.
        - "base_response": A function that retrieves the "base_response" field from an item.
        - "question": A function that retrieves the "question" field from an item.

        The "tags" chain is created using the `create_tagging_chain_pydantic` function, with the `AnswerTagger` pydantic schema and a prompt template.
        The "extraction" chain is created using the `create_extraction_chain_pydantic` function, with the `MistakeExtractor` pydantic schema and a prompt template.

        Returns:
            dict: A dictionary containing the three chains and the relevant item getter functions.
        """
        chain_extractor = create_extraction_chain_pydantic(
            pydantic_schema=self.MistakeExtractor,
            llm=self.checker_llm,
            prompt=PromptTemplate(
                template="Extract the mistakes found on: {base_response}", input_variables=["base_response"]
            ),
        )

        chain_tagger = create_tagging_chain_pydantic(
            pydantic_schema=self.AnswerTagger,
            llm=self.checker_llm,
            prompt=PromptTemplate(
                template="Tag the given response: {base_response}", input_variables=["base_response"]
            ),
        )

        # When this dict is piped into a Runnable, each value runs against the
        # same input mapping; the itemgetters pass the original fields through.
        return {
            "tags": chain_tagger,
            "extraction": chain_extractor,
            "base_response": itemgetter("base_response"),
            "question": itemgetter("question"),
        }

    def _get_system_prompt(self) -> ChatPromptTemplate:
        """Build the chat prompt: format instructions + teacher persona,
        the question/answer exchange, the intermediate tags/extraction,
        and the final level-aware response-generation instructions."""
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    # Implicit string concatenation: the parser's format
                    # instructions are prepended to the persona text.
                    "{format_instructions}"
                    """You are an excellent english teacher. You teach spanish people to speak English.
                    You will do the following tasks for that purpose:
                    - You will evaluate the quality of the responses given by the user.
                    - if a non-related text is asked you will politely decline to answer and you will
                    suggest to stay on the topic.
                    - Limit your evaluation to just once per interaction at a time.
                    - guide client to adquire fluency for English exams.
                    """,
                ),
                ("ai", "AI Question: {question}"),
                ("human", "Human Response: {base_response}"),
                # Fixed: was "\n\\Extraction:" which rendered a stray literal
                # backslash before "Extraction" instead of a newline.
                ("ai", "Tags:\n{tags}\nExtraction:\n{extraction}"),
                (
                    "system",
                    f"""Generate a final response given the AI Question, the Human Response and the detected Tags and Extraction:
                    - correct mistakes (just if there are) based on the Human Response
                        given by the MistakeExtractor according to the {self.level} english level.
                    - give relevant and related tips based on how complete the Human Response is given the punctuation of
                        the answer_complexity AnswerTagger Tags. Best responses are 7, 8 point responses since they are neither too simple
                        nor too complex.
                        - With too simple responses (1, 2, 3, 4 points) you must suggest an alternative response with a higher
                        degree of complexity.
                        - With too complex responses (9, 10 points) you must highlight which part of the response should be ignored.
                    - An excellent response must be grammatically correct, complete and clear.
                    - You will propose an excellent example answer to the AI Question given the above guidelines.
                    """,
                ),
            ]
        )
        return prompt

    def _get_output_parser(self, pydantic_schema: Type[BaseModel]) -> PydanticOutputParser[Any]:
        """Wrap a pydantic schema in a parser that also supplies format instructions."""
        return PydanticOutputParser(pydantic_object=pydantic_schema)

    def _create_chain(self) -> Runnable[Any, Any]:
        """Compose the full runnable: parallel tag/extraction dict -> final responder."""
        response_parser = self._get_output_parser(self.Output)
        # Bake the Output schema's format instructions into the prompt now;
        # only the per-call variables remain unresolved.
        prompt = self.prompt.partial(format_instructions=response_parser.get_format_instructions())

        final_responder = prompt | self.chat_llm | response_parser

        return self._get_multi_chain_dict() | final_responder

    def predict(self, response: str, question: str) -> Dict[str, str]:
        """Evaluate `response` to `question`; returns the Output fields keyed by alias.

        Args:
            response: The user's answer to evaluate.
            question: The question the answer responds to.

        Returns:
            Dict with "Evaluation", "Tips" and "Example" keys (aliases of Output).
        """
        # Validate inputs through the Input schema before invoking the chain.
        input_model = self.Input(response=response, question=question)

        result = self.chain.invoke(
            {"base_response": input_model.response, "question": input_model.question}, config=self.config
        )

        return result.dict(by_alias=True)