Christopher McMaster commited on
Commit
737007d
·
1 Parent(s): 56c6d77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.cm as cm
2
+ import html
3
+ import torch
4
+ import numpy as np
5
+ from transformers import pipeline
6
+ import gradio as gr
7
+ def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
8
+ "Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`."
9
+ c = cmap(x)
10
+ rgb = (np.array(c[:-1]) * 255).astype(int)
11
+ a = c[-1] * alpha_mult
12
+ return tuple(rgb.tolist() + [a])
13
+ def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
14
+ html_code,spans = ['<span style="font-family: monospace;">'], []
15
+ for p, a, cols, l in zip(pieces, prob, colors, label):
16
+ p = html.escape(p)
17
+ c = str(value2rgba(a, cmap=cols, alpha_mult=0.5, **kwargs))
18
+ spans.append(f'<span title="{l}: {a:.3f}" style="background-color: rgba{c};">{p}</span>')
19
+ html_code.append(sep.join(spans))
20
+ html_code.append('</span>')
21
+ return ''.join(html_code)
22
+ def nothing_ent(i, word):
23
+ return {
24
+ 'entity': 'O',
25
+ 'score': 0,
26
+ 'index': i,
27
+ 'word': word,
28
+ 'start': 0,
29
+ 'end': 0
30
+ }
31
+ def _gradio_highlighting(text):
32
+ result = ner_model(text)
33
+ tokens = ner_model.tokenizer.tokenize(text)
34
+ label_indeces = [i['index'] - 1 for i in result]
35
+ entities = list()
36
+ for i, word in enumerate(tokens):
37
+ if i in label_indeces:
38
+ entities.append(result[label_indeces.index(i)])
39
+ else:
40
+ entities.append(nothing_ent(i, word))
41
+ entities = ner_model.group_entities(entities)
42
+ spans = [e['word'] for e in entities]
43
+ probs = [e['score'] for e in entities]
44
+ labels = [e['entity_group'] for e in entities]
45
+ colors = [cm.RdPu if label == 'ADR' else cm.YlGn for i, label in enumerate(labels)]
46
+ return piece_prob_html(spans, probs, colors, labels, sep=' ')
47
+
48
+ default_text = """# Pancreatitis
49
+ - Lipase: 535 -> 154 -> 145
50
+ - Managed with NBM, IV fluids
51
+ - CT AP and abdo USS: normal
52
+ - Likely secondary to Azathioprine - ceased, never to be used again.
53
+ - Resolved with conservative measures
54
+ """
55
+ title = "Adverse Drug Reaction Highlighting"
56
+ description = "This app was made to accompany our recent paper. ADRs will be highlighted in purple, offending medications in green. Hover over a word to see the strength of each prediction on a 0-1 scale."
57
+ ner_model = pipeline(task = 'token-classification', model = "austin/adr-ner")
58
+ iface = gr.Interface(_gradio_highlighting,
59
+ [
60
+ gr.inputs.Textbox(
61
+ lines=7,
62
+ label="Text",
63
+ default=default_text),
64
+ ],
65
+ gr.outputs.HTML(label="ADR Prediction"),
66
+ title = title,
67
+ description = description,
68
+ theme = "darkdefault")
69
+ iface.launch()