Spaces:
Sleeping
Sleeping
File size: 5,562 Bytes
723d6ec 620cefc e70d647 9c18c2f 620cefc f095c1c e2180fd f095c1c e2180fd f095c1c 620cefc f095c1c 620cefc 7dc3119 3a47baa 5ee0be5 7dc3119 eec55b9 7dc3119 eec55b9 f095c1c 1239ceb f095c1c 1239ceb f095c1c 2c4e762 f095c1c eec55b9 2c4e762 f095c1c 9605440 c2cb68d 620cefc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics.pairwise import paired_cosine_distances
from sklearn.preprocessing import normalize
from rolaser import RoLaserEncoder
@st.cache_resource(show_spinner=False)
def load_models():
    """Build the three sentence encoders compared by the app.

    File locations come from the ``LASER`` and ``ROLASER`` environment
    variables; each encoder reads ``<root>/models/<stem>.pt`` and
    ``<root>/models/<stem>.cvocab``. The function is wrapped in
    ``st.cache_resource`` so the encoders are cached instead of being rebuilt
    on every rerun.

    Returns:
        tuple: ``(laser_model, rolaser_model, c_rolaser_model)``, in that order.

    Raises:
        KeyError: if ``LASER`` or ``ROLASER`` is missing from the environment.
    """
    def _encoder(env_var, stem, tokenizer):
        # Checkpoint and vocabulary share one stem under the root's models/ dir.
        models_dir = f"{os.environ[env_var]}/models"
        return RoLaserEncoder(
            model_path=f"{models_dir}/{stem}.pt",
            vocab=f"{models_dir}/{stem}.cvocab",
            tokenizer=tokenizer,
        )

    laser = _encoder('LASER', 'laser2', 'spm')
    rolaser = _encoder('ROLASER', 'rolaser', 'roberta')
    c_rolaser = _encoder('ROLASER', 'c-rolaser', 'char')
    return laser, rolaser, c_rolaser
@st.cache_data(show_spinner=False)
def load_sample_data():
    """Return the default sentence pairs used to pre-fill the input form.

    Returns:
        tuple[list[str], list[str]]: ``(standard, non_standard)`` where
        ``standard`` repeats the same reference sentence once per variant and
        ``non_standard`` holds ten noisy spellings of it, aligned by index.
    """
    reference = 'See you tomorrow.'
    noisy_variants = [
        'See you t03orro3.',
        'C. U. tomorrow.',
        'sea you tomorrow.',
        'See yo utomorrow.',
        'Cu 2moro.',
        'See you tkmoerow.',
        'See yow tomorrow.',
        'See you tmrw.',
        'C. Yew tomorrow.',
        'c ya 2morrow.',
    ]
    # One copy of the reference per variant so the two lists pair up 1:1.
    return [reference] * len(noisy_variants), noisy_variants
def _paired_cosine(model, std_texts, ugc_texts):
    """Return the cosine distance of each aligned (std, ugc) pair under *model*.

    Both sides are encoded with ``model.encode`` and L2-normalised; row i of
    the standard embeddings is compared with row i of the non-standard ones.
    """
    std_emb = normalize(model.encode(std_texts))
    ugc_emb = normalize(model.encode(ugc_texts))
    return paired_cosine_distances(std_emb, ugc_emb)


def _results_frame(std_texts, ugc_texts, distances_by_model):
    """Build the long-format table with columns model, pair, ugc, std, cos.

    Rows are grouped by model (in ``distances_by_model`` insertion order) and,
    within a model, ordered by pair number starting at 1.
    """
    num_pairs = len(std_texts)
    num_models = len(distances_by_model)
    return pd.DataFrame({
        'model': np.repeat(list(distances_by_model), num_pairs),
        'pair': np.tile(np.arange(1, num_pairs + 1), num_models),
        'ugc': np.tile(ugc_texts, num_models),
        'std': np.tile(std_texts, num_models),
        'cos': np.concatenate(list(distances_by_model.values())),
    })


def _box_figure(outputs, model_names):
    """Return a box plot (mean and std shown) of the cosine distances per model."""
    fig = go.Figure()
    for name in model_names:
        fig.add_trace(go.Box(
            y=outputs.loc[outputs['model'] == name, 'cos'],
            name=name,
            boxmean='sd',
        ))
    fig.update_xaxes(title_text='Model')
    fig.update_yaxes(title_text='Cosine Distance')
    return fig


def main():
    """Render the page.

    The page shows the input form for standard/non-standard text pairs, a bar
    chart of per-pair cosine distances for each model and, when more than one
    pair is given, a per-model box plot of those distances.
    """
    sample_std, sample_ugc = load_sample_data()
    model_names = ('LASER', 'RoLASER', 'c-RoLASER')
    models = dict(zip(model_names, load_models()))

    st.title('Pairwise Cosine Distance Calculator')
    info = '''
:bookmark: **Paper:** [Making Sentence Embeddings Robust to User-Generated Content (Nishimwe et al., 2024)](https://arxiv.org/abs/2403.17220)
:link: **Github:** [https://github.com/lydianish/RoLASER](https://github.com/lydianish/RoLASER)
:computer: **Demo:** This app computes the cosine distance between standard and non-standard text input pairs using LASER, RoLASER, and c-RoLASER models.
'''
    st.markdown(info)

    st.header('Standard and Non-standard Text Input Pairs')
    cols = st.columns(3)
    num_pairs = cols[1].number_input('Number of Text Input Pairs (1-10):', min_value=1, max_value=10, value=5)
    with st.form('text_input_form'):
        std_text_inputs = []
        ugc_text_inputs = []
        for i in range(num_pairs):
            col1, col2 = st.columns(2)
            with col1:
                std_text_inputs.append(
                    st.text_input(f'Standard text {i+1}:', key=f'std{i}', value=sample_std[i])
                )
            with col2:
                ugc_text_inputs.append(
                    st.text_input(f'Non-standard text {i+1}:', key=f'ugc{i}', value=sample_ugc[i])
                )
        st.caption('*The models are case-insensitive: all texts will be lowercased.*')
        st.form_submit_button('Compute')

    # The caption promises lowercasing, so enforce it here instead of relying
    # on encoder internals that are not visible in this file. str.lower is
    # idempotent, so this is harmless if the encoders already lowercase.
    # The charts still display the text exactly as the user typed it.
    std_model_inputs = [text.lower() for text in std_text_inputs]
    ugc_model_inputs = [text.lower() for text in ugc_text_inputs]
    distances = {
        name: _paired_cosine(model, std_model_inputs, ugc_model_inputs)
        for name, model in models.items()
    }
    outputs = _results_frame(std_text_inputs, ugc_text_inputs, distances)

    st.header('Cosine Distance Scores')
    st.caption('*This bar plot is interactive: Hover on the bars to display values. Click on the legend items to filter models.*')
    fig = px.bar(outputs, x='pair', y='cos', color='model', barmode='group', hover_data=['ugc', 'std'])
    fig.update_xaxes(title_text='Text Input Pair')
    fig.update_yaxes(title_text='Cosine Distance')
    st.plotly_chart(fig, use_container_width=True)

    # A distribution summary is only meaningful with more than one pair.
    if num_pairs > 1:
        st.header('Cosine Distance Statistics')
        st.caption('*This box plot is interactive: Hover on the boxes to display values. Click on the legend items to filter models.*')
        st.plotly_chart(_box_figure(outputs, model_names), use_container_width=True)
if __name__ == "__main__":
    # Start the app only when this file is executed as a script, not on import.
    main()
|