MEIRa / app.py
KawshikManikantan's picture
upload_trial
98e2ea5
raw
history blame
3.18 kB
import json
import os
import time
import uuid

import gradio as gr
import matplotlib.pyplot as plt
import spacy
from matplotlib.colors import to_hex
from spacy import displacy
from spacy.tokens import Doc, Span

from inference.model_inference import Inference
from configs import *
def get_MEIRa_clusters(doc_name, text, model_type):
    """Load the model registered under *model_type* and run coreference on *text*.

    Args:
        doc_name: Document name forwarded to ``Inference.perform_coreference``.
        text: Raw text to resolve.
        model_type: Key into ``MODELS`` (star-imported from ``configs``) that
            selects the model string passed to ``Inference``.

    Returns:
        The value returned by ``Inference.perform_coreference``. Callers in this
        file read its "tokenized_doc", "clusters" and "representative_names" keys.
    """
    # NOTE(review): a new Inference object is constructed on every call; if
    # model loading is expensive, consider caching instances per model string.
    inference = Inference(MODELS[model_type])
    return inference.perform_coreference(text, doc_name)
def _build_color_palette(labels):
    """Map each label to a hex colour sampled from matplotlib's "tab20" colormap.

    The colormap is resampled to ``len(labels)`` entries so every label gets its
    own sample. Duplicate labels collapse onto one key (the last one wins).
    """
    if not labels:
        return {}
    # pyplot.get_cmap(name, lut) replaces the removed matplotlib.cm.get_cmap;
    # build the colormap once rather than once per label.
    cmap = plt.get_cmap("tab20", len(labels))
    return {label: to_hex(cmap(i)) for i, label in enumerate(labels)}


def _save_outputs(html, coref_output, output_dir="gradio_outputs"):
    """Write the rendered HTML and the raw model output to uniquely named files.

    Args:
        html: HTML string produced by ``displacy.render``.
        coref_output: JSON-serialisable model output to dump alongside it.
        output_dir: Directory to write into; created if missing.

    Returns:
        The ``(html_path, json_path)`` pair.
    """
    # The directory may not exist on a fresh deployment.
    os.makedirs(output_dir, exist_ok=True)
    # A random UUID keeps concurrent requests from overwriting each other's files.
    stem = f"output_{uuid.uuid4().hex}"
    html_file = os.path.join(output_dir, f"{stem}.html")
    json_file = os.path.join(output_dir, f"{stem}.json")
    # Explicit UTF-8 so non-ASCII document text survives regardless of locale.
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html)
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(coref_output, f)
    return html_file, json_file


def coref_visualizer(doc_name, text, model_type):
    """Run MEIRa coreference on *text*, render it with displaCy, and save the results.

    Args:
        doc_name: User-supplied document name, forwarded to the model.
        text: Raw document text to resolve.
        model_type: Key into ``MODELS`` selecting which model to load.

    Returns:
        A 4-tuple matching the Gradio ``outputs`` order: the HTML file path, the
        JSON file path, then two visible ``gr.DownloadButton`` updates pointing
        at those files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    color_palette = _build_color_palette(labels)

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    spans = []
    # NOTE(review): the final cluster is deliberately skipped; presumably it
    # holds mentions not attached to any named entity — confirm against the
    # Inference output format.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # The +1 converts what is presumably an inclusive end index into
            # spaCy's exclusive Span end.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    html_file, json_file = _save_outputs(html, coref_output)
    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``gr.DownloadButton`` update that hides the HTML download button."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``gr.DownloadButton`` update that hides the JSON download button."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Model variants offered in the UI; each is a key into MODELS.
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Hidden output components. coref_visualizer fills the two files and makes
    # the download buttons visible once a document has been processed.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # Each button hides itself again after a click, via the download_html /
    # download_json handlers (Gradio's documented DownloadButton pattern).
    html_button.click(download_html, inputs=None, outputs=html_button)
    json_button.click(download_json, inputs=None, outputs=json_button)

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=100, placeholder="Enter text for coreference resolution:"),
            # Pre-select the first option: with no selection Gradio passes None,
            # which presumably fails the MODELS lookup in get_MEIRa_clusters.
            gr.Radio(choices=options, value=options[0], label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="Coreference Resolution Visualizer",
    )

if __name__ == "__main__":
    demo.launch(debug=True)