License change to apache-2.0
Browse files- README.md +8 -6
- inference_brain2vec.py +19 -1
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
license:
|
3 |
language:
|
4 |
- en
|
5 |
task_categories:
|
@@ -52,11 +52,13 @@ nohup python brain2vec.py train \
|
|
52 |
--n_epochs 10 \
|
53 |
> train_log.txt 2>&1 &
|
54 |
|
55 |
-
#
|
56 |
-
python
|
57 |
-
--
|
58 |
-
--
|
59 |
-
--output_dir
|
|
|
|
|
60 |
```
|
61 |
|
62 |
# Methods
|
|
|
1 |
---
|
2 |
+
license: apache-2.0
|
3 |
language:
|
4 |
- en
|
5 |
task_categories:
|
|
|
52 |
--n_epochs 10 \
|
53 |
> train_log.txt 2>&1 &
|
54 |
|
55 |
+
# model inference
|
56 |
+
python inference_brain2vec.py \
|
57 |
+
--checkpoint_path /path/to/model.pth \
|
58 |
+
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
59 |
+
--output_dir ./vae_inference_outputs \
|
60 |
+
--embeddings_filename pca_output/pca_embeddings_2.npy \
|
61 |
+
--save_recons
|
62 |
```
|
63 |
|
64 |
# Methods
|
inference_brain2vec.py
CHANGED
@@ -156,6 +156,18 @@ def main() -> None:
|
|
156 |
"--csv_input", type=str,
|
157 |
help="Path to a CSV file with an 'image_path' column."
|
158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
args = parser.parse_args()
|
160 |
|
161 |
os.makedirs(args.output_dir, exist_ok=True)
|
@@ -198,6 +210,7 @@ def main() -> None:
|
|
198 |
z_sigma_np = z_sigma.detach().cpu().numpy()
|
199 |
|
200 |
# Save each reconstruction (per image) as .npy
|
|
|
201 |
recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
|
202 |
np.save(recon_path, recon_np)
|
203 |
print(f"[INFO] Saved reconstruction to {recon_path}")
|
@@ -210,8 +223,13 @@ def main() -> None:
|
|
210 |
stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
|
211 |
stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
|
212 |
|
213 |
-
|
|
|
|
|
|
|
|
|
214 |
sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
|
|
|
215 |
np.save(mu_path, stacked_mu)
|
216 |
np.save(sigma_path, stacked_sigma)
|
217 |
|
|
|
156 |
"--csv_input", type=str,
|
157 |
help="Path to a CSV file with an 'image_path' column."
|
158 |
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--embeddings_filename",
|
161 |
+
type=str,
|
162 |
+
required=True,
|
163 |
+
help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--save_recons",
|
167 |
+
action="store_true",
|
168 |
+
help="If set, saves each reconstruction as .npy. Default is not to save."
|
169 |
+
)
|
170 |
+
|
171 |
args = parser.parse_args()
|
172 |
|
173 |
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
210 |
z_sigma_np = z_sigma.detach().cpu().numpy()
|
211 |
|
212 |
# Save each reconstruction (per image) as .npy
|
213 |
+
if args.save_recons:
|
214 |
recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
|
215 |
np.save(recon_path, recon_np)
|
216 |
print(f"[INFO] Saved reconstruction to {recon_path}")
|
|
|
223 |
stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
|
224 |
stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
|
225 |
|
226 |
+
mu_filename = args.embeddings_filename
|
227 |
+
if not mu_filename.lower().endswith(".npy"):
|
228 |
+
mu_filename += ".npy"
|
229 |
+
|
230 |
+
mu_path = os.path.join(args.output_dir, mu_filename)
|
231 |
sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
|
232 |
+
|
233 |
np.save(mu_path, stacked_mu)
|
234 |
np.save(sigma_path, stacked_sigma)
|
235 |
|