clane9 commited on
Commit
eda2644
·
1 Parent(s): de4850a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -0
README.md CHANGED
@@ -7,3 +7,38 @@ license: mit
7
  ![Example training predictions](example.png)
8
 
9
  A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat](https://huggingface.co/datasets/clane9/NSD-Flat). The training objective was to auto-regressively predict the next patch with shuffled patch order.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ![Example training predictions](example.png)
8
 
9
  A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat](https://huggingface.co/datasets/clane9/NSD-Flat). The training objective was to auto-regressively predict the next patch with shuffled patch order.
10
+
11
+ ## Dependencies
12
+
13
+ - [boldGPT](https://github.com/clane9/boldGPT)
14
+ - [huggingface_hub](https://huggingface.co/docs/huggingface_hub/index)
15
+ - [safetensors](https://huggingface.co/docs/safetensors/index)
16
+
17
+ ## Usage
18
+
19
+ ```python
20
+ from boldgpt.data import ActivityTransform
21
+ from boldgpt.models import create_model
22
+ from datasets import load_dataset
23
+ from huggingface_hub import hf_hub_download
24
+ from safetensors.torch import load_model
25
+
26
+ model = create_model("boldgpt_small_patch10")
27
+
28
+ load_model(
29
+ model,
30
+ hf_hub_download(
31
+ repo_id="clane9/boldgpt_small_patch10", filename="model.safetensors"
32
+ ),
33
+ )
34
+
35
+ dataset = load_dataset("clane9/NSD-Flat", split="train")
36
+ dataset.set_format("torch")
37
+ batch = dataset[:1]
38
+
39
+ transform = ActivityTransform()
40
+ batch["activity"] = transform(batch["activity"])
41
+
42
+ # output: (B, N, K) predicted next token logits
43
+ output, state = model(batch)
44
+ ```