import gradio as gr
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch

# model large
model_name = "pucpr/clinicalnerpt-chemical"
model_large = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_large = AutoTokenizer.from_pretrained(model_name)

# model base (note: currently the same checkpoint as the large model)
model_name = "pucpr/clinicalnerpt-chemical"
model_base = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_base = AutoTokenizer.from_pretrained(model_name)

# css ("#xxxxxx" is a placeholder swapped for the entity color at render time)
background_colors_entity_word = {
    'ChemicalDrugs': "#fae8ff",
}
background_colors_entity_tag = {
    'ChemicalDrugs': "#d946ef",
}
css = {
    'entity_word': 'color:#000000;background: #xxxxxx; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 2.5; border-radius: 0.35em;',
    'entity_tag': 'color:#fff;background: #xxxxxx; font-size: 0.8em; font-weight: bold; line-height: 2.5; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5em;',
}
list_EN = "ChemicalDrugs"

# infos
title = "BioBERTpt - Chemical entities"
description = "BioBERTpt - Chemical entities"
allow_screenshot = False
allow_flagging = False
examples = [
    ["Dispneia venoso central em subclavia D duplolumen recebendo solução salina e glicosada em BI."],
    ["Paciente com Sepse pulmonar em D8 tazocin (paciente não recebeu por 2 dias Atb)."],
    ["FOI REALIZADO CURSO DE ATB COM LEVOFLOXACINA POR 7 DIAS."],
]


def ner(input_text):
    num = 0
    for tokenizer, model in zip([tokenizer_large, tokenizer_base], [model_large, model_base]):
        # tokenization
        inputs = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt")

        # get predictions (one BIO label per token)
        with torch.no_grad():
            outputs = model(**inputs).logits
        predictions = torch.argmax(outputs, dim=2)
        preds = [model.config.id2label[prediction] for prediction in predictions[0].numpy()]

        # variables
        groups_pred = dict()
        group_indices = list()
        group_label = ''
        count = 0

        # group contiguous B-/I- tokens into named entities
        for i, en in enumerate(preds):
            if en == 'O':
                # close the open group, if any
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
            elif en.startswith('B'):
                # a B- tag always opens a new group
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
                group_indices.append(i)
                group_label = en.replace('B-', '')
            elif en.startswith('I'):
                if len(group_indices) > 0:
                    if en.replace('I-', '') == group_label:
                        # same entity type: extend the open group
                        group_indices.append(i)
                    else:
                        # type changed: close the group and start a new one
                        groups_pred[count] = {'indices': group_indices, 'en': group_label}
                        group_indices = [i]
                        group_label = en.replace('I-', '')
                        count += 1
                else:
                    # I- without a preceding B-: start a group anyway
                    group_indices = [i]
                    group_label = en.replace('I-', '')
            # flush the last group at the end of the sequence
            if i == len(preds) - 1 and len(group_indices) > 0:
                groups_pred[count] = {'indices': group_indices, 'en': group_label}
                group_indices = list()
                group_label = ''
                count += 1

        # there is at least one NE
        len_groups_pred = len(groups_pred)
        input_ids = inputs['input_ids'][0].numpy()  # [1:-1]
        if len_groups_pred > 0:
            for pred_num in range(len_groups_pred):
                en = groups_pred[pred_num]['en']
                indices = groups_pred[pred_num]['indices']
                # assumption: the highlight markup was lost in this copy; rebuilt as
                # nested <span> tags styled from the css templates above
                word_style = css['entity_word'].replace("#xxxxxx", background_colors_entity_word[en])
                tag_style = css['entity_tag'].replace("#xxxxxx", background_colors_entity_tag[en])
                entity_html = (f'<span style="{word_style}">'
                               + tokenizer.decode(input_ids[indices[0]:indices[-1] + 1])
                               + f'<span style="{tag_style}">' + en + '</span></span> ')
                if pred_num == 0:
                    if indices[0] > 0:
                        output = tokenizer.decode(input_ids[:indices[0]]) + entity_html
                    else:
                        output = entity_html
                else:
                    output += tokenizer.decode(input_ids[indices_prev[-1] + 1:indices[0]]) + entity_html
                indices_prev = indices
            output += tokenizer.decode(input_ids[indices_prev[-1] + 1:])
        else:
            output = input_text

        # output: strip special tokens and wordpiece markers
        output = output.replace('[CLS]', '').replace(' [SEP]', '').replace('##', '')
        # assumption: the source is truncated here; the tail most plausibly wrapped
        # the markup in a <div> and kept one HTML string per model pass
        output = "<div>" + output + "</div>"
        if num == 0:
            output_large = output
        else:
            output_base = output
        num += 1

    return output_large, output_base
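
# The variables above (title, description, examples, allow_screenshot,
# allow_flagging) are defined but the interface wiring is missing from this copy.
# A minimal sketch of the launch code, assuming ner() returns one HTML string per
# model; allow_screenshot/allow_flagging belong to older Gradio releases and are
# left unwired so the sketch runs on current versions:
interface = gr.Interface(
    fn=ner,
    inputs="text",
    outputs=["html", "html"],
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    interface.launch()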