# linguAIcoach / src/models/evaluator/text_evaluator.py
# Author: alvaroalon2
# feat: parameterized models (commit c39d1b7)
import logging
from operator import itemgetter
from typing import Any, Dict, Optional, Type
# from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain.chains import create_extraction_chain_pydantic, create_tagging_chain_pydantic
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from src.models.base.base_model import EvaluationChatModel
logger = logging.getLogger(__name__)
class EvaluationChatModelQA(EvaluationChatModel):
    """Question-answering evaluation model for English-learning responses.

    Pipeline: the user's answer is scored for complexity (AnswerTagger) and
    mined for grammar mistakes (MistakeExtractor) by two auxiliary chains
    running on ``checker_llm``; a final prompt then combines the question,
    the answer and both signals into structured feedback (``Output``).
    """

    class Input(BaseModel):
        # User's answer plus the question it addresses (question may be empty).
        response: str
        question: Optional[str] = Field(default="")

    class Output(BaseModel):
        # Structured feedback returned to the caller; serialized with the
        # capitalized aliases (see predict()'s dict(by_alias=True)).
        evaluation: str = Field(
            alias="Evaluation",
            description="Summarize evaluation of the response (just if there is)",
            default="",
        )
        tips: str = Field(
            alias="Tips",
            description="tips about complexity and detailed mistakes correction (just if there are)",
            default="",
        )
        example: str = Field(
            alias="Example",
            description="Example of the response following the given guidelines (just if there is)",
            default="",
        )

    class AnswerTagger(BaseModel):
        """
        Tags the answer considering the following aspects:
        - complexity
        """

        answer_complexity: int = Field(
            description="describes how complex the answer is. It is a number between 0 (simpler) and 10 (more complex)",
            enum=list(range(11)),
        )

    class MistakeExtractor(BaseModel):
        """
        Extracts the mistakes from the text.
        """

        grammar_mistake: Optional[str] = Field(
            description="Grammar syntax mistakes detected in text, just if there are, if not return empty string.",
            default="",
        )

    def __init__(
        self,
        level: str,
        openai_api_key: SecretStr,
        eval_model: str = "gpt-3.5-turbo",
        chat_temperature: float = 0.3,
        eval_temperature: float = 0.3,
    ) -> None:
        """
        Initializes the evaluation model and builds its runnable chain.

        Args:
            level (str): The English level targeted by the evaluation (interpolated into the prompt).
            openai_api_key (SecretStr): The API key for OpenAI.
            eval_model (str, optional): The model to use for evaluation. Defaults to "gpt-3.5-turbo".
            chat_temperature (float, optional): The temperature to use for chat. Defaults to 0.3.
            eval_temperature (float, optional): The temperature to use for evaluation. Defaults to 0.3.

        Returns:
            None
        """
        super().__init__(level=level, openai_api_key=openai_api_key, chat_temperature=chat_temperature)
        # Dedicated LLM for the auxiliary tagging/extraction side-chains,
        # parameterized independently of the main chat model.
        self.checker_llm = ChatOpenAI(api_key=self.openai_api_key, temperature=eval_temperature, model=eval_model)
        self.prompt = self._get_system_prompt()
        self.chain = self._create_chain()
        self.config = RunnableConfig({})
        # {"callbacks": [ConsoleCallbackHandler()]}

    def _get_multi_chain_dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary containing three chains for extracting mistakes, tagging responses, and retrieving relevant information from an item.

        The dictionary has the following keys:
        - "tags": A chain for tagging responses.
        - "extraction": A chain for extracting mistakes.
        - "base_response": A function that retrieves the "base_response" field from an item.
        - "question": A function that retrieves the "question" field from an item.

        The "tags" chain is created using the `create_tagging_chain_pydantic` function, with the `AnswerTagger` pydantic schema and a prompt template.
        The "extraction" chain is created using the `create_extraction_chain_pydantic` function, with the `MistakeExtractor` pydantic schema and a prompt template.

        Returns:
            dict: A dictionary containing the three chains and the relevant item getter functions.
        """
        chain_extractor = create_extraction_chain_pydantic(
            pydantic_schema=self.MistakeExtractor,
            llm=self.checker_llm,
            prompt=PromptTemplate(
                template="Extract the mistakes found on: {base_response}", input_variables=["base_response"]
            ),
        )
        chain_tagger = create_tagging_chain_pydantic(
            pydantic_schema=self.AnswerTagger,
            llm=self.checker_llm,
            prompt=PromptTemplate(
                template="Tag the given response: {base_response}", input_variables=["base_response"]
            ),
        )
        # When this dict is piped into a runnable (see _create_chain), each
        # value runs in parallel over the same input mapping.
        return {
            "tags": chain_tagger,
            "extraction": chain_extractor,
            "base_response": itemgetter("base_response"),
            "question": itemgetter("question"),
        }

    def _get_system_prompt(self) -> ChatPromptTemplate:
        """Build the chat prompt: role instructions, the Q/A exchange, the
        auxiliary-chain results, and the final response guidelines."""
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    # {format_instructions} is filled in later via .partial()
                    # with the Output parser's format instructions.
                    "{format_instructions}"
                    """You are an excellent english teacher. You teach spanish people to speak English.
                You will do the following tasks for that purpose:
                    - You will evaluate the quality of the responses given by the user.
                    - if a non-related text is asked you will politely decline to answer and you will
                    suggest to stay on the topic.
                    - Limit your evaluation to just once per interaction at a time.
                    - guide client to acquire fluency for English exams.
                """,
                ),
                ("ai", "AI Question: {question}"),
                ("human", "Human Response: {base_response}"),
                # Fixed a stray "\\" that rendered a literal backslash before
                # "Extraction" instead of a blank line.
                ("ai", "Tags:\n{tags}\n\nExtraction:\n{extraction}"),
                (
                    "system",
                    # f-string: only {self.level} is interpolated here; the
                    # template variables above are left to the prompt engine.
                    f"""Generate a final response given the AI Question, the Human Response and the detected Tags and Extraction:
                    - correct mistakes (just if there are) based on the Human Response
                    given by the MistakeExtractor according to the {self.level} english level.
                    - give relevant and related tips based on how complete the Human Response is given the punctuation of
                    the answer_complexity AnswerTagger Tags. Best responses are 7, 8 point responses since they are neither too simple
                    nor too complex.
                    - With too simple responses (1, 2, 3, 4 points) you must suggest an alternative response with a higher
                    degree of complexity.
                    - With too complex responses (9, 10 points) you must highlight which part of the response should be ignored.
                    - An excellent response must be grammatically correct, complete and clear.
                    - You will propose an excellent example answer to the AI Question given the above guidelines.
                """,
                ),
            ]
        )
        return prompt

    def _get_output_parser(self, pydantic_schema: Type[BaseModel]) -> PydanticOutputParser[Any]:
        """Return a parser that coerces LLM text into the given pydantic schema."""
        return PydanticOutputParser(pydantic_object=pydantic_schema)

    def _create_chain(self) -> Runnable[Any, Any]:
        """Assemble the full runnable: parallel side-chains feeding the final responder."""
        response_parser = self._get_output_parser(self.Output)
        prompt = self.prompt.partial(format_instructions=response_parser.get_format_instructions())
        final_responder = prompt | self.chat_llm | response_parser
        return self._get_multi_chain_dict() | final_responder

    def predict(self, response: str, question: str) -> Dict[str, str]:
        """Evaluate a user's answer to a question.

        Args:
            response (str): The user's answer.
            question (str): The question the user answered.

        Returns:
            Dict[str, str]: Feedback keyed by the Output aliases
            ("Evaluation", "Tips", "Example").
        """
        input_model = self.Input(response=response, question=question)
        result = self.chain.invoke(
            {"base_response": input_model.response, "question": input_model.question}, config=self.config
        )
        # pydantic-v1 style serialization; by_alias exposes capitalized keys.
        return result.dict(by_alias=True)