from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4

from loguru import logger
from pydantic import (
    BaseModel,
    ConfigDict,
    DirectoryPath,
    Field,
    ValidationInfo,
    field_validator,
)

from app.config.models.openai import OpenAIModelConfig
from app.config.models.vertexai import VertexAIModelConfig


def create_uuid() -> str:
    """Returns a new UUID4 as a string."""
    return str(uuid4())


class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    metadata: dict = Field(default_factory=dict)


class SentenseTransformerEmbeddingModel(BaseModel):
    """Settings for a sentence-transformers embedding model."""

    # Allow field names starting with "model_" (e.g. model_name).
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    additional_kwargs: dict = Field(default_factory=dict)


class DocumentPathSettings(BaseModel):
    doc_path: Union[DirectoryPath, str]
    additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
    passage_prefix: str = ""
    label: str = ""  # Optional label, will be included in the metadata


class EmbedddingsSpladeConfig(BaseModel):
    n_batch: int = 3


class EmbeddingsConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    embedding_model: SentenseTransformerEmbeddingModel
    embeddings_path: Union[DirectoryPath, str]
    document_settings: List[DocumentPathSettings]
    chunk_sizes: List[int] = [1024]
    splade_config: EmbedddingsSpladeConfig = EmbedddingsSpladeConfig(n_batch=5)

    @property
    def labels(self) -> List[str]:
        """Returns the list of non-empty labels in document settings."""
        return [setting.label for setting in self.document_settings if setting.label]


class SemanticSearchConfig(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    max_k: int = 15
    max_char_size: int = 2048
    query_prefix: str = ""


class LLMConfig(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
    )

    type: str
    params: dict

    @field_validator("params")
    def validate_params(cls, value: dict, info: ValidationInfo):
        """Converts the raw params dict into the model config matching `type`."""
        type_ = info.data.get("type")
        if type_ == "vertexai":
            # An attempt to force conversion to the required model config
            config = VertexAIModelConfig(**value)
        elif type_ == "openai":
            config = OpenAIModelConfig(**value)
        else:
            raise ValueError(f"Unknown LLM type: {type_}")
        logger.info(
            f"Loading model parameters in configuration class {type(config).__name__}"
        )
        return config


class ResponseModel(BaseModel):
    id: UUID = Field(default_factory=create_uuid)
    question: str
    response: str
    average_score: float
    semantic_search: List[str] = Field(default_factory=list)
    hyde_response: str = ""


class Config(BaseModel):
    cache_folder: Path
    embeddings: EmbeddingsConfig
    semantic_search: SemanticSearchConfig
    llm: Optional[LLMConfig] = None

    def check_embeddings_exist(self) -> bool:
        """Checks if embeddings exist in the specified folder."""
        p_splade = (
            Path(self.embeddings.embeddings_path) / "splade" / "splade_embeddings.npz"
        )
        p_embeddings = Path(self.embeddings.embeddings_path)
        all_parquets = list(p_embeddings.glob("*.parquet"))
        return p_splade.exists() and len(all_parquets) > 0
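

# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal sketch of building a `Config` from a plain dict, e.g. one parsed
# from a settings file. The paths and model name below are hypothetical
# placeholders; `llm` is left unset to avoid assuming the field names of
# OpenAIModelConfig / VertexAIModelConfig.
if __name__ == "__main__":
    sample_settings = {
        "cache_folder": "/tmp/llm_cache",  # hypothetical path
        "embeddings": {
            "embedding_model": {"model_name": "all-MiniLM-L6-v2"},
            "embeddings_path": "/tmp/embeddings",  # hypothetical path
            "document_settings": [{"doc_path": "/tmp/docs", "label": "docs"}],
        },
        "semantic_search": {"max_k": 10},
    }
    config = Config(**sample_settings)
    logger.info(f"Document labels: {config.embeddings.labels}")
    logger.info(f"Embeddings exist: {config.check_embeddings_exist()}")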