jer233 commited on
Commit
f0366e9
·
verified ·
1 Parent(s): 2172586

Create feature_ref_loader.py

Browse files
Files changed (1) hide show
  1. feature_ref_loader.py +23 -0
feature_ref_loader.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from utils import DEVICE
4
+
5
+
6
+ class FeatureRefLoader:
7
+ def __init__(self):
8
+ print("Feature Ref Loader init")
9
+
10
+ # TODO: The format of feature
11
+ def load(self, feature_ref_file_name, num_ref=5000):
12
+ print("Feature Ref Loader load")
13
+ load_ref_data = torch.load(feature_ref_file_name, map_location=DEVICE) # cpu
14
+ load_ref_data = load_ref_data.to(DEVICE)
15
+ feature_ref = load_ref_data[np.random.permutation(load_ref_data.shape[0])][
16
+ :num_ref
17
+ ].to(DEVICE)
18
+ return feature_ref
19
+
20
+
21
+ feature_two_sample_tester_ref = FeatureRefLoader().load("./feature_ref_for_test.pt")
22
+ feature_hwt_ref = FeatureRefLoader().load("./feature_ref_HWT.pt")
23
+ feature_mgt_ref = FeatureRefLoader().load("./feature_ref_MGT.pt")