#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Rank documents with TF-IDF scores."""

import logging
import numpy as np
import scipy.sparse as sp

from multiprocessing.pool import ThreadPool
from functools import partial

from . import utils
from . import DEFAULTS
from .. import tokenizers

logger = logging.getLogger(__name__)


class TfidfDocRanker(object):
    """Loads a pre-weighted inverted index of token/document terms.
    Scores new queries by taking sparse dot products.
    """

    def __init__(self, tfidf_path=None, strict=True):
        """
        Args:
            tfidf_path: path to saved model file
            strict: fail on empty queries or continue (and return empty result)
        """
        # Load from disk
        tfidf_path = tfidf_path or DEFAULTS['tfidf_path']
        logger.info('Loading %s' % tfidf_path)
        matrix, metadata = utils.load_sparse_csr(tfidf_path)
        self.doc_mat = matrix
        self.ngrams = metadata['ngram']
        self.hash_size = metadata['hash_size']
        self.tokenizer = tokenizers.get_class(metadata['tokenizer'])()
        self.doc_freqs = metadata['doc_freqs'].squeeze()
        self.doc_dict = metadata['doc_dict']
        self.num_docs = len(self.doc_dict[0])
        self.strict = strict

    def get_doc_index(self, doc_id):
        """Convert doc_id --> doc_index."""
        return self.doc_dict[0][doc_id]

    def get_doc_id(self, doc_index):
        """Convert doc_index --> doc_id."""
        return self.doc_dict[1][doc_index]

    def closest_docs(self, query, k=1):
        """Closest docs by dot product between query and documents
        in tfidf weighted word vector space.
        """
        spvec = self.text2spvec(query)
        res = spvec * self.doc_mat

        # Select the top-k scores; argpartition avoids fully sorting
        # res.data when there are more than k nonzero scores.
        if len(res.data) <= k:
            o_sort = np.argsort(-res.data)
        else:
            o = np.argpartition(-res.data, k)[0:k]
            o_sort = o[np.argsort(-res.data[o])]

        doc_scores = res.data[o_sort]
        doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]]
        return doc_ids, doc_scores

    def batch_closest_docs(self, queries, k=1, num_workers=None):
        """Process a batch of closest_docs requests multithreaded.
        Note: plain threads suffice here, as scipy releases the GIL
        during the sparse matrix products.
        """
        with ThreadPool(num_workers) as threads:
            closest_docs = partial(self.closest_docs, k=k)
            results = threads.map(closest_docs, queries)
        return results

    def parse(self, query):
        """Parse the query into tokens (unigrams or ngrams, depending on
        self.ngrams).
        """
        tokens = self.tokenizer.tokenize(query)
        return tokens.ngrams(n=self.ngrams, uncased=True,
                             filter_fn=utils.filter_ngram)

    def text2spvec(self, query):
        """Create a sparse tfidf-weighted word vector from query.

        tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
        """
        # Get hashed ngrams
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]

        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            else:
                logger.warning('No valid word in: %s' % query)
                return sp.csr_matrix((1, self.hash_size))

        # Count TF
        wids_unique, wids_counts = np.unique(wids, return_counts=True)
        tfs = np.log1p(wids_counts)

        # Compute smoothed IDF
        Ns = self.doc_freqs[wids_unique]
        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
        idfs[idfs < 0] = 0

        # TF-IDF
        data = np.multiply(tfs, idfs)

        # One row, sparse csr matrix
        indptr = np.array([0, len(wids_unique)])
        spvec = sp.csr_matrix(
            (data, wids_unique, indptr), shape=(1, self.hash_size)
        )
        return spvec
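
# Worked example of the weighting in text2spvec (illustrative numbers only,
# not drawn from any real index): with N = num_docs = 5,000,000 documents, a
# term appearing tf = 3 times in the query and in Nt = 100,000 documents
# contributes
#
#     log(1 + 3) * log((5,000,000 - 100,000 + 0.5) / (100,000 + 0.5))
#     ~= 1.386 * 3.892 ~= 5.40
#
# Rare terms (small Nt) therefore dominate the query vector, while terms
# appearing in more than half the collection get a negative IDF and are
# clamped to zero by the `idfs[idfs < 0] = 0` line above.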
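
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: assumes
    # DEFAULTS['tfidf_path'] points at a model produced by the companion
    # TF-IDF build script. Because of the relative imports above, run this
    # as a module (e.g. `python -m <package>.tfidf_doc_ranker`) rather than
    # directly as a script.
    logging.basicConfig(level=logging.INFO)
    ranker = TfidfDocRanker()
    doc_ids, doc_scores = ranker.closest_docs('Who wrote Hamlet?', k=5)
    for doc_id, score in zip(doc_ids, doc_scores):
        print('%s\t%.4f' % (doc_id, score))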