File size: 3,179 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import os
import time

import gradio as gr
import matplotlib.pyplot as plt
import spacy
from matplotlib.colors import to_hex
from spacy import displacy
from spacy.tokens import Doc, Span

from inference.model_inference import Inference
from configs import *


def get_MEIRa_clusters(doc_name, text, model_type):
    """Run coreference resolution on ``text`` with the selected model.

    ``model_type`` is used as a key into ``MODELS``; the value stored there is
    handed to ``Inference``. A fresh ``Inference`` object is constructed on
    every call.

    Returns whatever ``perform_coreference(text, doc_name)`` returns.
    """
    coref_model = Inference(MODELS[model_type])
    return coref_model.perform_coreference(text, doc_name)


def coref_visualizer(doc_name, text, model_type):
    """Run coreference resolution and write the result as HTML and JSON files.

    The coreference output is read through three keys: ``"tokenized_doc"``
    (the token list), ``"clusters"`` and ``"representative_names"`` (one
    label per cluster). Every mention of every cluster except the last one is
    turned into a spaCy ``Span`` and rendered with displaCy's ``span`` style.

    Args:
        doc_name: Document name; forwarded to the model and mixed into the
            output file name.
        text: Raw text to resolve.
        model_type: Key selecting the model (see ``get_MEIRa_clusters``).

    Returns:
        A 4-tuple ``(html_path, json_path, html_button, json_button)`` where
        the two buttons are visible ``gr.DownloadButton`` objects pointing at
        the respective files.
    """
    coref_output = get_MEIRa_clusters(doc_name, text, model_type)

    tokens = coref_output["tokenized_doc"]
    clusters = coref_output["clusters"]
    labels = coref_output["representative_names"]

    # One colour per label, taken from "tab20" resampled to len(labels) entries.
    # `plt.get_cmap` is used because `matplotlib.cm.get_cmap` was removed in
    # Matplotlib 3.9. The colormap is looked up once instead of once per label,
    # and not at all when there are no labels.
    color_palette = {}
    if labels:
        cmap = plt.get_cmap("tab20", len(labels))
        color_palette = {label: to_hex(cmap(i)) for i, label in enumerate(labels)}

    nlp = spacy.blank("en")
    doc = Doc(nlp.vocab, words=tokens)

    print("Tokens:", tokens, flush=True)
    print(color_palette)

    spans = []
    # NOTE(review): the last cluster is deliberately skipped here, as in the
    # original code; presumably it is a catch-all cluster -- confirm.
    for cluster_ind, cluster in enumerate(clusters[:-1]):
        label = labels[cluster_ind]
        for (start, end), _mention in cluster:
            # `end` is treated as inclusive, hence the +1 for spaCy's
            # exclusive span end.
            spans.append(Span(doc, start, end + 1, label=label))

    doc.spans["coref_spans"] = spans

    print("Rendering the visualization...")

    # Generate the HTML output
    html = displacy.render(
        doc,
        style="span",
        options={
            "spans_key": "coref_spans",
            "colors": color_palette,
        },
        jupyter=False,
    )

    # Make sure the output directory exists; otherwise the first request
    # fails with FileNotFoundError.
    os.makedirs("gradio_outputs", exist_ok=True)

    ## Create a hash based on time and doc_name
    time_hash = hash(str(time.time()) + doc_name)
    html_file = f"gradio_outputs/output_{time_hash}.html"
    json_file = f"gradio_outputs/output_{time_hash}.json"

    # Explicit UTF-8: the rendered HTML contains the document text, which may
    # not be encodable in the platform's default encoding.
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html)

    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(coref_output, f)

    return (
        html_file,
        json_file,
        gr.DownloadButton(value=html_file, visible=True),
        gr.DownloadButton(value=json_file, visible=True),
    )


def download_html():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button


def download_json():
    """Return a ``gr.DownloadButton`` constructed with ``visible=False``."""
    hidden_button = gr.DownloadButton(visible=False)
    return hidden_button


# Choices offered by the radio input below; the selected value is the third
# input of the interface and therefore reaches `coref_visualizer` as
# `model_type`. Presumably these match the keys of `MODELS` -- verify.
options = ["static", "hybrid"]

with gr.Blocks() as demo:
    # Output components. All four start hidden; `coref_visualizer` returns
    # file paths for the two File components and visible DownloadButtons
    # (with `value` set to those paths) for the two buttons.
    html_file = gr.File(visible=False)
    json_file = gr.File(visible=False)
    html_button = gr.DownloadButton("Download HTML", visible=False)
    json_button = gr.DownloadButton("Download JSON", visible=False)

    # NOTE(review): `click()` is called without a handler here.
    # `download_html` / `download_json` are defined in this file but are not
    # passed to these calls -- confirm whether that is intended.
    html_button.click()
    json_button.click()
    # Inputs map positionally to (doc_name, text, model_type); outputs map
    # positionally to the 4-tuple returned by `coref_visualizer`.
    iface = gr.Interface(
        fn=coref_visualizer,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter document name:"),
            gr.Textbox(lines=100, placeholder="Enter text for coreference resolution:"),
            gr.Radio(choices=options, label="Select an Option"),
        ],
        outputs=[
            html_file,
            json_file,
            html_button,
            json_button,
        ],
        title="Coreference Resolution Visualizer",
    )

# Runs unconditionally at module level (there is no `__main__` guard), so
# importing this module also starts the app.
demo.launch(debug=True)