|
""" |
|
description: |
|
metrics to compute model performance |
|
""" |
|
|
|
import Bio |
|
from Bio.Align import substitution_matrices |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch |
|
import re |
|
|
|
import Stage3_source.animation_tools as ani_tools |
|
|
|
|
|
' compute Blosum62 soft accuracy ' |
|
class blosum_soft_accuracy:
    """Soft sequence-recovery accuracy weighted by BLOSUM62 substitution scores.

    A predicted residue that is biochemically similar to the target (high
    BLOSUM62 score) earns partial credit instead of a hard 0/1 match.
    Special tokens ('-', '<START>', '<END>', '<PAD>') are scored with plain
    exact-match (hard) accuracy instead.
    """

    # tokens scored with hard accuracy rather than BLOSUM62 soft accuracy
    SPECIAL_TOKENS = ('-', '<START>', '<END>', '<PAD>')

    def __init__(self):
        # BLOSUM62 substitution matrix (Biopython Array) and its residue alphabet
        self.blosum62 = substitution_matrices.load("BLOSUM62")
        self.alphabet = self.blosum62.alphabet

    def blosum_acc(
        self,
        aa1: str,
        aa2: str
    ) -> np.single:
        """Soft accuracy for predicting residue ``aa1`` when the target is ``aa2``.

        The BLOSUM62 row of ``aa1`` is converted to a probability
        distribution with a softmax; the score is the probability mass on
        ``aa2`` normalized by the largest probability, so an exact (or
        best-scoring) match yields 1.0.
        """
        row = self.blosum62.alphabet.index(aa1)
        substitution_scores = self.blosum62[row, :].values()

        # softmax over aa1's substitution scores against every residue
        probs = np.exp(substitution_scores) / np.sum(np.exp(substitution_scores))

        # one-hot vector selecting the target residue aa2
        correct_index = self.alphabet.index(aa2)
        one_hot = np.zeros_like(probs)
        one_hot[correct_index] = 1

        # probability assigned to the target, scaled so the best case is 1.0
        soft_acc = np.dot(probs, one_hot) / np.max(probs)
        return soft_acc

    def split_seq(self, seq: str) -> list:
        """Tokenize ``seq`` into single residues and special tokens.

        '<START>', '<END>', '<PAD>' and '-' are kept whole; runs of plain
        residue characters are split into individual letters.
        """
        tokens = re.split(r'(-|<START>|<END>|<PAD>|(?<=\w)(?=\w))', seq)
        # drop the empty / whitespace-only fragments left behind by re.split
        return [tok for tok in tokens if tok and tok.strip()]

    def compute_soft_accuracy(
        self,
        seq1_list: list,
        seq2_list: list
    ) -> float:
        """Average per-sequence accuracy of predictions against targets.

        For each sequence pair, residues are scored with BLOSUM62 soft
        accuracy and special tokens with hard accuracy; the two components
        are each averaged over their own token counts and then combined
        (mean of the two, or whichever exists alone).

        Args:
            seq1_list: predicted sequences (strings).
            seq2_list: target sequences (strings), same length as seq1_list.

        Returns:
            Mean combined accuracy over the batch.

        Raises:
            ValueError: if the two lists differ in length.  (Previously this
                only printed a warning and then crashed later with an
                AttributeError on ``self.batch_size``.)
        """
        if len(seq1_list) != len(seq2_list):
            raise ValueError("Please make sequence batch size equivalent...")
        self.batch_size = len(seq1_list)

        avg_soft_acc_per_batch = 0.0
        for seq1, seq2 in zip(seq1_list, seq2_list):
            seq1 = self.split_seq(seq1)
            seq2 = self.split_seq(seq2)

            # per-sequence counters: L_s soft-scored residues, L_h hard-scored
            # special tokens; exposed as attributes as in the original code
            self.L = len(seq2)
            self.L_h = 0
            self.L_s = 0
            avg_soft_acc_per_seq = 0.0
            avg_hard_acc_per_seq = 0.0

            for aa1, aa2 in zip(seq1, seq2):
                if (aa1 not in self.SPECIAL_TOKENS) and (aa2 not in self.SPECIAL_TOKENS):
                    self.L_s += 1
                    avg_soft_acc_per_seq += self.blosum_acc(aa1=aa1, aa2=aa2)
                else:
                    self.L_h += 1
                    avg_hard_acc_per_seq += 1 * (aa1 == aa2)

            # normalize each component by its own token count (0 if none)
            avg_soft_acc_per_seq = avg_soft_acc_per_seq / self.L_s if self.L_s else 0
            avg_hard_acc_per_seq = avg_hard_acc_per_seq / self.L_h if self.L_h else 0

            # combine: use whichever component exists, or the mean of both
            if self.L_s == 0:
                avg_soft_acc_per_batch += avg_hard_acc_per_seq
            elif self.L_h == 0:
                avg_soft_acc_per_batch += avg_soft_acc_per_seq
            else:
                avg_soft_acc_per_batch += (avg_soft_acc_per_seq + avg_hard_acc_per_seq) / 2

        return avg_soft_acc_per_batch / self.batch_size
|
|
|
|
|
def compute_ppl(probs: torch.Tensor) -> float:
    """Average per-position perplexity ``exp(H(p))`` of a batch of distributions.

    Args:
        probs: tensor of shape (batch, sequence_length, class_labels); each
            slice along the last axis is assumed to be a probability
            distribution.

    Returns:
        Mean over all batch positions of exp(entropy), as a Python float.
    """
    # 0 * log(0) is treated as 0 (standard entropy convention); the previous
    # per-position Python loop produced NaN whenever a probability was exactly
    # zero (e.g. one-hot distributions).
    log_probs = torch.where(probs > 0, torch.log(probs), torch.zeros_like(probs))

    # entropy per position, vectorized over the whole batch at once
    entropy = -(probs * log_probs).sum(dim=-1)

    return torch.exp(entropy).mean().item()
|
|
|
def batch_compute_ppl(probs_list: list) -> float:
    """Mean perplexity over a list of (class_labels, length) probability tensors."""
    total = 0.0
    for probs in probs_list:
        # reshape (C, L) -> (1, L, C) so compute_ppl sees (batch, seq, classes)
        total += compute_ppl(probs=probs.unsqueeze(0).permute(0, 2, 1))
    return total / len(probs_list)
|
|
|
|
|
def compute_hard_acc(
    seq1: str,
    seq2: str
) -> float:
    """Exact-match accuracy of ``seq1`` against ``seq2``, ignoring '<PAD>' targets.

    Matches are counted over the paired (zipped) positions whose target is
    not '<PAD>'; the denominator counts every non-'<PAD>' token in ``seq2``.
    Returns 1.0 when no valid target positions exist (all padding / empty).
    """
    matches = 0
    for pred, target in zip(seq1, seq2):
        if target != '<PAD>' and pred == target:
            matches += 1

    # denominator spans ALL of seq2, not only the zipped prefix
    valid_length = sum(1 for target in seq2 if target != '<PAD>')
    if not valid_length:
        return 1.0

    return matches / valid_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_hard_acc(seq1_list: list, seq2_list: list) -> float:
    """Mean hard (exact-match) accuracy over paired prediction/target sequences."""
    total = 0.0
    for pred_seq, target_seq in zip(seq1_list, seq2_list):
        total += compute_hard_acc(seq1=pred_seq, seq2=target_seq)
    # normalized by the target list length, matching the pairing above
    return total / len(seq2_list)
|
|
|
|
|
def time_split_on_seq(
    seq: torch.Tensor,
    sample_seq_path: torch.Tensor,
    idx: torch.Tensor
) -> (
    list,
    list,
    list
):
    """Split each batch element of ``seq`` into current / past / future positions.

    Positions are classified by comparing the per-position sampling time in
    ``sample_seq_path`` with the query time ``idx``: equal -> current,
    smaller -> previous, larger -> future.

    Handles (batch, class_labels, length) probability tensors as well as
    (batch, length) token tensors; returns three lists with one (possibly
    empty) tensor per batch element.
    """
    if seq.dim() == 2:
        # token tensors: boolean-mask each batch row directly
        current_seq = [row[sample_seq_path[b] == idx[b]] for b, row in enumerate(seq)]
        prev_seq = [row[sample_seq_path[b] < idx[b]] for b, row in enumerate(seq)]
        fut_seq = [row[sample_seq_path[b] > idx[b]] for b, row in enumerate(seq)]
        return (
            current_seq,
            prev_seq,
            fut_seq
        )

    n_batch, n_classes, _ = seq.shape
    current_seq, prev_seq, fut_seq = [], [], []

    # hoist the device transfers out of the loops
    path_cpu = sample_seq_path.cpu()
    idx_cpu = idx.cpu()

    for b in range(n_batch):
        now_mask = path_cpu[b] == idx_cpu[b]
        past_mask = path_cpu[b] < idx_cpu[b]
        future_mask = path_cpu[b] > idx_cpu[b]

        # per-class masked slices, stacked back into (n_classes, n_selected)
        current_seq.append(torch.stack([seq[b, c][now_mask] for c in range(n_classes)]))
        prev_seq.append(torch.stack([seq[b, c][past_mask] for c in range(n_classes)]))
        fut_seq.append(torch.stack([seq[b, c][future_mask] for c in range(n_classes)]))

    return (
        current_seq,
        prev_seq,
        fut_seq
    )
|
|
|
@torch.no_grad()
def compute_acc_given_time_pos(
    real_tokens: torch.Tensor,
    sample_seq: torch.Tensor,
    sample_path: torch.Tensor,
    idx: torch.Tensor
) -> (
    float,
    float,
    float,
    float,
    float,
    float
):
    """Accuracy of sampled tokens vs. ground truth, split by time position.

    Positions are grouped by comparing ``sample_path`` with the query time
    ``idx`` (see ``time_split_on_seq``): current (==), previous (<) and
    future (>).  Hard (exact-match) accuracy is computed for each group.
    The soft (BLOSUM62) accuracies are currently hard-coded to 0 — that
    computation is disabled below.

    Args:
        real_tokens: ground-truth token indices per batch element.
        sample_seq: sampled token indices, same layout as ``real_tokens``.
        sample_path: per-position sampling time for each batch element.
        idx: query time per batch element.

    Returns:
        (prev_hard_acc, prev_soft_acc, fut_hard_acc, fut_soft_acc,
         current_hard_acc, current_soft_acc); the *_soft_acc values are
        always 0 while the soft-accuracy branch is disabled.
    """

    # index -> character vocabulary; order is assumed to match the model's
    # token indexing — TODO confirm against the tokenizer/dataset definition
    tokens = ['-', '<START>', 'A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','<END>','<PAD>']

    # extended codes (ambiguous / non-standard residues)
    tokens = tokens + ['X', 'U', 'Z', 'B', 'O']

    # split ground-truth tokens into current / previous / future positions
    current_real_tokens, prev_real_tokens, fut_real_tokens = time_split_on_seq(
        seq=real_tokens.cpu(),
        sample_seq_path=sample_path.cpu(),
        idx=idx.cpu()
    )

    # same split for the sampled tokens (identical masks)
    current_sample_tokens, prev_sample_tokens, fut_sample_tokens = time_split_on_seq(
        seq=sample_seq.cpu(),
        sample_seq_path=sample_path.cpu(),
        idx=idx.cpu()
    )

    # map ground-truth token indices back to character sequences
    current_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in current_real_tokens]
    prev_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in prev_real_tokens]
    fut_real_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in fut_real_tokens]

    # map sampled token indices back to character sequences
    current_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in current_sample_tokens]
    prev_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in prev_sample_tokens]
    fut_sample_chars = [ani_tools.convert_num_to_char(tokens,seq_tokens) for seq_tokens in fut_sample_tokens]

    # drop batch elements with no positions in the group.
    # NOTE(review): sample and real lists are filtered independently — if one
    # side were empty while the other is not, the pairing in batch_hard_acc
    # below would shift by one.  Presumably both sides are empty for the same
    # elements since they come from identical masks — verify.
    prev_sample_chars = [item for item in prev_sample_chars if item]
    prev_real_chars = [item for item in prev_real_chars if item]

    fut_real_chars = [item for item in fut_real_chars if item]
    fut_sample_chars = [item for item in fut_sample_chars if item]

    # tokenizer used to split character strings into residue/special tokens
    soft_acc_tool = blosum_soft_accuracy()

    # tokenize ground-truth character sequences
    prev_real_split_chars = [
        soft_acc_tool.split_seq(sample) for sample in prev_real_chars
    ]
    fut_real_split_chars = [
        soft_acc_tool.split_seq(sample) for sample in fut_real_chars
    ]

    # tokenize sampled character sequences
    prev_sample_split_chars = [
        soft_acc_tool.split_seq(sample) for sample in prev_sample_chars
    ]
    fut_sample_split_chars = [
        soft_acc_tool.split_seq(sample) for sample in fut_sample_chars
    ]

    ' soft accuracy: '

    # soft-accuracy computation is disabled; zeros are returned as placeholders
    prev_batch_soft_acc, fut_batch_soft_acc, current_soft_acc = 0, 0, 0

    ' hard accuracy: '

    # exact-match accuracy over positions sampled before time idx
    prev_batch_hard_acc = batch_hard_acc(
        seq1_list=prev_sample_split_chars,
        seq2_list=prev_real_split_chars
    )

    # exact-match accuracy over positions sampled after time idx
    fut_batch_hard_acc = batch_hard_acc(
        seq1_list=fut_sample_split_chars,
        seq2_list=fut_real_split_chars
    )

    # NOTE(review): this zips whole per-element character strings, so it
    # scores exact whole-sequence matches across the batch rather than
    # per-residue matches — confirm this is intended.
    current_hard_acc = compute_hard_acc(
        seq1=current_sample_chars,
        seq2=current_real_chars
    )

    return (
        prev_batch_hard_acc,
        prev_batch_soft_acc,
        fut_batch_hard_acc,
        fut_batch_soft_acc,
        current_hard_acc,
        current_soft_acc
    )
|
|
|
|
|
@torch.no_grad()
def compute_ppl_given_time_pos(
    probs: torch.Tensor,
    sample_path: torch.Tensor,
    idx: torch.Tensor
) -> (
    float,
    float,
    float
):
    """Model perplexity split into current / previous / future positions.

    Positions are grouped by comparing ``sample_path`` against the query
    time ``idx`` (see ``time_split_on_seq``) and a mean perplexity is
    computed per group.

    Returns:
        (current_ppl, prev_ppl, fut_ppl)
    """
    split_groups = time_split_on_seq(
        probs.cpu(),
        sample_seq_path=sample_path.cpu(),
        idx=idx.cpu()
    )
    current_probs, prev_probs, fut_probs = split_groups

    # one mean perplexity per time-position group
    current_ppl = batch_compute_ppl(probs_list=current_probs)
    prev_ppl = batch_compute_ppl(probs_list=prev_probs)
    fut_ppl = batch_compute_ppl(probs_list=fut_probs)

    return (
        current_ppl,
        prev_ppl,
        fut_ppl
    )
|
|