import gradio as gr from spacy import displacy from transformers import (AutoModelForTokenClassification, AutoTokenizer, pipeline, ) model_checkpoint = "jsylee/scibert_scivocab_uncased-finetuned-ner" model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=5, id2label={0: 'O', 1: 'DRUG', 2: 'DRUG', 3: 'ADVERSE EFFECT', 4: 'ADVERSE EFFECT'} # for grouping BIO tags back together ) tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) model.to("cpu") model_pipeline = pipeline(task="ner", model=model, tokenizer=tokenizer, device=-1, grouped_entities=True) def extract_entities(sentence): """ Extract drug and reaction entities, and show using displaCy's NER visualizer. source: https://github.com/jsylee/personal-projects/blob/master/Hugging%20Face%20ADR%20Fine-Tuning/SciBERT%20ADR%20Fine-Tuning.ipynb """ tokens = model_pipeline(sentence) entities = [] for token in tokens: label = token["entity_group"] if label != "0": # label 0 corresponds to "Outside" any entity we care about token["label"] = label entities.append(token) params = [{"text": sentence, "ents": entities, "title": None}] return displacy.render(params, style="ent", manual=True, options={ "colors": { "DRUG": "#f08080", "ADVERSE EFFECT": "#9bddff", }, }) # the following examples of adverse effects are taken from Wikipedia: # https://en.wikipedia.org/wiki/Adverse_effect#Medications examples = [ "Abortion, miscarriage or uterine hemorrhage associated with misoprostol (Cytotec), a labor-inducing drug.", "Addiction to many sedatives and analgesics, such as diazepam, morphine, etc.", "Birth defects associated with thalidomide", "Bleeding of the intestine associated with aspirin therapy", "Cardiovascular disease associated with COX-2 inhibitors (i.e. Vioxx)", "Deafness and kidney failure associated with gentamicin (an antibiotic)", "Death, following sedation, in children using propofol (Diprivan)", "Depression or hepatic injury caused by interferon", "Diabetes caused by atypical antipsychotic medications (neuroleptic psychiatric drugs)" ] footer = """