Spaces:
Sleeping
Sleeping
import time | |
import spacy | |
import json | |
import gradio as gr | |
from spacy.tokens import Doc, Span | |
from spacy import displacy | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import to_hex | |
from inference.model_inference import Inference | |
from configs import * | |
def get_MEIRa_clusters(doc_name, text, model_type):
    """Run coreference inference on ``text`` and return the model's output.

    Args:
        doc_name: Document name, forwarded to ``perform_coreference``.
        text: Text to run coreference resolution on.
        model_type: Key into ``MODELS`` selecting which model string to load.

    Returns:
        Whatever ``Inference.perform_coreference`` returns for this input.

    Raises:
        KeyError: If ``model_type`` is not a key of ``MODELS``.
    """
    # A new Inference object is constructed on every call.
    return Inference(MODELS[model_type]).perform_coreference(text, doc_name)
def coref_visualizer(doc_name, text, model_type):
    """Resolve coreference in ``text`` and render the clusters with displaCy.

    The model output is rendered with displaCy's "span" style. The HTML and
    the raw model output (as JSON) are both written under ``gradio_outputs/``.

    Args:
        doc_name: Document name; forwarded to the model and mixed into the
            hash used for the output file names.
        text: Text to run coreference resolution on.
        model_type: Model selector, forwarded to ``get_MEIRa_clusters``.

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)``: the
        two written file paths followed by visible ``gr.DownloadButton``
        instances pointing at those files.
    """
    # Function-scope import so this block does not depend on the file header.
    import os

    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One "tab20" colour per label. ``plt.cm.get_cmap`` was removed in
    # Matplotlib 3.9; ``plt.get_cmap`` takes the same (name, lut) arguments.
    # ``max(..., 1)`` avoids asking for a zero-entry colormap when there are
    # no labels (the palette is simply empty in that case).
    cmap = plt.get_cmap("tab20", max(len(labels), 1))
    color_palette = {label: to_hex(cmap(i)) for i, label in enumerate(labels)}

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)
    print("Tokens:", tokens, flush=True)
    print(color_palette)

    # NOTE(review): the last cluster is deliberately skipped here, presumably
    # because it is a catch-all / "others" cluster -- confirm against the
    # model's output format.
    spans = []
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # ``end + 1`` because Span's end is exclusive, so the stored end
            # index is treated as inclusive here.
            spans.append(Span(doc, start, end + 1, label=label))
    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    # File names are derived from a hash of the current time and doc_name.
    time_hash = hash(str(time.time()) + doc_name)
    output_dir = "gradio_outputs"
    # Create the output directory on first use instead of failing in open().
    os.makedirs(output_dir, exist_ok=True)
    html_file = f"{output_dir}/output_{time_hash}.html"
    json_file = f"{output_dir}/output_{time_hash}.json"

    # Explicit UTF-8 so non-ASCII tokens in the rendered HTML are written
    # correctly regardless of the platform's default encoding.
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html)
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(coref_output, f)

    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )
def download_html():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
def download_json():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button
# Choices shown in the radio input; the selected value reaches
# coref_visualizer as ``model_type``.
# NOTE(review): presumably these must be keys of MODELS from configs -- confirm.
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Hidden output components; coref_visualizer returns the two file paths
    # and two visible DownloadButton instances that fill these slots.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # NOTE(review): these click events are registered without a callback, so
    # download_html / download_json are never invoked. Left as-is; confirm
    # whether the buttons were meant to hide themselves after a download.
    html_button.click()
    json_button.click()

    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=100, placeholder="Enter text for coreference resolution:"),
            # Preselect the first option: with no selection the callback
            # receives None as model_type and the MODELS lookup fails.
            gr.Radio(choices=options, value=options[0], label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="Coreference Resolution Visualizer",
    )

demo.launch(debug=True)