Spaces:
Runtime error
Runtime error
import torch | |
import wandb | |
from model.lstmlstm import Seq2SeqTest | |
from model.lstmlstmattention import Seq2SeqAttTest | |
from model.transformer_seq2seq import Seq2SeqTransformerTest | |
from modelbuilder import ModelBuilder | |
class Test: | |
def run_testing(self, config, model, test_dataloader): | |
self.config = config | |
self.device = config['device'] | |
self.loss = self.config['loss'] | |
self.weight = self.config['loss_weight'] | |
self.model_name = self.config['model_name'] | |
self.classification = config['classification'] | |
self.n_output = len(self.config['selected_opensim_labels']) | |
if not self.n_output == len(self.weight): | |
self.weight = None | |
modelbuilder_handler = ModelBuilder(self.config) | |
criterion = modelbuilder_handler.get_criterion(self.weight) | |
self.tester = self.setup_tester() | |
y_pred, y_true, loss = self.tester(model, test_dataloader, criterion, self.device) | |
return y_pred, y_true, loss | |
def setup_tester(self): | |
if self.model_name == 'seq2seqatt': | |
tester = self.testing_seq2seqatt | |
elif self.model_name == 'seq2seqtransformer': | |
tester = self.testing_transformer_seq2seq | |
elif (self.model_name == 'transformer' and not self.classification) or (self.model_name == 'transformertsai' and not self.classification): | |
tester = self.testing_transformer | |
elif self.classification: | |
tester = self.testing_w_classification | |
else: | |
tester = self.testing | |
return tester | |
def testing(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y in test_dataloader: | |
x = x.to(device) | |
y = y.to(device) | |
y_pred = model(x.float()) | |
loss = criterion(y, y_pred) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
# wandb.log({"Test Loss": test_loss}) | |
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss | |
def testing_w_classification(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y, y_label in test_dataloader: | |
x = x.to(device).float() | |
y_label = y_label.type(torch.LongTensor).to(device) # The targets passed to nn.CrossEntropyLoss() should be in torch.long format | |
y = y.to(device) | |
y_pred = model(x) | |
y_pred[0] = y_pred[0].double() | |
y_pred[1] = y_pred[1].double() | |
y_true = [y, y_label] | |
loss = criterion(y_pred, y_true) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y_true) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
wandb.log({"Test Loss": test_loss}) | |
test_preds_reg = [] | |
test_trues_reg = [] | |
for pred, true in zip(test_preds, test_trues): | |
test_preds_reg.append(pred[0]) | |
test_trues_reg.append(true[0]) | |
return torch.cat(test_preds_reg, 0), torch.cat(test_trues_reg, 0), test_loss | |
def testing_seq2seq(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y in test_dataloader: | |
x = x.to(device) | |
y = y.to(device) | |
# y_pred = model(x.float(), y.float()) # just for seq 2 seq | |
y_pred = Seq2SeqTest(model, x.float()) | |
loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :]) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
wandb.log({"Test Loss": test_loss}) | |
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss | |
def testing_seq2seqatt(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y in test_dataloader: | |
x = x.to(device) | |
y = y.to(device) | |
# y_pred = model(x.float(), y.float()) # just for seq 2 seq | |
y_pred = Seq2SeqAttTest(model, x.float()) | |
loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :]) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
wandb.log({"Test Loss": test_loss}) | |
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss | |
def testing_transformer(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y in test_dataloader: | |
x = x.to(device) | |
y = y.to(device) | |
y_pred = model(x.float()) # just for transformer | |
loss = criterion(y, y_pred.to(device)) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
wandb.log({"Test Loss": test_loss}) | |
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss | |
def testing_transformer_seq2seq(self, model, test_dataloader, criterion, device): | |
model.eval() | |
with torch.no_grad(): | |
test_loss = [] | |
test_preds = [] | |
test_trues = [] | |
for x, y in test_dataloader: | |
x = x.to(device) | |
y = y.to(device) | |
y_pred = Seq2SeqTransformerTest(model, x.float()) | |
# y_pred = model(x.float(), y.float()[:, :-1, :]) # just for seq 2 seq transformer | |
loss = criterion(y_pred, y.to(device)) | |
test_loss.append(loss.item()) | |
test_preds.append(y_pred) | |
test_trues.append(y[:, 1:, :]) | |
test_loss = torch.mean(torch.tensor(test_loss)) | |
print('Test Accuracy of the model: {}'.format(test_loss)) | |
# wandb.log({"Test Loss": test_loss}) | |
return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss |