# NOTE: removed web-scrape artifacts here (Spaces page header, "Runtime error"
# banners, file size / commit hash / line-number rows) — they were not Python.
import torch
import wandb
from model.lstmlstm import Seq2SeqTest
from model.lstmlstmattention import Seq2SeqAttTest
from model.transformer_seq2seq import Seq2SeqTransformerTest
from modelbuilder import ModelBuilder
class Test:
def run_testing(self, config, model, test_dataloader):
self.config = config
self.device = config['device']
self.loss = self.config['loss']
self.weight = self.config['loss_weight']
self.model_name = self.config['model_name']
self.classification = config['classification']
self.n_output = len(self.config['selected_opensim_labels'])
if not self.n_output == len(self.weight):
self.weight = None
modelbuilder_handler = ModelBuilder(self.config)
criterion = modelbuilder_handler.get_criterion(self.weight)
self.tester = self.setup_tester()
y_pred, y_true, loss = self.tester(model, test_dataloader, criterion, self.device)
return y_pred, y_true, loss
def setup_tester(self):
if self.model_name == 'seq2seqatt':
tester = self.testing_seq2seqatt
elif self.model_name == 'seq2seqtransformer':
tester = self.testing_transformer_seq2seq
elif (self.model_name == 'transformer' and not self.classification) or (self.model_name == 'transformertsai' and not self.classification):
tester = self.testing_transformer
elif self.classification:
tester = self.testing_w_classification
else:
tester = self.testing
return tester
def testing(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
y_pred = model(x.float())
loss = criterion(y, y_pred)
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y)
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
# wandb.log({"Test Loss": test_loss})
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
def testing_w_classification(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y, y_label in test_dataloader:
x = x.to(device).float()
y_label = y_label.type(torch.LongTensor).to(device) # The targets passed to nn.CrossEntropyLoss() should be in torch.long format
y = y.to(device)
y_pred = model(x)
y_pred[0] = y_pred[0].double()
y_pred[1] = y_pred[1].double()
y_true = [y, y_label]
loss = criterion(y_pred, y_true)
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y_true)
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
wandb.log({"Test Loss": test_loss})
test_preds_reg = []
test_trues_reg = []
for pred, true in zip(test_preds, test_trues):
test_preds_reg.append(pred[0])
test_trues_reg.append(true[0])
return torch.cat(test_preds_reg, 0), torch.cat(test_trues_reg, 0), test_loss
def testing_seq2seq(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
# y_pred = model(x.float(), y.float()) # just for seq 2 seq
y_pred = Seq2SeqTest(model, x.float())
loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y)
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
wandb.log({"Test Loss": test_loss})
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
def testing_seq2seqatt(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
# y_pred = model(x.float(), y.float()) # just for seq 2 seq
y_pred = Seq2SeqAttTest(model, x.float())
loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y)
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
wandb.log({"Test Loss": test_loss})
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
def testing_transformer(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
y_pred = model(x.float()) # just for transformer
loss = criterion(y, y_pred.to(device))
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y)
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
wandb.log({"Test Loss": test_loss})
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
def testing_transformer_seq2seq(self, model, test_dataloader, criterion, device):
model.eval()
with torch.no_grad():
test_loss = []
test_preds = []
test_trues = []
for x, y in test_dataloader:
x = x.to(device)
y = y.to(device)
y_pred = Seq2SeqTransformerTest(model, x.float())
# y_pred = model(x.float(), y.float()[:, :-1, :]) # just for seq 2 seq transformer
loss = criterion(y_pred, y.to(device))
test_loss.append(loss.item())
test_preds.append(y_pred)
test_trues.append(y[:, 1:, :])
test_loss = torch.mean(torch.tensor(test_loss))
print('Test Accuracy of the model: {}'.format(test_loss))
# wandb.log({"Test Loss": test_loss})
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss |