# Page header captured together with the file (not part of the program):
# rolaser-demo / app.py
# lydianish's picture
# Update app.py
# e70d647 verified
# raw
# history blame
# 3.71 kB
import os, sys
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
from sklearn.metrics.pairwise import paired_cosine_distances
from sklearn.preprocessing import normalize
from rolaser import RoLaserEncoder
@st.cache_resource
def _load_encoder(model_path, vocab, tokenizer):
    """Build a RoLaserEncoder once and reuse it across Streamlit reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, each rerun would construct all three
    encoders again. ``st.cache_resource`` keys the cached object on the
    (string) arguments, so each distinct checkpoint is built a single time.

    Args:
        model_path: Path to the encoder checkpoint (``.pt``).
        vocab: Path to the matching vocabulary file (``.cvocab``).
        tokenizer: Tokenizer name passed through to ``RoLaserEncoder``.

    Returns:
        The ``RoLaserEncoder`` instance for these arguments.
    """
    # NOTE(review): cached resources are shared between sessions; this assumes
    # RoLaserEncoder.encode can be called concurrently -- confirm.
    return RoLaserEncoder(model_path=model_path, vocab=vocab, tokenizer=tokenizer)


# Model files are located through the LASER and ROLASER environment variables;
# a missing variable raises KeyError at start-up.

# LASER encoder.
laser_checkpoint = f"{os.environ['LASER']}/models/laser2.pt"
laser_vocab = f"{os.environ['LASER']}/models/laser2.cvocab"
laser_tokenizer = 'spm'
laser_model = _load_encoder(laser_checkpoint, laser_vocab, laser_tokenizer)

# RoLASER encoder.
rolaser_checkpoint = f"{os.environ['ROLASER']}/models/rolaser.pt"
rolaser_vocab = f"{os.environ['ROLASER']}/models/rolaser.cvocab"
rolaser_tokenizer = 'roberta'
rolaser_model = _load_encoder(rolaser_checkpoint, rolaser_vocab, rolaser_tokenizer)

# c-RoLASER encoder.
c_rolaser_checkpoint = f"{os.environ['ROLASER']}/models/c-rolaser.pt"
c_rolaser_vocab = f"{os.environ['ROLASER']}/models/c-rolaser.cvocab"
c_rolaser_tokenizer = 'char'
c_rolaser_model = _load_encoder(c_rolaser_checkpoint, c_rolaser_vocab, c_rolaser_tokenizer)
# Default non-standard spellings offered in the "non-standard text" inputs,
# one per selectable pair.
UGC_SENTENCES = [
    'See you tmrw.',
    'See you t03orro3.',
    'C. U. tomorrow.',
    'sea you tomorrow.',
    'See yo utomorrow.',
    'See you tkmoerow.',
    'Cu 2moro.',
    'See yow tomorrow.',
    'C. Yew tomorrow.',
    'c ya 2morrow.',
]

# Default standard sentence for every pair: the same text repeated so that
# index i lines up with UGC_SENTENCES[i].
STD_SENTENCES = ['See you tomorrow.' for _ in UGC_SENTENCES]
def add_text_inputs(i):
    """Render one row with a standard and a non-standard text box.

    The row is split into two columns. Widget keys are ``std{i}`` and
    ``ugc{i}``, and the boxes are pre-filled with ``STD_SENTENCES[i]`` and
    ``UGC_SENTENCES[i]``.

    Args:
        i: Zero-based index of the pair; used for the widget keys and to pick
            the default sentences.

    Returns:
        A ``(standard_text, non_standard_text)`` tuple with the current
        contents of the two inputs.
    """
    left_column, right_column = st.columns(2)

    with left_column:
        standard_text = st.text_input(
            'Enter standard text here:',
            key=f'std{i}',
            value=STD_SENTENCES[i],
        )

    with right_column:
        non_standard_text = st.text_input(
            'Enter non-standard text here:',
            key=f'ugc{i}',
            value=UGC_SENTENCES[i],
        )

    return standard_text, non_standard_text
def _paired_distances(encoder, std_texts, ugc_texts):
    """Return the cosine distance of each (standard, non-standard) pair.

    Both lists are encoded with ``encoder``, the embeddings are normalized,
    and row ``k`` of one is compared with row ``k`` of the other.
    """
    # assumes encoder.encode returns a 2-D array, one row per sentence -- TODO confirm
    std_embeddings = normalize(encoder.encode(std_texts))
    ugc_embeddings = normalize(encoder.encode(ugc_texts))
    return paired_cosine_distances(std_embeddings, ugc_embeddings)


def main():
    """Run the Streamlit page comparing LASER, RoLASER and c-RoLASER.

    The sidebar selects how many sentence pairs (1-10) are shown. When
    "Submit" is pressed, every pair is encoded with each of the three models,
    the per-pair cosine distances are shown as a grouped bar chart, and
    per-model summary statistics of the distances are written below it.
    """
    st.title('Pairwise Cosine Distance Calculator')

    num_pairs = st.sidebar.number_input('Number of Text Input Pairs', min_value=1, max_value=10, value=5)

    std_text_inputs = []
    ugc_text_inputs = []
    for i in range(num_pairs):
        std_text, ugc_text = add_text_inputs(i)
        std_text_inputs.append(std_text)
        ugc_text_inputs.append(ugc_text)

    # Nothing is computed until the user asks for it.
    if not st.button('Submit'):
        return

    # Insertion order fixes the order of the model blocks in the table below.
    encoders = {
        'LASER': laser_model,
        'RoLASER': rolaser_model,
        'c-RoLASER': c_rolaser_model,
    }
    distances = [
        _paired_distances(encoder, std_text_inputs, ugc_text_inputs)
        for encoder in encoders.values()
    ]

    # Long-format table: one row per (model, pair), grouped by model.
    num_models = len(encoders)
    outputs = pd.DataFrame({
        'model': np.repeat(list(encoders), num_pairs),
        'pair': np.tile(np.arange(1, num_pairs + 1), num_models),
        'ugc': np.tile(ugc_text_inputs, num_models),
        'std': np.tile(std_text_inputs, num_models),
        'cos': np.concatenate(distances),
    })

    st.write('## Cosine Distance Scores:')
    # x/y must name real columns of `outputs`; the former placeholders
    # 'x_column'/'y_column' do not exist in the frame and made px.bar raise.
    fig = px.bar(outputs, x='pair', y='cos', color='model', barmode='group')
    fig.update_layout(title='Cosine Distance Scores')
    fig.update_xaxes(title_text='Text Input Pair')
    fig.update_yaxes(title_text='Cosine Distance')
    st.plotly_chart(fig, use_container_width=True)

    st.write('## Average Cosine Distance Scores:')
    # describe() reports count/mean/std/min/quartiles/max per model.
    st.write(outputs.groupby('model')['cos'].describe())
# Launch the app only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()