import logging
import re
from logging import Logger
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
from elasticsearch.exceptions import ConnectionError
from natasha import Doc, MorphVocab, NewsEmbedding, NewsMorphTagger, Segmenter

from common.common import (
    get_elastic_abbreviation_query,
    get_elastic_group_query,
    get_elastic_people_query,
    get_elastic_query,
    get_elastic_rocks_nn_query,
    get_elastic_segmentation_query,
)
from common.configuration import Configuration, Query, SummaryChunks
from common.constants import PROMPT, PROMPT_CLASSIFICATION
from components.elastic import create_index_elastic_chunks
from components.elastic.elasticsearch_client import ElasticsearchClient
from components.embedding_extraction import EmbeddingExtractor
from components.nmd.aggregate_answers import aggregate_answers
from components.nmd.faiss_vector_search import FaissVectorSearch
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.nmd.metadata_manager import MetadataManager
from components.nmd.query_classification import QueryClassification
from components.nmd.rancker import DocumentRanking
from components.services.dataset import DatasetService

logger = logging.getLogger(__name__)


class Dispatcher:
    """Run a user query through the configured search back-ends.

    Combines FAISS vector search over the currently loaded dataset with the
    optional Elasticsearch indexes (people, chunks, groups, Rocks NN,
    segmentation, abbreviations) and exposes LLM-based query classification
    and answer selection.
    """

    def __init__(
        self,
        embedding_model: EmbeddingExtractor,
        config: Configuration,
        logger: Logger,
        dataset_service: DatasetService,
    ):
        """
        Args:
            embedding_model: Embedder passed to the FAISS vector search.
            config: Application configuration.
            logger: Logger handed to the LLM helpers (the other methods of
                this class log through the module-level logger).
            dataset_service: Service used to look up the default dataset.
        """
        self.dataset_service = dataset_service
        self.config = config
        self.embedder = embedding_model

        # Id of the dataset whose structures are currently loaded; None means
        # nothing is loaded. All per-dataset structures start out empty.
        self.dataset_id: Optional[int] = None
        self.faiss_search: Optional[FaissVectorSearch] = None
        self.meta_database: Optional[MetadataManager] = None
        self.document_ranking: Optional[DocumentRanking] = None
        self.try_load_default_dataset()

        self.llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
        if self.config.db_config.elastic.use_elastic:
            self.elastic_search = ElasticsearchClient(
                host=f'{config.db_config.elastic.es_host}',
                port=config.db_config.elastic.es_port,
            )
        self.query_classification = QueryClassification(
            config.llm_config, PROMPT_CLASSIFICATION, logger
        )
        self.segmenter = Segmenter()
        self.morph_tagger = NewsMorphTagger(NewsEmbedding())
        self.morph_vocab = MorphVocab()

    def try_load_default_dataset(self) -> None:
        """Synchronise the loaded search structures with the default dataset.

        - No default dataset (or it has no id): drop all per-dataset structures.
        - Default dataset differs from the loaded one: reload it.
        - Default dataset is already loaded: keep everything as is.
        """
        default_dataset = self.dataset_service.get_default_dataset()
        if default_dataset is None or default_dataset.id is None:
            self._clear_dataset()
            return
        if default_dataset.id != self.dataset_id:
            logger.info(f'Reloading dataset {default_dataset.id}')
            self.reset_dataset(default_dataset.id)

    def _clear_dataset(self) -> None:
        """Forget the loaded dataset and all structures built from it."""
        self.dataset_id = None
        self.faiss_search = None
        self.meta_database = None
        self.document_ranking = None

    def reset_dataset(self, dataset_id: int) -> None:
        """Load dataset `dataset_id` and rebuild every structure derived from it.

        All structures are built into locals first and swapped in together at
        the end, so a failure part-way leaves the previous state untouched
        (except for the Elastic index, which is created in place).

        Args:
            dataset_id: Id of the dataset directory under `path_to_metadata`.
        """
        logger.info(f'Reset dataset to dataset_id: {dataset_id}')
        data_path = Path(self.config.db_config.faiss.path_to_metadata)
        dataset_file = data_path / str(dataset_id) / 'dataset.pkl'
        # NOTE: unpickling can execute arbitrary code; this directory must
        # only ever contain trusted files.
        df = pd.read_pickle(dataset_file)
        logger.info(f'Dataset loaded from {dataset_file}')
        logger.info(f'Dataset shape: {df.shape}')

        faiss_search = FaissVectorSearch(self.embedder, df, self.config.db_config)
        logger.info('Faiss search initialized')
        meta_database = MetadataManager(df, logger)
        logger.info('Meta database initialized')
        if self.config.db_config.elastic.use_elastic:
            create_index_elastic_chunks(df, logger)
            logger.info('Elastic index created')
        document_ranking = DocumentRanking(df, self.config)
        logger.info('Document ranking initialized')

        self.faiss_search = faiss_search
        self.meta_database = meta_database
        self.document_ranking = document_ranking
        # Recording the id is what stops try_load_default_dataset from
        # reloading the same dataset on every query.
        self.dataset_id = dataset_id

    def __vector_search(self, query: str) -> Dict[int, Dict]:
        """Find the nearest chunks to `query` in the FAISS vector index.

        Requires a loaded dataset (callers check `self.faiss_search`).

        Args:
            query: User query.

        Returns:
            Chunks returned by the metadata manager for the found indexes.
        """
        query_embeds, scores, indexes = self.faiss_search.search_vectors(query)
        if self.config.db_config.ranker.use_ranging:
            indexes = self.document_ranking.doc_ranking(query_embeds, scores, indexes)
        return self.meta_database.search(indexes)

    def __elastic_search(
        self, query: str, index_name: str, search_function, size: int
    ) -> Dict:
        """Run a full-text search against one Elasticsearch index.

        Args:
            query: User query.
            index_name: Name of the index to search.
            search_function: Builds the Elasticsearch query body from `query`;
                depends on the index being searched.
            size: Number of hits to return.

        Returns:
            The hits returned by the Elasticsearch client.
        """
        self.elastic_search.set_index(index_name)
        return self.elastic_search.search(query=search_function(query), size=size)

    @staticmethod
    def _get_indexes_full_text_elastic_search(elastic_answer: Dict) -> List:
        """Extract chunk indexes from Elasticsearch chunk-search hits.

        Args:
            elastic_answer: Iterable of hits, each with `['_source']['index']`.

        Returns:
            List of chunk indexes, in hit order.
        """
        return [hit['_source']['index'] for hit in elastic_answer]

    def _lemmatize_tokens(self, text: str) -> List:
        """Segment, morph-tag and lemmatize `text`; return natasha doc tokens.

        Each token carries `.lemma` plus `.start`/`.stop` character offsets
        into `text`.
        """
        doc = Doc(text)
        doc.segment(self.segmenter)
        doc.tag_morph(self.morph_tagger)
        for token in doc.tokens:
            token.lemmatize(self.morph_vocab)
        return doc.tokens

    def _lemmatization_text(self, text: str) -> str:
        """Return `text` as its token lemmas joined by single spaces."""
        return ' '.join(token.lemma for token in self._lemmatize_tokens(text))

    @staticmethod
    def _find_phrase(lemma_text: str, phrase_lemma: str) -> int:
        """Return the start offset of `phrase_lemma` in `lemma_text`, or -1.

        Both arguments are space-joined lemma strings. Only whole-token
        matches count, so a phrase never matches inside a longer word.
        """
        if not phrase_lemma:
            return -1
        # Padding with spaces keeps offsets aligned: the padded match index
        # equals the phrase's start offset in the unpadded string.
        return f' {lemma_text} '.find(f' {phrase_lemma} ')

    def _insert_abbreviation(
        self, text: str, phrase_lemma: str, abbreviation: str
    ) -> str:
        """Insert ` (abbreviation)` right after the phrase in `text`.

        The phrase is located on the lemmatized form of `text`. The insertion
        point is mapped back via the original character offset of the
        phrase's last token, not via the lemma-string offset (lemmas differ in
        length from the surface words).

        Returns:
            The updated text, or `text` unchanged if the phrase is not found.
        """
        tokens = self._lemmatize_tokens(text)
        lemma_text = ' '.join(token.lemma for token in tokens)
        start = self._find_phrase(lemma_text, phrase_lemma)
        if start == -1:
            return text
        match_end = start + len(phrase_lemma)
        lemma_end = -1  # exclusive end of the current token's lemma in lemma_text
        for token in tokens:
            lemma_end += 1 + len(token.lemma)
            if lemma_end >= match_end:
                return f'{text[:token.stop]} ({abbreviation}){text[token.stop:]}'
        return text

    def _get_abbreviations(self, query: Query) -> Query:
        """Annotate the query with abbreviations of phrases it contains.

        Searches the abbreviation index for `query.query`. For every hit whose
        full text (lemmatized) occurs as whole tokens in the lemmatized query,
        ` (ABBR)` is inserted after that phrase in `query.query_abbreviation`.
        Elasticsearch connection errors are logged and the query is returned
        without further annotation.

        Args:
            query: User query.

        Returns:
            A new `Query` with the same `query` and `abbreviations_replaced`
            and the (possibly) annotated `query_abbreviation`.
        """
        query_abbreviation = query.query_abbreviation
        abbreviations_replaced = query.abbreviations_replaced
        try:
            if self.config.db_config.elastic.use_elastic:
                search_config = self.config.db_config.search.abbreviation_search
                if search_config.use_abbreviation_search:
                    abbreviation_answer = self.__elastic_search(
                        query=query.query,
                        index_name=search_config.index_name,
                        search_function=get_elastic_abbreviation_query,
                        size=search_config.k_neighbors,
                    )
                    if len(abbreviation_answer) > 0:
                        query_lemmatization = self._lemmatization_text(query.query)
                        for abbreviation in abbreviation_answer:
                            source = abbreviation['_source']
                            phrase_lemma = self._lemmatization_text(
                                source['text'].lower()
                            )
                            if self._find_phrase(query_lemmatization, phrase_lemma) == -1:
                                continue
                            query_abbreviation = self._insert_abbreviation(
                                query_abbreviation,
                                phrase_lemma,
                                source['abbreviation'],
                            )
        except ConnectionError:
            logger.warning("Connection Error Elasticsearch")
        return Query(
            query=query.query,
            query_abbreviation=query_abbreviation,
            abbreviations_replaced=abbreviations_replaced,
        )

    def search_answer(self, query: Query) -> SummaryChunks:
        """Find chunks answering the user query using every enabled search.

        Refreshes the loaded dataset first, annotates abbreviations, then runs
        vector search and the enabled Elasticsearch searches. Searches that
        need a loaded dataset are skipped (with a warning) when none is
        loaded.

        Args:
            query: User query.

        Returns:
            The aggregated chunks found for the query.
        """
        self.try_load_default_dataset()
        query = self._get_abbreviations(query)
        logger.info(f'Start search for {query.query_abbreviation}')
        logger.info(f'Use elastic search: {self.config.db_config.elastic.use_elastic}')

        answer = {}
        if self.config.db_config.search.vector_search.use_vector_search:
            if self.faiss_search is None:
                logger.warning('No dataset is loaded, skipping vector search.')
            else:
                logger.info('Start vector search.')
                answer['vector_answer'] = self.__vector_search(query.query_abbreviation)
                logger.info(f'Vector search found {len(answer["vector_answer"])} chunks')

        if self.config.db_config.elastic.use_elastic:
            try:
                self._collect_elastic_answers(query.query, answer)
            except ConnectionError:
                # Results gathered before the failure are kept in `answer`.
                logger.warning("Connection Error Elasticsearch")

        final_answer = aggregate_answers(**answer)
        logger.info(f'Final answer found {len(final_answer)} chunks')
        return SummaryChunks(**final_answer)

    def _collect_elastic_answers(self, query_text: str, answer: Dict) -> None:
        """Run every enabled Elasticsearch search and store results in `answer`.

        `people_answer` and `chunks_answer` are stored even when empty; the
        groups, Rocks NN and segmentation results only when non-empty.
        Raises `ConnectionError` if Elasticsearch is unreachable.
        """
        search_config = self.config.db_config.search

        if search_config.people_elastic_search.use_people_search:
            logger.info('Start people search.')
            people_answer = self.__elastic_search(
                query_text,
                index_name=search_config.people_elastic_search.index_name,
                search_function=get_elastic_people_query,
                size=search_config.people_elastic_search.k_neighbors,
            )
            logger.info(f'People search found {len(people_answer)} chunks')
            answer['people_answer'] = people_answer

        if search_config.chunks_elastic_search.use_chunks_search:
            if self.meta_database is None:
                logger.warning('No dataset is loaded, skipping full text chunks search.')
            else:
                logger.info('Start full text chunks search.')
                chunks_hits = self.__elastic_search(
                    query_text,
                    index_name=search_config.chunks_elastic_search.index_name,
                    search_function=get_elastic_query,
                    size=search_config.chunks_elastic_search.k_neighbors,
                )
                indexes = self._get_indexes_full_text_elastic_search(chunks_hits)
                chunks_answer = self.meta_database.search(indexes)
                logger.info(f'Full text chunks search found {len(chunks_answer)} chunks')
                answer['chunks_answer'] = chunks_answer

        if search_config.groups_elastic_search.use_groups_search:
            logger.info('Start groups search.')
            groups_answer = self.__elastic_search(
                query_text,
                index_name=search_config.groups_elastic_search.index_name,
                search_function=get_elastic_group_query,
                size=search_config.groups_elastic_search.k_neighbors,
            )
            if len(groups_answer) != 0:
                logger.info(f'Groups search found {len(groups_answer)} chunks')
                answer['groups_answer'] = groups_answer

        if search_config.rocks_nn_elastic_search.use_rocks_nn_search:
            logger.info('Start Rocks NN search.')
            rocks_nn_answer = self.__elastic_search(
                query_text,
                index_name=search_config.rocks_nn_elastic_search.index_name,
                search_function=get_elastic_rocks_nn_query,
                size=search_config.rocks_nn_elastic_search.k_neighbors,
            )
            if len(rocks_nn_answer) != 0:
                logger.info(f'Rocks NN search found {len(rocks_nn_answer)} chunks')
                answer['rocks_nn_answer'] = rocks_nn_answer

        if search_config.segmentation_elastic_search.use_segmentation_search:
            logger.info('Start Segmentation search.')
            segmentation_answer = self.__elastic_search(
                query_text,
                index_name=search_config.segmentation_elastic_search.index_name,
                search_function=get_elastic_segmentation_query,
                size=search_config.segmentation_elastic_search.k_neighbors,
            )
            if len(segmentation_answer) != 0:
                logger.info(f'Segmentation search found {len(segmentation_answer)} chunks')
                answer['segmentation_answer'] = segmentation_answer

    def llm_classification(self, query: str) -> str:
        """Classify the query type with the LLM classifier."""
        return self.query_classification.classification(query)

    def llm_answer(
        self, query: str, answer_chunks: SummaryChunks
    ) -> Tuple[str, str, str, int]:
        """Pick the chunk that answers the query using the LLM.

        Args:
            query: User query.
            answer_chunks: Results of the vector and Elasticsearch searches.

        Returns:
            The source chunks from the searches and the chunk the model
            selected, as returned by `LLMChunkSearch.llm_chunk_search`.
        """
        return self.llm_search.llm_chunk_search(query, answer_chunks, PROMPT)