|
import logging |
|
from datetime import datetime |
|
from typing import Dict, List, AnyStr |
|
|
|
from sentence_transformers import CrossEncoder |
|
import torch |
|
|
|
# Module-level logger; handlers and levels are left to the hosting process.
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Callable handler that scores passages against a query with a cross-encoder.

    The model is loaded once in ``__init__`` and reused by every call.
    """

    def __init__(self, path: str = ""):
        """Load the cross-encoder.

        Args:
            path: Model identifier or directory, passed unchanged to ``CrossEncoder``.
        """
        # Run on the GPU when torch reports one, otherwise on the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cross_encoder = CrossEncoder(path, device=device)

    def __call__(self, data: Dict[str, Dict[str, object]]) -> Dict[str, List[float]]:
        """Score every passage against the query.

        Args:
            data: Request payload of the form
                ``{"inputs": {"query": <query>, "passages": [<passage>, ...]}}``.

        Return:
            Dict[str, List[float]]: A dictionary with a single key "scores".
            Each number is the score of one passage for the given query, and
            the order of the scores matches the order of the passages. An
            empty "passages" list yields an empty "scores" list.

        Raises:
            ValueError: If "inputs" is missing or not a dictionary, if "query"
                is missing, or if "passages" is missing or not a list/tuple.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError(
                'Payload must contain an "inputs" object with "query" and "passages".'
            )

        query = inputs.get("query")
        if query is None:
            raise ValueError('"inputs" is missing the "query" field.')

        passages = inputs.get("passages")
        # A bare string would otherwise be iterated character by character.
        if not isinstance(passages, (list, tuple)):
            raise ValueError('"inputs" must contain "passages" as a list.')

        logger.info("Query: %s", query)
        logger.info("N. of passages: %s", len(passages))

        if not passages:
            # Nothing to score; skip the model call, which is not guaranteed
            # to accept an empty batch.
            return {"scores": []}

        start_time = datetime.now()

        # One (query, passage) pair per passage; the sigmoid is handed to
        # predict() as the activation applied to the model output.
        pairs = [(query, passage) for passage in passages]
        scores = self.cross_encoder.predict(pairs, activation_fct=torch.nn.Sigmoid())

        logger.info(
            "Time to run cross-encoder for query '%s' with %s passages: %s",
            query,
            len(passages),
            datetime.now() - start_time,
        )
        logger.info("Scores: %s", scores)

        return {"scores": scores.tolist()}
|
|
|
|