#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with TF-IDF scores."""

import logging
import numpy as np
import scipy.sparse as sp

from multiprocessing.pool import ThreadPool
from functools import partial

from . import utils
from . import DEFAULTS
from .. import tokenizers

logger = logging.getLogger(__name__)


class TfidfDocRanker(object):
    """Loads a pre-weighted inverted index of token/document terms.

    Scores new queries by taking sparse dot products.
    """

    def __init__(self, tfidf_path=None, strict=True):
        """
        Args:
            tfidf_path: path to the saved model file
            strict: whether to fail on empty queries or continue (and
                return an empty result)
        """
        # Load from disk
        tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
        logger.info('Loading %s' % tfidf_path)
        matrix, metadata = utils.load_sparse_csr(tfidf_path)
        self.doc_mat = matrix
        self.ngrams = metadata['ngram']
        self.hash_size = metadata['hash_size']
        self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
        self.doc_freqs = metadata['doc_freqs'].squeeze()
        self.doc_dict = metadata['doc_dict']
        self.num_docs = len(self.doc_dict[0])
        self.strict = strict
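
    # Note: doc_dict is assumed to hold a pair of mappings built at index
    # time: doc_dict[0] maps a doc_id to its column index in doc_mat, and
    # doc_dict[1] is the inverse mapping (column index -> doc_id), as the
    # two helpers below suggest.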

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index."""
        return self.doc_dict[0][doc_id]

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id."""
        return self.doc_dict[1][doc_index]

    def closest_docs(self, query, k=1):
        """Return the k closest docs by dot product between the query and
        documents in tfidf-weighted word vector space.
        """
        spvec = self.text2spvec(query)
        res = spvec * self.doc_mat
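        # res is a 1 x num_docs sparse row whose res.data holds the nonzero
        # scores. With more than k nonzero entries, np.argpartition does an
        # O(n) partial top-k selection and only those k entries are sorted;
        # otherwise a full argsort of the few entries suffices.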
        if len(res.data) <= k:
            o_sort = np.argsort(-res.data)
        else:
            o = np.argpartition(-res.data, k)[0:k]
            o_sort = o[np.argsort(-res.data[o])]
        doc_scores = res.data[o_sort]
        doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.

        Note: plain threads are enough here since scipy releases the GIL
        during its sparse matrix operations.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    def parse(self, query):
        """Parse the query into tokens and return their ngrams."""
        tokens = self.tokenizer.tokenize(query)
        return tokens.ngrams(n=self.ngrams, uncased=True,
                             filter_fn=utils.filter_ngram)

    def text2spvec(self, query):
        """Create a sparse tfidf-weighted word vector from the query, where

            tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))

        with N the total number of documents and Nt the number of documents
        containing the term.
        """
        # Get hashed ngrams
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]
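        # Each ngram is hashed into a fixed-size id space (the hashing
        # trick): no vocabulary needs to be stored, at the cost of possible
        # collisions between distinct ngrams.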
        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            else:
                logger.warning('No valid word in: %s' % query)
                return sp.csr_matrix((1, self.hash_size))

        # Count TF
        wids_unique, wids_counts = np.unique(wids, return_counts=True)
        tfs = np.log1p(wids_counts)

        # Count IDF
        Ns = self.doc_freqs[wids_unique]
        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
        idfs[idfs < 0] = 0

        # TF-IDF
        data = np.multiply(tfs, idfs)
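        # Worked example (illustrative numbers only): with num_docs = 1000,
        # an ngram seen twice in the query (tf = 2) and present in 10
        # documents (Nt = 10) scores
        #   log1p(2) * log(990.5 / 10.5) ~= 1.10 * 4.55 ~= 5.0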

        # One row, sparse csr matrix
        indptr = np.array([0, len(wids_unique)])
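        # In the (data, indices, indptr) CSR triple below, data holds the
        # tf-idf weights, wids_unique the column (hashed ngram) indices,
        # and indptr = [0, nnz] delimits the single row.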
        spvec = sp.csr_matrix(
            (data, wids_unique, indptr), shape=(1, self.hash_size)
        )
        return spvec
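

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes a prebuilt TF-IDF model file; the path below is hypothetical,
# and DEFAULTS['tfidf_path'] is used when no path is given. Because of the
# relative imports above, run it as a module (python -m <package>.<module>).
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    ranker = TfidfDocRanker(tfidf_path='/path/to/model.npz')  # hypothetical path
    doc_ids, doc_scores = ranker.closest_docs('who wrote hamlet', k=5)
    for doc_id, score in zip(doc_ids, doc_scores):
        print('%s\t%.4f' % (doc_id, score))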