# Hugging Face Spaces page header captured with this file: status "Runtime error".
import html

import gradio as gr
import matplotlib.cm as cm
import numpy as np
import torch
from transformers import pipeline
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    """Map `x` through `cmap` and return an ``(r, g, b, a)`` tuple.

    The colour channels returned by `cmap(x)` are scaled by 255 and cast to
    ``int``; the last component (alpha) is multiplied by `alpha_mult`.
    `x` is expected to lie in the 0 to 1 range (inclusive).
    """
    *channels, alpha = cmap(x)
    scaled = (np.array(channels) * 255).astype(int)
    return (*scaled.tolist(), alpha * alpha_mult)
def piece_prob_html(pieces, prob, colors, label, sep=' ', **kwargs):
    """Render text pieces as colour-coded HTML spans.

    Args:
        pieces: Text fragments to display, one span each.
        prob: Score for each fragment; passed to `value2rgba` and shown in
            the span's tooltip with three decimals.
        colors: Colormap for each fragment, forwarded as `cmap` to
            `value2rgba`.
        label: Label for each fragment, shown in the span's tooltip.
        sep: String placed between consecutive spans.
        **kwargs: Extra keyword arguments forwarded to `value2rgba`.

    Returns:
        A single HTML string: the spans joined by `sep`, wrapped in a
        monospace ``<span>``. The four sequences are walked in lockstep, so
        the output has as many spans as the shortest of them.
    """
    spans = []
    for piece, score, cmap, tag in zip(pieces, prob, colors, label):
        rgba = str(value2rgba(score, cmap=cmap, alpha_mult=0.5, **kwargs))
        # Both the visible text and the tooltip end up inside HTML, so both
        # are escaped; html.escape also escapes quotes, which keeps the
        # title="..." attribute well-formed.
        tooltip = html.escape(str(tag))
        body = html.escape(piece)
        spans.append(
            f'<span title="{tooltip}: {score:.3f}" '
            f'style="background-color: rgba{rgba};">{body}</span>'
        )
    return ''.join(
        ['<span style="font-family: monospace;">', sep.join(spans), '</span>']
    )
def nothing_ent(i, word):
    """Build a placeholder entity dict for a token with no predicted label.

    The entry carries the 'O' tag, a score of 0 and zeroed start/end offsets,
    along with the given position `i` and token text `word`.
    """
    placeholder = dict(entity='O', score=0, index=i, word=word)
    placeholder['start'] = 0
    placeholder['end'] = 0
    return placeholder
def _gradio_highlighting(text):
    """Run the NER model on `text` and return highlighted HTML.

    Every token produced by the model's tokenizer gets an entry: the model's
    prediction when one exists for that position, otherwise an 'O'
    placeholder from `nothing_ent`. The entries are merged with
    `ner_model.group_entities` and rendered by `piece_prob_html`, using the
    `cm.RdPu` colormap for groups labelled 'ADR' and `cm.YlGn` for all others.
    """
    predictions = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Index predictions by token position for O(1) lookup in the loop below.
    # The reported 'index' is shifted down by one to line up with `tokens`
    # (presumably because the pipeline counts a leading special token --
    # confirm against the tokenizer in use). setdefault keeps the first
    # prediction for a position, matching a first-match list search.
    by_position = {}
    for ent in predictions:
        by_position.setdefault(ent['index'] - 1, ent)

    entities = [
        by_position[i] if i in by_position else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    entities = ner_model.group_entities(entities)

    spans = [e['word'] for e in entities]
    probs = [e['score'] for e in entities]
    labels = [e['entity_group'] for e in entities]
    colors = [cm.RdPu if group == 'ADR' else cm.YlGn for group in labels]
    return piece_prob_html(spans, probs, colors, labels, sep=' ')
# Example note pre-filled in the input box.
default_text = """# Pancreatitis
- Lipase: 535 -> 154 -> 145
- Managed with NBM, IV fluids
- CT AP and abdo USS: normal
- Likely secondary to Azathioprine - ceased, never to be used again.
- Resolved with conservative measures
"""

# Texts shown around the interface.
title = "Adverse Drug Reaction Highlighting"
description = "Named Entity Recognition model to detect ADRs in discharge summaries"
# The colour words use inline <span> elements: a block-level <p> in the
# middle of a sentence would split the sentence across separate lines.
article = """This app was made to accompany our recent [paper](https://www.medrxiv.org/content/10.1101/2021.12.11.21267504v2).<br>
ADRs will be highlighted in <span style="color:purple">purple</span>, offending medications in <span style="color:green">green</span>.<br>
Hover over a word to see the strength of each prediction on a 0-1 scale.<br>
Our code can be found at [github](https://github.com/AustinMOS/adr-nlp).
"""
# Token-classification pipeline used by _gradio_highlighting.
ner_model = pipeline(task='token-classification', model="austin/adr-ner")

# Newer Gradio releases expose components at the top level (gr.Textbox,
# gr.HTML) and take the initial text as `value=`; the gr.inputs / gr.outputs
# namespaces and the `default=` keyword are the legacy spelling and are not
# available in current releases. Prefer the new API and fall back to the
# legacy one only when the top-level components are missing.
if hasattr(gr, "Textbox") and hasattr(gr, "HTML"):
    _text_input = gr.Textbox(lines=7, label="Text", value=default_text)
    _html_output = gr.HTML(label="ADR Prediction")
else:
    _text_input = gr.inputs.Textbox(lines=7, label="Text", default=default_text)
    _html_output = gr.outputs.HTML(label="ADR Prediction")

iface = gr.Interface(
    _gradio_highlighting,
    [_text_input],
    _html_output,
    title=title,
    description=description,
    article=article,
    theme="huggingface",
)
iface.launch()