"""FastAPI dependency providers for the query-search feature.

Every ``get_*`` function here is intended for use with ``Depends(...)``; FastAPI
resolves the provider graph for each request. Heavy ML objects (models,
tokenizers, the QA pipeline) are not built here: they are read from the
application state dict at ``request.scope["state"]``, which must already
contain them (presumably populated by the app's lifespan handler — confirm
against the application setup).
"""

from fastapi import Depends, Request
from transformers import (  # noqa: F401 - legacy names kept for any re-importers
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    Pipeline,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    pipeline,
)

from app.infrastructure.repository.query_search_repository import (
    QuerySearchRepository,
)
from app.modules.denseEmbeddings.denseEmbeddings import DenseEmbeddings
from app.modules.hybridSearcher.hybridSearcher import HybridSearcher
from app.modules.querySearch.controllers.querySearch_controller import (
    QuerySearchController,
)
from app.modules.querySearch.features.querySearch_feature import (
    QuerySearchFeature,
)
from app.modules.questionAnswer.questionAnswer import QuestionAnswering
from app.qdrant import QdrantConnectionDb


def _get_app_state(request: Request, key: str):
    """Return the object stored under ``key`` in the request's app state.

    Raises:
        RuntimeError: if the request scope has no state dict or the dict has
            no ``key`` entry, i.e. the resource was never registered in the
            application state.
    """
    try:
        return request.scope["state"][key]
    except KeyError as exc:
        raise RuntimeError(
            f"Application state has no {key!r} entry; it must be registered "
            "before requests are served."
        ) from exc


def get_qdrant_connection_db() -> QdrantConnectionDb:
    """Create the Qdrant connection wrapper.

    NOTE(review): a new ``QdrantConnectionDb`` is constructed for every request
    that depends on it (FastAPI caches dependencies only within one request).
    If construction opens a network connection, consider sharing a single
    instance via application state — confirm against ``QdrantConnectionDb``.
    """
    return QdrantConnectionDb()


def get_query_search_repository(
    qdrant_connection_db: QdrantConnectionDb = Depends(get_qdrant_connection_db),
) -> QuerySearchRepository:
    """Build the query-search repository on top of the Qdrant connection."""
    return QuerySearchRepository(qdrant_connection_db)


def get_dense_model(request: Request) -> PreTrainedModel:
    """Return the dense encoder model stored as ``"dense_model"`` in app state."""
    return _get_app_state(request, "dense_model")


def get_sparse_model(request: Request) -> PreTrainedModel:
    """Return the masked-LM model stored as ``"sparse_model"`` in app state."""
    return _get_app_state(request, "sparse_model")


def get_dense_tokenizer(request: Request) -> PreTrainedTokenizerBase:
    """Return the tokenizer stored as ``"dense_tokenizer"`` in app state."""
    return _get_app_state(request, "dense_tokenizer")


def get_sparse_tokenizer(request: Request) -> PreTrainedTokenizerBase:
    """Return the tokenizer stored as ``"sparse_tokenizer"`` in app state."""
    return _get_app_state(request, "sparse_tokenizer")


def get_dense_embeddings(
    dense_model: PreTrainedModel = Depends(get_dense_model),
    dense_tokenizer: PreTrainedTokenizerBase = Depends(get_dense_tokenizer),
    sparse_model: PreTrainedModel = Depends(get_sparse_model),
    sparse_tokenizer: PreTrainedTokenizerBase = Depends(get_sparse_tokenizer),
) -> DenseEmbeddings:
    """Build the embedder from both the dense and the sparse model/tokenizer pairs.

    Despite the name, ``DenseEmbeddings`` receives the sparse model and
    tokenizer as well.
    """
    return DenseEmbeddings(
        dense_model=dense_model,
        dense_tokenizer=dense_tokenizer,
        sparse_model=sparse_model,
        sparse_tokenizer=sparse_tokenizer,
    )


def get_qa_pipeline(request: Request) -> Pipeline:
    """Return the question-answering pipeline stored as ``"qa_pipeline"``."""
    return _get_app_state(request, "qa_pipeline")


def get_question_answering(
    qa_pipeline: Pipeline = Depends(get_qa_pipeline),
) -> QuestionAnswering:
    """Wrap the shared QA pipeline in a ``QuestionAnswering`` service."""
    return QuestionAnswering(qa_pipeline)


def get_question_ansering(
    qa_pipline: Pipeline = Depends(get_qa_pipeline),
) -> QuestionAnswering:
    """Misspelled legacy name kept for existing callers.

    Prefer :func:`get_question_answering`.
    """
    return get_question_answering(qa_pipline)


def get_hybrid_searcher(
    dense_embeddings: DenseEmbeddings = Depends(get_dense_embeddings),
    query_search_repository: QuerySearchRepository = Depends(
        get_query_search_repository
    ),
) -> HybridSearcher:
    """Combine the embedder and the vector repository into a hybrid searcher."""
    return HybridSearcher(dense_embeddings, query_search_repository)


def get_query_search_feature(
    qa_pipeline: Pipeline = Depends(get_qa_pipeline),
    hybrid_searcher: HybridSearcher = Depends(get_hybrid_searcher),
    question_answering: QuestionAnswering = Depends(get_question_answering),
) -> QuerySearchFeature:
    """Assemble the query-search feature from its collaborators.

    ``qa_pipeline`` is the same app-state object that ``question_answering``
    wraps; both are passed because ``QuerySearchFeature`` takes both.
    """
    return QuerySearchFeature(qa_pipeline, hybrid_searcher, question_answering)


def get_query_search_controller(
    query_search_feature: QuerySearchFeature = Depends(get_query_search_feature),
) -> QuerySearchController:
    """Top-level provider: the controller consumed by the query-search routes."""
    return QuerySearchController(query_search_feature)