English
medical
brain-data
mri
jesseab commited on
Commit
8c84e52
·
1 Parent(s): ac3730a

License change to apache-2.0

Browse files
Files changed (2) hide show
  1. README.md +8 -6
  2. inference_brain2vec.py +19 -1
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- license: mit
3
  language:
4
  - en
5
  task_categories:
@@ -52,11 +52,13 @@ nohup python brain2vec.py train \
52
  --n_epochs 10 \
53
  > train_log.txt 2>&1 &
54
 
55
- # run model inference to create *_embeddings.npz files
56
- python brain2vec.py infererence \
57
- --dataset_csv home/ubuntu/brain2vec/inputs.csv \
58
- --aekl_ckpt /home/ubuntu/brain2vec/autoencoder_final.pth \
59
- --output_dir /home/ubuntu/brain2vec
 
 
60
  ```
61
 
62
  # Methods
 
1
  ---
2
+ license: apache-2.0
3
  language:
4
  - en
5
  task_categories:
 
52
  --n_epochs 10 \
53
  > train_log.txt 2>&1 &
54
 
55
+ # model inference
56
+ python inference_brain2vec.py \
57
+ --checkpoint_path /path/to/model.pth \
58
+ --input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
59
+ --output_dir ./vae_inference_outputs \
60
+ --embeddings_filename pca_output/pca_embeddings_2.npy \
61
+ --save_recons
62
  ```
63
 
64
  # Methods
inference_brain2vec.py CHANGED
@@ -156,6 +156,18 @@ def main() -> None:
156
  "--csv_input", type=str,
157
  help="Path to a CSV file with an 'image_path' column."
158
  )
 
 
 
 
 
 
 
 
 
 
 
 
159
  args = parser.parse_args()
160
 
161
  os.makedirs(args.output_dir, exist_ok=True)
@@ -198,6 +210,7 @@ def main() -> None:
198
  z_sigma_np = z_sigma.detach().cpu().numpy()
199
 
200
  # Save each reconstruction (per image) as .npy
 
201
  recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
202
  np.save(recon_path, recon_np)
203
  print(f"[INFO] Saved reconstruction to {recon_path}")
@@ -210,8 +223,13 @@ def main() -> None:
210
  stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
211
  stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
212
 
213
- mu_path = os.path.join(args.output_dir, "all_z_mu.npy")
 
 
 
 
214
  sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
 
215
  np.save(mu_path, stacked_mu)
216
  np.save(sigma_path, stacked_sigma)
217
 
 
156
  "--csv_input", type=str,
157
  help="Path to a CSV file with an 'image_path' column."
158
  )
159
+ parser.add_argument(
160
+ "--embeddings_filename",
161
+ type=str,
162
+ required=True,
163
+ help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
164
+ )
165
+ parser.add_argument(
166
+ "--save_recons",
167
+ action="store_true",
168
+ help="If set, saves each reconstruction as .npy. Default is not to save."
169
+ )
170
+
171
  args = parser.parse_args()
172
 
173
  os.makedirs(args.output_dir, exist_ok=True)
 
210
  z_sigma_np = z_sigma.detach().cpu().numpy()
211
 
212
  # Save each reconstruction (per image) as .npy
213
+ if args.save_recons:
214
  recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
215
  np.save(recon_path, recon_np)
216
  print(f"[INFO] Saved reconstruction to {recon_path}")
 
223
  stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
224
  stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
225
 
226
+ mu_filename = args.embeddings_filename
227
+ if not mu_filename.lower().endswith(".npy"):
228
+ mu_filename += ".npy"
229
+
230
+ mu_path = os.path.join(args.output_dir, mu_filename)
231
  sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
232
+
233
  np.save(mu_path, stacked_mu)
234
  np.save(sigma_path, stacked_sigma)
235