lydianish commited on
Commit
f095c1c
·
verified ·
1 Parent(s): 5336a2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -83
app.py CHANGED
@@ -7,45 +7,45 @@ from sklearn.metrics.pairwise import paired_cosine_distances
7
  from sklearn.preprocessing import normalize
8
  from rolaser import RoLaserEncoder
9
 
10
- laser_checkpoint = f"{os.environ['LASER']}/models/laser2.pt"
11
- laser_vocab = f"{os.environ['LASER']}/models/laser2.cvocab"
12
- laser_tokenizer = 'spm'
13
- laser_model = RoLaserEncoder(model_path=laser_checkpoint, vocab=laser_vocab, tokenizer=laser_tokenizer)
14
-
15
- rolaser_checkpoint = f"{os.environ['ROLASER']}/models/rolaser.pt"
16
- rolaser_vocab = f"{os.environ['ROLASER']}/models/rolaser.cvocab"
17
- rolaser_tokenizer = 'roberta'
18
- rolaser_model = RoLaserEncoder(model_path=rolaser_checkpoint, vocab=rolaser_vocab, tokenizer=rolaser_tokenizer)
19
-
20
- c_rolaser_checkpoint = f"{os.environ['ROLASER']}/models/c-rolaser.pt"
21
- c_rolaser_vocab = f"{os.environ['ROLASER']}/models/c-rolaser.cvocab"
22
- c_rolaser_tokenizer = 'char'
23
- c_rolaser_model = RoLaserEncoder(model_path=c_rolaser_checkpoint, vocab=c_rolaser_vocab, tokenizer=c_rolaser_tokenizer)
24
-
25
-
26
- STD_SENTENCES = ['See you tomorrow.'] * 10
27
- UGC_SENTENCES = [
28
- 'See you t03orro3.',
29
- 'C. U. tomorrow.',
30
- 'sea you tomorrow.',
31
- 'See yo utomorrow.',
32
- 'See you tmrw.',
33
- 'See you tkmoerow.',
34
- 'Cu 2moro.',
35
- 'See yow tomorrow.',
36
- 'C. Yew tomorrow.',
37
- 'c ya 2morrow.'
38
- ]
39
-
40
- def add_text_inputs(i):
41
- col1, col2 = st.columns(2)
42
- with col1:
43
- text_input1 = st.text_input('Enter standard text here:', key=f'std{i}', value=STD_SENTENCES[i], label_visibility='collapsed')
44
- with col2:
45
- text_input2 = st.text_input('Enter non-standard text here:', key=f'ugc{i}', value=UGC_SENTENCES[i], label_visibility='collapsed')
46
- return text_input1, text_input2
47
 
48
  def main():
 
 
 
49
  st.title('Pairwise Cosine Distance Calculator')
50
 
51
  info = '''
@@ -59,52 +59,60 @@ def main():
59
 
60
  st.header('Standard and Non-standard Text Input Pairs:')
61
 
62
- num_pairs = st.sidebar.number_input('Number of Text Input Pairs', min_value=1, max_value=10, value=5)
63
-
64
- col1, col2 = st.columns(2)
65
- with col1:
66
- st.write('Enter standard text here:')
67
- with col2:
68
- st.write('Enter non-standard text here:')
69
-
70
- std_text_inputs = []
71
- ugc_text_inputs = []
72
- for i in range(num_pairs):
73
- pair = add_text_inputs(i)
74
- std_text_inputs.append(pair[0])
75
- ugc_text_inputs.append(pair[1])
76
-
77
- st.caption('*The models are case-insensitive: all text will be lowercased.*')
78
- if st.button('Submit'):
79
- X_std_laser = normalize(laser_model.encode(std_text_inputs))
80
- X_ugc_laser = normalize(laser_model.encode(ugc_text_inputs))
81
- X_cos_laser = paired_cosine_distances(X_std_laser, X_ugc_laser)
82
-
83
- X_std_rolaser = normalize(rolaser_model.encode(std_text_inputs))
84
- X_ugc_rolaser = normalize(rolaser_model.encode(ugc_text_inputs))
85
- X_cos_rolaser = paired_cosine_distances(X_std_rolaser, X_ugc_rolaser)
86
-
87
- X_std_c_rolaser = normalize(c_rolaser_model.encode(std_text_inputs))
88
- X_ugc_c_rolaser = normalize(c_rolaser_model.encode(ugc_text_inputs))
89
- X_cos_c_rolaser = paired_cosine_distances(X_std_c_rolaser, X_ugc_c_rolaser)
90
-
91
- outputs = pd.DataFrame(columns=[ 'model', 'pair', 'ugc', 'std', 'cos'])
92
- outputs['model'] = np.repeat(['LASER', 'RoLASER', 'c-RoLASER'], num_pairs)
93
- outputs['pair'] = np.tile(np.arange(1,num_pairs+1), 3)
94
- outputs['std'] = np.tile(std_text_inputs, 3)
95
- outputs['ugc'] = np.tile(ugc_text_inputs, 3)
96
- outputs['cos'] = np.concatenate([X_cos_laser, X_cos_rolaser, X_cos_c_rolaser])
97
-
98
- st.header('Cosine Distance Scores:')
99
- st.caption('*This bar plot is interactive: Hover on the bars to display values. Click on the legend items to filter models.*')
100
- fig = px.bar(outputs, x='pair', y='cos', color='model', barmode='group', hover_data=['ugc', 'std'])
101
- fig.update_xaxes(title_text='Text Input Pair')
102
- fig.update_yaxes(title_text='Cosine Distance')
103
- st.plotly_chart(fig, use_container_width=True)
104
-
105
- st.header('Average Cosine Distance Scores:')
106
- st.caption('*This data table is interactive: Click on a column header to sort values.*')
107
- st.write(outputs.groupby('model')['cos'].describe())
 
 
 
 
 
 
 
 
108
 
109
 
110
  if __name__ == "__main__":
 
7
  from sklearn.preprocessing import normalize
8
  from rolaser import RoLaserEncoder
9
 
10
+ @st.cache_resource(show_spinner=False)
11
+ def load_models():
12
+ laser_checkpoint = f"{os.environ['LASER']}/models/laser2.pt"
13
+ laser_vocab = f"{os.environ['LASER']}/models/laser2.cvocab"
14
+ laser_tokenizer = 'spm'
15
+ laser_model = RoLaserEncoder(model_path=laser_checkpoint, vocab=laser_vocab, tokenizer=laser_tokenizer)
16
+
17
+ rolaser_checkpoint = f"{os.environ['ROLASER']}/models/rolaser.pt"
18
+ rolaser_vocab = f"{os.environ['ROLASER']}/models/rolaser.cvocab"
19
+ rolaser_tokenizer = 'roberta'
20
+ rolaser_model = RoLaserEncoder(model_path=rolaser_checkpoint, vocab=rolaser_vocab, tokenizer=rolaser_tokenizer)
21
+
22
+ c_rolaser_checkpoint = f"{os.environ['ROLASER']}/models/c-rolaser.pt"
23
+ c_rolaser_vocab = f"{os.environ['ROLASER']}/models/c-rolaser.cvocab"
24
+ c_rolaser_tokenizer = 'char'
25
+ c_rolaser_model = RoLaserEncoder(model_path=c_rolaser_checkpoint, vocab=c_rolaser_vocab, tokenizer=c_rolaser_tokenizer)
26
+ return laser_model, rolaser_model, c_rolaser_model
27
+
28
+ @st.cache_data(show_spinner=False)
29
+ def load_sample_data():
30
+ STD_SENTENCES = ['See you tomorrow.'] * 10
31
+ UGC_SENTENCES = [
32
+ 'See you t03orro3.',
33
+ 'C. U. tomorrow.',
34
+ 'sea you tomorrow.',
35
+ 'See yo utomorrow.',
36
+ 'See you tmrw.',
37
+ 'See you tkmoerow.',
38
+ 'Cu 2moro.',
39
+ 'See yow tomorrow.',
40
+ 'C. Yew tomorrow.',
41
+ 'c ya 2morrow.'
42
+ ]
43
+ return STD_SENTENCES, UGC_SENTENCES
 
 
 
44
 
45
  def main():
46
+ sample_std, sample_ugc = load_sample_data()
47
+ laser_model, rolaser_model, c_rolaser_model = load_models()
48
+
49
  st.title('Pairwise Cosine Distance Calculator')
50
 
51
  info = '''
 
59
 
60
  st.header('Standard and Non-standard Text Input Pairs:')
61
 
62
+ num_pairs = st.sidebar.number_input('Number of Text Input Pairs (1-10)', min_value=1, max_value=10, value=5)
63
+
64
+ with st.form('text_input_form'):
65
+ col1, col2 = st.columns(2)
66
+ with col1:
67
+ st.write('Enter standard text here:')
68
+ with col2:
69
+ st.write('Enter non-standard text here:')
70
+
71
+ std_text_inputs = []
72
+ ugc_text_inputs = []
73
+
74
+ for i in range(num_pairs):
75
+ col1, col2 = st.columns(2)
76
+ with col1:
77
+ text_input1 = st.text_input('Enter standard text here:', key=f'std{i}', value=sample_std[i], label_visibility='collapsed')
78
+ std_text_inputs.append(text_input1)
79
+ with col2:
80
+ text_input2 = st.text_input('Enter non-standard text here:', key=f'ugc{i}', value=sample_ugc[i], label_visibility='collapsed')
81
+ ugc_text_inputs.append(text_input2)
82
+
83
+ st.caption('*The models are case-insensitive: all text will be lowercased.*')
84
+
85
+ st.form_submit_button('Compute')
86
+
87
+ X_std_laser = normalize(laser_model.encode(std_text_inputs))
88
+ X_ugc_laser = normalize(laser_model.encode(ugc_text_inputs))
89
+ X_cos_laser = paired_cosine_distances(X_std_laser, X_ugc_laser)
90
+
91
+ X_std_rolaser = normalize(rolaser_model.encode(std_text_inputs))
92
+ X_ugc_rolaser = normalize(rolaser_model.encode(ugc_text_inputs))
93
+ X_cos_rolaser = paired_cosine_distances(X_std_rolaser, X_ugc_rolaser)
94
+
95
+ X_std_c_rolaser = normalize(c_rolaser_model.encode(std_text_inputs))
96
+ X_ugc_c_rolaser = normalize(c_rolaser_model.encode(ugc_text_inputs))
97
+ X_cos_c_rolaser = paired_cosine_distances(X_std_c_rolaser, X_ugc_c_rolaser)
98
+
99
+ outputs = pd.DataFrame(columns=[ 'model', 'pair', 'ugc', 'std', 'cos'])
100
+ outputs['model'] = np.repeat(['LASER', 'RoLASER', 'c-RoLASER'], num_pairs)
101
+ outputs['pair'] = np.tile(np.arange(1,num_pairs+1), 3)
102
+ outputs['std'] = np.tile(std_text_inputs, 3)
103
+ outputs['ugc'] = np.tile(ugc_text_inputs, 3)
104
+ outputs['cos'] = np.concatenate([X_cos_laser, X_cos_rolaser, X_cos_c_rolaser])
105
+
106
+ st.header('Cosine Distance Scores:')
107
+ st.caption('*This bar plot is interactive: Hover on the bars to display values. Click on the legend items to filter models.*')
108
+ fig = px.bar(outputs, x='pair', y='cos', color='model', barmode='group', hover_data=['ugc', 'std'])
109
+ fig.update_xaxes(title_text='Text Input Pair')
110
+ fig.update_yaxes(title_text='Cosine Distance')
111
+ st.plotly_chart(fig, use_container_width=True)
112
+
113
+ st.header('Average Cosine Distance Scores:')
114
+ st.caption('*This data table is interactive: Click on a column header to sort values.*')
115
+ st.write(outputs.groupby('model')['cos'].describe())
116
 
117
 
118
  if __name__ == "__main__":