Code changes
Browse files- README.md +6 -0
- inference_brain2vec.py +1 -1
README.md
CHANGED
@@ -44,6 +44,12 @@ python create_csv.py
|
|
44 |
mkdir ae_cache
|
45 |
mkdir ae_output
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# train the model
|
48 |
nohup python brain2vec.py train \
|
49 |
--dataset_csv /home/ubuntu/brain2vec/inputs.csv \
|
|
|
44 |
mkdir ae_cache
|
45 |
mkdir ae_output
|
46 |
|
47 |
+
# install git lfs to pull large model weights
|
48 |
+
sudo apt-get update
|
49 |
+
sudo apt install git-lfs
|
50 |
+
git lfs install
|
51 |
+
git lfs pull
|
52 |
+
|
53 |
# train the model
|
54 |
nohup python brain2vec.py train \
|
55 |
--dataset_csv /home/ubuntu/brain2vec/inputs.csv \
|
inference_brain2vec.py
CHANGED
@@ -119,7 +119,7 @@ class Brain2vec(AutoencoderKL):
|
|
119 |
if checkpoint_path is not None:
|
120 |
if not os.path.exists(checkpoint_path):
|
121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
122 |
-
state_dict = torch.load(checkpoint_path, map_location=device
|
123 |
model.load_state_dict(state_dict)
|
124 |
|
125 |
model.to(device)
|
|
|
119 |
if checkpoint_path is not None:
|
120 |
if not os.path.exists(checkpoint_path):
|
121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
122 |
+
state_dict = torch.load(checkpoint_path, map_location=device)
|
123 |
model.load_state_dict(state_dict)
|
124 |
|
125 |
model.to(device)
|