import numpy as np
import pickle
import urllib.request
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer

import gradio as gr
import matplotlib.pyplot as plt

plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"
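# Pre-rendered example figure (a pickled matplotlib Figure hosted on the Space),
# shown in the plot area before the first generation.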

dates = []
for y in range(18, 22):  # years 2018-2021
    dates.extend([f"{y} {m}" for m in range(1, 13)])
months = [x.split(" ")[-1] for x in dates]
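# `dates` holds 48 "year month" prefixes ("18 1" ... "21 12") matching the
# temporal format the model expects at the start of each input; `months`
# keeps only the month number for the x-axis labels.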

model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
#pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)  # GPU variant
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
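# For each input, the fill-mask pipeline returns a ranked list of candidate
# fills, each a dict with 'token_str' and 'score' (among other keys).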


def get_mf_dict(text):

    # predictions
    texts = []
    for d in dates:
        texts.append(f"{d} {text}")
    # ask for the full vocabulary so every token gets a score
    tmp_preds = pipe(texts, top_k=50265)
    preds = dict(zip(dates, tmp_preds))

    # summarize predictions: union of the top n candidates across all months
    top_n = 5  # top candidates kept from each month's prediction
    most_freq_tokens = set()
    for d in dates:
        tmp = [t['token_str'] for t in preds[d][:top_n]]
        most_freq_tokens.update(tmp)

    token_prob = {}
    for d in dates:
        token_prob[d] = {p['token_str']:p['score'] for p in preds[d]}

    mf_dict = {p:np.zeros(len(dates)) for p in most_freq_tokens}

    for c, d in enumerate(dates):
        for t in most_freq_tokens:
            mf_dict[t][c] = token_prob[d][t]

    return mf_dict
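# A sketch of the expected result (the token strings here are hypothetical):
#   get_mf_dict("Happy <mask>!")
#   -> {' Birthday': array([...]), ' holidays': array([...]), ...}
# where each array holds one probability per entry in `dates` (48 months).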

def plot_time(text):
    mf_dict = get_mf_dict(text)
    
    #max_tokens = 10

    fig = plt.figure(figsize=(16,9))
    ax = fig.add_subplot(111)
    #fig, ax = plt.subplots(figsize=(16,9))

    x = list(range(len(dates)))

    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)
    # ax.set_yticks([-1,0,1])

    # secondary x-axis on top: year boundaries on the major ticks (unlabeled),
    # year names centered between them on the minor ticks
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels([''] * 5)  # one empty label per major tick

    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)

    ax2.grid()

    # plot lines
    for k, probs in mf_dict.items():
        ax.plot(x, probs, label=k)
    # k = list(mf_dict.keys())
    # for i in range(max_tokens):
    #     ax.plot(x, mf_dict[k[i]], label = k[i])

    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    
    return fig

def add_mask(text):
    # append a <mask> token, adding a separating space when needed
    if len(text) == 0 or text[-1] == " ":
        return text + "<mask>"
    return text + " <mask>"
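# e.g. add_mask("I love") -> "I love <mask>"; add_mask("I love ") -> "I love <mask>"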

with gr.Blocks() as demo:
    
    text_description="""
    # TimeLMs Demo

    This is a demo for **TimeLMs**:
    - [GitHub](https://github.com/cardiffnlp/timelms)
    - [Paper](https://aclanthology.org/2022.acl-demo.25.pdf)

    Input any text containing a *\<mask\>* token, as in the example, and press
    "Generate Plot" (the demo does not use GPUs, so generating takes about 1
    minute). The graph shows the probability of the top candidate tokens for
    the mask across different months.

    In this demo we use a roberta-base model trained on tweets, where the
    first two tokens of each input are the year and the month ("21 1" for
    January 2021). It was trained on tweets from January 2018 to December 2021.
    """
    
    description = gr.Markdown(text_description)
    
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)

    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    # plot (with starting example already loaded)
    f = urllib.request.urlopen(plot_url)
    plot_example = pickle.load(f)
    plot = gr.Plot(plot_example)
    
    #textbox.change(fn=plot_time, inputs=textbox, outputs=plot)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)