# test_time_1.1 / app.py (Hugging Face Space by fvancesco)
# Last commit 483cccd: "fix pickle load (plot)"
# File-viewer residue from the web page: raw | history | blame | 3.29 kB
import numpy as np
import pickle
import urllib
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
# Location of a pickled figure that is displayed as the initial example plot.
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# One "YY M" label per month, January 2018 ("18 1") through December 2021
# ("21 12"); get_mf_dict prepends each label to the user's text.
dates = [f"{year} {month}" for year in (18, 19, 20, 21) for month in range(1, 13)]

# The month component of every label, used as the bottom x-axis tick labels.
months = [label.rpartition(" ")[2] for label in dates]
# Masked-language-model checkpoint used for every prediction.
model_name = "fvancesco/tmp_date"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# eval() switches the model to inference mode and returns the model itself.
model = AutoModelForMaskedLM.from_pretrained(model_name).eval()
# Fill-mask pipeline on the default device (a GPU variant passed device=0).
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
def get_mf_dict(text, top_n=5):
    """Return per-date scores for the tokens the model ranks highest.

    The fill-mask pipeline is run once per entry of the module-level ``dates``
    list, with the date label prepended to ``text``. Every token string that
    appears among the ``top_n`` highest-ranked predictions for at least one
    date is kept, and its score is collected for all dates.

    Args:
        text: sentence containing the mask token (e.g. "Happy <mask>!").
        top_n: number of top-ranked predictions per date to include.

    Returns:
        dict mapping token string -> 1-D ``np.ndarray`` of length
        ``len(dates)`` with that token's score for each date, in ``dates``
        order. Keys are ordered by first appearance (date order, then rank),
        so the plot legend is stable across runs. A token with no prediction
        for some date gets 0.0 there.
    """
    texts = [f"{d} {text}" for d in dates]
    # Ask for (presumably) the whole vocabulary so every selected token has
    # a score for every date, not only where it ranked in the top_n.
    # NOTE(review): 50265 is the RoBERTa-base vocabulary size; confirm it
    # matches this checkpoint's vocabulary.
    all_preds = pipe(texts, top_k=50265)

    # Per-date {token_str: score}. Predictions are assumed to be sorted by
    # descending score (the top_n selection below relies on that too), and
    # nothing guarantees token_str is unique across token ids, so keep the
    # first (highest) score instead of letting a lower-ranked duplicate
    # overwrite it.
    token_prob = []
    for date_preds in all_preds:
        probs = {}
        for p in date_preds:
            probs.setdefault(p['token_str'], p['score'])
        token_prob.append(probs)

    # Union of each date's top_n token strings, in first-seen order.
    top_tokens = dict.fromkeys(
        p['token_str'] for date_preds in all_preds for p in date_preds[:top_n]
    )

    return {
        t: np.array([probs.get(t, 0.0) for probs in token_prob])
        for t in top_tokens
    }
def plot_time(text):
    """Plot, month by month, the scores of the top predicted tokens for *text*.

    Args:
        text: sentence containing the mask token; forwarded to ``get_mf_dict``.

    Returns:
        A ``matplotlib.figure.Figure`` with one line per token returned by
        ``get_mf_dict``. The bottom axis is labelled with months and the top
        axis with years.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps a
    # global reference to every figure it creates until plt.close() is
    # called, so a long-running app calling this per click would accumulate
    # figures.
    from matplotlib.figure import Figure

    mf_dict = get_mf_dict(text)
    x = list(range(len(dates)))
    last = len(dates) - 1

    fig = Figure(figsize=(16, 9))
    ax = fig.add_subplot(111)
    ax.set_xlabel('Month')
    ax.set_xlim(0, last)
    ax.set_xticks(x)
    ax.set_xticklabels(months)

    # Secondary x-axis on top.
    # - Major ticks at the year boundaries: unlabelled, but drawn by grid().
    # - Minor ticks mid-year: carry the year labels.
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0, last)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels([])
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)
    ax2.grid()

    for token, probs in mf_dict.items():
        ax.plot(x, probs, label=token)
    # Legend placed just outside the right edge of the axes.
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig
def add_mask(text):
    """Return *text* with the "<mask>" token appended.

    A separating space is inserted unless *text* is empty or already ends in
    a space, e.g. "Happy" -> "Happy <mask>" and "Happy " -> "Happy <mask>".
    """
    needs_space = len(text) > 0 and not text.endswith(" ")
    return text + (" <mask>" if needs_space else "<mask>")
# `import urllib` alone does not guarantee the `urllib.request` submodule is
# loaded, so import it explicitly before calling urlopen.
import urllib.request
import warnings

with gr.Blocks() as demo:
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    # Start with a precomputed example plot so the page is not empty.
    plot_example = None
    try:
        # The with-statement closes the HTTP response once it has been read.
        with urllib.request.urlopen(plot_url) as f:
            # SECURITY: pickle.load can execute code embedded in the data.
            # Acceptable only because plot_url points into this Space's own
            # repository (spaces/fvancesco/test_time_1.1). Never unpickle
            # data from a location you do not control.
            plot_example = pickle.load(f)
    except (OSError, pickle.UnpicklingError, EOFError) as err:
        # The example is cosmetic. Start with an empty plot rather than
        # failing to launch when it cannot be fetched or decoded.
        warnings.warn(f"could not load example plot from {plot_url}: {err}")
    plot = gr.Plot(plot_example)

    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)
demo.launch(debug=True)