# Uploaded by krishnapal2308 -- "Upload 2 files" (commit cbc9c5c, verified)
import gradio as gr
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForTokenClassification
# Identifier handed to from_pretrained() to load the fine-tuned
# token-classification weights.
model_name = "krishnapal2308/NER-Task3"
# NOTE(review): the tokenizer is loaded from "bert-base-cased" rather than from
# model_name -- presumably the fine-tuned checkpoint reuses that vocabulary;
# confirm before changing either identifier.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# TensorFlow token-classification model; called once per request in predict().
model = TFAutoModelForTokenClassification.from_pretrained(model_name)
# Maps the class index chosen by argmax over the model logits to its BIO tag.
# The position of each tag in the list below is its class index.
id2label = dict(
    enumerate(
        [
            "O",
            "B-treatment",
            "I-treatment",
            "B-chronic_disease",
            "I-chronic_disease",
            "B-cancer",
            "I-cancer",
            "B-allergy_name",
            "I-allergy_name",
        ]
    )
)
def _merge_wordpieces(tokens, labels):
    """Rebuild whole words from WordPiece tokens as (word, entity_type) pairs.

    Special tokens are dropped and "##" continuation pieces are glued onto the
    preceding word. A word takes its entity type from the tag of its first
    piece: a "B-" tag always starts an entity, an "I-" tag is only accepted
    when the previous word carried the same entity type, and anything else
    yields None.
    """
    special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    words = []
    current_word = ""
    current_entity = None
    for token, label in zip(tokens, labels):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            current_word += token[2:]  # continuation piece: strip the '##'
            continue
        if current_word:
            # A new word starts, so flush the one built so far.
            words.append((current_word, current_entity))
        current_word = token
        entity = label[2:]
        if label.startswith("B-"):
            current_entity = entity
        elif label.startswith("I-") and current_entity == entity:
            # current_entity still holds the previous word's type here.
            current_entity = entity
        else:
            current_entity = None
    if current_word:
        words.append((current_word, current_entity))
    return words


def _merge_adjacent_entities(word_labels):
    """Join runs of *consecutive* words that share the same entity type.

    Unlabelled words (entity type None) are never merged. Only neighbouring
    words are joined, so two separate mentions of the same entity type stay
    separate and keep their position in the text.
    """
    merged = []
    for word, entity in word_labels:
        if entity and merged and merged[-1][1] == entity:
            previous_word, _ = merged[-1]
            merged[-1] = (f"{previous_word} {word}", entity)
        else:
            merged.append((word, entity))
    return merged


def predict(text):
    """Run the NER model on ``text`` and return labelled text spans.

    The text is tokenized (with truncation), classified token by token, and the
    predictions are collapsed first into words and then into entity spans.

    Args:
        text: The input string to analyse.

    Returns:
        A list of ``(span_text, label)`` tuples in reading order. ``label`` is
        the entity type without its BIO prefix ("allergy_name" is reported as
        "allergy"), or None for words that are not part of an entity.
    """
    inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
    outputs = model(inputs)
    predictions = tf.argmax(outputs.logits, axis=-1)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # Convert the first (only) sequence to NumPy once instead of per token.
    labels = [id2label[int(class_id)] for class_id in predictions[0].numpy()]

    spans = _merge_adjacent_entities(_merge_wordpieces(tokens, labels))
    # Present the model's "allergy_name" type under the shorter "allergy" label.
    return [
        (span, "allergy") if label == "allergy_name" else (span, label)
        for span, label in spans
    ]
def ner_function(text):
    """Gradio callback: return the labelled spans that predict() finds in text."""
    return predict(text)
# Sample inputs passed to the interface; each inner list holds the value for
# the single text input.
examples = [
    [sentence]
    for sentence in (
        "The patient was diagnosed with stage 2 breast cancer and treated with tamoxifen.",
        "He has a history of type 2 diabetes and is allergic to penicillin.",
    )
]
# Text passed to the interface as its description.
_DESCRIPTION = """
This interface presents a Named Entity Recognition (NER) system specifically designed for analyzing clinical trial data.
Leveraging a fine-tuned BERT-based model, the system is capable of identifying and classifying key medical entities such as treatments, chronic diseases, cancers, and allergies.
Explore the provided examples to observe the model's capabilities in action.
"""

# Input and output components: a multi-line text box feeding ner_function, and
# a highlighted-text view for the (span, label) pairs it returns.
_text_input = gr.Textbox(lines=5, label="Input Text")
_entity_output = gr.HighlightedText(label="Text with Entities")

# Assemble the Gradio interface around ner_function.
iface = gr.Interface(
    fn=ner_function,
    inputs=_text_input,
    outputs=_entity_output,
    title="Clinical Trial Named Entity Recognition",
    description=_DESCRIPTION,
    examples=examples,
    cache_examples=True,
    allow_flagging="never",
    theme="default",
)

# Start serving the interface.
iface.launch()