sample_tool_1 / app.py
Hariharan Vijayachandran
fix
a6af52b
import streamlit as st
import pandas as pd
import sys
import os
from datasets import load_from_disk
# from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time
from annotated_text import annotated_text
# Directory containing this script. abspath() makes the value genuinely
# absolute even when __file__ is relative (dirname('app.py') would be '').
ABSOLUTE_PATH = os.path.dirname(os.path.abspath(__file__))
# Root for per-model artefact directories and the query/candidate .jsonl files.
ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')
from nltk.data import find
import nltk
import gensim
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_embed_model():
    """Load the NLTK ``word2vec_sample`` vectors as a gensim ``KeyedVectors``.

    The sample is fetched via ``nltk.download`` (a no-op when already present)
    and then parsed from its text-format file.
    """
    nltk.download("word2vec_sample")
    vectors_path = str(find('models/word2vec_sample/pruned.word2vec.txt'))
    return gensim.models.KeyedVectors.load_word2vec_format(vectors_path, binary=False)
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_top_n_closest(query_word, candidate, n):
    """Return up to *n* distinct words of *candidate* most similar to *query_word*.

    *candidate* is tokenised with ``preprocess_text``. Each distinct token is
    scored with the word2vec model's ``similarity``. Tokens the model cannot
    score, because either word is out of vocabulary, get similarity 0.

    Args:
        query_word: The word to compare against.
        candidate: Raw candidate text.
        n: Maximum number of words to return. Negative values are treated as 0.

    Returns:
        A list of at most ``n`` distinct tokens, most similar first. Ties keep
        the order in which the tokens first appear in *candidate*.
    """
    model = get_embed_model()
    # Deduplicate, keeping first-appearance order, so a word that repeats in
    # the text does not occupy several of the n slots.
    tokens = list(dict.fromkeys(preprocess_text(candidate)))
    similarity = []
    for token in tokens:
        try:
            similarity.append(model.similarity(query_word, token))
        except KeyError:
            # gensim raises KeyError for words missing from the vocabulary.
            similarity.append(0)
    top_n = max(0, min(len(tokens), n))
    # Negate so argsort yields descending similarity; stable keeps ties deterministic.
    order = np.argsort(-np.asarray(similarity, dtype=float), kind='stable')[:top_n]
    return [tokens[i] for i in order]
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def annotate_text(text, words):
    """Split *text* into ``annotated_text`` segments, tagging each of *words*.

    Every whole-word occurrence of every entry in *words* becomes a
    ``(word, 'SIMILAR')`` tuple. The text in between is kept as plain string
    segments. "Whole word" uses the same rule as ``preprocess_text``: the
    match must not touch an alphanumeric character on either side, so ``"a"``
    does not highlight the inside of ``"cat"``.

    Args:
        text: The text to annotate.
        words: Words to highlight. Duplicates and empty strings are ignored.

    Returns:
        A tuple of ``str`` and ``(str, 'SIMILAR')`` segments whose text, in
        order, concatenates back to *text*.
    """
    segments = [text]
    for word in words:
        if not word:
            continue  # an empty word would match between every character
        # Rebuild the segment list instead of mutating it while iterating.
        next_segments = []
        for segment in segments:
            if isinstance(segment, str):
                next_segments.extend(_split_on_word(segment, word))
            else:
                next_segments.append(segment)  # already highlighted
        segments = next_segments
    return tuple(segments)


def _split_on_word(segment, word):
    """Split *segment* around each whole-word occurrence of *word*."""
    pieces = []
    cursor = 0        # end of the text already emitted
    search_from = 0   # where the next str.find starts
    while True:
        start = segment.find(word, search_from)
        if start == -1:
            break
        end = start + len(word)
        starts_word = start == 0 or not segment[start - 1].isalnum()
        ends_word = end == len(segment) or not segment[end].isalnum()
        if starts_word and ends_word:
            if start > cursor:
                pieces.append(segment[cursor:start])
            pieces.append((segment[start:end], 'SIMILAR'))
            cursor = search_from = end
        else:
            search_from = start + 1  # substring inside a longer word; keep looking
    if cursor < len(segment) or not pieces:
        pieces.append(segment[cursor:])
    return pieces
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def preprocess_text(s):
    """Tokenise *s* into maximal runs of alphanumeric characters.

    Every character that is neither alphanumeric nor a space is replaced by a
    space. The result is then split on spaces, and empty tokens are dropped.
    """
    cleaned = ''.join(ch if (ch.isalnum() or ch == ' ') else ' ' for ch in s)
    return [token for token in cleaned.split(' ') if token]
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_pairwise_distances(model):
    """Load ``<ASSETS_PATH>/<model>/pairwise_distances.csv``, indexed by its ``index`` column."""
    csv_path = f"{ASSETS_PATH}/{model}/pairwise_distances.csv"
    distances = pd.read_csv(csv_path)
    return distances.set_index('index')
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_pairwise_distances_chunked(model, chunk):
    """Return the full pairwise-distance table for *model*.

    NOTE: *chunk* is currently ignored. A chunked reader
    (``pd.read_csv(..., chunksize=16)``, returning the chunk whose first
    ``queries`` value equals *chunk*) was prototyped here and then disabled.
    """
    full_table = get_pairwise_distances(model)
    return full_table
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_query_strings():
    """Load the English test queries and number them by row position.

    The added ``index`` column holds each row's 0-based position in the file.
    """
    queries = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines = True)
    queries['index'] = range(len(queries))
    return queries
# df['partition'] = df['index']%100
# df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
# return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_candidate_strings():
    """Load the English test candidates, indexed by their own ``index`` values.

    The ``index`` column is kept. A copy of it becomes the row index (named
    ``i``), so candidates can be looked up by that value with ``.at``.
    """
    candidates = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
    row_labels = candidates['index'].rename('i')
    return candidates.set_index(row_labels)
# df['partition'] = df['index']%100
# df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
# return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_embedding_dataset(model):
    """Load the saved ``datasets`` object from ``<ASSETS_PATH>/<model>/embedding``."""
    embedding_dir = f"{ASSETS_PATH}/{model}/embedding"
    return load_from_disk(embedding_dir)
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_bad_queries(model):
    """Return the queries that appear in *model*'s pairwise-distance table.

    Each distinct ``queries`` value is used as a row position into the full
    query table. Only ``fullText``, ``index`` and ``authorIDs`` are kept. The
    returned frame keeps the original row labels, which are not 0..n-1.
    """
    all_queries = get_query_strings()
    query_positions = list(get_pairwise_distances(model)['queries'].unique())
    wanted_columns = ['fullText', 'index', 'authorIDs']
    return all_queries.iloc[query_positions][wanted_columns]
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_gt_candidates(model, author):
    """Return the candidates whose ``authorIDs`` value equals *author*.

    *model* is not used in the lookup. It only takes part in the cache key.
    NOTE(review): this is an equality test per row. If ``authorIDs`` holds
    lists, comparing them to a single author string never matches. Confirm
    the column type.
    """
    candidates = get_candidate_strings()
    same_author = candidates['authorIDs'] == author
    return candidates[same_author]
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_candidate_text(l):
    """Return the ``fullText`` of the candidate whose row label is *l*."""
    candidates = get_candidate_strings()
    full_text = candidates.at[l, 'fullText']
    return full_text
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_annotated_text(text, word, pos):
    """Tag the first occurrence of *word* at or after offset *pos* as ``'SELECTED'``.

    Returns:
        ``((before, (word_text, 'SELECTED'), after), end)``, where ``end`` is
        the offset just past the match.

    Raises:
        ValueError: if *word* does not occur in ``text[pos:]``.
    """
    print("here", word, pos)
    start = text.index(word, pos)
    stop = start + len(word)
    before, selected, after = text[:start], text[start:stop], text[stop:]
    return (before, (selected, 'SELECTED'), after), stop
# class AgGridBuilder:
# __static_key = 0
# def build_ag_grid(table, display_columns):
# AgGridBuilder.__static_key += 1
# options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
# options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
# options_builder.configure_selection(selection_mode= 'single', pre_selected_rows = [0])
# options = options_builder.build()
# return AgGrid(table, gridOptions = options, fit_columns_on_grid_load=True, key = AgGridBuilder.__static_key, reload_data = True, update_mode = GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
def _token_search_start(text, tokens, k):
    """Return the offset in *text* where the search for token number *k* (1-based) starts.

    *tokens* must be ``preprocess_text(text)``. The tokens are located one
    after another from the start of *text*. The result is the offset just
    past token ``k - 1``, or 0 for the first token. This pins down which
    occurrence of a repeated word was selected, with no per-session history.
    """
    pos = 0
    for token in tokens[:k - 1]:
        pos = text.index(token, pos) + len(token)
    return pos


if __name__ == '__main__':
    st.set_page_config(layout="wide")
    # Each sub-directory of model_assets is offered as a model. Sorting keeps
    # the dropdown order stable, because os.listdir order is arbitrary.
    models = sorted(
        file_name
        for file_name in os.listdir(ASSETS_PATH)
        if os.path.isdir(f"{ASSETS_PATH}/{file_name}") and not file_name.endswith(".parquet")
    )
    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models
        )
    if current_model is None:
        # selectbox yields None when there are no options. Every loader below
        # would otherwise build a path containing "None".
        st.error(f"No model directories found under {ASSETS_PATH}")
        st.stop()
    pairwise_distances = get_pairwise_distances(current_model)
    embedding_dataset = get_embedding_dataset(current_model)

    # --- Query: choose a query and, optionally, one of its words to highlight.
    with st.container():
        t1 = time.time()
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        t2 = time.time()
        query_table = get_bad_queries(current_model)
        t3 = time.time()
        if query_table.empty:
            st.warning("This model's pairwise distance table contains no queries.")
            st.stop()
        with col2:
            # Positional row number into query_table, so the last valid value is rows - 1.
            index = st.number_input('Enter Query number to inspect', min_value = 0, max_value = query_table.shape[0] - 1, step = 1)
            # Use .iloc for every lookup. query_table's labels are positions in
            # the full query file, not 0..rows-1, so .loc[index] would fetch a
            # different query or raise KeyError.
            query_row = query_table.iloc[index]
            query_text = query_row['fullText']
            query_index = int(query_row['index'])
            query_author = query_row['authorIDs'][0]
            preprocessed_query_text = preprocess_text(query_text)
            # 0 means no word is selected; k >= 1 selects the k-th token (1-based).
            text_highlight_index = st.number_input('Enter word #', min_value = 0, max_value = len(preprocessed_query_text), step = 1)
        with col1:
            if text_highlight_index >= 1:
                search_start = _token_search_start(query_text, preprocessed_query_text, text_highlight_index)
                query_segments, _ = get_annotated_text(query_text, preprocessed_query_text[text_highlight_index - 1], search_start)
            else:
                # Use a 1-tuple. Splatting a bare string would pass each character separately.
                query_segments = (query_text,)
            annotated_text(*query_segments)
        t4 = time.time()
        print(f"query time query text: {t3-t2}, total time: {t4-t1}")

    # --- Recommended candidates for the selected query.
    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        # Non-empty: query_table was built from this table's 'queries' values.
        query_pairs = pairwise_distances[pairwise_distances["queries"] == query_index]
        rec_candidates = list(query_pairs['candidates'])
        with col3:
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect', min_value = 0, max_value = len(rec_candidates) - 1, step = 1)
            pairwise_candidate_index = int(rec_candidates[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            candidate_text = get_candidate_text(pairwise_candidate_index)
            if text_highlight_index == 0:
                annotated_text(candidate_text)
            else:
                top_n_words_to_highlight = get_top_n_closest(preprocessed_query_text[text_highlight_index - 1], candidate_text, 4)
                annotated_text(*annotate_text(candidate_text, top_n_words_to_highlight))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            distance = query_pairs.loc[query_pairs['candidates'] == pairwise_candidate_index, 'distances']
            # Take the scalar explicitly. float() on a one-element Series is deprecated in pandas.
            st.write(float(distance.iloc[0]))
        print(f"candidate string retreival: {t2-t1}")

    # --- Candidates written by the same author as the selected query.
    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # Use the selected row's author. query_index is a position in the full
        # query file, not a row position in query_table.
        gt_candidates = get_gt_candidates(current_model, query_author)
        t3 = time.time()
        if gt_candidates.empty:
            with col1:
                st.info("No candidates share an author with this query.")
            st.stop()
        with col3:
            candidate_index = st.number_input('Enter ground truthnumber to inspect', min_value = 0, max_value = gt_candidates.shape[0] - 1, step = 1)
            gt_row = gt_candidates.iloc[candidate_index]
            gt_candidate_index = int(gt_row['index'])
        with col1:
            st.header("Text")
            st.write(gt_row['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            # Candidate embeddings are located via their 'index' column.
            # Query embeddings are taken by position query_index.
            candidate_indices = list(embedding_dataset['candidates']['index'])
            query_embedding = np.array([embedding_dataset['queries'][query_index]['embedding']])
            candidate_embedding = np.array([embedding_dataset['candidates'][candidate_indices.index(gt_candidate_index)]['embedding']])
            st.write(1 - cosine_similarity(query_embedding, candidate_embedding)[0, 0])
            t5 = time.time()
        print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")