Spaces:
Sleeping
Sleeping
from operator import itemgetter | |
from typing import Any, Dict | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema import HumanMessage | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr | |
from langchain_openai import ChatOpenAI | |
from src.models.evaluator.text_evaluator import EvaluationChatModelQA | |
from typing_extensions import override | |
class EvaluationChatModelImg(EvaluationChatModelQA): | |
class Input(BaseModel): | |
image_url: str | |
user_desc: str | |
class Output(BaseModel): | |
evaluation: str = Field( | |
alias="Evaluation", description="Summarize evaluation of the response (just if there is)" | |
) | |
tips: str = Field( | |
alias="Tips", description="tips about complexity and detailed mistakes correction (just if there are)" | |
) | |
example: str = Field( | |
alias="Example", | |
description="Example of a response based on the AI image description following the given guidelines (just if there is)", | |
) | |
def __init__( | |
self, | |
eval_model: str, | |
level: str, | |
openai_api_key: SecretStr, | |
chat_temperature: float = 0.3, | |
eval_temperature: float = 0.3, | |
) -> None: | |
""" | |
Initializes the class with the given parameters. | |
Args: | |
eval_model (str): The vision model to use for image description. | |
level (str): The level of the exam. | |
openai_api_key (SecretStr): The API key for OpenAI. | |
chat_temperature (float, optional): The temperature to use for chat. Defaults to 0.3. | |
eval_temperature (float, optional): The temperature to use for image model. Defaults to 0.3. | |
Returns: | |
None | |
""" | |
self.vision_model = ChatOpenAI( | |
temperature=eval_temperature, model=eval_model, max_tokens=1024, api_key=openai_api_key | |
) | |
super().__init__(level, openai_api_key=openai_api_key, chat_temperature=chat_temperature) | |
def _get_system_prompt(self) -> ChatPromptTemplate: | |
return ChatPromptTemplate.from_messages( | |
[ | |
( | |
"system", | |
"{format_instructions}" | |
f"""You will evaluate how close is the the image | |
description given by the user compared to the image | |
description given by the ai vision model. Take into account | |
that this image description will be used for evaluating | |
the user how well can describe this image in the context | |
of an {self.level} Speaking English exam. An Example of | |
an image description will also be provided based on the AI image description.""", | |
), | |
("ai", "AI image description: {ai_img_desc}"), | |
("human", "User image description: {base_response}"), | |
("ai", "Tags:\n{tags}\n\\Extraction:\n{extraction}"), | |
( | |
"system", | |
f"""Generate a final response given the question, its response and the detected Tags and Extraction: | |
- correct mistakes (just if there are) based on the response | |
given by the MistakeExtractor according to the {self.level} english level. | |
- give relevant and related tips based on how complete the answer is given the punctuation of | |
the AnswerTagger. Best responses are 7, 8 point questions since they are neither too simple | |
nor too complex. | |
- With too simple questions (1, 2, 3, 4 points) you must suggest an alternative response with a higher | |
degree of complexity. | |
- With too complex questions (9, 10 points) you must highlight which part of the response should be ignored. | |
- An excellent response must be grammatically correct, complete and clear. | |
- Provide an example answer to the question given the above guidelines and the AI image description. | |
""", | |
), | |
] | |
) | |
def _get_multi_chain_dict(self) -> Dict[str, Any]: | |
multi_chain_dict = super()._get_multi_chain_dict() | |
multi_chain_dict = {key: multi_chain_dict[key] for key in ["tags", "extraction"]} | |
multi_chain_dict.update( | |
{ | |
"ai_img_desc": itemgetter("image_url") | self.vision_model | StrOutputParser(), | |
"base_response": itemgetter("base_response"), | |
} | |
) | |
return multi_chain_dict | |
def predict(self, user_desc: str, image_url: str) -> Dict[str, str]: | |
"""Make a prediction using the provided input. | |
Args: | |
user_desc (str): The user description. | |
image_url (str): The image url. | |
Returns: | |
Dict: The output of the prediction. | |
""" | |
input_model = self.Input(user_desc=user_desc, image_url=image_url) | |
vision_model_input = [ | |
HumanMessage( | |
content=[ | |
{"type": "text", "text": "What is this image showing?"}, | |
{ | |
"type": "image_url", | |
"image_url": {"url": input_model.image_url, "detail": "auto"}, | |
}, | |
] | |
) | |
] | |
result = self.chain.invoke( | |
{"base_response": input_model.user_desc, "image_url": vision_model_input}, config=self.config | |
) | |
return result.dict(by_alias=True) | |