Spaces:
Sleeping
Sleeping
File size: 3,286 Bytes
1d0c5c9 3c16c9f 483cccd 1d0c5c9 5f81c24 483cccd 1d0c5c9 0487285 1d0c5c9 5f81c24 1d0c5c9 5f81c24 1d0c5c9 483cccd 1d0c5c9 483cccd 3c16c9f 1d0c5c9 5f81c24 1d0c5c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import pickle
import urllib
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import pipeline
# Location of a pickled example figure that is shown before the first query.
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# Date prefixes of the form "<yy> <m>", one per month for the years 18..21
# (labelled 2018-2021 on the plot), in chronological order.
dates = [f"{year} {month}" for year in (18, 19, 20, 21) for month in range(1, 13)]

# Month component of every prefix; used as tick labels on the month axis.
months = [date.rsplit(" ", 1)[-1] for date in dates]

# Masked-language model and tokenizer used for all predictions.
model_name = "fvancesco/tmp_date"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model.eval()  # inference only

# Fill-mask pipeline on the default device; pass device=0 as an extra
# argument to pipeline() to run on the first GPU instead.
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
def get_mf_dict(text):
    """Score the most likely mask fillers for ``text`` under every date prefix.

    ``text`` is prefixed with each entry of the module-level ``dates`` list and
    run through the module-level fill-mask ``pipe``. Every token that appears
    among the top 5 predictions for at least one date is selected, and its
    score at each date is collected.

    Args:
        text: Input sentence. It must contain the pipeline's mask token
            exactly once.

    Returns:
        A dict mapping each selected token string to a NumPy array holding
        one score per entry of ``dates`` (in the same order as ``dates``).

    Raises:
        ValueError: If ``text`` does not contain exactly one mask token.
    """
    # The code below expects one flat list of predictions per input. Per the
    # transformers fill-mask documentation that is only what the pipeline
    # returns for single-mask inputs (several masks add a nesting level, no
    # mask raises), so reject anything else with a readable message.
    mask_token = pipe.tokenizer.mask_token
    n_masks = text.count(mask_token)
    if n_masks != 1:
        raise ValueError(
            f"Expected exactly one {mask_token!r} in the input text, found {n_masks}."
        )

    top_n = 5  # how many of the best predictions per date are tracked

    # One query per date prefix. top_k=50265 is presumably the model's
    # vocabulary size (TODO confirm), so that every tracked token has a
    # score at every date.
    texts = [f"{d} {text}" for d in dates]
    all_preds = pipe(texts, top_k=50265)

    # Union of the top-n tokens over all dates.
    most_freq_tokens = set()
    for date_preds in all_preds:
        most_freq_tokens.update(p["token_str"] for p in date_preds[:top_n])

    mf_dict = {tok: np.zeros(len(dates)) for tok in most_freq_tokens}
    for col, date_preds in enumerate(all_preds):
        # Keep only the scores that are actually needed instead of a
        # vocabulary-sized dict per date.
        scores = {
            p["token_str"]: p["score"]
            for p in date_preds
            if p["token_str"] in most_freq_tokens
        }
        for tok in most_freq_tokens:
            mf_dict[tok][col] = scores[tok]
    return mf_dict
def plot_time(text):
    """Plot how the scores of the top mask fillers for ``text`` vary by date.

    Args:
        text: Sentence containing the mask token; forwarded to
            ``get_mf_dict``.

    Returns:
        A matplotlib ``Figure`` with one line per token returned by
        ``get_mf_dict``. The bottom x axis is labelled with the month
        numbers from ``months``; a twin top axis carries the year labels.
    """
    mf_dict = get_mf_dict(text)

    # Create the figure directly rather than through ``plt.figure``: pyplot
    # keeps a global reference to every figure it creates, so in a
    # long-running app each request would leak one figure.
    fig = Figure(figsize=(16, 9))
    ax = fig.add_subplot(111)

    # Bottom axis: one tick per date, labelled with the month number.
    x = list(range(len(dates)))
    ax.set_xlabel("Month")
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)

    # Top axis: unlabelled major ticks at the year boundaries (they drive
    # the grid lines) and minor ticks mid-year that carry the year labels.
    ax2 = ax.twiny()
    ax2.set_xlabel("Year")
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels([])
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(["2018", "2019", "2020", "2021", ""], minor=True)
    ax2.grid()

    # One line per tracked token.
    for token, scores in mf_dict.items():
        ax.plot(x, scores, label=token)

    # Legend anchored to the right edge of the axes, vertically centred.
    ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
    return fig
def add_mask(text):
    """Return ``text`` with a ``<mask>`` token appended.

    A single space is inserted before the token unless ``text`` is empty or
    already ends with a space.
    """
    needs_space = bool(text) and not text.endswith(" ")
    separator = " " if needs_space else ""
    return f"{text}{separator}<mask>"
# Gradio UI: a one-line text box, two buttons and a plot area.
with gr.Blocks() as demo:
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    # Show a pre-rendered example figure until the first plot is generated.
    # SECURITY NOTE: pickle.load can execute arbitrary code contained in the
    # payload; this is only acceptable while plot_url points at a trusted
    # file under the app owner's control.
    with urllib.request.urlopen(plot_url) as response:  # closes the connection
        plot_example = pickle.load(response)
    plot = gr.Plot(plot_example)

    # "Generate Plot" redraws the plot for the current text;
    # "Add <mask>" appends a mask token to the text box.
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)