from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from nltk.tokenize import sent_tokenize
import torch
import numpy as np


class Detector:
    def __init__(self, model_name, device):
        # The head size is inferred from the checkpoint name: binary
        # classification, multi-dimension scoring, or a single regression score.
        if "classification" in model_name:
            num_labels = 2
        elif "multi-dimension" in model_name:
            num_labels = 3
        else:
            num_labels = 1
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device
        self.model.to(device)
        self.model.eval()

    @torch.no_grad()
    def __call__(self, text, preprocess=True, threshold=None):
        """Score `text` sentence by sentence.

        Returns a list of (sentence, score) pairs, or (sentence, bool) pairs
        when `threshold` is given.
        """
        if preprocess:
            # Split into sentences, then rejoin so tokenization sees the full text.
            sents = sent_tokenize(text)
            text = " ".join(sents)
        else:
            # The caller is responsible for providing space-separated units.
            sents = text.split(" ")
        input_ids = self.tokenizer(text, max_length=2048, truncation=True)["input_ids"]
        # Positions of the separator token (id 2, i.e. </s> in RoBERTa-style
        # vocabularies); their logits are read out as sentence-level scores.
        sent_label_idx = [i for i, ids in enumerate(input_ids) if ids == 2]
        tensor_input = torch.tensor([input_ids]).to(self.device)
        outputs = self.model(tensor_input).logits.detach().cpu().numpy()
        outputs_logits: np.ndarray = outputs[0][sent_label_idx]
        if outputs_logits.shape[1] == 2:
            # Binary head: keep the positive-class logit.
            outputs_logits = outputs_logits[:, 1]
        elif outputs_logits.shape[1] == 3:
            # Multi-dimension head: average the three dimensions.
            outputs_logits = outputs_logits.mean(axis=-1)
        outputs_logits = outputs_logits.flatten()
        if threshold is None:
            return list(zip(sents, outputs_logits.tolist()))
        else:
            return list(zip(sents, (outputs_logits > threshold).tolist()))
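

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of how the Detector is meant to be driven. The checkpoint
# name below is a placeholder assumption; substitute the token-classification
# checkpoint the detector was trained with. sent_tokenize also needs NLTK's
# punkt data on first use (nltk.download("punkt")).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = Detector("your-org/your-detector-classification-checkpoint", device)

    sample = "The Eiffel Tower is in Paris. It was completed in 1889."
    # Raw per-sentence scores.
    print(detector(sample))
    # Hard per-sentence decisions at a chosen cutoff.
    print(detector(sample, threshold=0.5))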