#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with an ElasticSearch index"""
import logging
from functools import partial
from multiprocessing.pool import ThreadPool

import scipy.sparse as sp
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan

from . import utils
from . import DEFAULTS
from .. import tokenizers
# Module-level logger, named after this module so output can be filtered
# per-module through the standard logging configuration.
logger = logging.getLogger(__name__)
class ElasticDocRanker(object):
    """Rank and fetch documents through an ElasticSearch index.

    Scores are the relevance scores computed by ElasticSearch itself; this
    class only builds the queries and unpacks the responses.

    Instances can be used as context managers: leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self, elastic_url=None, elastic_index=None,
                 elastic_fields=None, elastic_field_doc_name=None,
                 strict=True, elastic_field_content=None):
        """
        Args:
            elastic_url: URL of the ElasticSearch server, including the port.
                Falls back to ``DEFAULTS['elastic_url']`` when falsy.
            elastic_index: Index name of ElasticSearch
            elastic_fields: Fields of the ElasticSearch index to search in
            elastic_field_doc_name: Field containing the name of the document
                (index). May be a list of keys describing a nested field.
            strict: fail on empty queries or continue (and return empty
                result). NOTE(review): the flag is only stored; no method of
                this class reads it.
            elastic_field_content: Field containing the content of the
                document in plain text
        """
        # Connect to the server (nothing is loaded from disk here).
        elastic_url = elastic_url or DEFAULTS['elastic_url']
        # Lazy %-style arguments: the message is only formatted if emitted.
        logger.info('Connecting to %s', elastic_url)
        self.es = Elasticsearch(hosts=elastic_url)
        self.elastic_index = elastic_index
        self.elastic_fields = elastic_fields
        self.elastic_field_doc_name = elastic_field_doc_name
        self.elastic_field_content = elastic_field_content
        self.strict = strict

    # Elastic Ranker

    @staticmethod
    def _first_hit(result, key):
        """Return the top hit of a search response.

        Raises:
            IndexError: if the response contains no hits. The exception type
                matches what indexing the empty hit list would raise, but the
                message names the key that was looked up.
        """
        hits = result['hits']['hits']
        if not hits:
            raise IndexError(
                'No document matching %r found in the index' % (key,))
        return hits[0]

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index

        Args:
            doc_id: value of the document-name field to look up.

        Returns:
            The ElasticSearch ``_id`` of the best-matching document.

        Raises:
            IndexError: if no document matches ``doc_id``.
        """
        field_index = self.elastic_field_doc_name
        if isinstance(field_index, list):
            # Nested fields are addressed with dotted paths in queries.
            field_index = '.'.join(field_index)
        result = self.es.search(
            index=self.elastic_index,
            body={'query': {'match': {field_index: doc_id}}})
        return self._first_hit(result, doc_id)['_id']

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id

        Args:
            doc_index: ElasticSearch ``_id`` of the document.

        Returns:
            The value of the document-name field for that document.

        Raises:
            IndexError: if no document has this ``_id``.
        """
        result = self.es.search(
            index=self.elastic_index,
            body={'query': {'match': {"_id": doc_index}}})
        source = self._first_hit(result, doc_index)['_source']
        return utils.get_field(source, self.elastic_field_doc_name)

    def closest_docs(self, query, k=1):
        """Closest docs by using ElasticSearch

        Args:
            query: free-text query string.
            k: maximum number of documents to return.

        Returns:
            A ``(doc_ids, doc_scores)`` pair of parallel lists, in the order
            the hits are returned by ElasticSearch.
        """
        results = self.es.search(
            index=self.elastic_index,
            body={'size': k,
                  'query': {'multi_match': {
                      'query': query,
                      'type': 'most_fields',
                      'fields': self.elastic_fields}}})
        hits = results['hits']['hits']
        doc_ids = [utils.get_field(row['_source'], self.elastic_field_doc_name)
                   for row in hits]
        doc_scores = [row['_score'] for row in hits]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.

        Note: plain threads are used; each task spends its time in a search
        request issued by :meth:`closest_docs`.

        Args:
            queries: iterable of query strings.
            k: maximum number of documents to return per query.
            num_workers: thread count handed to ``ThreadPool`` (``None``
                lets ``ThreadPool`` pick its default).

        Returns:
            A list with one ``(doc_ids, doc_scores)`` pair per query, in the
            same order as ``queries``.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    # Elastic DB

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Without __exit__ the object cannot be used in a ``with`` statement
        # at all. Returns None so exceptions from the body still propagate.
        self.close()

    def close(self):
        """Close the connection to the database."""
        self.es = None

    def get_doc_ids(self):
        """Fetch all ids of docs stored in the db.

        Uses the scroll-based ``scan`` helper: a plain search only returns
        the first page of hits, not the whole index.
        """
        hits = scan(self.es, index=self.elastic_index,
                    query={"query": {"match_all": {}}})
        return [utils.get_field(hit['_source'], self.elastic_field_doc_name)
                for hit in hits]

    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'.

        Raises:
            IndexError: if no document matches ``doc_id``.
        """
        idx = self.get_doc_index(doc_id)
        # NOTE(review): `doc_type` may be rejected by newer Elasticsearch
        # clients -- confirm against the client version in use.
        result = self.es.get(index=self.elastic_index, doc_type='_doc', id=idx)
        if result is None:
            return result
        return result['_source'][self.elastic_field_content]