# HuggingFace Space "fvancesco/test_time_1.1" — TimeLMs demo app.
import pickle
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np

import gradio as gr
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# "YY M" date prefixes for every month from January 2018 through December 2021,
# in the exact format the model expects as its leading tokens.
dates = [f"{year} {month}" for year in (18, 19, 20, 21) for month in range(1, 13)]
# Month number of each date, used as x-axis tick labels.
months = [d.split(" ")[-1] for d in dates]
# Masked-LM checkpoint whose training inputs were prefixed with "YY M" date
# tokens (see `dates` above) — presumably a roberta-base variant; confirm on the hub.
model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Inference only — disable dropout etc.
model.eval()
# GPU variant kept for reference; this Space runs on CPU.
#pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
def get_mf_dict(text, top_n=5):
    """Compute per-token probability trajectories over all date prefixes.

    Runs the fill-mask pipeline on *text* prefixed with every "YY M" date in
    `dates`, collects the union of the top `top_n` predicted tokens across all
    dates, and returns a dict mapping each such token to a numpy array holding
    its predicted probability at every date (in `dates` order).

    Parameters:
        text: input containing a "<mask>" token.
        top_n: how many top predictions per date contribute tokens to the
            result (default 5, matching the original behavior).
    """
    # One masked prompt per date prefix, e.g. "21 1 Happy <mask>!".
    prompts = [f"{d} {text}" for d in dates]
    # top_k=50265 is the roberta-base vocabulary size: request scores for the
    # full vocabulary so every candidate token has a probability at every date.
    preds = dict(zip(dates, pipe(prompts, top_k=50265)))

    # Union of the top-n tokens across all dates.
    most_freq_tokens = set()
    for d in dates:
        most_freq_tokens.update(p['token_str'] for p in preds[d][:top_n])

    # Per-date token -> probability lookup.
    token_prob = {d: {p['token_str']: p['score'] for p in preds[d]} for d in dates}

    # token -> array of probabilities, one entry per date.
    mf_dict = {t: np.zeros(len(dates)) for t in most_freq_tokens}
    for i, d in enumerate(dates):
        for t in most_freq_tokens:
            mf_dict[t][i] = token_prob[d][t]
    return mf_dict
def plot_time(text):
    """Return a matplotlib figure of mask-token probabilities month by month."""
    mf_dict = get_mf_dict(text)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111)
    positions = list(range(len(dates)))

    # Bottom axis: one tick per month.
    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(positions)
    ax.set_xticklabels(months)

    # Top axis: year boundaries (major, unlabeled) and year names (minor,
    # centered between boundaries) with a grid at the boundaries.
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels('')
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)
    ax2.grid()

    # One line per candidate token.
    for token, probs in mf_dict.items():
        ax.plot(positions, probs, label=token)
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig
def add_mask(text):
    """Append a "<mask>" token to *text*, adding a separating space unless
    the text is empty or already ends with one."""
    if text and not text.endswith(" "):
        return text + " <mask>"
    return text + "<mask>"
with gr.Blocks() as demo:
    # Markdown intro shown above the input box.
    text_description = """
# TimeLMs Demo
This is a demo for **TimeLMs**:
- [Github](https://github.com/cardiffnlp/timelms)
- [Paper](https://aclanthology.org/2022.acl-demo.25.pdf)
Input any text with a *\<mask\>* token as in the example, and generate the plot
(the demo does not use GPUs, and it takes about 1 min). In the graph, we show
the probability of some token candidates for the mask over different months.
In this demo we use a roberta-base model trained on tweets, where the first two
tokens are the year and the month ("21 1" for January 2021). It was trained
on tweets between January 2018 and December 2021.
"""
    description = gr.Markdown(text_description)
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    # Pre-rendered example plot so the page is not empty on first load.
    # NOTE(security): unpickling remote data executes arbitrary code if the
    # host is compromised; the URL points at this Space's own repo, but a
    # safer serialization (image/JSON) would remove the risk.
    with urllib.request.urlopen(plot_url) as f:
        plot_example = pickle.load(f)
    plot = gr.Plot(plot_example)

    #textbox.change(fn=plot_time, inputs=textbox, outputs=plot)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)