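"""FastAPI dependency providers for the application.

Each function below is a factory intended to be used with ``Depends`` to supply
configuration, the database session factory, logging, the embedding extractor,
and the dataset, document, acronym and LLM-related services.
"""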
import logging
import os
from logging import Logger
from typing import Annotated

from fastapi import Depends
from sqlalchemy.orm import Session, sessionmaker

from common.configuration import Configuration
from common.db import session_factory
from components.datasets.dispatcher import Dispatcher
from components.embedding_extraction import EmbeddingExtractor
from components.llm.common import LlmParams
from components.llm.deepinfra_api import DeepInfraApi
from components.services.acronym import AcronymService
from components.services.dataset import DatasetService
from components.services.document import DocumentService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService


def get_config() -> Configuration:
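    """Load the application configuration from CONFIG_PATH (defaults to config_dev.yaml)."""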
    return Configuration(os.environ.get('CONFIG_PATH', 'config_dev.yaml'))


def get_db() -> sessionmaker:
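    """Return the shared SQLAlchemy session factory used by the service layer."""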
    return session_factory


def get_logger() -> Logger:
    return logging.getLogger(__name__)


def get_embedding_extractor(config: Annotated[Configuration, Depends(get_config)]) -> EmbeddingExtractor:
    return EmbeddingExtractor(
        config.db_config.faiss.model_embedding_path,
        config.db_config.faiss.device,
    )


def get_dataset_service(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)]
) -> DatasetService:
    return DatasetService(vectorizer, config, db)


def get_dispatcher(vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
                   config: Annotated[Configuration, Depends(get_config)],
                   logger: Annotated[Logger, Depends(get_logger)],
                   dataset_service: Annotated[DatasetService, Depends(get_dataset_service)]) -> Dispatcher:
    return Dispatcher(vectorizer, config, logger, dataset_service)


def get_acronym_service(db: Annotated[Session, Depends(get_db)]) -> AcronymService:
    return AcronymService(db)


def get_document_service(dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
                        config: Annotated[Configuration, Depends(get_config)],
                        db: Annotated[sessionmaker, Depends(get_db)]) -> DocumentService:
    return DocumentService(dataset_service, config, db)


def get_llm_config_service(db: Annotated[Session, Depends(get_db)]) -> LLMConfigService:
    return LLMConfigService(db)


def get_llm_service(config: Annotated[Configuration, Depends(get_config)]) -> DeepInfraApi:
    llm_params = LlmParams(**{
        "url": config.llm_config.base_url,
        "model": config.llm_config.model,
        "tokenizer": config.llm_config.tokenizer,
        "type": "deepinfra",
        "default": True,
        "predict_params": None,  # must be provided with each request
        "api_key": os.environ.get(config.llm_config.api_key_env),
        "context_length": 128000
    })
    return DeepInfraApi(params=llm_params)


def get_llm_prompt_service(db: Annotated[Session, Depends(get_db)]) -> LlmPromptService:
    return LlmPromptService(db)
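

# Illustrative usage sketch: these providers are meant to be consumed through
# FastAPI's dependency injection in a router. The router, path, and the
# ``get_default`` call below are hypothetical, not defined in this module.
#
#     from fastapi import APIRouter
#
#     router = APIRouter()
#
#     @router.get("/llm/config")
#     def read_llm_config(
#         llm_config_service: Annotated[LLMConfigService, Depends(get_llm_config_service)],
#     ):
#         return llm_config_service.get_default()  # hypothetical method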