import pickle
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

# Pre-rendered example plot shown before the user generates one.
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# One "YY M" date prefix per month, January 2018 to December 2021 (48 in total).
dates = []
dates.extend([f"18 {m}" for m in range(1, 13)])
dates.extend([f"19 {m}" for m in range(1, 13)])
dates.extend([f"20 {m}" for m in range(1, 13)])
dates.extend([f"21 {m}" for m in range(1, 13)])
months = [x.split(" ")[-1] for x in dates]

model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
# pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)


def get_mf_dict(text):
    # Run the fill-mask pipeline once per date prefix.
    texts = [f"{d} {text}" for d in dates]
    # top_k=50265 is the full roberta-base vocabulary, so every candidate
    # token gets a score for every month.
    tmp_preds = pipe(texts, top_k=50265)
    preds = {}
    for i in range(len(tmp_preds)):
        preds[dates[i]] = tmp_preds[i]

    # Collect the union of the top-n candidate tokens across all months.
    top_n = 5  # top n for each prediction
    most_freq_tokens = set()
    for d in dates:
        tmp = [t['token_str'] for t in preds[d][:top_n]]
        most_freq_tokens.update(tmp)

    # Per-month lookup: token string -> probability.
    token_prob = {}
    for d in dates:
        token_prob[d] = {p['token_str']: p['score'] for p in preds[d]}

    # Token string -> array of 48 monthly probabilities.
    mf_dict = {p: np.zeros(len(dates)) for p in most_freq_tokens}
    c = 0
    for d in dates:
        for t in most_freq_tokens:
            mf_dict[t][c] = token_prob[d][t]
        c += 1

    return mf_dict


def plot_time(text):
    mf_dict = get_mf_dict(text)
    # max_tokens = 10
    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111)
    # fig, ax = plt.subplots(figsize=(16, 9))
    x = list(range(len(dates)))
    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)
    # ax.set_yticks([-1, 0, 1])

    # Secondary x-axis marking year boundaries, with year labels centred
    # between them via minor ticks.
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels('')
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)
    ax2.grid()

    # Plot one line per candidate token.
    for k in mf_dict.keys():
        ax.plot(x, mf_dict[k], label=k)
    # k = list(mf_dict.keys())
    # for i in range(max_tokens):
    #     ax.plot(x, mf_dict[k[i]], label=k[i])

    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig


def add_mask(text):
    # Append a <mask> token, with a separating space when needed.
    if len(text) == 0 or text[-1] == " ":
        return text + "<mask>"
    return text + " <mask>"


with gr.Blocks() as demo:
    text_description = """
# TimeLMs Demo

This is a demo for **TimeLMs**:
- [Github](https://github.com/cardiffnlp/timelms)
- [Paper](https://aclanthology.org/2022.acl-demo.25.pdf)

Input any text containing a *\<mask\>* token, as in the example, and press "Generate Plot" (the demo does not use a GPU, so it takes about 1 minute). The plot shows the probability of the top candidate tokens for *\<mask\>* over different months. The demo uses a roberta-base model trained on tweets posted between January 2018 and December 2021, where the first two tokens of each input encode the year and month ("21 1" for January 2021).
"""
    description = gr.Markdown(text_description)
    textbox = gr.Textbox(value="Happy <mask> !", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")
    # Plot, with the starting example already loaded.
    f = urllib.request.urlopen(plot_url)
    plot_example = pickle.load(f)
    plot = gr.Plot(plot_example)
    # textbox.change(fn=plot_time, inputs=textbox, outputs=plot)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)
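
# A minimal sketch of the single-month query that get_mf_dict() batches over
# all 48 date prefixes (the input text below is an assumption chosen for
# illustration; the output shown is only the pipeline's return shape, not
# real model scores):
#
#   preds = pipe("21 1 Happy <mask> !", top_k=3)
#   # preds ~ [{'score': ..., 'token': ..., 'token_str': ..., 'sequence': ...},
#   #          ...]  # one dict per candidate token, highest score first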