Hugging Face Spaces — commit "Update app.py" (file: app.py, changed).
Space status at capture time: Runtime error.
@@ -17,6 +17,11 @@ tokenizer_finbert = AutoTokenizer.from_pretrained("ProsusAI/finbert")
|
|
17 |
kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
|
18 |
kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
text = st.text_input("Enter word or key-phrase")
|
21 |
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")
|
22 |
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
@@ -28,7 +33,7 @@ with st.sidebar:
|
|
28 |
k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
|
29 |
|
30 |
#columns
|
31 |
-
col1, col2 = st.columns(
|
32 |
#load kp dicts
|
33 |
with open(kp_dict_checkpoint,'rb') as handle:
|
34 |
kp_dict = pickle.load(handle)
|
@@ -38,11 +43,17 @@ with open(kp_dict_finbert_checkpoint,'rb') as handle:
|
|
38 |
kp_dict_finbert = pickle.load(handle)
|
39 |
keys_finbert = list(kp_dict_finbert.keys())
|
40 |
|
|
|
|
|
|
|
|
|
41 |
#load cosine distances of kp dict
|
42 |
with open(kp_cosine_checkpoint,'rb') as handle:
|
43 |
cosine_kp = pickle.load(handle)
|
44 |
with open(kp_cosine_finbert_checkpoint,'rb') as handle:
|
45 |
cosine_finbert_kp = pickle.load(handle)
|
|
|
|
|
46 |
|
47 |
def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
|
48 |
sim_dict = {}
|
@@ -100,11 +111,15 @@ if text:
|
|
100 |
new_tokens.pop("KPS")
|
101 |
new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
|
102 |
new_tokens_finbert.pop("KPS")
|
|
|
|
|
103 |
with torch.no_grad():
|
104 |
outputs = model(**new_tokens)
|
105 |
outputs_finbert = model_finbert(**new_tokens_finbert)
|
|
|
106 |
sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
107 |
sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
|
|
108 |
if not diversify_box:
|
109 |
with col1:
|
110 |
st.write("distilbert-cvent")
|
@@ -112,11 +127,16 @@ if text:
|
|
112 |
with col2:
|
113 |
st.write("finbert")
|
114 |
st.json(sim_dict_finbert)
|
|
|
|
|
|
|
115 |
else:
|
116 |
idxs = extract_idxs(sim_dict, kp_dict)
|
117 |
idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
|
|
|
118 |
distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
|
119 |
distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
|
|
|
120 |
#first do distilbert
|
121 |
candidate = None
|
122 |
min_sim = np.inf
|
@@ -133,6 +153,14 @@ if text:
|
|
133 |
if sim < min_sim:
|
134 |
candidate_finbert = combination
|
135 |
min_sim = sim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
#distilbert
|
137 |
ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
|
138 |
ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
|
@@ -141,9 +169,16 @@ if text:
|
|
141 |
ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
|
142 |
ret_finbert = sorted(ret_finbert.items(), key= lambda x: x[1], reverse = True)
|
143 |
ret_finbert = {x:y for x,y in ret_finbert}
|
|
|
|
|
|
|
|
|
144 |
with col1:
|
145 |
st.write("distilbert-cvent")
|
146 |
st.json(ret)
|
147 |
with col2:
|
148 |
st.write("finbert")
|
149 |
-
st.json(ret_finbert)
|
|
|
|
|
|
|
|
# Checkpoint paths for the FinBERT key-phrase dictionary and its
# pre-computed cosine-distance matrix (loaded via pickle further down).
kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"

# Third model added by this change: SapBERT, set up the same way as the
# distilbert-cvent / FinBERT pair above.  output_hidden_states=True makes
# the model return all hidden layers in its output (presumably pooled by
# downstream code — that code is outside this view; TODO confirm).
tokenizer_sapbert = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
model_sapbert = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True)
kp_dict_sapbert_checkpoint = "kp_dict_sapbert.pickle"
kp_cosine_sapbert_checkpoint = "cosine_kp_sapbert.pickle"
# Query input and result-filtering toggles.
text = st.text_input("Enter word or key-phrase")
# NOTE(review): st.radio over [True, False] renders as a two-option radio;
# a st.checkbox would be the more conventional widget, but keeping radio
# preserves the existing UI behavior.
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
|
33 |
k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
|
34 |
|
#columns
# One result column per model: distilbert-cvent / finbert / sapbert
# (widened from two columns when SapBERT was added).
col1, col2, col3 = st.columns(3)
#load kp dicts
# Key-phrase -> embedding dictionary for the distilbert-cvent model.
# NOTE(review): pickle.load is safe only for trusted, locally produced
# checkpoints — never point these paths at untrusted files.
with open(kp_dict_checkpoint,'rb') as handle:
    kp_dict = pickle.load(handle)
|
|
43 |
kp_dict_finbert = pickle.load(handle)
|
44 |
keys_finbert = list(kp_dict_finbert.keys())
|
45 |
|
46 |
+
with open(kp_dict_sapbert_checkpoint,'rb') as handle:
|
47 |
+
kp_dict_sapbert = pickle.load(handle)
|
48 |
+
keys_sapbert = list(kp_dict_sapbert.keys())
|
49 |
+
|
#load cosine distances of kp dict
# Pre-computed pairwise cosine distances for each model's key-phrase set;
# consumed by the diversify branch below via np.ix_ index grids.
with open(kp_cosine_checkpoint,'rb') as handle:
    cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint,'rb') as handle:
    cosine_finbert_kp = pickle.load(handle)
# SapBERT distances, added alongside the new third model.
with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
    cosine_sapbert_kp = pickle.load(handle)
57 |
|
58 |
def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
|
59 |
sim_dict = {}
|
|
|
111 |
new_tokens.pop("KPS")
|
112 |
new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
|
113 |
new_tokens_finbert.pop("KPS")
|
114 |
+
new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
|
115 |
+
new_tokens_sapbert.pop("KPS")
|
116 |
with torch.no_grad():
|
117 |
outputs = model(**new_tokens)
|
118 |
outputs_finbert = model_finbert(**new_tokens_finbert)
|
119 |
+
outputs_sapbert = model_sapbert(**new_tokens_sapbert)
|
120 |
sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
121 |
sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
122 |
+
sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
123 |
if not diversify_box:
|
124 |
with col1:
|
125 |
st.write("distilbert-cvent")
|
|
|
127 |
with col2:
|
128 |
st.write("finbert")
|
129 |
st.json(sim_dict_finbert)
|
130 |
+
with col3:
|
131 |
+
st.write("sapbert")
|
132 |
+
st.json(sim_dict_sapbert)
|
133 |
else:
|
134 |
idxs = extract_idxs(sim_dict, kp_dict)
|
135 |
idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
|
136 |
+
idxs_sapbert = extract_idxs(sim_dict_sapbert, kp_dict_sapbert)
|
137 |
distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
|
138 |
distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
|
139 |
+
distances_candidates_sapbert = cosine_sapbert_kp[np.ix_(idxs_sapbert, idxs_sapbert)]
|
140 |
#first do distilbert
|
141 |
candidate = None
|
142 |
min_sim = np.inf
|
|
|
153 |
if sim < min_sim:
|
154 |
candidate_finbert = combination
|
155 |
min_sim = sim
|
156 |
+
#sapbert
|
157 |
+
candidate_sapbert = None
|
158 |
+
min_sim = np.inf
|
159 |
+
for combination in itertools.combinations(range(len(idxs_sapbert)), k):
|
160 |
+
sim = sum([distances_candidates_sapbert[i][j] for i in combination for j in combination if i != j])
|
161 |
+
if sim < min_sim:
|
162 |
+
candidate_sapbert = combination
|
163 |
+
min_sim = sim
|
164 |
#distilbert
|
165 |
ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
|
166 |
ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
|
|
|
169 |
ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
|
170 |
ret_finbert = sorted(ret_finbert.items(), key= lambda x: x[1], reverse = True)
|
171 |
ret_finbert = {x:y for x,y in ret_finbert}
|
172 |
+
#sapbert
|
173 |
+
ret_sapbert = {keys_sapbert[idxs_sapbert[idx]]:sim_dict_sapbert[keys_sapbert[idxs_sapbert[idx]]] for idx in candidate_sapbert}
|
174 |
+
ret_sapbert = sorted(ret_sapbert.items(), key= lambda x: x[1], reverse = True)
|
175 |
+
ret_sapbert = {x:y for x,y in ret_sapbert}
|
176 |
with col1:
|
177 |
st.write("distilbert-cvent")
|
178 |
st.json(ret)
|
179 |
with col2:
|
180 |
st.write("finbert")
|
181 |
+
st.json(ret_finbert)
|
182 |
+
with col3:
|
183 |
+
st.write("sapbert")
|
184 |
+
st.json(ret_sapbert)
|