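# generic-chatbot-backend / common/dependencies.py
# Source metadata: author muryshev, commit 57cf043 ("init"), 3.2 kB.
# FastAPI dependency providers (used via `Depends`) for configuration,
# the DB session factory, services and the LLM client.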
import logging
from logging import Logger
import os
from fastapi import Depends
from common.configuration import Configuration
from components.llm.common import LlmParams
from components.llm.deepinfra_api import DeepInfraApi
from components.services.dataset import DatasetService
from components.embedding_extraction import EmbeddingExtractor
from components.datasets.dispatcher import Dispatcher
from components.services.document import DocumentService
from components.services.acronym import AcronymService
from components.services.llm_config import LLMConfigService
from typing import Annotated
from sqlalchemy.orm import sessionmaker, Session
from common.db import session_factory
from components.services.llm_prompt import LlmPromptService
def get_config() -> Configuration:
    """Load the application configuration.

    The config file path comes from the ``CONFIG_PATH`` environment variable
    and falls back to ``config_dev.yaml``. A new ``Configuration`` is built on
    every call; nothing is cached here.
    """
    config_path = os.getenv('CONFIG_PATH', 'config_dev.yaml')
    return Configuration(config_path)
def get_db() -> sessionmaker:
    """Return the shared ``session_factory`` imported from ``common.db``.

    The return annotation declares it a SQLAlchemy ``sessionmaker``
    (NOTE(review): confirm against common.db). Callers receive the factory
    itself, not an open ``Session``.
    """
    factory = session_factory
    return factory
def get_logger() -> Logger:
    """Return the logger named after this module's ``__name__``."""
    module_name = __name__
    return logging.getLogger(module_name)
def get_embedding_extractor(
    config: Annotated[Configuration, Depends(get_config)],
) -> EmbeddingExtractor:
    """Create an ``EmbeddingExtractor`` from the FAISS section of the config.

    Uses ``config.db_config.faiss.model_embedding_path`` and
    ``config.db_config.faiss.device`` as the two positional constructor
    arguments.

    NOTE(review): a fresh extractor is constructed on every call. If its
    constructor loads model weights, consider caching it. Confirm that the
    extractor is safe to share before doing so.
    """
    faiss_settings = config.db_config.faiss
    return EmbeddingExtractor(
        faiss_settings.model_embedding_path,
        faiss_settings.device,
    )
def get_dataset_service(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DatasetService:
    """Build a ``DatasetService``.

    The service is built from the embedding extractor, the configuration and
    the session factory.
    """
    service = DatasetService(vectorizer, config, db)
    return service
def get_dispatcher(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    config: Annotated[Configuration, Depends(get_config)],
    logger: Annotated[Logger, Depends(get_logger)],
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
) -> Dispatcher:
    """Build a dataset ``Dispatcher`` from its injected collaborators.

    The arguments are passed positionally, in this order:
    vectorizer, config, logger, dataset_service.
    """
    dispatcher_args = (vectorizer, config, logger, dataset_service)
    return Dispatcher(*dispatcher_args)
def get_acronym_service(db: Annotated[sessionmaker, Depends(get_db)]) -> AcronymService:
    """Build an ``AcronymService`` bound to the shared session factory.

    ``get_db`` returns a ``sessionmaker``, not an open ``Session``, so the
    parameter is annotated accordingly. This matches ``get_dataset_service``
    and ``get_document_service``.
    """
    return AcronymService(db)
def get_document_service(
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DocumentService:
    """Build a ``DocumentService``.

    The service is built from the dataset service, the configuration and the
    session factory.
    """
    document_service = DocumentService(dataset_service, config, db)
    return document_service
def get_llm_config_service(db: Annotated[sessionmaker, Depends(get_db)]) -> LLMConfigService:
    """Build an ``LLMConfigService`` bound to the shared session factory.

    ``get_db`` returns a ``sessionmaker``, not an open ``Session``, so the
    parameter is annotated accordingly. This matches the other DB-backed
    providers in this module.
    """
    return LLMConfigService(db)
def get_llm_service(config: Annotated[Configuration, Depends(get_config)]) -> DeepInfraApi:
    """Build the default DeepInfra LLM client from ``config.llm_config``.

    The API key is read from the environment variable whose *name* is
    ``config.llm_config.api_key_env``. If that variable is unset or empty, a
    warning is logged. The (possibly missing) value is still passed through
    unchanged, so construction behaves as before.

    Returns:
        A ``DeepInfraApi`` configured with the resulting ``LlmParams``.
    """
    llm_config = config.llm_config
    api_key = os.environ.get(llm_config.api_key_env)
    if not api_key:
        # Surface a likely misconfiguration early instead of failing later
        # with an opaque authentication error from the provider.
        logging.getLogger(__name__).warning(
            "LLM API key environment variable %r is not set or empty",
            llm_config.api_key_env,
        )
    llm_params = LlmParams(
        url=llm_config.base_url,
        model=llm_config.model,
        tokenizer=llm_config.tokenizer,
        type="deepinfra",
        default=True,
        # Prediction parameters must be supplied with every request.
        predict_params=None,
        api_key=api_key,
        # Hard-coded context window. NOTE(review): confirm it matches the
        # configured model.
        context_length=128000,
    )
    return DeepInfraApi(params=llm_params)
def get_llm_prompt_service(db: Annotated[sessionmaker, Depends(get_db)]) -> LlmPromptService:
    """Build an ``LlmPromptService`` bound to the shared session factory.

    ``get_db`` returns a ``sessionmaker``, not an open ``Session``, so the
    parameter is annotated accordingly. This matches the other DB-backed
    providers in this module.
    """
    return LlmPromptService(db)