File size: 2,807 Bytes
6cf191b
 
 
 
 
 
f633022
 
 
 
6cf191b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import matplotlib.pyplot as plt

from inference import RelationsInference
from utils import KGType,Model_Type

# Prep: fetch NLTK's "popular" data collection before the app is built.
# NOTE(review): this runs as a side effect of importing this module and may
# need network access on a fresh machine; presumably the project's inference
# code depends on these NLTK resources -- confirm.
import nltk
nltk.download(info_or_id='popular')

#############################
#   Constants
#############################

# Sample [input text, task, decoding strategy] rows.
# NOTE(review): not referenced anywhere else in this file's visible code, and
# "commonsense_qa" is not one of the task choices offered by the UI radio.
examples = [
    ["What's the meaning of life?", "eli5", "constraint"],
    ["boat, water, bird", "commongen", "constraint"],
    ["What flows under a bridge?", "commonsense_qa", "constraint"],
]

# Single shared inference wrapper, built once at module import and reused by
# the request handlers below (presumably this is where the model weights are
# loaded -- confirm in RelationsInference).
bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)

#############################
#   Helper
#############################

def infer_bart(context, task_type, decoding_type_str):
    """Run the shared BART-RA model on ``context`` and return its first result.

    Args:
        context: Text taken from the "Input:" textbox.
        task_type: Task picked in the UI radio. NOTE(review): currently not
            used -- generation does not depend on it.
        decoding_type_str: Decoding strategy picked in the UI radio.
            NOTE(review): currently not used either.

    Returns:
        Element 0 of the response produced by
        ``bart.generate_based_on_context``.
    """
    # The call yields three values; only the response is needed here. The
    # attentions and model input are discarded. Generation runs with
    # use_kg=False.
    generated, _attentions, _model_inputs = bart.generate_based_on_context(
        context, use_kg=False
    )
    first_result = generated[0]
    return first_result


def plot_attention(layer, head):
    """Build the figure shown in the "Observe Attention" plot panel.

    NOTE(review): this is a placeholder. ``layer`` and ``head`` are not used
    yet, and the plotted data is a fixed dummy line; real attention scores
    still need to be wired in.

    Args:
        layer: Value of the "Layer" slider (currently unused).
        head: Value of the "Head" slider (currently unused).

    Returns:
        A ``matplotlib.figure.Figure`` containing the dummy line plot.
    """
    # Build the figure with the object-oriented API instead of pyplot.
    # ``plt.figure()`` registers every figure with pyplot's global figure
    # manager, and nothing here ever closes them, so each button click leaked
    # one figure (matplotlib warns once more than 20 are open). pyplot's global
    # state is also not thread-safe, while UI handlers may run on worker
    # threads. A bare Figure avoids both problems and renders the same way.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot([1, 2, 3], [2, 4, 6])
    ax.set_title("Things")
    ax.set_ylabel("Cases")
    ax.set_xlabel("Days since Day 0")
    return fig


#############################
#   Interface
#############################

# Build the Gradio Blocks UI. Components are laid out in the order they are
# created inside `with app:`, and the order of the .click() registrations
# fixes each handler's index, so keep these statements in order.
app = gr.Blocks()
with app:
    # Page header and short usage instructions.
    gr.Markdown(
        """
        # Demo
        ### Test Commonsense Relation-Aware BART (BART-RA) model
    
        Tutorial: <br>
            1) Select the possible model variations and tasks;<br>
            2) Change the inputs and Click the buttons to produce results;<br>
            3) See attention visualisations, by choosing a specific layer and head;<br>
        """)
    # Free-text input side by side with the generated output.
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
    # Task and decoding-strategy selectors; both are passed to infer_bart
    # by the model button below.
    with gr.Column():
        task_type_choice = gr.Radio(
            ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
        )
        decoding_type_choice = gr.Radio(
            ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
        )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    # Divider and heading for the attention section.
    gr.Markdown(
        """
        ---
        Observe Attention
        """
    )
    # Layer/head pickers next to the plot area.
    # NOTE(review): the slider ranges (0-11 layers, 0-15 heads) are hard-coded;
    # confirm they match the loaded model's layer and head counts.
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")
    # Event wiring: text generation, then the attention plot.
    model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
                    outputs=[model_result_output])
    vis_btn.click(fn=plot_attention, inputs=[layer, head], outputs=[plot_output])

# Start the web server only when run as a script, not when imported.
if __name__ == '__main__':
    app.launch()