"""Compare sparse (TF-IDF) and dense (embedding) document retrieval.

Loads a pre-built in-memory document store from a pickle file and exposes
helpers that run the same query through a TF-IDF retriever, a base
multilingual sentence-transformers retriever, and a second retriever loaded
from the local ``./adapted-retriever`` model directory.
"""
import pickle
from pprint import pprint

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import FARMReader
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.utils import convert_files_to_docs


class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """InMemoryDocumentStore whose ``indexes`` can be pickled and restored.

    When the application is deployed to Huggingface Spaces there is no GPU
    available, so pre-calculated data is loaded into the store from disk
    instead of being recomputed.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle the store's ``indexes`` attribute to *file_name*."""
        with open(file_name, 'wb') as f:
            pickle.dump(self.indexes, f)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace the store's ``indexes`` with the pickled contents of *file_name*.

        SECURITY: ``pickle.load`` can execute arbitrary code while loading.
        Only load files produced by :meth:`export` from a trusted source.
        """
        with open(file_name, 'rb') as f:
            self.indexes = pickle.load(f)


document_store = ExportableInMemoryDocumentStore(similarity='cosine')
document_store.load_data('documentstore_german-election-idx.pkl')

retriever = TfidfRetriever(document_store=document_store)

base_dense_retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    model_format='sentence_transformers',
)

fine_tuned_retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model='./adapted-retriever',
    model_format='sentence_transformers',
)

# Build each search pipeline once at import time instead of on every query;
# the pipelines only wrap the retrievers created above.
_sparse_pipeline = DocumentSearchPipeline(retriever)
_dense_pipelines = {
    'base': DocumentSearchPipeline(base_dense_retriever),
    'adapted': DocumentSearchPipeline(fine_tuned_retriever),
}


def sparse_retrieval(query):
    """Sparse retrieval pipeline: run *query* through the TF-IDF retriever."""
    return _sparse_pipeline.run(query=query)


def dense_retrieval(query, retriever='base'):
    """Run *query* through one of the dense (embedding) search pipelines.

    Args:
        query: Search query text.
        retriever: ``'base'`` for the pretrained multilingual model, or
            ``'adapted'`` for the model loaded from ``./adapted-retriever``.
            (This parameter shadows the module-level TF-IDF ``retriever``.)

    Returns:
        The pipeline's run result, or ``None`` if *retriever* is not one of
        the names above.
    """
    pipeline = _dense_pipelines.get(retriever)
    if pipeline is None:
        return None
    return pipeline.run(query=query)


def do_search(query):
    """Run *query* through all three retrievers.

    Returns:
        Tuple ``(sparse, dense_base, dense_adapted)`` holding the
        ``'documents'`` entry of each pipeline's result.
    """
    sparse_result = sparse_retrieval(query)['documents']
    dense_base_result = dense_retrieval(query, retriever='base')['documents']
    dense_adapted_result = dense_retrieval(query, retriever='adapted')['documents']
    return sparse_result, dense_base_result, dense_adapted_result


if __name__ == '__main__':
    query = 'Klimawandel stoppen?'
    result = do_search(query)
    pprint(result)