motionsh commited on
Commit
732e0d9
·
1 Parent(s): adf1a95

Upload codes

Browse files
Files changed (5) hide show
  1. config.py +34 -0
  2. dataset.py +261 -0
  3. datasetbuilder.py +62 -0
  4. modelbuilder.py +127 -0
  5. test.py +164 -0
config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import yaml
3
+
4
+ def get_config_universal(dataset_name):
5
+ with open('./configs/' + dataset_name + '_config.json') as f:
6
+ config = json.load(f)
7
+ return config
8
+
9
+
10
+ def get_sweep_config_universal(dataset_name):
11
+ with open('./configs/sweep_' + dataset_name + '_config.yaml') as f:
12
+ config = yaml.load(f,Loader=yaml.FullLoader)
13
+ return config
14
+
15
+
16
+ def get_config():
17
+ with open('configs/camargo_config.json') as f:
18
+ config = json.load(f)
19
+ return config
20
+
21
+ def get_kiha_config():
22
+ with open('./configs/kiha_config.json') as f:
23
+ config = json.load(f)
24
+ return config
25
+
26
+ def get_model_config(model_config):
27
+ with open(f'./configs/{model_config}.json') as f:
28
+ config = json.load(f)
29
+ return config
30
+
31
+ def get_sweep_config():
32
+ with open('configs/sweep_camargo_config.yaml') as f:
33
+ config = yaml.load(f,Loader=yaml.FullLoader)
34
+ return config
dataset.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from loading.loadpickledataset import LoadPickleDataSet
4
+ from preprocessing.augmentation.gaussiannoise import GaussianNoise
5
+ from preprocessing.augmentation.imurotation import IMURotation
6
+ from preprocessing.filter_imu import FilterIMU
7
+ from preprocessing.filter_opensim import FilterOpenSim
8
+ from preprocessing.remove_outlier import remove_outlier
9
+ from preprocessing.resample import Resample
10
+ from preprocessing.segmentation.fixwindowsegmentation import FixWindowSegmentation
11
+ from preprocessing.segmentation.gaitcyclesegmentation import GaitCycleSegmentation
12
+ from preprocessing.segmentation.zeropaddingsegmentation import ZeroPaddingSegmentation
13
+
14
+
15
+ class DataSet:
16
+ def __init__(self, config, load_dataset=True):
17
+ self.config = config
18
+ self.x = []
19
+ self.y = []
20
+ self.labels = []
21
+ self.selected_trial_type = config['selected_trial_type']
22
+ self.selected_activity_label = config['selected_activity_label']
23
+ self.segmentation_method = config['segmentation_method']
24
+ if self.config['gc_dataset']:
25
+ self.segmentation_method = 'zeropadding'
26
+ self.resample = config['resample']
27
+ self.n_sample = len(self.y)
28
+ if load_dataset:
29
+ self.load_dataset()
30
+ self.train_subjects = config['train_subjects']
31
+ self.test_subjects = config['test_subjects']
32
+ self.train_activity = config['train_activity']
33
+ self.test_activity = config['test_activity']
34
+ # self.winsize = 128
35
+ self.train_dataset = {}
36
+ self.test_dataset = {}
37
+
38
+ def load_dataset(self):
39
+ getdata_handler = LoadPickleDataSet(self.config)
40
+ x, y, labels = getdata_handler.run_get_dataset()
41
+ self.x, self.y, self.labels = self.run_activity_based_filter(x, y, labels)
42
+ self._preprocess()
43
+
44
+ def _preprocess(self):
45
+ self.x, self.y, self.labels = remove_outlier(self.x, self.y, self.labels)
46
+ if self.resample:
47
+ self.x, self.y, self.labels = self.run_resample_signal(self.x, self.y, self.labels)
48
+ if self.config['opensim_filter']:
49
+ filteropensim_handler = FilterOpenSim(self.y, lowcut=6, fs=100, order=2)
50
+ self.y = filteropensim_handler.run_lowpass_filter()
51
+ if self.config['imu_filter']:
52
+ filterimu_handler = FilterIMU(self.x, lowcut=10, fs=100, order=2)
53
+ self.x = filterimu_handler.run_lowpass_filter()
54
+
55
+
56
+ def run_resample_signal(self, x, y, labels):
57
+ resample_handler = Resample(x, y, labels, 200, 100)
58
+ x, y, labels = resample_handler._run_resample()
59
+ return x, y, labels
60
+
61
+ def run_segmentation(self, x, y, labels):
62
+ if self.segmentation_method == 'fixedwindow':
63
+ segmentation_handler = FixWindowSegmentation(x, y, labels, winsize=self.config['target_padding_length'], overlap=0.5, start_over=True)
64
+ self.x, self.y, self.labels = segmentation_handler._run_segmentation()
65
+
66
+ elif self.segmentation_method == 'zeropadding':
67
+ segmentation_handler = ZeroPaddingSegmentation(x, y, labels, target_padding_length=self.config['target_padding_length'], start_over=True)
68
+ self.x, self.y, self.labels = segmentation_handler._run_segmentation()
69
+
70
+ elif self.segmentation_method == 'gaitcycle':
71
+ segmentation_handler = GaitCycleSegmentation(x, y, labels, winsize=128, overlap=0.5, start_over=True)
72
+ self.x, self.y, self.labels = segmentation_handler._run_segmentation()
73
+
74
+ if self.config['opensim_filter']:
75
+ filteropensim_handler = FilterOpenSim(self.y, lowcut=6, fs=100, order=2)
76
+ self.y = filteropensim_handler.run_lowpass_filter()
77
+
78
+ if self.config['rotation']:
79
+ imu_rotation_handler = IMURotation(knom=10)
80
+ self.x, self.y, self.labels = imu_rotation_handler.run_rotation(self.x.copy(), self.y.copy(), self.labels.copy())
81
+
82
+ if self.config['gaussian_noise']:
83
+ gaussian_noise_handler = GaussianNoise(0, .05)
84
+ self.x, self.y, self.labels = gaussian_noise_handler.run_add_noise(self.x, self.y, self.labels)
85
+ del x, y, labels
86
+ return self.x, self.y, self.labels
87
+
88
+ def run_activity_based_filter(self, x, y, label):
89
+ '''
90
+ :return: updated x, y, and labels which contains only the selected labels (activity section)
91
+ '''
92
+ updated_x = []
93
+ update_y = []
94
+ updated_label = []
95
+ s = 0
96
+ for ll, xx, yy, in zip(label, x, y):
97
+ # print(ll['subject'][0])
98
+ # print(ll['trialNum'][0])
99
+ if self.config['dataset_name']=='camargo' and ll['trialType'].isin(self.selected_trial_type).all() and self.selected_activity_label == ['all_idle']:
100
+ l_temp = ll[ll['trialType'].isin(self.selected_trial_type)]
101
+ l_temp_index = l_temp.index.values
102
+ xx_temp = xx[l_temp_index]
103
+ yy_temp = yy[l_temp_index]
104
+
105
+ updated_x.append(xx_temp)
106
+ update_y.append(yy_temp)
107
+ updated_label.append(l_temp)
108
+ elif self.config['dataset_name']=='camargo' and ll['trialType'].isin(self.selected_trial_type).all() and self.selected_activity_label == ['all']:
109
+ update_selected_activity_label = list(ll['Label'].unique())
110
+ update_selected_activity_label = [i for i in update_selected_activity_label if i not in ['idle', 'stand']]
111
+ l_temp = ll[(ll['trialType'].isin(self.selected_trial_type)) & (ll['Label'].isin(update_selected_activity_label))]
112
+ l_temp_index = l_temp.index.values
113
+ xx_temp = xx[l_temp_index]
114
+ yy_temp = yy[l_temp_index]
115
+ updated_x.append(xx_temp)
116
+ update_y.append(yy_temp)
117
+ updated_label.append(l_temp)
118
+
119
+ elif self.config['dataset_name'] == 'camargo' and ll['trialType'].isin(self.selected_trial_type).all() and self.selected_activity_label == ['all_split']:
120
+ ll_temp = ll.copy()
121
+ ll_temp['trialType2'] =ll_temp['Label']
122
+ if ll['trialType'][0] =='levelground':
123
+ # get the turn index if it's there
124
+ turn1_indx = ll_temp[ll_temp['Label'] == 'turn1'].index.values
125
+ turn2_indx = ll_temp[ll_temp['Label'] == 'turn2'].index.values
126
+ # check which turn is turn 1
127
+ if turn1_indx[0]<turn2_indx[0]:
128
+ pass
129
+ else:
130
+ turn2_indx_temp = turn1_indx
131
+ turn1_indx = turn2_indx
132
+ turn2_indx = turn2_indx_temp
133
+ # devide into two segments
134
+ seg1 = ll_temp.iloc[0:turn1_indx[-1]+1]
135
+ seg2 = ll_temp.iloc[turn2_indx[0]:]
136
+ seg1_trialType2 = seg1['trialType2'].replace({'idle': 'idle', 'stand': 'idle', 'turn1': 'idle', 'turn2': 'idle',
137
+ 'stand-walk':'levelground1', 'walk':'levelground1',
138
+ 'walk-stand': 'levelground1'})
139
+ seg2_trialType2 = seg2['trialType2'].replace({'idle': 'idle', 'stand': 'idle', 'turn1': 'idle','turn2': 'idle',
140
+ 'stand-walk':'levelground2', 'walk':'levelground2',
141
+ 'walk-stand': 'levelground2'})
142
+ ll_temp['trialType2'] = pd.concat([seg1_trialType2, seg2_trialType2])
143
+ ll = ll_temp
144
+ elif ll['trialType'][0] =='ramp':
145
+ ll_temp['trialType2'] = ll_temp['trialType2'].replace({'idle': 'idle',
146
+ 'walk-rampascent': 'rampascent', 'rampascent':'rampascent','rampascent-walk': 'rampascent',
147
+ 'walk-rampdescent': 'rampdescent', 'rampdescent':'rampdescent','rampdescent-walk': 'rampdescent'})
148
+ ll = ll_temp
149
+ elif ll['trialType'][0] == 'stair':
150
+ ll_temp['trialType2'] = ll_temp['trialType2'].replace({'idle': 'idle',
151
+ 'walk-stairascent': 'stairascent', 'stairascent':'stairascent','stairascent-walk': 'stairascent',
152
+ 'walk-stairdescent': 'stairdescent', 'stairdescent':'stairdescent','stairdescent-walk': 'stairdescent'})
153
+ ll = ll_temp
154
+
155
+ update_selected_activity_label = list(ll['trialType2'].unique())
156
+ # remove stand, idle, turn1, turn2 samples
157
+ update_selected_activity_label = [i for i in update_selected_activity_label if
158
+ i not in ['idle']]
159
+ for activity_label in update_selected_activity_label:
160
+ # if trial type == levelground ->save stand-walk and walk into one trial and walk-stand into another trial. all samples would be continues
161
+ # if ramp or stair--> save trial for ascent and descent individually
162
+ if isinstance(activity_label, str):
163
+ l_temp = ll[(ll['trialType'].isin(self.selected_trial_type)) & (ll['trialType2']==activity_label)]
164
+ l_temp_index = l_temp.index.values
165
+ xx_temp = xx[l_temp_index]
166
+ yy_temp = yy[l_temp_index]
167
+ updated_x.append(xx_temp)
168
+ update_y.append(yy_temp)
169
+ updated_label.append(l_temp)
170
+ if len(xx_temp)==0:
171
+ print(i)
172
+ elif self.config['dataset_name']=='camargo':
173
+ l_temp = ll[(ll['trialType'].isin(self.selected_trial_type)) & (ll['Label'].isin(self.selected_activity_label))]
174
+ l_temp_index = l_temp.index.values
175
+ xx_temp = xx[l_temp_index]
176
+ yy_temp = yy[l_temp_index]
177
+
178
+ updated_x.append(xx_temp)
179
+ update_y.append(yy_temp)
180
+ updated_label.append(l_temp)
181
+ elif self.config['dataset_name']=='kiha':
182
+ l_temp = ll[(ll['trialType'].isin(self.selected_trial_type))]
183
+ l_temp_index = l_temp.index.values
184
+ xx_temp = xx[l_temp_index]
185
+ yy_temp = yy[l_temp_index]
186
+
187
+
188
+ updated_x.append(xx_temp)
189
+ update_y.append(yy_temp)
190
+ updated_label.append(l_temp)
191
+ # else:
192
+ # continue
193
+ return updated_x, update_y, updated_label
194
+
195
+ def concatenate_data(self):
196
+ self.labels = pd.concat(self.labels, axis=0, ignore_index = True)
197
+ self.x = np.concatenate(self.x, axis=0)
198
+ self.y = np.concatenate(self.y, axis=0)
199
+
200
+ def run_dataset_split_loop(self):
201
+ train_labels = []
202
+ test_labels = []
203
+ train_x = []
204
+ train_y = []
205
+ test_x = []
206
+ test_y = []
207
+ for t, trial in enumerate(self.labels):
208
+ if all(trial['subject'].isin(self.train_subjects)) and all(trial['trialType2'].isin(self.train_activity)):
209
+ train_labels.append(trial)
210
+ train_x.append(self.x[t])
211
+ train_y.append(self.y[t])
212
+
213
+ elif all(trial['subject'].isin(self.test_subjects)) and all(trial['trialType2'].isin(self.test_activity)):
214
+ test_labels.append(trial)
215
+ test_x.append(self.x[t])
216
+ test_y.append(self.y[t])
217
+
218
+ self.train_dataset['x'] = train_x
219
+ self.train_dataset['y'] = train_y
220
+ self.train_dataset['labels'] = train_labels
221
+
222
+ self.test_dataset['x'] = test_x
223
+ self.test_dataset['y'] = test_y
224
+ self.test_dataset['labels'] = test_labels
225
+ return self.train_dataset, self.test_dataset
226
+
227
+ def run_dataset_split(self):
228
+ if set(self.test_subjects).issubset(self.train_subjects):
229
+ train_labels = self.labels[~self.labels['subject'].isin(self.test_subjects)]
230
+ test_labels = self.labels[(self.labels['subjects'].isin(self.test_subjects))]
231
+ else:
232
+ train_labels = self.labels[self.labels['subject'].isin(self.train_subjects)]
233
+ test_labels = self.labels[(self.labels['subject'].isin(self.test_subjects))]
234
+ print(train_labels['subject'].unique())
235
+ print(test_labels['subject'].unique())
236
+
237
+
238
+ train_index = train_labels.index.values
239
+ test_index = test_labels.index.values
240
+ print('training length', len(train_index))
241
+ print('test length', len(test_index))
242
+
243
+ train_x = self.x[train_index]
244
+ train_y = self.y[train_index]
245
+ # self.train_dataset['x'] = train_x.reshape([int(train_x.shape[0]/self.config['target_padding_length']), self.config['target_padding_length'], train_x.shape[1]])
246
+ # self.train_dataset['y'] = train_y.reshape([int(train_y.shape[0]/self.config['target_padding_length']), self.config['target_padding_length'], train_y.shape[1]])
247
+ self.train_dataset['x'] = train_x
248
+ self.train_dataset['y'] = train_y
249
+ self.train_dataset['labels'] = train_labels.reset_index(drop=True)
250
+
251
+ test_x = self.x[test_index]
252
+ test_y = self.y[test_index]
253
+ # self.test_dataset['x'] = test_x.reshape([int(test_x.shape[0]/self.config['target_padding_length']), self.config['target_padding_length'], test_x.shape[1]])
254
+ # self.test_dataset['y'] = test_y.reshape([int(test_y.shape[0]/self.config['target_padding_length']), self.config['target_padding_length'], test_y.shape[1]])
255
+ self.test_dataset['x'] = test_x
256
+ self.test_dataset['y'] = test_y
257
+ self.test_dataset['labels'] = test_labels.reset_index(drop=True)
258
+ del train_labels, test_labels, train_x, train_y, test_x, test_y
259
+ return self.train_dataset, self.test_dataset
260
+
261
+
datasetbuilder.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from preprocessing.augmentation.gaussiannoise import GaussianNoise
5
+ from preprocessing.transformation.transformation import Transformation
6
+ import torch.nn.functional as F
7
+ from sklearn import preprocessing
8
+ import torch
9
+
10
+
11
+ class DataSetBuilder(Dataset):
12
+ def __init__(self, x, y, labels, transform_method=None, scaler=None, noise=None, classification=None):
13
+ self.x = x
14
+ self.y = y
15
+ self.labels = labels
16
+ self.y_label = []
17
+
18
+ self.transform_method = transform_method
19
+ self.scaler = scaler
20
+ self.noise = noise
21
+ self.classification = classification
22
+ self._preprocess()
23
+ if self.classification:
24
+ self._run_label_encoding()
25
+ self.n_sample = len(y)
26
+
27
+ # x = np.transpose(self.x, (0, 2, 1))
28
+ self.x = torch.from_numpy(x).double()
29
+ self.y = torch.from_numpy(self.y).double()
30
+
31
+ def _run_label_encoding(self):
32
+ le = preprocessing.LabelEncoder()
33
+ y_label = le.fit_transform(self.labels[:, 0, 3])
34
+ y_label = torch.as_tensor(y_label)
35
+ # self.y_label = F.one_hot(y_label.to(torch.int64))
36
+ self.y_label = y_label.to(torch.int64)
37
+
38
+ def _preprocess(self):
39
+ if self.transform_method['data_transformer_method'] is not None:
40
+ self._run_transform()
41
+ if self.noise is not None:
42
+ self._run_noise()
43
+
44
+ def _run_transform(self):
45
+ transform_handler = Transformation(method=self.transform_method['data_transformer_method'], by=self.transform_method['data_transformer_by'])
46
+ if self.scaler is None:
47
+ self.scaler, self.x = transform_handler.run_transform(train=self.x, scaler_fit=self.scaler)
48
+ else:
49
+ self.x = transform_handler.run_transform(val=self.x, scaler_fit=self.scaler)
50
+
51
+ def _run_noise(self, ):
52
+ gaussiannoise_handler = GaussianNoise(mean=0, std=1)
53
+ self.x, self.y, self.labels = gaussiannoise_handler.run_add_noise(self.x, self.y, self.labels)
54
+
55
+ def __len__(self):
56
+ return self.n_sample
57
+
58
+ def __getitem__(self, item):
59
+ if self.classification:
60
+ return self.x[item], self.y[item], self.y_label[item]
61
+ else:
62
+ return self.x[item], self.y[item]
modelbuilder.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ # from tsai.models.TST import TST
4
+ from sklearn.neighbors import KNeighborsRegressor
5
+ from config import get_model_config
6
+ from loss.weightedmseloss import WeightedMSELoss
7
+ from loss.weightedmultioutputsloss import WeightedMultiOutputLoss
8
+ from loss.weightedrmseloss import WeightedRMSELoss
9
+ from model.Hernandez2021cnnlstm import Hernandez2021CNNLSTM
10
+ from model.bilstmmodel import BiLSTMModel
11
+ from model.cnnlstm import CNNLSTM
12
+ from model.dorschky2020cnn import Dorschky2020CNN
13
+ from model.gholami2020cnn import Gholami2020CNN
14
+ from model.lstmlstm import Seq2Seq
15
+ from model.lstmlstmattention import Seq2SeqAtt
16
+ from model.lstmlstmrec import Seq2SeqRec
17
+ from model.lstmmodel import LSTMModel
18
+ from model.tcnmodel import TCNModel
19
+ from model.transformer import Transformer
20
+ from model.transformer_seq2seq import Seq2SeqTransformer
21
+ from model.transformer_tsai import TransformerTSAI
22
+ from model.zrenner2018cnn import Zrenner2018CNN
23
+ from utils.update_config import update_model_config
24
+
25
+
26
+ class ModelBuilder:
27
+ def __init__(self, config):
28
+ self.config = config
29
+ self.n_input_channel = len(self.config['selected_sensors'])*6
30
+ self.n_output = len(self.config['selected_opensim_labels'])
31
+ self.model_name = self.config['model_name']
32
+ self.model_config = get_model_config(f'config_{self.model_name}')
33
+ self.model_config = update_model_config(self.config, self.model_config)
34
+ self.optimizer_name = self.config['optimizer_name']
35
+ self.learning_rate = self.config['learning_rate']
36
+ self.l2_weight_decay_status = self.config['l2_weight_decay_status']
37
+ self.l2_weight_decay = self.config['l2_weight_decay']
38
+ self.loss = self.config['loss']
39
+ self.weight = self.config['loss_weight']
40
+ self.device = self.config['device']
41
+ # self.device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
42
+ if not self.n_output == len(self.weight):
43
+ self.weight = None
44
+
45
+ def run_model_builder(self):
46
+ model = self.get_model_architecture()
47
+ criterion = self.get_criterion(self.weight)
48
+ optimizer = self.get_optimizer()
49
+ return model, optimizer, criterion
50
+
51
+ def get_model_architecture(self):
52
+ if self.model_name == 'lstm': # done
53
+ self.model = LSTMModel(self.model_config)
54
+ elif self.model_name == 'bilstm': # done
55
+ self.model = BiLSTMModel(self.model_config)
56
+ elif self.model_name == 'cnnlstm': # done
57
+ self.model = CNNLSTM(self.model_config)
58
+ elif self.model_name == 'hernandez2021cnnlstm': # done
59
+ self.model = Hernandez2021CNNLSTM(self.model_config)
60
+ elif self.model_name == 'seq2seq': # done
61
+ self.model = Seq2Seq(self.config)
62
+ elif self.model_name == 'seq2seqrec':
63
+ self.model = Seq2SeqRec(self.n_input_channel, self.n_output)
64
+ elif self.model_name == 'seq2seqatt':# done
65
+ self.model = Seq2SeqAtt(self.model_config)
66
+ elif self.model_name == 'transformer': #done
67
+ self.model = Transformer(d_input=self.n_input_channel, d_model=12, d_output=self.n_output, d_len=self.config['target_padding_length'], h=8, N=1, attention_size=None,
68
+ dropout=0.5, chunk_mode=None, pe='original', multihead=True)
69
+ elif self.model_name == 'seq2seqtransformer':
70
+ self.model = Seq2SeqTransformer(d_input=self.n_input_channel, d_model=24, d_output=self.n_output, h=8, N=4, attention_size=None,
71
+ dropout=0.1, chunk_mode=None, pe='original')
72
+ elif self.model_name == 'transformertsai':
73
+ c_in = self.n_input_channel # aka channels, features, variables, dimensions
74
+ c_out = self.n_output
75
+ seq_len = self.config['target_padding_length']
76
+ y_range = self.config['target_padding_length']
77
+ max_seq_len = self.config['target_padding_length']
78
+ d_model = self.model_config['tsai_d_model']
79
+ n_heads = self.model_config['tsai_n_heads']
80
+ d_k = d_v = None # if None --> d_model // n_heads
81
+ d_ff = self.model_config['tsai_d_ff']
82
+ res_dropout = self.model_config['tsai_res_dropout_p']
83
+ activation = "gelu"
84
+ n_layers = self.model_config['tsai_n_layers']
85
+ fc_dropout = self.model_config['tsai_fc_dropout_p']
86
+ classification = self.model_config['classification']
87
+ kwargs = {}
88
+ self.model = TransformerTSAI(c_in, c_out, seq_len, max_seq_len=max_seq_len, d_model=d_model, n_heads=n_heads,
89
+ d_k=d_k, d_v=d_v, d_ff=d_ff, res_dropout=res_dropout, act=activation, n_layers=n_layers,
90
+ fc_dropout=fc_dropout, classification=classification, **kwargs)
91
+ elif self.model_name == 'Gholami2020CNN':
92
+ self.model = Gholami2020CNN(self.model_config)
93
+ elif self.model_name == 'Dorschky2020CNN':
94
+ self.model = Dorschky2020CNN(self.model_config)
95
+ elif self.model_name == 'Zrenner2018CNN':
96
+ self.model = Zrenner2018CNN(self.model_config)
97
+ elif self.model_name == 'tcn':
98
+ self.model = TCNModel(self.model_config)
99
+ elif self.model_name == 'knn':
100
+ self.model = KNeighborsRegressor()
101
+ return self.model
102
+
103
+ def get_optimizer(self):
104
+ if self.optimizer_name == 'Adam':
105
+ if self.l2_weight_decay_status:
106
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.l2_weight_decay)
107
+ else:
108
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
109
+ return self.optimizer
110
+
111
+ def get_criterion(self, weight=None):
112
+ if self.loss == 'RMSE' and weight is not None:
113
+ weight = torch.tensor(weight).to(self.device)
114
+ self.criterion = WeightedRMSELoss(weight)
115
+ elif self.loss == 'RMSE' and weight is None:
116
+ self.criterion = torch.sqrt(nn.MSELoss())
117
+ elif self.loss == 'MSE' and weight is not None:
118
+ weight = torch.tensor(weight).to(self.device)
119
+ self.criterion = WeightedMSELoss(weight)
120
+ elif self.loss == 'MSE-CE' and weight is not None:
121
+ weight = torch.tensor(weight).to(self.device)
122
+ self.criterion = WeightedMultiOutputLoss(weight)
123
+ else:
124
+ self.criterion = nn.MSELoss()
125
+ return self.criterion
126
+
127
+
test.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import wandb
3
+
4
+ from model.lstmlstm import Seq2SeqTest
5
+ from model.lstmlstmattention import Seq2SeqAttTest
6
+ from model.transformer_seq2seq import Seq2SeqTransformerTest
7
+ from modelbuilder import ModelBuilder
8
+
9
+
10
+ class Test:
11
+ def run_testing(self, config, model, test_dataloader):
12
+ self.config = config
13
+ self.device = config['device']
14
+ self.loss = self.config['loss']
15
+ self.weight = self.config['loss_weight']
16
+ self.model_name = self.config['model_name']
17
+ self.classification = config['classification']
18
+ self.n_output = len(self.config['selected_opensim_labels'])
19
+ if not self.n_output == len(self.weight):
20
+ self.weight = None
21
+ modelbuilder_handler = ModelBuilder(self.config)
22
+ criterion = modelbuilder_handler.get_criterion(self.weight)
23
+ self.tester = self.setup_tester()
24
+ y_pred, y_true, loss = self.tester(model, test_dataloader, criterion, self.device)
25
+ return y_pred, y_true, loss
26
+
27
+ def setup_tester(self):
28
+ if self.model_name == 'seq2seqatt':
29
+ tester = self.testing_seq2seqatt
30
+ elif self.model_name == 'seq2seqtransformer':
31
+ tester = self.testing_transformer_seq2seq
32
+ elif (self.model_name == 'transformer' and not self.classification) or (self.model_name == 'transformertsai' and not self.classification):
33
+ tester = self.testing_transformer
34
+ elif self.classification:
35
+ tester = self.testing_w_classification
36
+ else:
37
+ tester = self.testing
38
+ return tester
39
+
40
+ def testing(self, model, test_dataloader, criterion, device):
41
+ model.eval()
42
+ with torch.no_grad():
43
+ test_loss = []
44
+ test_preds = []
45
+ test_trues = []
46
+ for x, y in test_dataloader:
47
+ x = x.to(device)
48
+ y = y.to(device)
49
+ y_pred = model(x.float())
50
+ loss = criterion(y, y_pred)
51
+ test_loss.append(loss.item())
52
+ test_preds.append(y_pred)
53
+ test_trues.append(y)
54
+ test_loss = torch.mean(torch.tensor(test_loss))
55
+ print('Test Accuracy of the model: {}'.format(test_loss))
56
+ # wandb.log({"Test Loss": test_loss})
57
+ return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
58
+
59
+ def testing_w_classification(self, model, test_dataloader, criterion, device):
60
+ model.eval()
61
+ with torch.no_grad():
62
+ test_loss = []
63
+ test_preds = []
64
+ test_trues = []
65
+ for x, y, y_label in test_dataloader:
66
+ x = x.to(device).float()
67
+ y_label = y_label.type(torch.LongTensor).to(device) # The targets passed to nn.CrossEntropyLoss() should be in torch.long format
68
+ y = y.to(device)
69
+ y_pred = model(x)
70
+ y_pred[0] = y_pred[0].double()
71
+ y_pred[1] = y_pred[1].double()
72
+ y_true = [y, y_label]
73
+ loss = criterion(y_pred, y_true)
74
+ test_loss.append(loss.item())
75
+ test_preds.append(y_pred)
76
+ test_trues.append(y_true)
77
+ test_loss = torch.mean(torch.tensor(test_loss))
78
+ print('Test Accuracy of the model: {}'.format(test_loss))
79
+ wandb.log({"Test Loss": test_loss})
80
+ test_preds_reg = []
81
+ test_trues_reg = []
82
+ for pred, true in zip(test_preds, test_trues):
83
+ test_preds_reg.append(pred[0])
84
+ test_trues_reg.append(true[0])
85
+ return torch.cat(test_preds_reg, 0), torch.cat(test_trues_reg, 0), test_loss
86
+
87
+ def testing_seq2seq(self, model, test_dataloader, criterion, device):
88
+ model.eval()
89
+ with torch.no_grad():
90
+ test_loss = []
91
+ test_preds = []
92
+ test_trues = []
93
+ for x, y in test_dataloader:
94
+ x = x.to(device)
95
+ y = y.to(device)
96
+ # y_pred = model(x.float(), y.float()) # just for seq 2 seq
97
+ y_pred = Seq2SeqTest(model, x.float())
98
+ loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
99
+ test_loss.append(loss.item())
100
+ test_preds.append(y_pred)
101
+ test_trues.append(y)
102
+ test_loss = torch.mean(torch.tensor(test_loss))
103
+ print('Test Accuracy of the model: {}'.format(test_loss))
104
+ wandb.log({"Test Loss": test_loss})
105
+ return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
106
+
107
+ def testing_seq2seqatt(self, model, test_dataloader, criterion, device):
108
+ model.eval()
109
+ with torch.no_grad():
110
+ test_loss = []
111
+ test_preds = []
112
+ test_trues = []
113
+ for x, y in test_dataloader:
114
+ x = x.to(device)
115
+ y = y.to(device)
116
+ # y_pred = model(x.float(), y.float()) # just for seq 2 seq
117
+ y_pred = Seq2SeqAttTest(model, x.float())
118
+ loss = criterion(y_pred[:, 1:, :].to(device), y[:, 1:, :])
119
+ test_loss.append(loss.item())
120
+ test_preds.append(y_pred)
121
+ test_trues.append(y)
122
+ test_loss = torch.mean(torch.tensor(test_loss))
123
+ print('Test Accuracy of the model: {}'.format(test_loss))
124
+ wandb.log({"Test Loss": test_loss})
125
+ return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
126
+
127
+ def testing_transformer(self, model, test_dataloader, criterion, device):
128
+ model.eval()
129
+ with torch.no_grad():
130
+ test_loss = []
131
+ test_preds = []
132
+ test_trues = []
133
+ for x, y in test_dataloader:
134
+ x = x.to(device)
135
+ y = y.to(device)
136
+ y_pred = model(x.float()) # just for transformer
137
+ loss = criterion(y, y_pred.to(device))
138
+ test_loss.append(loss.item())
139
+ test_preds.append(y_pred)
140
+ test_trues.append(y)
141
+ test_loss = torch.mean(torch.tensor(test_loss))
142
+ print('Test Accuracy of the model: {}'.format(test_loss))
143
+ wandb.log({"Test Loss": test_loss})
144
+ return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss
145
+
146
+ def testing_transformer_seq2seq(self, model, test_dataloader, criterion, device):
147
+ model.eval()
148
+ with torch.no_grad():
149
+ test_loss = []
150
+ test_preds = []
151
+ test_trues = []
152
+ for x, y in test_dataloader:
153
+ x = x.to(device)
154
+ y = y.to(device)
155
+ y_pred = Seq2SeqTransformerTest(model, x.float())
156
+ # y_pred = model(x.float(), y.float()[:, :-1, :]) # just for seq 2 seq transformer
157
+ loss = criterion(y_pred, y.to(device))
158
+ test_loss.append(loss.item())
159
+ test_preds.append(y_pred)
160
+ test_trues.append(y[:, 1:, :])
161
+ test_loss = torch.mean(torch.tensor(test_loss))
162
+ print('Test Accuracy of the model: {}'.format(test_loss))
163
+ # wandb.log({"Test Loss": test_loss})
164
+ return torch.cat(test_preds, 0), torch.cat(test_trues, 0), test_loss