File size: 2,848 Bytes
47b5f0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from fastapi import Depends, Request
from transformers import (AutoModel, AutoModelForMaskedLM, AutoTokenizer,
                          pipeline)

from app.infrastructure.repository.query_search_repository import \
    QuerySearchRepository
from app.modules.denseEmbeddings.denseEmbeddings import DenseEmbeddings
from app.modules.hybridSearcher.hybridSearcher import HybridSearcher
from app.modules.querySearch.controllers.querySearch_controller import \
    QuerySearchController
from app.modules.querySearch.features.querySearch_feature import \
    QuerySearchFeature
from app.modules.questionAnswer.questionAnswer import QuestionAnswering
from app.qdrant import QdrantConnectionDb


def get_qdrant_connection_db() -> QdrantConnectionDb:
    """Construct and return a new ``QdrantConnectionDb``.

    NOTE(review): a fresh instance is created on every call. If the
    constructor opens a network client, caching it or turning this into a
    ``yield`` dependency with cleanup may be worthwhile — confirm.
    """
    connection = QdrantConnectionDb()
    return connection


def get_query_search_repository(
    qdrant_connection_db: QdrantConnectionDb = Depends(get_qdrant_connection_db),
):
    """Provide a ``QuerySearchRepository`` bound to the injected Qdrant connection.

    Args:
        qdrant_connection_db: Connection resolved through
            ``get_qdrant_connection_db``.

    Returns:
        A new ``QuerySearchRepository`` wrapping that connection.
    """
    repository = QuerySearchRepository(qdrant_connection_db)
    return repository


def get_dense_model(request: Request) -> AutoModel:
    """Fetch the dense encoder model stored in ``request.scope["state"]``.

    The state dict is presumably populated at application startup — confirm.

    Raises:
        KeyError: if ``"state"`` or ``"dense_model"`` is missing.
    """
    shared_state = request.scope["state"]
    return shared_state["dense_model"]


def get_sparse_model(request: Request) -> AutoModelForMaskedLM:
    """Fetch the sparse (masked-LM) model stored in ``request.scope["state"]``.

    The state dict is presumably populated at application startup — confirm.

    Raises:
        KeyError: if ``"state"`` or ``"sparse_model"`` is missing.
    """
    shared_state = request.scope["state"]
    return shared_state["sparse_model"]


def get_dense_tokenizer(request: Request) -> AutoTokenizer:
    """Fetch the dense-model tokenizer stored in ``request.scope["state"]``.

    The state dict is presumably populated at application startup — confirm.

    Raises:
        KeyError: if ``"state"`` or ``"dense_tokenizer"`` is missing.
    """
    shared_state = request.scope["state"]
    return shared_state["dense_tokenizer"]


def get_sparse_tokenizer(request: Request) -> AutoTokenizer:
    """Fetch the sparse-model tokenizer stored in ``request.scope["state"]``.

    The state dict is presumably populated at application startup — confirm.

    Raises:
        KeyError: if ``"state"`` or ``"sparse_tokenizer"`` is missing.
    """
    shared_state = request.scope["state"]
    return shared_state["sparse_tokenizer"]


def get_dense_embeddings(
    dense_model: AutoModel = Depends(get_dense_model),
    dense_tokenizer: AutoTokenizer = Depends(get_dense_tokenizer),
    sparse_model: AutoModelForMaskedLM = Depends(get_sparse_model),
    sparse_tokenizer: AutoTokenizer = Depends(get_sparse_tokenizer),
):
    """Assemble a ``DenseEmbeddings`` from the injected models and tokenizers.

    Despite its name, ``DenseEmbeddings`` receives both the dense and the
    sparse model/tokenizer pairs.

    Returns:
        A new ``DenseEmbeddings`` instance.
    """
    encoder_parts = dict(
        dense_model=dense_model,
        dense_tokenizer=dense_tokenizer,
        sparse_model=sparse_model,
        sparse_tokenizer=sparse_tokenizer,
    )
    return DenseEmbeddings(**encoder_parts)


def get_qa_pipeline(request: Request):
    """Fetch the question-answering pipeline stored in ``request.scope["state"]``.

    The state dict is presumably populated at application startup — confirm.

    Raises:
        KeyError: if ``"state"`` or ``"qa_pipeline"`` is missing.
    """
    shared_state = request.scope["state"]
    return shared_state["qa_pipeline"]


def get_question_ansering(qa_pipline: pipeline = Depends(get_qa_pipeline)):
    """Wrap the injected QA pipeline in a ``QuestionAnswering`` service.

    The misspelled function and parameter names are kept as-is because other
    providers reference this function via ``Depends(...)``.

    NOTE(review): ``pipeline`` is the transformers factory function rather
    than a class, so this annotation is informational only.
    ``transformers.Pipeline`` may be the intended type — confirm.
    """
    qa_service = QuestionAnswering(qa_pipline)
    return qa_service


def get_hybrid_searcher(
    dense_embeddings: DenseEmbeddings = Depends(get_dense_embeddings),
    query_search_repository: QuerySearchRepository = Depends(
        get_query_search_repository
    ),
):
    """Combine the embedding encoder and the Qdrant repository into a searcher.

    Args:
        dense_embeddings: Encoder built by ``get_dense_embeddings``.
        query_search_repository: Repository built by
            ``get_query_search_repository``.

    Returns:
        A new ``HybridSearcher`` (arguments passed positionally in this order).
    """
    searcher = HybridSearcher(dense_embeddings, query_search_repository)
    return searcher


def get_query_search_feature(
    qa_pipeline: pipeline = Depends(get_qa_pipeline),
    hybrid_searcher: HybridSearcher = Depends(get_hybrid_searcher),
    question_answering: QuestionAnswering = Depends(get_question_ansering),
):
    """Build the ``QuerySearchFeature`` use case from its collaborators.

    ``qa_pipeline`` and the pipeline inside ``question_answering`` both come
    from the same ``request.scope["state"]["qa_pipeline"]`` entry.

    Returns:
        A new ``QuerySearchFeature`` constructed with
        ``(qa_pipeline, hybrid_searcher, question_answering)``.
    """
    collaborators = (qa_pipeline, hybrid_searcher, question_answering)
    return QuerySearchFeature(*collaborators)


def get_query_search_controller(
    query_search_feature: QuerySearchFeature = Depends(get_query_search_feature),
):
    """Top-level provider: expose the query-search feature through its controller.

    Args:
        query_search_feature: Use case built by ``get_query_search_feature``.

    Returns:
        A new ``QuerySearchController`` wrapping that feature.
    """
    controller = QuerySearchController(query_search_feature)
    return controller