File size: 3,878 Bytes
18c0acd
 
 
c39d1b7
18c0acd
 
c39d1b7
 
 
 
 
18c0acd
 
 
 
 
 
 
 
 
 
 
 
 
c39d1b7
 
 
 
 
 
18c0acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c39d1b7
 
 
 
18c0acd
c39d1b7
 
 
 
18c0acd
 
 
 
 
 
 
 
 
 
c39d1b7
 
 
 
18c0acd
 
 
 
 
 
 
 
 
 
c39d1b7
 
 
18c0acd
 
 
 
 
 
 
 
c39d1b7
 
 
 
 
 
18c0acd
c39d1b7
 
 
 
 
 
 
18c0acd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from langchain.schema import BaseMessage
from langchain_core.pydantic_v1 import SecretStr

from src.models.base.base_model import ChainGenerator, ContentGenerator, EvaluationChatModel
from src.models.evaluator.img_evaluator import EvaluationChatModelImg
from src.models.evaluator.text_evaluator import EvaluationChatModelQA
from src.models.gen.img_generator import ImgGenerator
from src.models.gen.text_generator import QuestionGenerator


class ModelFactory(ABC):
    """Abstract interface shared by the model factories in this module."""

    @abstractmethod
    def create_model(
        self, *args: Any, **kwargs: Any
    ) -> ChainGenerator | EvaluationChatModel:
        """
        Create and return a model instance.

        Concrete factories override this method and declare the specific
        parameters they need; this base signature accepts anything.

        Args:
            *args (Any): Positional arguments defined by the concrete factory.
            **kwargs (Any): Keyword arguments defined by the concrete factory.

        Returns:
            ChainGenerator | EvaluationChatModel: The created model.
        """
        # NOTE(review): GeneratorModelFactory annotates its return as
        # ContentGenerator - presumably compatible with ChainGenerator; confirm.


class EvaluationChatModelFactory(ModelFactory):
    """Factory that builds evaluation chat models selected by a string key."""

    def create_model(
        self,
        model_class: str,
        openai_api_key: SecretStr,
        **kwargs: Any,
    ) -> EvaluationChatModel:
        """
        Build the evaluation chat model that corresponds to ``model_class``.

        ``"qa"`` selects ``EvaluationChatModelQA`` and ``"img_desc"`` selects
        ``EvaluationChatModelImg``. The API key and every extra keyword
        argument are passed to the selected class unchanged.

        Args:
            model_class (str): Key naming the evaluator to build.
            openai_api_key (SecretStr): The API key for OpenAI.
            **kwargs (Any): Extra keyword arguments for the evaluator.

        Returns:
            EvaluationChatModel: The newly constructed evaluator.

        Raises:
            ValueError: If ``model_class`` is not one of the supported keys.
        """
        # Pick the concrete class first; construction is identical for both.
        if model_class == "qa":
            evaluator_cls = EvaluationChatModelQA
        elif model_class == "img_desc":
            evaluator_cls = EvaluationChatModelImg
        else:
            raise ValueError("Invalid model class provided")

        return evaluator_cls(openai_api_key=openai_api_key, **kwargs)


class GeneratorModelFactory(ModelFactory):
    """Factory that builds content generators selected by a string key."""

    def create_model(
        self,
        model_class: str,
        openai_api_key: SecretStr,
        history_chat: Optional[List[BaseMessage]] = None,
        n_messages_memory: int = 20,
        prompt_model: str = "gpt-3.5-turbo",
        prompt_model_temperature: float = 0.3,
        img_size: str = "256x256",
        **kwargs: Any,
    ) -> ContentGenerator:
        """
        Build the content generator that corresponds to ``model_class``.

        ``"qa"`` builds a ``QuestionGenerator`` and only forwards
        ``history_chat`` and ``n_messages_memory``; ``"img_desc"`` builds an
        ``ImgGenerator`` and only forwards ``prompt_model``,
        ``prompt_model_temperature`` and ``img_size``. The API key and any
        extra keyword arguments are forwarded in both cases.

        Args:
            model_class (str): Key naming the generator to build.
            openai_api_key (SecretStr): The API key for OpenAI.
            history_chat (Optional[List[BaseMessage]]): Previous chat
                messages. A falsy value (``None`` or an empty list) is
                replaced by a new empty list. Defaults to None.
            n_messages_memory (int): Number of messages to keep in memory.
                Defaults to 20.
            prompt_model (str): The prompt model. Defaults to "gpt-3.5-turbo".
            prompt_model_temperature (float): The prompt model temperature.
                Defaults to 0.3.
            img_size (str): The size of the image. Defaults to "256x256".
            **kwargs (Any): Extra keyword arguments for the generator.

        Returns:
            ContentGenerator: The newly constructed generator.

        Raises:
            ValueError: If ``model_class`` is not one of the supported keys.
        """
        if model_class == "qa":
            chat_history = history_chat or []
            return QuestionGenerator(
                openai_api_key=openai_api_key,
                history_chat=chat_history,
                n_messages_memory=n_messages_memory,
                **kwargs,
            )

        if model_class == "img_desc":
            return ImgGenerator(
                openai_api_key=openai_api_key,
                prompt_model=prompt_model,
                prompt_model_temperature=prompt_model_temperature,
                img_size=img_size,
                **kwargs,
            )

        # Any other key is unsupported.
        raise ValueError("Invalid model class provided")