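"""Streamlit demo: nearest key-phrase search across three domain-specific BERT models
(distilbert-cvent, FinBERT, SapBERT).

For a query phrase, the app embeds it with each model, ranks pre-computed key-phrase
embeddings by cosine similarity, and optionally diversifies the top-k results using
pre-computed pairwise cosine matrices. Run locally with `streamlit run` on this file
(typically app.py on Hugging Face Spaces).
"""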
from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
import tokenizers
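# NOTE: legacy st.cache tries to hash cached values, and models/tokenizers are not
# hashable, so the loaders below use hash_funcs and allow_output_mutation=True to skip
# that check. (In newer Streamlit, st.cache is deprecated in favor of st.cache_resource.)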
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForMaskedLM: lambda _: None}, allow_output_mutation=True)
def load_bert():
    return (AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True),
            AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022"))
model, tokenizer = load_bert()
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"
@st.cache(allow_output_mutation=True)
def load_finbert():
    return (AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True),
            AutoTokenizer.from_pretrained("ProsusAI/finbert"))
model_finbert, tokenizer_finbert = load_finbert()
kp_dict_finbert_checkpoint = "kp_dict_finance.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finance.pickle"
@st.cache(allow_output_mutation=True)
def load_sapbert():
    return (AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True),
            AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext"))
model_sapbert, tokenizer_sapbert = load_sapbert()
kp_dict_sapbert_checkpoint = "kp_dict_medical.pickle"
kp_cosine_sapbert_checkpoint = "cosine_kp_medical.pickle"
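# query input and search options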
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio("Exclude words", [True, False], help="Exclude results that contain any word from the query")
exclude_text = st.radio("Exclude text", [True, False], help="Exclude results that contain the full query (e.g. exclude 'tomato soup recipe' when the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases",1,10,5)
with st.sidebar:
    diversify_box = st.checkbox("Diversify results", True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)
# one result column per model
col1, col2, col3 = st.columns(3)
# load kp dicts {key-phrase: embedding}
with open(kp_dict_checkpoint, 'rb') as handle:
    kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())
with open(kp_dict_finbert_checkpoint, 'rb') as handle:
    kp_dict_finbert = pickle.load(handle)
keys_finbert = list(kp_dict_finbert.keys())
with open(kp_dict_sapbert_checkpoint, 'rb') as handle:
    kp_dict_sapbert = pickle.load(handle)
keys_sapbert = list(kp_dict_sapbert.keys())
# load pairwise cosine matrices of each kp dict
with open(kp_cosine_checkpoint, 'rb') as handle:
    cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint, 'rb') as handle:
    cosine_finbert_kp = pickle.load(handle)
with open(kp_cosine_sapbert_checkpoint, 'rb') as handle:
    cosine_sapbert_kp = pickle.load(handle)
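# rank all key-phrases in kp_dict by cosine similarity to the query embedding and
# return the top k as {phrase: score}; `pooler` chooses mean-pooled hidden states
# vs. the model's pooler_output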
def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5, pooler=True):
    sim_dict = {}
    if pooler:
        pools = pool_embeddings(out, tokens).detach().numpy()
    else:
        pools = out["pooler_output"].detach().numpy()
    for key in kp_dict.keys():
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if exclude_words and any(word in key for word in text.split(" ")):
            continue
        sim_dict[key] = cosine_similarity(pools, [kp_dict[key]])[0][0]
    sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
    return dict(sims)
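# batch-tokenize a list of phrases into padded input_ids/attention_mask tensors,
# keeping the raw phrases under 'KPS'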
def concat_tokens(sentences, tokenizer):
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat lists of tensors into single stacked tensors
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens
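# mean-pool the last hidden layer over non-padding tokens: zero out padded positions
# via the attention mask, sum, then divide by the (clamped) token count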
def pool_embeddings(out, tok):
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled
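# positions of the selected key-phrases within kp_dict's key order, used to slice
# the pre-computed cosine matrices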
def extract_idxs(top_dict, kp_dict):
    idxs = []
    for c, key in enumerate(kp_dict.keys()):
        if key in top_dict:
            idxs.append(c)
    return idxs
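# main flow: embed the query with each model, then rank the key-phrases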
if text:
    text = text.lower()
    new_tokens = concat_tokens([text], tokenizer)
    new_tokens.pop("KPS")
    new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
    new_tokens_finbert.pop("KPS")
    new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
    new_tokens_sapbert.pop("KPS")
    with torch.no_grad():
        outputs = model(**new_tokens)
        outputs_finbert = model_finbert(**new_tokens_finbert)
        outputs_sapbert = model_sapbert(**new_tokens_sapbert)
    # when diversifying, retrieve a larger candidate pool (k_diversify) to pick k diverse results from
    top_n = k_diversify if diversify_box else k
    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=top_n)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=top_n)
    sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text, exclude_words=exclude_words, k=top_n, pooler=False)
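    # without diversification, show each model's raw top-k side by side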
    if not diversify_box:
        with col1:
            st.write("distilbert-cvent")
            st.json(sim_dict)
        with col2:
            st.write("finbert")
            st.json(sim_dict_finbert)
        with col3:
            st.write("sapbert")
            st.json(sim_dict_sapbert)
    else:
        idxs = extract_idxs(sim_dict, kp_dict)
        idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
        idxs_sapbert = extract_idxs(sim_dict_sapbert, kp_dict_sapbert)
        distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
        distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
        distances_candidates_sapbert = cosine_sapbert_kp[np.ix_(idxs_sapbert, idxs_sapbert)]
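        # Diversify by exhaustively searching for the k-subset of candidates with the
        # lowest total pairwise cosine similarity (the "max sum distance" idea used by
        # keyword extractors such as KeyBERT). Note this is combinatorial in the pool
        # size, so k and the candidate pool should stay small.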
        def most_diverse(distances, n_candidates, k):
            # brute force: the k-subset with the smallest summed pairwise similarity
            best, min_sim = None, np.inf
            for combination in itertools.combinations(range(n_candidates), k):
                sim = sum(distances[i][j] for i in combination for j in combination if i != j)
                if sim < min_sim:
                    best, min_sim = combination, sim
            return best

        candidate = most_diverse(distances_candidates, len(idxs), k)
        candidate_finbert = most_diverse(distances_candidates_finbert, len(idxs_finbert), k)
        candidate_sapbert = most_diverse(distances_candidates_sapbert, len(idxs_sapbert), k)
        # map the selected candidate indices back to key-phrases, sorted by similarity
        ret = {keys[idxs[idx]]: sim_dict[keys[idxs[idx]]] for idx in candidate}
        ret = dict(sorted(ret.items(), key=lambda x: x[1], reverse=True))
        ret_finbert = {keys_finbert[idxs_finbert[idx]]: sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
        ret_finbert = dict(sorted(ret_finbert.items(), key=lambda x: x[1], reverse=True))
        ret_sapbert = {keys_sapbert[idxs_sapbert[idx]]: sim_dict_sapbert[keys_sapbert[idxs_sapbert[idx]]] for idx in candidate_sapbert}
        ret_sapbert = dict(sorted(ret_sapbert.items(), key=lambda x: x[1], reverse=True))
        with col1:
            st.write("distilbert-cvent")
            st.json(ret)
        with col2:
            st.write("finbert")
            st.json(ret_finbert)
        with col3:
            st.write("sapbert")
            st.json(ret_sapbert)