Allan Victor commited on
Commit
6824f92
·
1 Parent(s): ba148d6

train_loop

Browse files
Files changed (1) hide show
  1. Util_funs.py +15 -13
Util_funs.py CHANGED
@@ -242,22 +242,20 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
242
  input_ids, attention_mask,q_token_type_ids, label_id = batch
243
 
244
  # Predictions
245
- _, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
246
 
247
- # Save batch's predictions
248
- prediction = prediction.detach().cpu().squeeze()
249
- label_id = label_id.detach().cpu()
250
- labels.append(label_id.numpy().squeeze())
251
-
252
  logit = feature[1].detach().cpu()
253
- predi_logit.append(logit.numpy())
254
 
255
- feature_lat = feature[0].detach().cpu()
256
- features.append(feature_lat.numpy())
 
257
 
258
  # Accuracy over the test's bach
259
- acc = fn.accuracy(prediction, label_id).item()
260
- all_acc.append(acc)
261
  del input_ids, attention_mask, label_id, batch
262
 
263
  if print_info:
@@ -268,9 +266,13 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
268
  torch.cuda.empty_cache()
269
 
270
  del model_meta, optimizer
 
 
 
 
271
 
272
- return map_feature_tsne(features, labels, predi_logit)
273
-
274
  # Process predictions and map the feature_map in tsne
275
  def map_feature_tsne(features, labels, predi_logit):
276
 
 
242
  input_ids, attention_mask,q_token_type_ids, label_id = batch
243
 
244
  # Predictions
245
+ _, feature, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
246
 
247
+ # prediction = prediction.detach().cpu().squeeze()
248
+ # label_id = label_id.detach().cpu()
 
 
 
249
  logit = feature[1].detach().cpu()
250
+ # feature_lat = feature[0].detach().cpu()
251
 
252
+ # labels.append(label_id.numpy().squeeze())
253
+ # features.append(feature_lat.numpy())
254
+ predi_logit.append(logit.numpy())
255
 
256
  # Accuracy over the test's bach
257
+ # acc = fn.accuracy(prediction, label_id).item()
258
+ # all_acc.append(acc)
259
  del input_ids, attention_mask, label_id, batch
260
 
261
  if print_info:
 
266
  torch.cuda.empty_cache()
267
 
268
  del model_meta, optimizer
269
+
270
+ logits = np.concatenate(np.array(predi_logit,dtype=object))
271
+ logits = torch.tensor(logits.astype(np.float32)).detach().clone()
272
+ # return features, labels, predi_logit
273
 
274
+ return logits.detach().clone()
275
+
276
  # Process predictions and map the feature_map in tsne
277
  def map_feature_tsne(features, labels, predi_logit):
278