Spaces:
Sleeping
Sleeping
fvancesco
committed on
Commit
·
1d0c5c9
1
Parent(s):
5f81c24
update app, v1
Browse files- app.py +113 -16
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,23 +1,120 @@
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import pandas as pd
|
3 |
-
import seaborn as sns
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
def
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
return fig
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
fn=plot_pens,
|
16 |
-
layout='vertical',
|
17 |
-
inputs=['checkbox'],
|
18 |
-
outputs=['plot'],
|
19 |
-
title="Scatterplot of Palmer Penguins",
|
20 |
-
description="Let's talk pens.",
|
21 |
-
article="Talk more about Penguins here, shall we?",
|
22 |
-
theme='peach'
|
23 |
-
).launch()
|
|
|
1 |
+
import numpy as np
|
2 |
+
from transformers import pipeline
|
3 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
4 |
+
|
5 |
import gradio as gr
|
|
|
|
|
6 |
import matplotlib.pyplot as plt
|
7 |
|
8 |
+
# Date prefixes prepended to every query: one entry per month for the
# two-digit years 18 through 21, formatted "<year> <month>" ("18 1" ... "21 12").
dates = [f"{year} {month}" for year in range(18, 22) for month in range(1, 13)]

# Month component of each entry in `dates`; used as the x-axis tick labels.
months = [date.split(" ")[-1] for date in dates]
|
15 |
+
|
16 |
+
# Checkpoint used for mask filling; loaded once at import time and shared by
# every request.
model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# Run on the first GPU when one is present, otherwise on CPU (-1), so the app
# still starts on CPU-only hardware instead of failing at pipeline creation.
import torch  # imported here because only this setup step needs it

_device = 0 if torch.cuda.is_available() else -1
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=_device)

# Debugging hook: nothing in the visible code assigns to it.
last_mf_dict = None
|
23 |
+
|
24 |
+
|
25 |
+
def get_mf_dict(text):
    """Return per-date scores for the most likely mask fillers of *text*.

    The text is queried once per entry of the module-level ``dates`` list,
    each time prefixed with that date. The first ``top_n`` candidates of
    every date are pooled, and for each pooled token the score it received
    under every date is collected.

    Args:
        text: Input sentence. NOTE(review): the code assumes it contains
            exactly one mask token, so that each pipeline result is a flat
            list of candidate dicts - confirm against the pipeline docs.

    Returns:
        A dict mapping each pooled token string to a NumPy array of length
        ``len(dates)``, where position ``i`` holds the token's score for
        ``dates[i]``. Keys are in first-seen order, so repeated calls with
        the same predictions give the same ordering.
    """
    # One query per date: the date string is prepended to the user text.
    texts = [f"{d} {text}" for d in dates]

    # top_k is presumably the model's vocabulary size, so that every token
    # receives a score for every date - TODO confirm for this checkpoint.
    tmp_preds = pipe(texts, top_k=50265)
    preds = dict(zip(dates, tmp_preds))

    top_n = 5  # number of leading candidates pooled from each date

    # A dict serves as an insertion-ordered set: unlike set(), it keeps the
    # key order (and hence the plotted legend order) stable between runs.
    most_freq_tokens = dict.fromkeys(
        p['token_str'] for d in dates for p in preds[d][:top_n]
    )

    mf_dict = {tok: np.zeros(len(dates)) for tok in most_freq_tokens}

    for col, d in enumerate(dates):
        # Built per date so only one full-vocabulary lookup table is alive
        # at a time.
        token_prob = {p['token_str']: p['score'] for p in preds[d]}
        for tok in most_freq_tokens:
            # A token missing from this date's candidates counts as 0.0
            # instead of raising KeyError.
            mf_dict[tok][col] = token_prob.get(tok, 0.0)

    return mf_dict
|
56 |
|
57 |
+
def plot_time(text):
    """Build a line chart of mask-filler scores over the dates in ``dates``.

    Args:
        text: Input sentence, passed unchanged to ``get_mf_dict``.

    Returns:
        A Matplotlib figure with one line per token returned by
        ``get_mf_dict``, month labels on the bottom x-axis and year labels
        on a secondary top x-axis.
    """
    mf_dict = get_mf_dict(text)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111)

    x = list(range(len(dates)))

    # Bottom axis: one tick per date, labelled with its month number.
    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)

    # Top axis: unlabelled major ticks at the listed positions draw the grid
    # lines, while minor ticks placed between them carry the year labels.
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels([])
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)
    ax2.grid()

    # One line per token.
    for token, scores in mf_dict.items():
        ax.plot(x, scores, label=token)

    # Legend sits just outside the right edge of the axes.
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))

    # Unregister the figure from pyplot's global figure list so repeated
    # calls do not accumulate open figures; the returned Figure object can
    # still be rendered by the caller.
    plt.close(fig)
    return fig
|
96 |
|
97 |
+
def add_mask(text):
    """Append a ``<mask>`` placeholder to *text*.

    A single space is inserted before the placeholder unless *text* is empty
    or already ends with a space.

    Args:
        text: Current textbox content; ``None`` is treated like an empty
            string.

    Returns:
        The text with ``<mask>`` appended.
    """
    if not text:
        return "<mask>"
    separator = "" if text.endswith(" ") else " "
    return f"{text}{separator}<mask>"
|
104 |
+
|
105 |
+
# UI layout: a prompt box, a row with the two action buttons, then the plot.
with gr.Blocks() as demo:
    # Single-line prompt, pre-filled with an example that already has a mask.
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)

    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    plot = gr.Plot()

    # The plot is regenerated only on an explicit button press.
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    # Writes the textbox content back with a mask placeholder appended.
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,2 +1 @@
|
|
1 |
-
torch
|
2 |
transformers
|
|
|
|
|
1 |
transformers
|