clane9 commited on
Commit
72efd6d
·
1 Parent(s): a91432e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -17
README.md CHANGED
@@ -1,18 +1,16 @@
1
  ---
2
- license: mit
3
  ---
4
 
5
- # Model card for boldgpt_small_patch10
6
 
7
  ![Example training predictions](example.png)
8
 
9
- A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat](https://huggingface.co/datasets/clane9/NSD-Flat). The training objective was to auto-regressively predict the next patch with shuffled patch order.
10
 
11
  ## Dependencies
12
 
13
  - [boldGPT](https://github.com/clane9/boldGPT)
14
- - [huggingface_hub](https://huggingface.co/docs/huggingface_hub/index)
15
- - [safetensors](https://huggingface.co/docs/safetensors/index)
16
 
17
  ## Usage
18
 
@@ -20,25 +18,16 @@ A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat]
20
  from boldgpt.data import ActivityTransform
21
  from boldgpt.models import create_model
22
  from datasets import load_dataset
23
- from huggingface_hub import hf_hub_download
24
- from safetensors.torch import load_model
25
 
26
- model = create_model("boldgpt_small_patch10")
27
-
28
- load_model(
29
- model,
30
- hf_hub_download(
31
- repo_id="clane9/boldgpt_small_patch10", filename="model.safetensors"
32
- ),
33
- )
34
 
35
  dataset = load_dataset("clane9/NSD-Flat", split="train")
36
  dataset.set_format("torch")
37
- batch = dataset[:1]
38
 
39
  transform = ActivityTransform()
 
40
  batch["activity"] = transform(batch["activity"])
41
 
42
- # output: (B, N, K) predicted next token logits
43
  output, state = model(batch)
44
  ```
 
1
  ---
2
+ license: cc-by-nc-4.0
3
  ---
4
 
5
+ # Model card for `boldgpt_small_patch10.kmq`
6
 
7
  ![Example training predictions](example.png)
8
 
9
+ A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat](https://huggingface.co/datasets/clane9/NSD-Flat). Patches were quantized to discrete tokens using k-means (`KMeansTokenizer`). The training objective was to auto-regressively predict the next patch with shuffled patch order and cross-entropy loss.
10
 
11
  ## Dependencies
12
 
13
  - [boldGPT](https://github.com/clane9/boldGPT)
 
 
14
 
15
  ## Usage
16
 
 
18
  from boldgpt.data import ActivityTransform
19
  from boldgpt.models import create_model
20
  from datasets import load_dataset
 
 
21
 
22
+ model = create_model("boldgpt_small_patch10.kmq", pretrained=True)
 
 
 
 
 
 
 
23
 
24
  dataset = load_dataset("clane9/NSD-Flat", split="train")
25
  dataset.set_format("torch")
 
26
 
27
  transform = ActivityTransform()
28
+ batch = dataset[:1]
29
  batch["activity"] = transform(batch["activity"])
30
 
31
+ # output: (B, N + 1, K) predicted next token logits
32
  output, state = model(batch)
33
  ```