fvancesco commited on
Commit
1d0c5c9
·
1 Parent(s): 5f81c24

update app, v1

Browse files
Files changed (2) hide show
  1. app.py +113 -16
  2. requirements.txt +0 -1
app.py CHANGED
@@ -1,23 +1,120 @@
 
 
 
 
1
  import gradio as gr
2
- import pandas as pd
3
- import seaborn as sns
4
  import matplotlib.pyplot as plt
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- def plot_pens(alpha):
8
- df_pens = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv')
9
- fig = plt.figure()
10
- plt.scatter(x=df_pens['bill_length_mm'], y=df_pens['bill_depth_mm'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  return fig
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- iface = gr.Interface(
15
- fn=plot_pens,
16
- layout='vertical',
17
- inputs=['checkbox'],
18
- outputs=['plot'],
19
- title="Scatterplot of Palmer Penguins",
20
- description="Let's talk pens.",
21
- article="Talk more about Penguins here, shall we?",
22
- theme='peach'
23
- ).launch()
 
1
+ import numpy as np
2
+ from transformers import pipeline
3
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
4
+
5
  import gradio as gr
 
 
6
  import matplotlib.pyplot as plt
7
 
8
+ dates = []
9
+ dates.extend([f"18 {m}" for m in range(1,13)])
10
+ dates.extend([f"19 {m}" for m in range(1,13)])
11
+ dates.extend([f"20 {m}" for m in range(1,13)])
12
+ dates.extend([f"21 {m}" for m in range(1,13)])
13
+
14
+ months = [x.split(" ")[-1] for x in dates]
15
+
16
+ model_name = "fvancesco/tmp_date"
17
+ model = AutoModelForMaskedLM.from_pretrained(model_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+ model.eval()
20
+ pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
21
+
22
+ last_mf_dict = None
23
+
24
+
25
+ def get_mf_dict(text):
26
+
27
+ # predictions
28
+ texts = []
29
+ for d in dates:
30
+ texts.append(f"{d} {text}")
31
+ tmp_preds = pipe(texts, top_k=50265)
32
+ preds = {}
33
+ for i in range(len(tmp_preds)):
34
+ preds[dates[i]] = tmp_preds[i]
35
+
36
+ # get preds summary (only top words)
37
+ top_n = 5 # top n for each prediction
38
+ most_freq_tokens = set()
39
+ for d in dates:
40
+ tmp = [t['token_str'] for t in preds[d][:top_n]]
41
+ most_freq_tokens.update(tmp)
42
+
43
+ token_prob = {}
44
+ for d in dates:
45
+ token_prob[d] = {p['token_str']:p['score'] for p in preds[d]}
46
+
47
+ mf_dict = {p:np.zeros(len(dates)) for p in most_freq_tokens}
48
+
49
+ c=0
50
+ for d in dates:
51
+ for t in most_freq_tokens:
52
+ mf_dict[t][c] = token_prob[d][t]
53
+ c+=1
54
+
55
+ return mf_dict
56
 
57
+ def plot_time(text):
58
+ mf_dict = get_mf_dict(text)
59
+ #last_mf_dict = mf_dict # just for debugging, remove in final version
60
+
61
+ #max_tokens = 10
62
+
63
+ fig = plt.figure(figsize=(16,9))
64
+ ax = fig.add_subplot(111)
65
+ #fig, ax = plt.subplots(figsize=(16,9))
66
+
67
+ x = [i for i in range(len(dates))]
68
+
69
+ ax.set_xlabel('Month')
70
+ ax.set_xlim(0)
71
+ ax.set_xticks(x)
72
+ ax.set_xticklabels(months)
73
+ # ax.set_yticks([-1,0,1])
74
+
75
+ ax2 = ax.twiny()
76
+ ax2.set_xlabel('Year')
77
+ ax2.set_xlim(0)
78
+ ax2.set_xticks([0,12,24,36,47])
79
+
80
+ ax2.set_xticklabels('')
81
+ ax2.set_xticks([6,18,30,42,47], minor=True)
82
+ ax2.set_xticklabels(['2018','2019','2020','2021',''], minor=True)
83
+
84
+ ax2.grid()
85
+
86
+ # plot lines
87
+ for k in mf_dict.keys():
88
+ ax.plot(x, mf_dict[k], label = k)
89
+ # k = list(mf_dict.keys())
90
+ # for i in range(max_tokens):
91
+ # ax.plot(x, mf_dict[k[i]], label = k[i])
92
+
93
+ ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
94
+
95
  return fig
96
 
97
+ def add_mask(text):
98
+ out = ""
99
+ if len(text) == 0 or text[-1] == " ":
100
+ out = text+"<mask>"
101
+ else:
102
+ out = text+" <mask>"
103
+ return out
104
+
105
+ with gr.Blocks() as demo:
106
+ #textbox = gr.Textbox(placeholder="Type here and press enter...")
107
+ textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
108
+
109
+ with gr.Row():
110
+ generate_btn = gr.Button("Generate Plot")
111
+ mask_btn = gr.Button("Add <mask>")
112
+
113
+ plot = gr.Plot()
114
+
115
+ #textbox.change(fn=plot_time, inputs=textbox, outputs=plot)
116
+ #textbox.click(fn=plot_time, inputs=textbox, outputs=plot)
117
+ generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
118
+ mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)
119
 
120
+ demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,2 +1 @@
1
- torch
2
  transformers
 
 
1
  transformers