# Gradio demo for biomedical named-entity recognition: tags entities in a
# note, then shows the highlighted text, per-group entity counts, the label
# of the matching example (if any), and a bar chart of the counts.
import gradio as gr
import pandas as pd
import json
from collections import defaultdict

# Biomedical NER pipeline. The tokenizer and the token-classification model
# are both loaded from the same Hugging Face checkpoint.
# Model card: https://huggingface.co/d4data/biomedical-ner-all?text=asthma
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
_MODEL_ID = "d4data/biomedical-ner-all"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForTokenClassification.from_pretrained(_MODEL_ID)
# The rest of this file reads "entity_group", "word", "score", "start" and
# "end" from each result produced by this pipeline.
pipe = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)

# Matplotlib for entity graph
import matplotlib.pyplot as plt
# Switch pyplot to the "Agg" backend right after import, before any figure is
# created. NOTE(review): presumably chosen because the app runs without a
# display — confirm.
plt.switch_backend("Agg")

# Load examples from JSON
# EXAMPLES maps each example note's text to its label. It is used both to
# populate the interface's example list and to look up the label shown for a
# submitted note. The file is expected to be a JSON array of objects with
# "text" and "label" keys.
EXAMPLES = {}
# Pass the encoding explicitly: without it, open() falls back to the
# platform's locale encoding, which can mis-decode non-ASCII note text.
with open("examples.json", "r", encoding="utf-8") as f:
    example_json = json.load(f)
    EXAMPLES = {x["text"]: x["label"] for x in example_json}

def group_by_entity(raw):
    """Count how many detected entities fall into each entity group.

    Args:
        raw: Iterable of mappings, each with an ``"entity_group"`` key.

    Returns:
        A ``defaultdict(int)`` mapping each group label to its number of
        occurrences, in first-seen order.
    """
    tally = defaultdict(int)
    for label in (entity["entity_group"] for entity in raw):
        tally[label] = tally[label] + 1
    return tally


def plot_to_figure(grouped):
    """Render per-entity-group counts as a bar chart.

    Args:
        grouped: Mapping of entity-group label to count. Keys become the
            bars' x positions and values their heights.

    Returns:
        The Matplotlib figure containing the chart.
    """
    # Draw through explicit figure/axes objects instead of pyplot's implicit
    # "current figure", so the calls cannot land on some other figure.
    fig, ax = plt.subplots()
    ax.bar(x=list(grouped.keys()), height=list(grouped.values()))
    ax.margins(0.2)
    # Leave room under the axes for the rotated tick labels.
    fig.subplots_adjust(bottom=0.4)
    ax.tick_params(axis="x", labelrotation=90)
    # pyplot keeps a reference to every figure it creates; without releasing
    # it, each call would leave one more figure alive for the life of the
    # process. The Figure object itself stays usable after close().
    plt.close(fig)
    return fig


def ner(text):
    """Run the NER pipeline on ``text`` and assemble the interface outputs.

    Args:
        text: The note text to analyse.

    Returns:
        A 4-tuple ``(ner_content, meta, label, figure)``:

        * ``ner_content`` - dict with the input ``"text"`` and an
          ``"entities"`` list (entity, word, score, start, end per hit).
        * ``meta`` - dict with the per-group counts, the number of distinct
          groups, and the total number of detected entities.
        * ``label`` - the label stored in ``EXAMPLES`` for this exact text,
          or ``"Unknown"`` when the text is not one of the examples.
        * ``figure`` - bar chart of the per-group counts.
    """
    detections = pipe(text)

    # (output key, pipeline result key) pairs, in output order; only the
    # group field is renamed.
    field_map = (
        ("entity", "entity_group"),
        ("word", "word"),
        ("score", "score"),
        ("start", "start"),
        ("end", "end"),
    )
    spans = [{dst: hit[src] for dst, src in field_map} for hit in detections]
    ner_content = {"text": text, "entities": spans}

    grouped = group_by_entity(detections)
    figure = plot_to_figure(grouped)
    label = EXAMPLES.get(text, "Unknown")

    meta = {
        "entity_counts": grouped,
        "entities": len(grouped),
        "counts": sum(grouped.values()),
    }

    return ner_content, meta, label, figure


# Build the web UI around ner(). The four output components are listed in the
# same order as the tuple ner() returns: (ner_content, meta, label, figure).
interface = gr.Interface(
    ner,
    inputs=gr.Textbox(label="Note text", value=""),
    outputs=[
        gr.HighlightedText(label="NER", combine_adjacent=True),  # ner_content
        gr.JSON(label="Entity Counts"),  # meta
        gr.Label(label="Rating"),  # label looked up in EXAMPLES
        gr.Plot(label="Bar"),  # figure
    ],
    # Offer every example text loaded from examples.json as a preset input.
    examples=list(EXAMPLES.keys()),
    allow_flagging="never",
)

# Start the app at import/run time (no __main__ guard in this script).
interface.launch()