File size: 3,572 Bytes
57cf043
 
86c402d
 
 
57cf043
86c402d
 
57cf043
 
86c402d
 
 
57cf043
 
 
 
86c402d
57cf043
 
 
 
 
 
 
 
 
 
 
 
 
86c402d
57cf043
 
86c402d
 
 
57cf043
 
 
 
 
 
86c402d
 
 
 
 
 
 
 
 
 
 
57cf043
86c402d
57cf043
86c402d
 
 
57cf043
 
86c402d
 
 
 
 
 
 
57cf043
 
86c402d
 
 
 
 
57cf043
 
 
 
 
 
86c402d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57cf043
 
86c402d
57cf043
86c402d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
import os
from logging import Logger
from typing import Annotated

from fastapi import Depends
from ntr_text_fragmentation import InjectionBuilder
from sqlalchemy.orm import Session, sessionmaker

from common.configuration import Configuration
from common.db import session_factory
from components.dbo.chunk_repository import ChunkRepository
from components.embedding_extraction import EmbeddingExtractor
from components.llm.common import LlmParams
from components.llm.deepinfra_api import DeepInfraApi
from components.services.dataset import DatasetService
from components.services.document import DocumentService
from components.services.entity import EntityService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService


def get_config() -> Configuration:
    """Load the application configuration.

    The YAML path is taken from the CONFIG_PATH environment variable,
    falling back to 'config_dev.yaml' when unset.
    """
    config_path = os.environ.get('CONFIG_PATH', 'config_dev.yaml')
    return Configuration(config_path)


def get_db() -> sessionmaker:
    """Expose the shared SQLAlchemy session factory for dependency injection.

    Note that this returns the factory itself, not an open Session.
    """
    return session_factory


def get_logger() -> Logger:
    """Return the logger scoped to this module's name."""
    module_logger = logging.getLogger(__name__)
    return module_logger


def get_embedding_extractor(
    config: Annotated[Configuration, Depends(get_config)],
) -> EmbeddingExtractor:
    """Build the embedding extractor from the FAISS section of the config.

    Reads the embedding model path and target device from
    ``config.db_config.faiss``.
    """
    faiss_settings = config.db_config.faiss
    return EmbeddingExtractor(
        faiss_settings.model_embedding_path,
        faiss_settings.device,
    )


def get_chunk_repository(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> ChunkRepository:
    """Provide the chunk repository backed by the shared session factory.

    Fix: ``get_db`` returns a ``sessionmaker`` (the factory), not a live
    ``Session``, so the dependency is annotated accordingly — consistent
    with ``get_dataset_service`` and ``get_document_service``.
    """
    return ChunkRepository(db)


def get_injection_builder(
    chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
) -> InjectionBuilder:
    """Assemble an InjectionBuilder on top of the chunk repository."""
    builder = InjectionBuilder(chunk_repository)
    return builder


def get_entity_service(
    vectorizer: Annotated[EmbeddingExtractor, Depends(get_embedding_extractor)],
    chunk_repository: Annotated[ChunkRepository, Depends(get_chunk_repository)],
    config: Annotated[Configuration, Depends(get_config)],
) -> EntityService:
    """DI provider for the entity service.

    Wires together the embedding extractor, the chunk repository and the
    application configuration.
    """
    service = EntityService(vectorizer, chunk_repository, config)
    return service


def get_dataset_service(
    entity_service: Annotated[EntityService, Depends(get_entity_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DatasetService:
    """DI provider for the dataset service.

    Combines the entity service, the configuration and the session factory.
    """
    service = DatasetService(entity_service, config, db)
    return service


def get_document_service(
    dataset_service: Annotated[DatasetService, Depends(get_dataset_service)],
    config: Annotated[Configuration, Depends(get_config)],
    db: Annotated[sessionmaker, Depends(get_db)],
) -> DocumentService:
    """DI provider for the document service, built over the dataset service."""
    service = DocumentService(dataset_service, config, db)
    return service


def get_llm_config_service(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> LLMConfigService:
    """Provide the LLM-configuration service.

    Fix: ``get_db`` supplies a ``sessionmaker`` (the factory), not a live
    ``Session``, so the dependency annotation is corrected to match
    ``get_dataset_service`` / ``get_document_service``.
    """
    return LLMConfigService(db)


def get_llm_service(
    config: Annotated[Configuration, Depends(get_config)],
) -> DeepInfraApi:
    """Build the DeepInfra LLM client from ``config.llm_config``.

    The API key is read from the environment variable whose name is given by
    ``config.llm_config.api_key_env``; it may be ``None`` when that variable
    is unset — presumably handled downstream by DeepInfraApi, TODO confirm.

    Fix: the original built the parameters via ``LlmParams(**{...})``; the
    dict-unpack added nothing over plain keyword arguments and obscured them.
    """
    llm_params = LlmParams(
        url=config.llm_config.base_url,
        model=config.llm_config.model,
        tokenizer=config.llm_config.tokenizer,
        type="deepinfra",
        default=True,
        predict_params=None,  # must be supplied with each request
        api_key=os.environ.get(config.llm_config.api_key_env),
        context_length=128000,  # model context window; TODO: move into config
    )
    return DeepInfraApi(params=llm_params)


def get_llm_prompt_service(
    db: Annotated[sessionmaker, Depends(get_db)],
) -> LlmPromptService:
    """Provide the LLM-prompt service.

    Fix: ``get_db`` supplies a ``sessionmaker`` (the factory), not a live
    ``Session``, so the dependency annotation is corrected to match
    ``get_dataset_service`` / ``get_document_service``.
    """
    return LlmPromptService(db)