"""Similarity metrics and context extraction for ranking lexical substitutes."""
import math

import numpy as np

from utilities_language_general.rus_constants import nlp, PHRASES, LEVEL_NUMBERS


def euclidean_distance(x, y):
    """Euclidean (L2) distance between two vectors."""
    return np.sqrt(np.sum((x - y) ** 2))


# Backwards-compatible alias for the original misspelling, in case other
# modules still import it under the old name.
eucledian_distance = euclidean_distance


def cosine_similarity(x, y):
    """Cosine similarity between two vectors; None if undefined (zero vector)."""
    out = np.dot(x, y) / (np.sqrt(np.dot(x, x)) * np.sqrt(np.dot(y, y)))
    if np.isnan(out):
        return None
    return out


def get_vector_for_token(model, token):
    """Look up a vector for a 'lemma_POS' token.

    Multiword tokens ('lemma1_POS1_lemma2_POS2') missing from the vocabulary
    fall back to the mean vector of their constituent words.
    """
    if model.has_index_for(token):
        return model.get_vector(token)
    # Re-pair the split parts into 'lemma_POS' constituents with a step of 2,
    # so that spurious 'POS1_lemma2' cross-pairs are not produced.
    splitted = token.split('_')
    token_list = [f'{splitted[i]}_{splitted[i + 1]}' for i in range(0, len(splitted) - 1, 2)]
    try:
        return model.get_mean_vector(token_list)
    except ValueError:
        # No constituent is present in the vocabulary (or the list is empty).
        return None
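
# Minimal usage sketch, assuming a gensim KeyedVectors model (the file name is
# hypothetical; only has_index_for / get_vector / get_mean_vector are relied on):
#     from gensim.models import KeyedVectors
#     model = KeyedVectors.load('ru_vectors.kv')          # hypothetical file
#     get_vector_for_token(model, 'дом_NOUN')             # direct vocabulary hit
#     get_vector_for_token(model, 'белый_ADJ_гриб_NOUN')  # mean of 'белый_ADJ' and 'гриб_NOUN'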


def compute_metric(func, vector1, vector2):
    """Apply ``func`` to two vectors, propagating None for missing vectors."""
    if vector1 is not None and vector2 is not None:
        return func(vector1, vector2)
    return None


def compute_positive_cos(x, y):
    """Cosine similarity rescaled from [-1, 1] to [0, 1]."""
    cos_sim = cosine_similarity(x, y)
    # Compare with None explicitly: a legitimate similarity of 0.0 is falsy.
    if cos_sim is None:
        return None
    return (cos_sim + 1) / 2


def addition_metric(substitute, target, context):
    """Arithmetic-mean fit of a substitute to the target and its context:
    (cos(s, t) + sum_c cos(s, c)) / (|C| + 1).
    """
    substitute_target_cos = compute_metric(cosine_similarity, substitute, target)
    if substitute_target_cos is None:
        return None
    if not context:
        return None
    context_sims = []
    for context_tk in context:
        substitute_context_cos = compute_metric(cosine_similarity, substitute, context_tk)
        if substitute_context_cos is not None:
            context_sims.append(substitute_context_cos)
    # Context words without vectors are skipped but still counted in the
    # denominator, so they effectively contribute zero similarity.
    return (substitute_target_cos + np.sum(context_sims)) / (len(context) + 1)
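
# Worked toy example (not from the project data): with unit vectors
#     s = np.array([1.0, 0.0]); t = np.array([0.6, 0.8]); c = [np.array([0.0, 1.0])]
# cos(s, t) = 0.6 and cos(s, c[0]) = 0.0, so
#     addition_metric(s, t, c) == (0.6 + 0.0) / (1 + 1) == 0.3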


def balanced_addition_metric(substitute, target, context):
    """Addition metric with the target weighted as heavily as the whole context:
    (|C| * cos(s, t) + sum_c cos(s, c)) / (2 * |C|).
    """
    substitute_target_cos = compute_metric(cosine_similarity, substitute, target)
    if substitute_target_cos is None:
        return None
    if not context:
        return None
    context_sims = []
    for context_tk in context:
        substitute_context_cos = compute_metric(cosine_similarity, substitute, context_tk)
        if substitute_context_cos is not None:
            context_sims.append(substitute_context_cos)
    context_len = len(context)
    return (context_len * substitute_target_cos + np.sum(context_sims)) / (2 * context_len)


def multiplication_metric(substitute, target, context):
    """Geometric-mean fit over positive cosines:
    (pcos(s, t) * prod_c pcos(s, c)) ** (1 / (|C| + 1)).
    """
    substitute_target_cos = compute_metric(compute_positive_cos, substitute, target)
    if substitute_target_cos is None:
        return None
    if not context:
        return None
    context_sims = []
    for context_tk in context:
        substitute_context_positive_cos = compute_metric(compute_positive_cos, substitute, context_tk)
        if substitute_context_positive_cos is not None:
            context_sims.append(substitute_context_positive_cos)
    prod_of_context_sims = np.prod(context_sims)
    try:
        # Multiply (not add) the target term into the product: the metric is a
        # geometric mean of positive cosines.
        return math.pow(substitute_target_cos * prod_of_context_sims, 1 / (len(context) + 1))
    except ValueError:
        return None


def balanced_multiplication_metric(substitute, target, context):
    """Multiplication metric with the target weighted as heavily as the context:
    (pcos(s, t) ** |C| * prod_c pcos(s, c)) ** (1 / (2 * |C|)).
    """
    substitute_target_cos = compute_metric(compute_positive_cos, substitute, target)
    if substitute_target_cos is None:
        return None
    if not context:
        return None
    context_sims = []
    for context_tk in context:
        substitute_context_positive_cos = compute_metric(compute_positive_cos, substitute, context_tk)
        if substitute_context_positive_cos is not None:
            context_sims.append(substitute_context_positive_cos)
    prod_of_context_sims = np.prod(context_sims)
    context_len = len(context)
    try:
        return math.pow(math.pow(substitute_target_cos, context_len) * prod_of_context_sims,
                        1 / (2 * context_len))
    except ValueError:
        return None
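
# With the same toy vectors as above, the positive cosines are
#     pcos(s, t) = (0.6 + 1) / 2 = 0.8 and pcos(s, c[0]) = (0.0 + 1) / 2 = 0.5,
# so multiplication_metric(s, t, c) == (0.8 * 0.5) ** (1 / 2) ≈ 0.632; for a
# single-word context (|C| = 1) the balanced variant gives the same value.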


def bind_phrases(context_list):
    """Merge adjacent 'lemma_POS' tokens that form a known phrase into a single
    'lemma1_POS1_lemma2_POS2' token.
    """
    context = []
    previous_was_phrase = False
    for i in range(len(context_list) - 1):
        phrase_candidate = f'{context_list[i]}_{context_list[i + 1]}'
        if phrase_candidate in PHRASES and not previous_was_phrase:
            context.append(phrase_candidate)
            previous_was_phrase = True
        else:
            if not previous_was_phrase:
                context.append(context_list[i])
            # If the flag was set, this token was already consumed as the
            # second half of a phrase and is skipped; either way, reset it.
            previous_was_phrase = False
    # Append the final token unless it was absorbed into the last phrase.
    if context_list:
        if not context:
            context.append(context_list[-1])
        elif context_list[-1] not in context[-1]:
            context.append(context_list[-1])
    return context
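
# Example, assuming PHRASES contains 'белый_ADJ_гриб_NOUN' (illustrative entry):
#     bind_phrases(['белый_ADJ', 'гриб_NOUN', 'расти_VERB'])
#     -> ['белый_ADJ_гриб_NOUN', 'расти_VERB']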


def get_context_windows(doc, target_text, window_size):
    """Return two symmetric context windows around the target (with and without
    stopwords) as lists of 'lemma_POS' tokens with known phrases bound.
    """
    sentence_masked = doc.text.lower().replace(target_text.lower().strip(), ' [MASK] ')
    # Parse once; '[MASK]' is tokenised so that 'MASK' surfaces as 'mask_PROPN'.
    masked_doc = nlp(sentence_masked)
    alpha_tokens_lemma_pos = [f'{tk.lemma_.lower()}_{tk.pos_}' for tk in masked_doc if tk.text.isalpha()]
    alpha_tokens_lemma_pos_no_stop = [f'{tk.lemma_.lower()}_{tk.pos_}' for tk in masked_doc
                                      if tk.text.isalpha() and not tk.is_stop]
    try:
        mask_token_index = alpha_tokens_lemma_pos.index('mask_PROPN')
        mask_token_index_no_stop = alpha_tokens_lemma_pos_no_stop.index('mask_PROPN')
    except ValueError:
        # Mask not found: return empty windows so callers can still unpack the
        # pair and the metric functions degrade gracefully.
        return [], []
    left_border = max(mask_token_index - window_size, 0)
    right_border = min(mask_token_index + window_size, len(alpha_tokens_lemma_pos))
    l_context = alpha_tokens_lemma_pos[left_border:mask_token_index]
    r_context = alpha_tokens_lemma_pos[mask_token_index + 1:right_border + 1]
    left_border_no_stop = max(mask_token_index_no_stop - window_size, 0)
    right_border_no_stop = min(mask_token_index_no_stop + window_size, len(alpha_tokens_lemma_pos_no_stop))
    l_context_no_stop = alpha_tokens_lemma_pos_no_stop[left_border_no_stop:mask_token_index_no_stop]
    r_context_no_stop = alpha_tokens_lemma_pos_no_stop[mask_token_index_no_stop + 1:right_border_no_stop + 1]
    return (bind_phrases(l_context) + bind_phrases(r_context),
            bind_phrases(l_context_no_stop) + bind_phrases(r_context_no_stop))
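
# Hedged usage sketch (the sentence and target are illustrative, not project data):
#     doc = nlp('Он прочитал интересную книгу о путешествиях.')
#     window, window_no_stop = get_context_windows(doc, 'книгу', window_size=3)
# Both results are lists of 'lemma_POS' tokens around the masked target.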


def get_context_linked_words(doc, target_position, target_text):
    """Collect syntactically linked context words (subtree, children, ancestors)
    for the target and bind known phrases, returning 'lemma_POS' tokens.
    """
    answer_list = target_text.split(' ')
    context_words = []
    for tk in doc:
        if tk.text.isalpha():
            if tk.text in answer_list and abs(target_position - tk.idx) <= sum(len(t) for t in answer_list):
                context_words.extend([t for t in tk.subtree if t.text.isalpha() and not t.is_stop])
                context_words.extend([t for t in tk.children if t.text.isalpha() and not t.is_stop])
                context_words.extend([t for t in tk.ancestors if t.text.isalpha() and not t.is_stop])
    context_words = [(tk, f'{tk.lemma_}_{tk.pos_}')
                     for tk in sorted(set(context_words), key=lambda tk: tk.i)
                     if tk.text not in answer_list]
    context = []
    previous_was_phrase = False
    for i in range(len(context_words) - 1):
        phrase_candidate = f'{context_words[i][1]}_{context_words[i + 1][1]}'
        if (phrase_candidate in PHRASES and not previous_was_phrase
                and abs(context_words[i][0].i - context_words[i + 1][0].i) <= 1):
            context.append(phrase_candidate)
            previous_was_phrase = True
        else:
            if not previous_was_phrase:
                context.append(context_words[i][1])
            # Reset the flag so that only the token directly after a phrase is
            # skipped (mirrors bind_phrases).
            previous_was_phrase = False
    if context and context_words:
        if context_words[-1][1] not in context[-1]:
            context.append(context_words[-1][1])
    elif context_words:
        context.append(context_words[-1][1])
    return context


def compute_all_necessary_metrics(target_lemma, target_text, target_position,
                                  substitute_lemma, doc, model_type: str, model=None):
    """Compute the 22 similarity features for a (target, substitute) pair.

    Returns None for BERT models, whose single score is produced elsewhere.
    """
    if model_type == 'bert':
        return None
    target_vector = get_vector_for_token(model, target_lemma)
    substitute_vector = get_vector_for_token(model, substitute_lemma)
    cosimilarity = compute_metric(cosine_similarity, substitute_vector, target_vector)
    euclidean_similarity = compute_metric(euclidean_distance, substitute_vector, target_vector)
context_window3, context_window3_no_stop = get_context_windows(doc=doc, target_text=target_text, window_size=3)
context_window5, context_window5_no_stop = get_context_windows(doc=doc, target_text=target_text, window_size=5)
context_window_synt = get_context_linked_words(doc, target_position, target_text)
context_window3 = [get_vector_for_token(model, token) for token in context_window3]
context_window3_no_stop = [get_vector_for_token(model, token) for token in context_window3_no_stop]
context_window5 = [get_vector_for_token(model, token) for token in context_window5]
context_window5_no_stop = [get_vector_for_token(model, token) for token in context_window5_no_stop]
context_window_synt = [get_vector_for_token(model, token) for token in context_window_synt]
    # The metric functions score the substitute (first argument) against the
    # target and the context, so the substitute vector must be passed first.
    add_metric_window3 = addition_metric(substitute_vector, target_vector, context_window3)
    bal_add_metric_window3 = balanced_addition_metric(substitute_vector, target_vector, context_window3)
    add_metric_window3_no_stop = addition_metric(substitute_vector, target_vector, context_window3_no_stop)
    bal_add_metric_window3_no_stop = balanced_addition_metric(substitute_vector, target_vector, context_window3_no_stop)
    mult_metric_window3 = multiplication_metric(substitute_vector, target_vector, context_window3)
    bal_mult_metric_window3 = balanced_multiplication_metric(substitute_vector, target_vector, context_window3)
    mult_metric_window3_no_stop = multiplication_metric(substitute_vector, target_vector, context_window3_no_stop)
    bal_mult_metric_window3_no_stop = balanced_multiplication_metric(substitute_vector, target_vector, context_window3_no_stop)
    add_metric_window5 = addition_metric(substitute_vector, target_vector, context_window5)
    bal_add_metric_window5 = balanced_addition_metric(substitute_vector, target_vector, context_window5)
    add_metric_window5_no_stop = addition_metric(substitute_vector, target_vector, context_window5_no_stop)
    bal_add_metric_window5_no_stop = balanced_addition_metric(substitute_vector, target_vector, context_window5_no_stop)
    mult_metric_window5 = multiplication_metric(substitute_vector, target_vector, context_window5)
    bal_mult_metric_window5 = balanced_multiplication_metric(substitute_vector, target_vector, context_window5)
    mult_metric_window5_no_stop = multiplication_metric(substitute_vector, target_vector, context_window5_no_stop)
    bal_mult_metric_window5_no_stop = balanced_multiplication_metric(substitute_vector, target_vector, context_window5_no_stop)
    add_metric_synt = addition_metric(substitute_vector, target_vector, context_window_synt)
    bal_add_metric_synt = balanced_addition_metric(substitute_vector, target_vector, context_window_synt)
    mult_metric_synt = multiplication_metric(substitute_vector, target_vector, context_window_synt)
    bal_mult_metric_synt = balanced_multiplication_metric(substitute_vector, target_vector, context_window_synt)
    return (cosimilarity, euclidean_similarity,
add_metric_window3, bal_add_metric_window3,
mult_metric_window3, bal_mult_metric_window3,
add_metric_window3_no_stop, bal_add_metric_window3_no_stop,
mult_metric_window3_no_stop, bal_mult_metric_window3_no_stop,
add_metric_window5, bal_add_metric_window5,
mult_metric_window5, bal_mult_metric_window5,
add_metric_window5_no_stop, bal_add_metric_window5_no_stop,
mult_metric_window5_no_stop, bal_mult_metric_window5_no_stop,
add_metric_synt, bal_add_metric_synt,
mult_metric_synt, bal_mult_metric_synt)


def make_decision(doc, model_type, scaler, classifier, pos_dict, level, target_lemma, target_text, target_pos,
                  target_position, substitute_lemma, substitute_pos, model=None, bert_score=None):
    """Scale the features for a candidate substitution and let the classifier
    accept or reject it.

    The scaler must match the feature set of the chosen model type: a single
    BERT score, or the full metric tuple for static-vector models.
    """
    metrics = compute_all_necessary_metrics(target_lemma=target_lemma, target_text=target_text,
                                            target_position=target_position, substitute_lemma=substitute_lemma,
                                            doc=doc, model_type=model_type, model=model)
    # A bound two-word phrase ('lemma1_POS1_lemma2_POS2') contains three underscores.
    target_multiword, substitute_multiword = target_lemma.count('_') > 2, substitute_lemma.count('_') > 2
    if model_type == 'bert':
        scaled_data = scaler.transform([[bert_score]]).tolist()[0]
    else:
        scaled_data = scaler.transform([metrics]).tolist()[0]
    data = [LEVEL_NUMBERS.get(level), pos_dict.get(target_pos), target_multiword,
            pos_dict.get(substitute_pos), substitute_multiword] + scaled_data
    # scikit-learn predictors expect a 2D array of samples, one row per sample.
    predict = classifier.predict([data])[0]
    return bool(predict)
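
# End-to-end sketch (all objects are hypothetical stand-ins: a spaCy doc, a
# fitted scaler over the 22 metric columns, a fitted sklearn classifier, and
# pos_dict / level values from the project's constants):
#     accept = make_decision(doc=doc, model_type='w2v', scaler=scaler, classifier=classifier,
#                            pos_dict=pos_dict, level='B1', target_lemma='книга_NOUN',
#                            target_text='книгу', target_pos='NOUN',
#                            target_position=doc.text.find('книгу'),
#                            substitute_lemma='том_NOUN', substitute_pos='NOUN', model=model)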