# nel-mgenre-multilingual/model_handler_nel.py
import json
import logging
import os
import pickle
import sys

import numpy as np
import torch
from nltk import pos_tag
from nltk.chunk import conlltags2tree
from nltk.tree import Tree
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)

# Make the handler's directory importable so extra files packaged in the
# model archive can be found.
current_directory = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, current_directory)
def pickle_load(path, verbose=False):
    """Load a pickled object from ``path``; return None when ``path`` is None."""
    if path is None:
        return None
    if verbose:
        logger.info('Loading %s', path)
    with open(path, "rb") as f:
        obj = pickle.load(f)
    return obj
DEFAULT_MODEL = 'facebook/mgenre-wiki'
def tokenize(text):
    """Whitespace tokenization; punctuation stays attached to its token."""
    return text.split()
logger.info('Loading the (language, title) -> Wikidata QID mapping')
lang_title2wikidataID_path = "lang_title2wikidataID-normalized_with_redirect.pkl"
lang_title2wikidataID = pickle_load(lang_title2wikidataID_path, verbose=True)
def text_to_id(x):
    """Resolve an mGENRE prediction of the form "title >> lang" to a Wikidata QID.

    The reversed (lang, title) pair indexes ``lang_title2wikidataID``; when
    several QIDs match, the numerically largest one is returned.
    """
    return max(
        lang_title2wikidataID[tuple(reversed([y.strip() for y in x.split(" >> ")]))],
        key=lambda y: int(y[1:]),
    )
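# Example (hypothetical mapping contents): assuming the pickle maps
# (lang, title) tuples to sets of QID strings, the prediction
# "Albert Einstein >> en" is looked up as ("en", "Albert Einstein") and would
# resolve to {"Q937"} -> "Q937".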
"""
Method for retrieving the Qid
"""
def get_wikidata_qid(wikipedia_titles, scores):
    """Return the first beam candidate that resolves to a Wikidata QID.

    Falls back to ('NIL', top title, top score) when no candidate is found
    in the mapping.
    """
    qid = 'NIL'
    wikipedia_title = wikipedia_titles[0]
    score = scores[0]
    for idx, title in enumerate(wikipedia_titles):
        try:
            qid = text_to_id(title)
            wikipedia_title = wikipedia_titles[idx]
            score = scores[idx]
            return qid, wikipedia_title, score
        except (KeyError, ValueError):
            qid = 'NIL'
    return qid, wikipedia_title, score
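# Sketch of the expected call (hypothetical beam output and scores):
#   titles = ["Albert Einstein >> en", "Einstein (crater) >> en"]
#   scores = [62.0, 38.0]
#   get_wikidata_qid(titles, scores)  # -> ("Q937", "Albert Einstein >> en", 62.0)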
def get_entities(tokens, preds_list_coarse, preds_list_fine, coarse_confidences, fine_confidences):
    """Convert IOBES tag sequences into entity spans for both NERC granularities."""
    tags_coarse = [tag.replace('S-', 'B-').replace('E-', 'I-') for tag in preds_list_coarse]
    tags_fine = [tag.replace('S-', 'B-').replace('E-', 'I-') for tag in preds_list_fine]
    pos_tags = [pos for token, pos in pos_tag(tokens)]
    conll_coarse_tags = [(token, pos, tg)
                         for token, pos, tg in zip(tokens, pos_tags, tags_coarse)]
    conll_fine_tags = [(token, pos, tg)
                       for token, pos, tg in zip(tokens, pos_tags, tags_fine)]
    ne_tree_coarse = conlltags2tree(conll_coarse_tags)
    ne_tree_fine = conlltags2tree(conll_fine_tags)
    coarse_entities = get_entities_from_tree(ne_tree_coarse, coarse_confidences)
    fine_entities = get_entities_from_tree(ne_tree_fine, fine_confidences)
    return coarse_entities, fine_entities
def logarithmic_scaling(confidence_score):
    # Add a small epsilon to avoid log(0).
    return np.log(confidence_score + 1e-10)


def classify_confidence(confidence_score):
    # Return an integer percentage: numpy float32 values are not JSON
    # serializable, so the score is converted before being returned to the
    # client.
    return int(confidence_score * 100.0)
def get_entities_from_tree(ne_tree, token_confidences):
    entities = []
    idx = 0
    char_position = 0  # Current character position in the space-joined text
    for subtree in ne_tree:
        # Entity chunks are Tree nodes; tokens tagged 'O' are plain tuples.
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join(
                [token for token, pos in subtree.leaves()])
            entity_start_position = char_position
            entity_end_position = entity_start_position + len(original_string)
            confidences = token_confidences[idx:idx + len(subtree)]
            # Average the per-token confidences over the entity span
            avg_confidence = sum(confidences) / len(confidences)
            logger.debug('%s - confidence - %s - avg - %s (%s) - label - %s',
                         original_string, confidences, avg_confidence,
                         classify_confidence(avg_confidence), original_label)
            entities.append(
                (original_string,
                 original_label,
                 (idx, idx + len(subtree)),
                 (entity_start_position, entity_end_position),
                 classify_confidence(avg_confidence)))
            idx += len(subtree)
            # Advance past the entity text plus one separating space
            char_position += len(original_string) + 1
        else:
            token, pos = subtree
            # Outside an entity: still advance the character position
            char_position += len(token) + 1  # +1 for the space
            idx += 1
    return entities
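# Each returned entity is a tuple of the form (hypothetical example):
#   ("LONDON HOTEL", "loc", (9, 11), (52, 64), 87)
# i.e. (surface string, label, token span, character span, confidence in %).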
def realign(
        text_sentence,
        tokens_coarse_result,
        tokens_fine_result,
        coarse_confidences,
        fine_confidences,
        tokenizer,
        language,
        nerc_coarse_label_map,
        nerc_fine_label_map):
    """Map subword-level predictions back onto the original words."""
    preds_list_coarse, preds_list_fine, words_list, coarse_confidences_list, fine_confidences_list = [], [], [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        try:
            # Take the prediction of the first subword of each word
            beginning_index = word_ids.index(idx)
            preds_list_coarse.append(nerc_coarse_label_map[tokens_coarse_result[beginning_index]])
            preds_list_fine.append(nerc_fine_label_map[tokens_fine_result[beginning_index]])
            coarse_confidences_list.append(coarse_confidences[beginning_index])
            fine_confidences_list.append(fine_confidences[beginning_index])
        except Exception:  # the sentence was longer than max_length; pad with 'O'
            preds_list_coarse.append('O')
            preds_list_fine.append('O')
            coarse_confidences_list.append(1.0)
            fine_confidences_list.append(1.0)
        words_list.append(word)
    return words_list, preds_list_coarse, preds_list_fine, coarse_confidences_list, fine_confidences_list
class NewsAgencyHandler(BaseHandler):
    def __init__(self):
        super().__init__()
        self.model = None
        self.tokenizer = None
        self.device = None

    def initialize(self, ctx):
        # TorchServe boilerplate: pick the device from the system properties.
        properties = ctx.system_properties
        self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.map_location + ":" + str(
            properties.get("gpu_id")) if torch.cuda.is_available() else self.map_location)

        # model_dir is the inside of the model archive; extra-files live there.
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        logger.info("Model %s loading tokenizer", model_name)
        logger.info('getcwd: %s', os.getcwd())
        logger.info('__file__: %s', __file__)
        logger.info('Device: %s', self.device)

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        logger.info("Transformer model from path %s loaded successfully", model_name)
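    # The model name above is read from the TorchServe model-config YAML. A
    # minimal sketch of the expected file (key names follow the
    # ctx.model_yaml_config lookup above; the value shown is an assumption):
    #
    #   handler:
    #     model_name: facebook/mgenre-wiki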
    def preprocess(self, requests):
        logger.info('Preprocessing %d requests', len(requests))
        data = requests[0]
        text_sentences = []
        language = None
        # Each body item carries OCR text with the mention marked, e.g.:
        # "THE next MEETLNG of the TRITSTEE, will be held at the [START] LONDON
        # HOTEL [END] in POOLE, on ldomaT, the 12th day or MARCH next. at 12
        # oClock at Noon"
        for item in data['body']:
            item = json.loads(item)
            text_sentences.append(item['text'])
            language = item['language']
        return text_sentences, language
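    # A body item is a JSON string such as (hypothetical doc_id):
    #   {"text": "... [START] LONDON HOTEL [END] ...",
    #    "language": "en", "doc_id": "example-001"}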
    def inference(self, inputs):
        text_sentences, language = inputs
        qids = []
        with torch.no_grad():
            for sentence in text_sentences:
                sentences = [sentence]
                outputs = self.model.generate(
                    **self.tokenizer(sentences, return_tensors="pt").to(self.device),
                    num_beams=5,
                    num_return_sequences=5,
                    return_dict_in_generate=True,
                    output_scores=True)
                token_ids, scores = outputs['sequences'], outputs['sequences_scores']
                wikipedia_titles = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
                # sequences_scores are length-normalized log-likelihoods;
                # exponentiate and renormalize them into beam-relative
                # percentages (not true probabilities).
                probabilities = torch.exp(scores)
                normalized_probabilities = probabilities / torch.sum(probabilities)
                percentages = normalized_probabilities * 100
                qid, wikipedia_title, score = get_wikidata_qid(wikipedia_titles, percentages)
                percentage_score = int(score)
                qids.append({'qid': qid, 'wikipedia_title': wikipedia_title, 'score': percentage_score})
        return qids, text_sentences, language
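    # One inference result looks like (hypothetical scores):
    #   {"qid": "Q937", "wikipedia_title": "Albert Einstein >> en", "score": 62}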
    def postprocess(self, outputs):
        # Convert the inference outputs into the TorchServe response payload.
        qids, text_sentences, language = outputs
        logger.info('Result NEL: %s', qids)
        return [[qids]]
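

# A minimal local smoke test of the generation path used in inference(),
# independent of TorchServe. This is a sketch: it assumes network access to
# download facebook/mgenre-wiki and that the lang_title2wikidataID pickle sits
# next to this file (the module loads it at import time).
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL).eval()
    sentences = ["[START] Einstein [END] era un fisico tedesco."]
    with torch.no_grad():
        outputs = model.generate(
            **tokenizer(sentences, return_tensors="pt"),
            num_beams=5,
            num_return_sequences=5,
            return_dict_in_generate=True,
            output_scores=True)
    titles = tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)
    print(titles)  # e.g. ['Albert Einstein >> it', ...]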