Spaces:
Runtime error
Runtime error
Christopher McMaster
commited on
Commit
·
737007d
1
Parent(s):
56c6d77
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.cm as cm
|
2 |
+
import html
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from transformers import pipeline
|
6 |
+
import gradio as gr
|
7 |
+
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
|
8 |
+
"Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`."
|
9 |
+
c = cmap(x)
|
10 |
+
rgb = (np.array(c[:-1]) * 255).astype(int)
|
11 |
+
a = c[-1] * alpha_mult
|
12 |
+
return tuple(rgb.tolist() + [a])
|
13 |
+
def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
|
14 |
+
html_code,spans = ['<span style="font-family: monospace;">'], []
|
15 |
+
for p, a, cols, l in zip(pieces, prob, colors, label):
|
16 |
+
p = html.escape(p)
|
17 |
+
c = str(value2rgba(a, cmap=cols, alpha_mult=0.5, **kwargs))
|
18 |
+
spans.append(f'<span title="{l}: {a:.3f}" style="background-color: rgba{c};">{p}</span>')
|
19 |
+
html_code.append(sep.join(spans))
|
20 |
+
html_code.append('</span>')
|
21 |
+
return ''.join(html_code)
|
22 |
+
def nothing_ent(i, word):
|
23 |
+
return {
|
24 |
+
'entity': 'O',
|
25 |
+
'score': 0,
|
26 |
+
'index': i,
|
27 |
+
'word': word,
|
28 |
+
'start': 0,
|
29 |
+
'end': 0
|
30 |
+
}
|
31 |
+
def _gradio_highlighting(text):
|
32 |
+
result = ner_model(text)
|
33 |
+
tokens = ner_model.tokenizer.tokenize(text)
|
34 |
+
label_indeces = [i['index'] - 1 for i in result]
|
35 |
+
entities = list()
|
36 |
+
for i, word in enumerate(tokens):
|
37 |
+
if i in label_indeces:
|
38 |
+
entities.append(result[label_indeces.index(i)])
|
39 |
+
else:
|
40 |
+
entities.append(nothing_ent(i, word))
|
41 |
+
entities = ner_model.group_entities(entities)
|
42 |
+
spans = [e['word'] for e in entities]
|
43 |
+
probs = [e['score'] for e in entities]
|
44 |
+
labels = [e['entity_group'] for e in entities]
|
45 |
+
colors = [cm.RdPu if label == 'ADR' else cm.YlGn for i, label in enumerate(labels)]
|
46 |
+
return piece_prob_html(spans, probs, colors, labels, sep=' ')
|
47 |
+
|
48 |
+
default_text = """# Pancreatitis
|
49 |
+
- Lipase: 535 -> 154 -> 145
|
50 |
+
- Managed with NBM, IV fluids
|
51 |
+
- CT AP and abdo USS: normal
|
52 |
+
- Likely secondary to Azathioprine - ceased, never to be used again.
|
53 |
+
- Resolved with conservative measures
|
54 |
+
"""
|
55 |
+
title = "Adverse Drug Reaction Highlighting"
|
56 |
+
description = "This app was made to accompany our recent paper. ADRs will be highlighted in purple, offending medications in green. Hover over a word to see the strength of each prediction on a 0-1 scale."
|
57 |
+
ner_model = pipeline(task = 'token-classification', model = "austin/adr-ner")
|
58 |
+
iface = gr.Interface(_gradio_highlighting,
|
59 |
+
[
|
60 |
+
gr.inputs.Textbox(
|
61 |
+
lines=7,
|
62 |
+
label="Text",
|
63 |
+
default=default_text),
|
64 |
+
],
|
65 |
+
gr.outputs.HTML(label="ADR Prediction"),
|
66 |
+
title = title,
|
67 |
+
description = description,
|
68 |
+
theme = "darkdefault")
|
69 |
+
iface.launch()
|