Spaces:
Runtime error
Runtime error
from haystack.document_stores import InMemoryDocumentStore | |
from haystack.utils import convert_files_to_docs | |
from haystack.nodes.retriever import TfidfRetriever | |
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline | |
from haystack.nodes.retriever import EmbeddingRetriever | |
from haystack.nodes import FARMReader | |
import pickle | |
from pprint import pprint | |
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """InMemoryDocumentStore whose contents can be saved to / restored from a pickle file.

    The Hugging Face Spaces deployment has no GPU, so the store's data is
    pre-calculated elsewhere, written out with ``export`` and read back at
    start-up with ``load_data``.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into ``file_name``, overwriting any existing file."""
        with open(file_name, mode='wb') as sink:
            pickle.dump(self.indexes, sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from ``file_name``.

        Security: unpickling can execute arbitrary code, so only load files
        produced by ``export`` from a source you trust.
        """
        with open(file_name, mode='rb') as source:
            self.indexes = pickle.load(source)
def _sentence_transformers_retriever(model_name_or_path):
    """Build an EmbeddingRetriever over the module-level ``document_store``.

    ``model_name_or_path`` is forwarded as ``embedding_model``; in this module it
    is either a Hugging Face model id or a local directory.
    """
    return EmbeddingRetriever(
        document_store=document_store,
        embedding_model=model_name_or_path,
        model_format='sentence_transformers',
    )


# The store's data is pre-calculated offline and loaded from disk.
document_store = ExportableInMemoryDocumentStore(similarity='cosine')
document_store.load_data('documentstore_german-election-idx.pkl')

# Sparse baseline. Kept after load_data: presumably TfidfRetriever builds its
# index from the store's documents when constructed -- verify before reordering.
retriever = TfidfRetriever(document_store=document_store)

# NOTE(review): both dense retrievers query the same document_store, and
# therefore presumably the same stored document embeddings. Embeddings from
# two different models are generally not comparable -- confirm the pickle was
# built for the model(s) that will query it.
base_dense_retriever = _sentence_transformers_retriever(
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
)
fine_tuned_retriever = _sentence_transformers_retriever('./adapted-retriever')
def sparse_retrieval(query):
    """Search the document store for ``query`` using the TF-IDF (sparse) retriever.

    Returns the raw result of ``DocumentSearchPipeline.run``; callers in this
    module read its ``'documents'`` entry.
    """
    return DocumentSearchPipeline(retriever).run(query=query)
def dense_retrieval(query, retriever='base'):
    """Search the document store for ``query`` with one of the dense retrievers.

    Args:
        query: The search string.
        retriever: ``'base'`` for the pretrained multilingual sentence-transformers
            model, ``'adapted'`` for the fine-tuned model in ``./adapted-retriever``.

    Returns:
        The raw result of ``DocumentSearchPipeline.run``, or ``None`` when
        ``retriever`` is neither ``'base'`` nor ``'adapted'``.
    """
    if retriever not in ('base', 'adapted'):
        return None
    chosen = base_dense_retriever if retriever == 'base' else fine_tuned_retriever
    return DocumentSearchPipeline(chosen).run(query=query)
def do_search(query):
    """Run ``query`` through all three retrieval setups.

    Returns:
        A 3-tuple of the ``'documents'`` results from, in order: sparse
        (TF-IDF) retrieval, the base dense retriever and the adapted dense
        retriever.
    """
    # Tuple elements are evaluated left to right, so the searches run in the listed order.
    return (
        sparse_retrieval(query)['documents'],
        dense_retrieval(query, retriever='base')['documents'],
        dense_retrieval(query, retriever='adapted')['documents'],
    )
if __name__ == '__main__':
    # Manual smoke test: print all three result sets for a sample German query.
    pprint(do_search('Klimawandel stoppen?'))