# Page header captured from the hosting site (not program code): "Spaces: Sleeping"
import gradio as gr | |
import tensorflow as tf | |
from transformers import AutoTokenizer, TFAutoModelForTokenClassification | |
# Hub repository that holds the fine-tuned token-classification weights.
model_name = "krishnapal2308/NER-Task3"

# The tokenizer comes from the stock cased BERT checkpoint, not from `model_name`.
# NOTE(review): presumably the fine-tuned repo reuses that vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForTokenClassification.from_pretrained(model_name)

# Class index -> BIO tag. The position of each tag in the tuple IS its class
# index (0 = "O", 1 = "B-treatment", ... 8 = "I-allergy_name"), so the order
# below must not be changed.
id2label = dict(
    enumerate(
        (
            "O",
            "B-treatment",
            "I-treatment",
            "B-chronic_disease",
            "I-chronic_disease",
            "B-cancer",
            "I-cancer",
            "B-allergy_name",
            "I-allergy_name",
        )
    )
)
# Tokens added by the tokenizer itself; they never belong to the user's text.
_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})


def _group_wordpieces(tokens, labels):
    """Join WordPiece tokens into whole words, keeping one entity type per word.

    Args:
        tokens: Token strings; continuation pieces start with ``"##"``.
        labels: BIO tags (``"O"``, ``"B-x"``, ``"I-x"``), aligned with ``tokens``.

    Returns:
        A list of ``(word, entity_type_or_None)`` tuples in input order.
    """
    words = []
    current_word = ""
    current_label = None
    for token, label in zip(tokens, labels):
        if token in _SPECIAL_TOKENS:
            continue
        if token.startswith("##"):
            # Continuation piece: glue it on; its own tag is ignored, the word
            # keeps the label decided on its first piece.
            current_word += token[2:]
            continue
        if current_word:
            words.append((current_word, current_label))
        current_word = token
        prefix, _, entity = label.partition("-")
        # "B-" always opens an entity. "I-" is honoured only when the previous
        # word carried the same entity type; a dangling "I-" (and "O") yields
        # an unlabelled word.
        if prefix == "B" or (prefix == "I" and entity == current_label):
            current_label = entity
        else:
            current_label = None
    if current_word:
        words.append((current_word, current_label))
    return words


def _merge_adjacent_entities(words):
    """Collapse runs of consecutive words that share the same entity type.

    Unlabelled words (label ``None``) are never merged. Only *adjacent* words
    are joined, so two separate mentions of the same entity type stay separate
    and the original word order is preserved.
    """
    merged = []
    for word, label in words:
        if label and merged and merged[-1][1] == label:
            previous_word = merged[-1][0]
            merged[-1] = (previous_word + " " + word, label)
        else:
            merged.append((word, label))
    return merged


def predict(text):
    """Run the NER model on ``text`` and return labelled text spans.

    The text is tokenized, classified token by token, re-assembled into whole
    words, and consecutive words of the same entity type are joined into a
    single span.

    Args:
        text: Raw input text. Long inputs are truncated by the tokenizer
            (``truncation=True``).

    Returns:
        A list of ``(span_text, label)`` tuples in reading order, where
        ``label`` is ``"treatment"``, ``"chronic_disease"``, ``"cancer"``,
        ``"allergy"`` or ``None`` for unlabelled words. Blank or missing input
        yields an empty list.
    """
    # Nothing to tag: skip the model forward pass entirely.
    if not text or not text.strip():
        return []

    inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
    outputs = model(inputs)
    predictions = tf.argmax(outputs.logits, axis=-1)

    # Only one text is submitted, so only batch row 0 is read.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[class_id] for class_id in predictions[0].numpy().tolist()]

    spans = _merge_adjacent_entities(_group_wordpieces(tokens, labels))

    # Present the model's "allergy_name" class under the shorter "allergy" label.
    return [
        (span, "allergy" if label == "allergy_name" else label)
        for span, label in spans
    ]
def ner_function(text):
    """Gradio callback: tag ``text`` and return the spans produced by ``predict``.

    The result of ``predict`` is passed through unchanged.
    """
    return predict(text)
# Sample inputs offered in the UI; each row holds the single textbox value.
examples = [
    ["The patient was diagnosed with stage 2 breast cancer and treated with tamoxifen."],
    ["He has a history of type 2 diabetes and is allergic to penicillin."],
]

# Blurb rendered under the title of the interface.
_DESCRIPTION = """
This interface presents a Named Entity Recognition (NER) system specifically designed for analyzing clinical trial data.
Leveraging a fine-tuned BERT-based model, the system is capable of identifying and classifying key medical entities such as treatments, chronic diseases, cancers, and allergies.
Explore the provided examples to observe the model's capabilities in action.
"""

# Create Gradio interface: free text in, highlighted entity spans out.
iface = gr.Interface(
    fn=ner_function,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.HighlightedText(label="Text with Entities"),
    title="Clinical Trial Named Entity Recognition",
    description=_DESCRIPTION,
    examples=examples,
    cache_examples=True,
    allow_flagging="never",
    theme="default",
)

# Launch the interface
iface.launch()