File size: 3,428 Bytes
7bd11ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4

from loguru import logger
from pydantic import (
    BaseModel,
    DirectoryPath,
    Field,
    field_validator,
    ConfigDict,
    ValidationInfo,
)

from app.config.models.openai import OpenAIModelConfig
from app.config.models.vertexai import VertexAIModelConfig


def create_uuid() -> str:
    """Generate a fresh random UUID4 and return it as a string."""
    generated = uuid4()
    return str(generated)


class Document(BaseModel):
    """A unit of text together with arbitrary metadata about its origin."""

    # Raw text content of the document (or document chunk).
    page_content: str
    # Free-form metadata (e.g. source path, label); empty dict by default.
    metadata: dict = Field(default_factory=dict)


class SentenseTransformerEmbeddingModel(BaseModel):
    """Settings for a sentence-transformers embedding model.

    NOTE: the class name keeps the existing "Sentense" spelling because
    external callers import it by this name.
    """

    # Disable pydantic's protected "model_" namespace so the `model_name`
    # field below does not trigger a warning. Set directly in the
    # ConfigDict call instead of mutating the dict after construction.
    model_config = ConfigDict(protected_namespaces=())

    # Name (or path) of the sentence-transformers model to load.
    model_name: str
    # Extra keyword arguments forwarded to the model constructor.
    additional_kwargs: dict = Field(default_factory=dict)


class DocumentPathSettings(BaseModel):
    """Location of one set of source documents plus per-set parse options."""

    # Directory containing the documents; validated as a path when it exists.
    doc_path: Union[DirectoryPath, str]
    # Parser-specific overrides, keyed by parser name.
    additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
    # Prefix prepended to each passage before embedding.
    passage_prefix: str = ""
    # Optional label; when set it is included in each document's metadata.
    label: str = ""


class EmbedddingsSpladeConfig(BaseModel):
    """Batching options for SPLADE sparse-embedding generation.

    NOTE: class name keeps the existing "Embedddings" spelling because
    external callers reference it.
    """

    # Number of documents processed per SPLADE batch.
    n_batch: int = 3


class EmbeddingsConfig(BaseModel):
    """Embedding pipeline configuration: model, output path, and documents."""

    # Reject unknown keys so config typos fail loudly.
    model_config = ConfigDict(extra="forbid")

    # The sentence-transformers model used to embed documents.
    embedding_model: SentenseTransformerEmbeddingModel
    # Directory where generated embeddings are written.
    embeddings_path: Union[DirectoryPath, str]
    # One entry per document source directory.
    document_settings: List[DocumentPathSettings]
    # Chunk sizes used when splitting documents before embedding.
    # (Pydantic deep-copies field defaults, so the mutable default is safe.)
    chunk_sizes: List[int] = [1024]
    splade_config: EmbedddingsSpladeConfig = EmbedddingsSpladeConfig(n_batch=5)

    @property
    def labels(self) -> List[str]:
        """All non-empty labels across the configured document settings."""
        return [ds.label for ds in self.document_settings if ds.label]


class SemanticSearchConfig(BaseModel):
    """Limits applied to semantic-search retrieval."""

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    # Maximum number of passages to retrieve.
    max_k: int = 15
    # Maximum character budget for returned passages.
    max_char_size: int = 2048
    # Prefix prepended to the query before embedding.
    query_prefix: str = ""


class LLMConfig(BaseModel):
    """Configuration for an LLM backend, selected by the `type` field."""

    # `protected_namespaces=()` is set directly in the ConfigDict call
    # instead of mutating the dict after construction.
    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
    )

    # Backend identifier: "vertexai" or "openai".
    type: str
    # Raw parameters; converted to the backend's typed config by the validator.
    params: dict

    @field_validator("params")
    def validate_params(cls, value, info: ValidationInfo):
        """Coerce the raw `params` dict into the typed config for `type`.

        Returns the backend-specific model config instance.
        Raises ValueError when `type` is not a supported backend.

        Bug fixes vs. the previous version: the log statement referenced an
        undefined name (`LlamaModelConfig`), raising NameError on every call,
        and an unsupported `type` left `config` unbound (UnboundLocalError).
        """
        # `type` is declared before `params`, so it is already validated
        # and available via info.data here.
        type_ = info.data.get("type")
        if type_ == "vertexai":
            # Force conversion to the required model config.
            config = VertexAIModelConfig(**value)
        elif type_ == "openai":
            config = OpenAIModelConfig(**value)
        else:
            raise ValueError(f"Unsupported LLM type: {type_!r}")
        logger.info(
            f"Loading model parameters in configuration class {type(config).__name__}"
        )
        return config


class ResponseModel(BaseModel):
    """One question/answer exchange together with its retrieval context."""

    # Unique identifier; a fresh UUID4 string, coerced to UUID by pydantic.
    id: UUID = Field(default_factory=create_uuid)
    question: str
    response: str
    average_score: float
    # Passages returned by semantic search for this question.
    semantic_search: List[str] = Field(default_factory=list)
    # Optional HyDE (hypothetical document) response, when generated.
    hyde_response: str = ""


class Config(BaseModel):
    """Top-level application configuration."""

    # Directory used for model/download caches.
    cache_folder: Path
    embeddings: EmbeddingsConfig
    semantic_search: SemanticSearchConfig
    # Optional: omitted when no LLM backend is configured.
    llm: Optional[LLMConfig] = None

    def check_embeddings_exist(self) -> bool:
        """Return True when embeddings exist in the configured folder.

        Requires both the SPLADE sparse-embeddings file and at least one
        dense-embedding parquet file to be present.
        """
        root = Path(self.embeddings.embeddings_path)
        splade_file = root / "splade" / "splade_embeddings.npz"
        has_parquet = any(root.glob("*.parquet"))
        return splade_file.exists() and has_parquet