Spaces:
Sleeping
Sleeping
File size: 3,195 Bytes
57cf043 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import logging
from logging import Logger
import os
from fastapi import Depends
from common.configuration import Configuration
from components.llm.common import LlmParams
from components.llm.deepinfra_api import DeepInfraApi
from components.services.dataset import DatasetService
from components.embedding_extraction import EmbeddingExtractor
from components.datasets.dispatcher import Dispatcher
from components.services.document import DocumentService
from components.services.acronym import AcronymService
from components.services.llm_config import LLMConfigService
from typing import Annotated
from sqlalchemy.orm import sessionmaker, Session
from common.db import session_factory
from components.services.llm_prompt import LlmPromptService
def get_config() -> Configuration:
    """Load the application configuration from CONFIG_PATH (defaults to config_dev.yaml)."""
    config_path = os.environ.get('CONFIG_PATH', 'config_dev.yaml')
    return Configuration(config_path)
def get_db() -> sessionmaker:
    """Expose the shared SQLAlchemy session factory as a FastAPI dependency.

    Note: this yields the *factory* (sessionmaker), not an open Session.
    """
    return session_factory
def get_logger() -> Logger:
    """Return this module's logger for injection into route handlers."""
    module_logger = logging.getLogger(__name__)
    return module_logger
def get_embedding_extractor(config: Annotated[Configuration, Depends(get_config)]) -> EmbeddingExtractor:
    """Build the embedding extractor from the FAISS model path and device in config."""
    faiss_cfg = config.db_config.faiss
    return EmbeddingExtractor(faiss_cfg.model_embedding_path, faiss_cfg.device)
def get_dataset_service(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DatasetService:
    """Assemble the dataset service from its injected collaborators."""
    return DatasetService(vectorizer, config, db)
def get_dispatcher(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    config: Annotated[Configuration, Depends(get_config)],
    logger: Annotated[Logger, Depends(get_logger)],
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
) -> Dispatcher:
    """Create the dataset dispatcher wired with the embedding extractor,
    configuration, logger, and dataset service."""
    return Dispatcher(vectorizer, config, logger, dataset_service)
def get_acronym_service(db: Annotated[sessionmaker, Depends(get_db)]) -> AcronymService:
    """Provide the acronym service.

    Fix: `get_db` injects `session_factory`, a `sessionmaker` (session factory),
    not an open `Session` — the annotation now matches what is actually injected,
    consistent with `get_dataset_service` / `get_document_service`.
    """
    return AcronymService(db)
def get_document_service(
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DocumentService:
    """Assemble the document service on top of the dataset service."""
    return DocumentService(dataset_service, config, db)
def get_llm_config_service(db: Annotated[sessionmaker, Depends(get_db)]) -> LLMConfigService:
    """Provide the LLM configuration service.

    Fix: `get_db` injects `session_factory`, a `sessionmaker` (session factory),
    not an open `Session` — the annotation now matches what is actually injected.
    """
    return LLMConfigService(db)
def get_llm_service(config: Annotated[Configuration, Depends(get_config)]) -> DeepInfraApi:
    """Construct the DeepInfra LLM client from the configured connection settings.

    The API key is read from the environment variable named by
    ``config.llm_config.api_key_env`` (may be absent, yielding ``None``).
    """
    llm_cfg = config.llm_config
    params = LlmParams(
        url=llm_cfg.base_url,
        model=llm_cfg.model,
        tokenizer=llm_cfg.tokenizer,
        type="deepinfra",
        default=True,
        # Prediction parameters are supplied per request, not at construction time.
        predict_params=None,
        api_key=os.environ.get(llm_cfg.api_key_env),
        context_length=128000,
    )
    return DeepInfraApi(params=params)
def get_llm_prompt_service(db: Annotated[sessionmaker, Depends(get_db)]) -> LlmPromptService:
    """Provide the LLM prompt service.

    Fix: `get_db` injects `session_factory`, a `sessionmaker` (session factory),
    not an open `Session` — the annotation now matches what is actually injected.
    """
    return LlmPromptService(db)