from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4

from loguru import logger
from pydantic import (
    BaseModel,
    DirectoryPath,
    Field,
    field_validator,
    ConfigDict,
    ValidationInfo,
)

from app.config.models.openai import OpenAIModelConfig
from app.config.models.vertexai import VertexAIModelConfig


def create_uuid() -> str:
    return str(uuid4())


class Document(BaseModel):
    """Interface for interacting with a document."""

    page_content: str
    metadata: dict = Field(default_factory=dict)


class SentenseTransformerEmbeddingModel(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    additional_kwargs: dict = Field(default_factory=dict)


class DocumentPathSettings(BaseModel):
    doc_path: Union[DirectoryPath, str]
    additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
    passage_prefix: str = ""
    label: str = ""  # Optional label, will be included in the metadata


class EmbedddingsSpladeConfig(BaseModel):
    n_batch: int = 3


class EmbeddingsConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    embedding_model: SentenseTransformerEmbeddingModel
    embeddings_path: Union[DirectoryPath, str]
    document_settings: List[DocumentPathSettings]
    chunk_sizes: List[int] = [1024]
    splade_config: EmbedddingsSpladeConfig = EmbedddingsSpladeConfig(n_batch=5)

    def labels(self) -> List[str]:
        """Returns list of labels in document settings"""
        return [setting.label for setting in self.document_settings if setting.label]
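
# A hedged illustration of how `labels()` behaves, using hypothetical paths and
# label names (an assumption for illustration, not taken from project docs):
#
#     settings = [
#         DocumentPathSettings(doc_path="/data/manuals", label="manuals"),
#         DocumentPathSettings(doc_path="/data/notes"),  # empty label
#     ]
#
# With `document_settings=settings`, `labels()` returns ["manuals"]; entries
# whose label is an empty string are skipped.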


class SemanticSearchConfig(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")

    max_k: int = 15
    max_char_size: int = 2048
    query_prefix: str = ""


class LLMConfig(BaseModel):
    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
    )

    type: str
    params: dict

    @field_validator("params")
    def validate_params(cls, value, info: ValidationInfo):
        values = info.data
        type_ = values.get("type")

        if type_ == "vertexai":
            # An attempt to force conversion to the required model config
            config = VertexAIModelConfig(**value)
        elif type_ == "openai":
            config = OpenAIModelConfig(**value)
        else:
            raise ValueError(f"Unsupported llm type: {type_}")

        logger.info(
            f"Loading model parameters in configuration class {config.__class__.__name__}"
        )
        return config
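
# A hedged sketch of the shape `validate_params` expects (field names come from
# this module; the contents of `params` are hypothetical and depend on
# OpenAIModelConfig / VertexAIModelConfig, which are defined elsewhere):
#
#     LLMConfig(type="openai", params={...})    # params unpacked into OpenAIModelConfig
#     LLMConfig(type="vertexai", params={...})  # params unpacked into VertexAIModelConfig
#
# Other `type` values raise a ValueError inside the validator.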


class ResponseModel(BaseModel):
    id: UUID = Field(default_factory=create_uuid)
    question: str
    response: str
    average_score: float
    semantic_search: List[str] = Field(default_factory=list)
    hyde_response: str = ""


class Config(BaseModel):
    cache_folder: Path
    embeddings: EmbeddingsConfig
    semantic_search: SemanticSearchConfig
    llm: Optional[LLMConfig] = None

    def check_embeddings_exist(self) -> bool:
        """Checks if embeddings exist in the specified folder"""
        p_splade = (
            Path(self.embeddings.embeddings_path) / "splade" / "splade_embeddings.npz"
        )
        p_embeddings = Path(self.embeddings.embeddings_path)
        all_parquets = list(p_embeddings.glob("*.parquet"))
        return p_splade.exists() and len(all_parquets) > 0
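

# ---------------------------------------------------------------------------
# Hedged usage sketch (a minimal example, not part of the configuration API):
# builds a Config programmatically with hypothetical paths and model names.
# In practice the values would come from an external configuration source and
# be passed to Config(**data).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_config = Config(
        cache_folder=Path("/tmp/cache"),  # hypothetical path
        embeddings=EmbeddingsConfig(
            embedding_model=SentenseTransformerEmbeddingModel(
                model_name="sentence-transformers/all-MiniLM-L6-v2"  # hypothetical model
            ),
            embeddings_path="/tmp/embeddings",  # hypothetical path
            document_settings=[
                DocumentPathSettings(doc_path="/tmp/docs", label="sample"),
            ],
        ),
        semantic_search=SemanticSearchConfig(),
        llm=None,
    )
    logger.info(f"Labels: {sample_config.embeddings.labels()}")
    logger.info(f"Embeddings exist: {sample_config.check_embeddings_exist()}")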