import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
# model large
model_name = "pucpr/clinicalnerpt-chemical"
model_large = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_large = AutoTokenizer.from_pretrained(model_name)
# model base
model_name = "pucpr/clinicalnerpt-chemical"
model_base = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_base = AutoTokenizer.from_pretrained(model_name)
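# note: the "large" and "base" slots above currently load the same checkpoint,
# so the two demo outputs will be identical unless one of the names is changed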
# css: '#xxxxxx' is a placeholder replaced by the entity color at render time
background_colors_entity_word = {
    'ChemicalDrugs': "#fae8ff",
}
background_colors_entity_tag = {
    'ChemicalDrugs': "#d946ef",
}
css = {
    'entity_word': 'color:#000000;background: #xxxxxx; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 2.5; border-radius: 0.35em;',
    'entity_tag': 'color:#fff;background: #xxxxxx; font-size: 0.8em; font-weight: bold; line-height: 2.5; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5em;'
}
list_EN = "<span style='" | |
list_EN += f"{css['entity_tag'].replace('#xxxxxx',background_colors_entity_tag['ChemicalDrugs'])};padding:0.5em;" | |
list_EN += "'>ChemicalDrugs</span>" | |
# infos
title = "BioBERTpt - Chemical entities"
description = "BioBERTpt - Chemical entities"
article = list_EN  # entity legend shown under the demo (assumption: reuses the list_EN badge built above)
allow_screenshot = False
allow_flagging = False
examples = [
    ["Dispneia venoso central em subclavia D duplolumen recebendo solução salina e glicosada em BI."],
    ["Paciente com Sepse pulmonar em D8 tazocin (paciente não recebeu por 2 dias Atb)."],
    ["FOI REALIZADO CURSO DE ATB COM LEVOFLOXACINA POR 7 DIAS."],
]
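# ner() runs the same text through both (tokenizer, model) pairs: it tokenizes the
# input, takes the argmax label per wordpiece, groups consecutive B-/I- tags into
# entity spans, and wraps each span in the colored HTML defined in `css` above.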
def ner(input_text):
    num = 0
    for tokenizer, model in zip([tokenizer_large, tokenizer_base], [model_large, model_base]):
        # tokenization
        inputs = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt")
        tokens = inputs.tokens()
        # get predictions
        outputs = model(**inputs).logits
        predictions = torch.argmax(outputs, dim=2)
        preds = [model.config.id2label[prediction] for prediction in predictions[0].numpy()]
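        # preds holds one BIO tag per wordpiece, e.g. ['O', 'B-ChemicalDrugs', 'I-ChemicalDrugs', 'O', ...]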
        # variables
        groups_pred = dict()
        group_indices = list()
        group_label = ''
        count = 0
        # group consecutive B-/I- predictions into entity spans (BIO scheme)
        for i, en in enumerate(preds):
            if en == 'O':
                # outside any entity: close the current span, if one is open
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
            if en.startswith('B'):
                # a new entity begins: flush any open span, then start a new one
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
                group_indices.append(i)
                group_label = en.replace('B-', '')
            elif en.startswith('I'):
                if len(group_indices) > 0:
                    if en.replace('I-', '') == group_label:
                        # continuation of the open span
                        group_indices.append(i)
                    else:
                        # label changed mid-span: flush and start a new span
                        groups_pred[count] = {'indices': group_indices, 'en': group_label}
                        group_indices = [i]
                        group_label = en.replace('I-', '')
                        count += 1
                else:
                    # I- tag without a preceding B-: start a new span anyway
                    group_indices = [i]
                    group_label = en.replace('I-', '')
            if i == len(preds) - 1 and len(group_indices) > 0:
                # end of sequence: flush the last open span
                groups_pred[count] = {'indices': group_indices, 'en': group_label}
                group_indices = list()
                group_label = ''
                count += 1
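        # groups_pred now maps 0..n-1 to {'indices': [wordpiece positions], 'en': entity label}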
        # there is at least one NE
        len_groups_pred = len(groups_pred)
        inputs = inputs['input_ids'][0].numpy()
        if len_groups_pred > 0:
            for pred_num in range(len_groups_pred):
                en = groups_pred[pred_num]['en']
                indices = groups_pred[pred_num]['indices']
                if pred_num == 0:
                    if indices[0] > 0:
                        output = tokenizer.decode(inputs[:indices[0]]) + f'<span style="{css["entity_word"].replace("#xxxxxx",background_colors_entity_word[en])}">' + tokenizer.decode(inputs[indices[0]:indices[-1]+1]) + f'<span style="{css["entity_tag"].replace("#xxxxxx",background_colors_entity_tag[en])}">' + en + '</span></span> '
                    else:
                        output = f'<span style="{css["entity_word"].replace("#xxxxxx",background_colors_entity_word[en])}">' + tokenizer.decode(inputs[indices[0]:indices[-1]+1]) + f'<span style="{css["entity_tag"].replace("#xxxxxx",background_colors_entity_tag[en])}">' + en + '</span></span> '
                else:
                    output += tokenizer.decode(inputs[indices_prev[-1]+1:indices[0]]) + f'<span style="{css["entity_word"].replace("#xxxxxx",background_colors_entity_word[en])}">' + tokenizer.decode(inputs[indices[0]:indices[-1]+1]) + f'<span style="{css["entity_tag"].replace("#xxxxxx",background_colors_entity_tag[en])}">' + en + '</span></span> '
                indices_prev = indices
            output += tokenizer.decode(inputs[indices_prev[-1]+1:])
        else:
            output = input_text
        # output: strip BERT special tokens and wordpiece markers, then wrap in a scrollable div
        output = output.replace('[CLS]', '').replace(' [SEP]', '').replace('##', '')
        output = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + output + "</div>"
        if num == 0:
            output_large = output
            num += 1
        else:
            output_base = output
    return output_large, output_base
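# quick local sanity check outside Gradio (assumes the checkpoints downloaded correctly):
#   html_large, html_base = ner(examples[0][0])
#   print(html_large)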
# interface gradio
iface = gr.Interface(
    title=title,
    description=description,
    article=article,
    allow_screenshot=allow_screenshot,
    allow_flagging=allow_flagging,
    fn=ner,
    inputs=gr.inputs.Textbox(placeholder="Digite uma frase aqui ou clique em um exemplo:", lines=5),
    outputs=[gr.outputs.HTML(label="NER1"), gr.outputs.HTML(label="NER2")],
    examples=examples
)
iface.launch()