MEIRa / app.py
KawshikManikantan's picture
large_example
0d75905
raw
history blame
4.89 kB
import time
import spacy
import json
import gradio as gr
from spacy.tokens import Doc, Span
from spacy import displacy
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import tempfile
from inference.model_inference import Inference
from configs import *
# HTML/Markdown blurb explaining the demo and its input convention (entities
# marked with double curly braces). It is handed to ``gr.Interface`` as the
# ``description`` argument further down in this file.
DESC_MD = """
<font size="3">
This space is a demo for <a href="https://arxiv.org/abs/2406.14654"> Major Entity Identification (MEI) </a>. MEI takes entities as additional input and aims to detect the mentions that refer only to these entities. <br/>
<br/>
Place the text in the text box with a single phrase of a selected entity in double curly braces(example: a single instance of {{Ron}} if you want to track Ron). Note that you can select one phrase for each entity and multiple entities can be selected. Check out the example below for clarity. <br/>
<br/>
Static: Uses an instance of: MEIRa-S model <br/>
Hybrid: Uses an instance of: MEIRa-H model <br/>
<br/>
The demo provides a json file with clusters and an HTML file with visualizations. The visualizations are color-coded based on the clusters. <br/>
</font>
"""
def get_MEIRa_clusters(doc_name, text, model_type):
    """Run MEIRa coreference on ``text`` and return the model output.

    ``model_type`` is looked up in ``MODELS`` to obtain the model identifier,
    a new ``Inference`` object is created from it on every call, and the
    result of ``perform_coreference(text, doc_name)`` is returned unchanged.

    Raises whatever the ``MODELS`` lookup raises when ``model_type`` is not a
    known key.
    """
    return Inference(MODELS[model_type]).perform_coreference(text, doc_name)
def _write_temp_file(payload, suffix):
    """Write ``payload`` (bytes) to a new persistent temp file and return its path."""
    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
        tmp_file.write(payload)
        return tmp_file.name


def coref_visualizer(doc_name, text, model_type):
    """Run MEIRa on ``text`` and produce downloadable HTML and JSON results.

    The model output is rendered with displaCy's span style, one colour per
    entity label, and written to a temporary ``.html`` file. The raw model
    output is dumped to a temporary ``.json`` file.

    Args:
        doc_name: Document name forwarded to the model.
        text: Input text forwarded to the model.
        model_type: Model selector forwarded to ``get_MEIRa_clusters``.

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)`` where
        the two buttons are visible ``gr.DownloadButton`` updates pointing at
        the respective files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One colour per label, sampled from the "tab20" colormap. The colormap is
    # built once via ``plt.get_cmap``; ``matplotlib.cm.get_cmap`` was removed
    # in Matplotlib 3.9. ``max(..., 1)`` avoids asking for a zero-size
    # colormap when there are no labels (the palette is then simply empty).
    cmap = plt.get_cmap("tab20", max(len(labels), 1))
    color_palette = {label: to_hex(cmap(i)) for i, label in enumerate(labels)}

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    # The final cluster is skipped; the remaining clusters are matched to
    # labels by position.
    # NOTE(review): presumably the last cluster collects mentions that belong
    # to none of the selected entities - confirm against the model output.
    spans = []
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # spaCy spans are end-exclusive, so the model's end index is
            # treated as inclusive here.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file = _write_temp_file(html.encode("utf-8"), ".html")
    json_file = _write_temp_file(json.dumps(coref_output).encode("utf-8"), ".json")

    print("HTML file:", html_file)
    print("JSON file:", json_file)
    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``gr.DownloadButton`` update that hides the button."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``gr.DownloadButton`` update that hides the button."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Long sample document offered as the second example in the interface.
# Decoded explicitly as UTF-8 so loading does not depend on the platform's
# default text encoding.
with open("example_harry.txt", "r", encoding="utf-8") as f:
    example_harry = f.read()

# Choices shown in the model selector; the chosen value is passed to
# ``coref_visualizer`` as ``model_type``.
options = ["static", "hybrid"]

# NOTE(review): block nesting was reconstructed - the Interface is created
# inside the Blocks context so that it becomes part of ``demo``, and the app is
# launched after the context closes. Confirm against the intended layout.
with gr.Blocks() as demo:
    # Output components start hidden; ``coref_visualizer`` returns updates
    # that make the download buttons visible.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # NOTE(review): ``click()`` is called without a handler, so
    # ``download_html`` / ``download_json`` are not attached here - confirm
    # whether they were meant to be wired up.
    html_button.click()
    json_button.click()

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=10, placeholder="Enter text for coreference resolution:"),
            gr.Radio(choices=options, label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="MEI Visualizer",
        description=DESC_MD,
        examples=[
            [
                "example",
                "{{Harry}} went to Hogwarts to meet Hemoine and {{Ron}} . He also met Ron's mother at the railway station.",
                "static",
            ],
            ["example_large", example_harry, "static"],
        ],
    )

demo.launch(debug=True)