# semantic-demo / retriever.py
# Author: mrchtr
# Last commit: "Update styles" (01628bb)
# File size: 2.43 kB
from haystack.document_stores import InMemoryDocumentStore
from haystack.utils import convert_files_to_docs
from haystack.nodes.retriever import TfidfRetriever
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes import FARMReader
import pickle
from pprint import pprint
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """
    InMemoryDocumentStore whose ``indexes`` can be saved to and restored from disk.

    When the application is deployed to Huggingface Spaces no GPU is available,
    so pre-calculated data has to be loaded into the store instead of being
    computed there.

    Security note: unpickling can execute arbitrary code, so only call
    ``load_data`` on files you produced yourself.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into *file_name* (overwriting it)."""
        with open(file_name, 'wb') as sink:
            pickle.dump(self.indexes, sink)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from *file_name*."""
        with open(file_name, 'rb') as source:
            restored = pickle.load(source)
        self.indexes = restored
# One in-memory store shared by every retriever defined below.
# NOTE(review): similarity='cosine' is presumably the metric used when the dense
# retrievers compare query and document embeddings — confirm in Haystack docs.
document_store = ExportableInMemoryDocumentStore(similarity='cosine')
# Replace the store's indexes with pre-calculated data from disk (see the class
# docstring: no GPU is available on Huggingface Spaces).
# Keep this BEFORE the retrievers are constructed — NOTE(review): TfidfRetriever
# likely fits on the store's current documents when it is created; verify before
# reordering these statements.
document_store.load_data('documentstore_german-election-idx.pkl')
# Sparse (TF-IDF) retriever, used by sparse_retrieval().
retriever = TfidfRetriever(document_store=document_store)
# Dense retriever backed by the pretrained sentence-transformers model
# 'paraphrase-multilingual-mpnet-base-v2' ("base" variant in dense_retrieval()).
base_dense_retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
model_format='sentence_transformers'
)
# Dense retriever backed by a model stored locally in ./adapted-retriever
# ("adapted" variant in dense_retrieval()); presumably a fine-tuned version of
# the base model, as the variable name suggests — verify.
fine_tuned_retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model='./adapted-retriever',
model_format='sentence_transformers'
)
def sparse_retrieval(query):
    """Run *query* through a document-search pipeline over the TF-IDF retriever.

    Returns the raw result of ``DocumentSearchPipeline.run``.
    """
    return DocumentSearchPipeline(retriever).run(query=query)
def dense_retrieval(query, retriever='base'):
    """Run *query* through a document-search pipeline over a dense retriever.

    Args:
        query: The search text.
        retriever: Which embedding retriever to use: ``'base'`` for the
            off-the-shelf model or ``'adapted'`` for the locally stored one.
            (Note: this parameter shadows the module-level TF-IDF ``retriever``.)

    Returns:
        The raw result of ``DocumentSearchPipeline.run``, or ``None`` when
        *retriever* names neither variant.
    """
    if retriever == 'base':
        model = base_dense_retriever
    elif retriever == 'adapted':
        model = fine_tuned_retriever
    else:
        # Unknown variant name: no search is performed.
        return None
    pipeline = DocumentSearchPipeline(model)
    return pipeline.run(query=query)
def do_search(query):
    """Search *query* with all three retrievers and return their document lists.

    Returns:
        A 3-tuple ``(sparse_docs, dense_base_docs, dense_adapted_docs)``, each
        being the ``'documents'`` entry of the corresponding pipeline result.
    """
    collected = [sparse_retrieval(query)['documents']]
    for variant in ('base', 'adapted'):
        collected.append(dense_retrieval(query, retriever=variant)['documents'])
    return tuple(collected)
if __name__ == '__main__':
    # Manual smoke check: pretty-print the three retrievers' results for a sample German query.
    pprint(do_search('Klimawandel stoppen?'))