# app/config/models/configs.py

from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4

from loguru import logger
from pydantic import (
    BaseModel,
    DirectoryPath,
    Field,
    field_validator,
    ConfigDict,
    ValidationInfo,
)

from app.config.models.openai import OpenAIModelConfig
from app.config.models.vertexai import VertexAIModelConfig


def create_uuid() -> str:
    return str(uuid4())


class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    metadata: dict = Field(default_factory=dict)


class SentenseTransformerEmbeddingModel(BaseModel):
    # "model_name" starts with Pydantic v2's protected "model_" prefix,
    # so the protected namespaces are cleared for this model.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    additional_kwargs: dict = Field(default_factory=dict)


class DocumentPathSettings(BaseModel):
    doc_path: Union[DirectoryPath, str]
    additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
    passage_prefix: str = ""
    label: str = ""  # Optional label, will be included in the metadata


class EmbedddingsSpladeConfig(BaseModel):
    n_batch: int = 3


class EmbeddingsConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    embedding_model: SentenseTransformerEmbeddingModel
    embeddings_path: Union[DirectoryPath, str]
    document_settings: List[DocumentPathSettings]
    chunk_sizes: List[int] = [1024]
    splade_config: EmbedddingsSpladeConfig = EmbedddingsSpladeConfig(n_batch=5)

    @property
    def labels(self) -> List[str]:
        """Returns the non-empty labels from the document settings."""
        return [setting.label for setting in self.document_settings if setting.label]


class SemanticSearchConfig(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    max_k: int = 15
    max_char_size: int = 2048
    query_prefix: str = ""


class LLMConfig(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
    )

    type: str
    params: dict

    @field_validator("params")
    @classmethod
    def validate_params(cls, value, info: ValidationInfo):
        # Force conversion of the raw params dict to the typed model config
        # that matches the declared "type".
        type_ = info.data.get("type")
        if type_ == "vertexai":
            config = VertexAIModelConfig(**value)
        elif type_ == "openai":
            config = OpenAIModelConfig(**value)
        else:
            raise ValueError(f"Unknown LLM type: {type_}")

        logger.info(
            f"Loading model parameters in configuration class {type(config).__name__}"
        )
        return config


class ResponseModel(BaseModel):
    id: UUID = Field(default_factory=create_uuid)
    question: str
    response: str
    average_score: float
    semantic_search: List[str] = Field(default_factory=list)
    hyde_response: str = ""


class Config(BaseModel):
    cache_folder: Path
    embeddings: EmbeddingsConfig
    semantic_search: SemanticSearchConfig
    llm: Optional[LLMConfig] = None

    def check_embeddings_exist(self) -> bool:
        """Checks if embeddings exist in the specified folder."""
        p_splade = (
            Path(self.embeddings.embeddings_path) / "splade" / "splade_embeddings.npz"
        )
        p_embeddings = Path(self.embeddings.embeddings_path)
        all_parquets = list(p_embeddings.glob("*.parquet"))
        return p_splade.exists() and len(all_parquets) > 0
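

if __name__ == "__main__":
    # Minimal usage sketch: the paths and model name below are placeholders,
    # and the optional "llm" section is omitted because its parameter schema
    # is defined in the OpenAI / VertexAI config modules.
    sample = {
        "cache_folder": "/tmp/llm_cache",
        "embeddings": {
            "embedding_model": {
                "model_name": "sentence-transformers/all-MiniLM-L6-v2"
            },
            "embeddings_path": "/tmp/embeddings",
            "document_settings": [{"doc_path": "/tmp/docs", "label": "demo"}],
        },
        "semantic_search": {"max_k": 10},
    }
    config = Config(**sample)
    logger.info(f"Document labels: {config.embeddings.labels}")
    logger.info(f"Embeddings exist: {config.check_embeddings_exist()}")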