Hariharan Vijayachandran commited on
Commit
0c44400
·
1 Parent(s): 22427a2
Files changed (2) hide show
  1. app.py +69 -11
  2. requirements.txt +4 -1
app.py CHANGED
@@ -13,23 +13,73 @@ from annotated_text import annotated_text
13
  ABSOLUTE_PATH = os.path.dirname(__file__)
14
  ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')
15
 
16
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def preprocess_text(s):
18
  return list(filter(lambda x: x!= '', (''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)).split(' ')))
19
 
20
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
21
  def get_pairwise_distances(model):
22
  df = pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv").set_index('index')
23
  return df
24
 
25
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
26
  def get_pairwise_distances_chunked(model, chunk):
27
  # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize = 16):
28
  # print(df.iloc[0]['queries'])
29
  # if chunk == int(df.iloc[0]['queries']):
30
  # return df
31
  return get_pairwise_distances(model)
32
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
33
  def get_query_strings():
34
  df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines = True)
35
  df['index'] = df.reset_index().index
@@ -38,7 +88,7 @@ def get_query_strings():
38
  # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
39
 
40
  # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
41
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
42
  def get_candidate_strings():
43
  df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
44
  df['i'] = df['index']
@@ -49,24 +99,24 @@ def get_candidate_strings():
49
  # df['partition'] = df['index']%100
50
  # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
51
  # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
52
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
53
  def get_embedding_dataset(model):
54
  data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
55
  return data
56
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
57
  def get_bad_queries(model):
58
  df = get_query_strings().iloc[list(get_pairwise_distances(model)['queries'].unique())][['fullText', 'index', 'authorIDs']]
59
  return df
60
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
61
  def get_gt_candidates(model, author):
62
  gt_candidates = get_candidate_strings()
63
  df = gt_candidates[gt_candidates['authorIDs'] == author]
64
  return df
65
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
66
  def get_candidate_text(l):
67
  return get_candidate_strings().at[l,'fullText']
68
 
69
- @st.cache(suppress_st_warning=True, allow_output_mutation=True)
70
  def get_annotated_text(text, word, pos):
71
  print("here", word, pos)
72
  start= text.index(word, pos)
@@ -146,7 +196,15 @@ if __name__ == '__main__':
146
  with col1:
147
  st.header("Text")
148
  t1 = time.time()
149
- st.write(get_candidate_text(pairwise_candidate_index))
 
 
 
 
 
 
 
 
150
  t2 = time.time()
151
  with col2:
152
  st.header("Cosine Distance")
 
13
  ABSOLUTE_PATH = os.path.dirname(__file__)
14
  ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')
15
 
16
+
17
+ from nltk.data import find
18
+ import nltk
19
+ import gensim
20
+
21
+ @st.cache_data
22
+ def get_embed_model():
23
+ nltk.download("word2vec_sample")
24
+ word2vec_sample = str(find('models/word2vec_sample/pruned.word2vec.txt'))
25
+
26
+ model = gensim.models.KeyedVectors.load_word2vec_format(word2vec_sample, binary=False)
27
+ return model
28
+
29
+ @st.cache_data
30
+ def get_top_n_closest(query_word, candidate, n):
31
+ model = get_embed_model()
32
+ t = time.time()
33
+ p_c = preprocess_text(candidate)
34
+ similarity = []
35
+ t = time.time()
36
+ for i in p_c:
37
+ try:
38
+ similarity.append(model.similarity(query_word, i))
39
+ except:
40
+ similarity.append(0)
41
+ top_n = min(len(p_c), n)
42
+ t = time.time()
43
+ sorted = (-1*np.array(similarity)).argsort()[:top_n]
44
+ top = [p_c[i] for i in sorted]
45
+ return top
46
+
47
+ @st.cache_data
48
+ def annotate_text(text, words):
49
+ annotated = [text]
50
+ for word in words:
51
+ for i in range(len(annotated)):
52
+ if type(annotated[i]) != str:
53
+ continue
54
+ string = annotated[i]
55
+ try:
56
+ index = string.index(word)
57
+ except:
58
+ continue
59
+ first = string[:index]
60
+ second = (string[index:index+len(word)],'SIMILAR')
61
+ third = string[index+len(word):]
62
+ annotated = annotated[:i] + [first, second, third] + annotated[i+1:]
63
+ return tuple(annotated)
64
+
65
+
66
+ @st.cache_data
67
  def preprocess_text(s):
68
  return list(filter(lambda x: x!= '', (''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)).split(' ')))
69
 
70
+ @st.cache_data
71
  def get_pairwise_distances(model):
72
  df = pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv").set_index('index')
73
  return df
74
 
75
+ @st.cache_data
76
  def get_pairwise_distances_chunked(model, chunk):
77
  # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize = 16):
78
  # print(df.iloc[0]['queries'])
79
  # if chunk == int(df.iloc[0]['queries']):
80
  # return df
81
  return get_pairwise_distances(model)
82
+ @st.cache_data
83
  def get_query_strings():
84
  df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines = True)
85
  df['index'] = df.reset_index().index
 
88
  # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
89
 
90
  # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
91
+ @st.cache_data
92
  def get_candidate_strings():
93
  df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
94
  df['i'] = df['index']
 
99
  # df['partition'] = df['index']%100
100
  # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
101
  # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
102
+ @st.cache_data
103
  def get_embedding_dataset(model):
104
  data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
105
  return data
106
+ @st.cache_data
107
  def get_bad_queries(model):
108
  df = get_query_strings().iloc[list(get_pairwise_distances(model)['queries'].unique())][['fullText', 'index', 'authorIDs']]
109
  return df
110
+ @st.cache_data
111
  def get_gt_candidates(model, author):
112
  gt_candidates = get_candidate_strings()
113
  df = gt_candidates[gt_candidates['authorIDs'] == author]
114
  return df
115
+ @st.cache_data
116
  def get_candidate_text(l):
117
  return get_candidate_strings().at[l,'fullText']
118
 
119
+ @st.cache_data
120
  def get_annotated_text(text, word, pos):
121
  print("here", word, pos)
122
  start= text.index(word, pos)
 
196
  with col1:
197
  st.header("Text")
198
  t1 = time.time()
199
+ candidate_text = get_candidate_text(pairwise_candidate_index)
200
+
201
+ if st.session_state['pos_highlight'] == 0:
202
+ annotated_text(candidate_text)
203
+ else:
204
+ top_n_words_to_highlight = get_top_n_closest(preprocessed_query_text[text_highlight_index-1], candidate_text, 4)
205
+ print("TOPN", top_n_words_to_highlight)
206
+ annotated_text(*annotate_text(candidate_text, top_n_words_to_highlight))
207
+
208
  t2 = time.time()
209
  with col2:
210
  st.header("Cosine Distance")
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  scikit-learn==1.2.0
2
  numpy==1.23.5
3
  pandas==1.5.2
4
- st-annotated-text==3.0.0
 
 
 
 
1
  scikit-learn==1.2.0
2
  numpy==1.23.5
3
  pandas==1.5.2
4
+ st-annotated-text==3.0.0
5
+ nltk==3.8.1
6
+ gensim==4.3.1
7
+ streamlit==1.20.0