File size: 2,061 Bytes
bdf7636
023a91a
bdf7636
5826177
bdf7636
24d9d43
3469da9
bdf7636
a9898b9
bc05f33
a9898b9
bc05f33
 
 
bdf7636
5826177
24d9d43
621a7b7
bdf7636
023a91a
621a7b7
bdf7636
32eb862
 
 
5826177
 
 
ecca0d1
 
5826177
 
 
 
023a91a
32eb862
24d9d43
32eb862
 
 
 
 
 
 
 
 
 
 
 
023a91a
 
 
621a7b7
023a91a
 
 
 
bc05f33
 
621a7b7
bc05f33
 
 
 
 
 
621a7b7
 
023a91a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
from collections import defaultdict, Counter

import matplotlib.pyplot as plt
import gradio as gr
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification

# Hugging Face model ids available to the demo; only the first one is loaded.
MODELS = ["d4data/biomedical-ner-all"]

current_model = MODELS[0]

tokenizer = AutoTokenizer.from_pretrained(current_model)
model = AutoModelForTokenClassification.from_pretrained(current_model)

# Non-interactive backend: figures are only rendered to images for the UI,
# never displayed in a desktop window.
plt.switch_backend("Agg")

# Each record in examples.json is expected to carry "label" and "text" keys;
# they become "label: text" strings offered in the examples dropdown.
# JSON is UTF-8 by specification, so pin the encoding instead of relying on
# the platform default (which breaks on non-ASCII clinical text on e.g. Windows).
with open("examples.json", "r", encoding="utf-8") as f:
    content = json.load(f)
examples = [f"{x['label']}: {x['text']}" for x in content]

# "simple" aggregation merges sub-word tokens into entity spans that carry an
# "entity_group" key (consumed by run_ner).
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")


def plot_to_figure(grouped):
    """Draw a bar chart of entity counts and return the matplotlib Figure.

    Args:
        grouped: Mapping of entity-group label -> count. One bar is drawn per
            key, in the mapping's iteration order.

    Returns:
        A new ``matplotlib.figure.Figure`` containing the bar chart.
    """
    fig, ax = plt.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Reserve space under the axes for the vertically rotated category labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    # pyplot keeps every figure it creates alive until it is closed; without
    # this, each request in the long-running app leaks a figure. The Figure
    # object remains valid and can still be saved/rendered by the caller.
    plt.close(fig)
    return fig


def run_ner(text):
    """Run the NER pipeline on *text* and build the three UI outputs.

    Returns:
        A 3-tuple of:
        - a ``{"text": ..., "entities": [...]}`` payload for the highlighted
          text view, one entity dict per aggregated prediction;
        - table rows ``[[entity_group, count], ...]`` in first-seen order;
        - a bar-chart figure of the same counts.
    """
    predictions = pipe(text)

    spans = []
    for prediction in predictions:
        spans.append(
            {
                "entity": prediction["entity_group"],
                "word": prediction["word"],
                "score": prediction["score"],
                "start": prediction["start"],
                "end": prediction["end"],
            }
        )
    highlighted = {"text": text, "entities": spans}

    # Counter preserves first-insertion order, so table rows and bars match.
    counts = Counter(prediction["entity_group"] for prediction in predictions)
    table_rows = [list(pair) for pair in counts.items()]
    return highlighted, table_rows, plot_to_figure(counts)


with gr.Blocks() as demo:
    # Input area: the note to analyse plus a button to run the analysis.
    note = gr.Textbox(label="Note text")
    submit = gr.Button("Submit")

    # Choosing an example copies its string verbatim into the note box.
    # NOTE(review): each choice is a "label: text" string, so the label prefix
    # is also sent to NER when the note is submitted -- confirm that's intended.
    example_dropdown = gr.Dropdown(label="Examples", choices=examples)
    example_dropdown.change(
        lambda choice: gr.Textbox.update(value=choice),
        inputs=example_dropdown,
        outputs=note,
    )

    # Output area, in the same order run_ner returns its results.
    highlight = gr.HighlightedText(label="NER", combine_adjacent=True)
    table = gr.Dataframe(headers=["Entity", "Count"])
    plot = gr.Plot(label="Bar")

    # Clicking the button and pressing Enter in the textbox run the same analysis.
    for trigger in (submit.click, note.submit):
        trigger(run_ner, [note], [highlight, table, plot])

demo.launch()