Update README.md
README.md CHANGED
@@ -32,7 +32,6 @@ import torch
 from PIL import Image
 from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextStreamer
 
-
 model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 processor = AutoProcessor.from_pretrained(model_id)
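The only change in the hunk above is dropping a stray blank line. The README's actual model construction sits in the unchanged region between this hunk and the next, so for orientation here is a minimal loading sketch; the `torch_dtype` and `device_map` values are assumptions, not taken from this commit.

```python
# Minimal loading sketch for the checkpoint above.
# torch_dtype and device_map are assumed conveniences, not from the diff.
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # assumed; the card reports bf16 training
    device_map="auto",           # assumed; requires accelerate
)
model.eval()
```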
@@ -46,7 +45,7 @@ pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["
 messages = [
     {"role": "user", "content": "What is in this image?"}
 ]
-input_ids = tokenizer.apply_chat_template(messages,
+input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 image_token_id = tokenizer.convert_tokens_to_ids("<image>")
 image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
 input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
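The substantive fix is the `apply_chat_template` line: `add_generation_prompt=True` makes the template append the assistant-turn header so the model starts answering rather than continuing the user turn, and `return_tensors="pt"` returns a tensor (without it the call yields a plain list, and the `torch.cat` on the next lines would fail). Below is a hedged end-to-end sketch of how this hunk's lines fit into a full inference call; only the middle lines mirror the README, while the image path and `generate` settings are illustrative assumptions.

```python
# Hedged sketch of the prompt assembly this diff fixes, carried through to
# generation. Image path and generate() settings are assumptions.
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextStreamer

model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

image = Image.open("example.jpg").convert("RGB")  # assumed local input image
pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

messages = [{"role": "user", "content": "What is in this image?"}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

# Prepend one <image> placeholder per visual token so the language model
# reserves positions where the image embeddings will be inserted.
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
outputs = model.generate(input_ids=input_ids, pixel_values=pixel_values, streamer=streamer, max_new_tokens=128)
```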
@@ -70,6 +69,48 @@ The following hyperparameters were used during training:
 - lr_scheduler_type: cosine
 - mixed_precision_training: bf16
 
+<details>
+<summary><b>Show Llama Factory Config [CLICK TO EXPAND]</b></summary>
+
+```yaml
+### model
+model_name_or_path: google/paligemma-3b-mix-448
+visual_inputs: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+
+### ddp
+ddp_timeout: 180000000
+deepspeed: examples/deepspeed/ds_z3_config.json
+
+### dataset
+dataset: identity,llava_1k_en,llava_1k_zh
+template: gemma
+cutoff_len: 1536
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/paligemma-chat
+logging_steps: 10
+save_steps: 100
+plot_loss: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 0.00001
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_steps: 50
+bf16: true
+```
+
+</details>
+
 ### Framework versions
 
 - Pytorch 2.3.0
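The added `<details>` block documents the LLaMA-Factory run behind this model. Configs of this shape are normally launched with LLaMA-Factory's CLI, e.g. `llamafactory-cli train paligemma.yaml` (filename illustrative). As a sanity check on the numbers, here is a short sketch deriving the effective global batch size from the config; the GPU count is an assumption, since the card does not state it.

```python
# Hedged sketch: derive the effective global batch size from the config
# above. world_size is an assumption; the card does not state a GPU count.
import yaml

cfg = yaml.safe_load("""
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
""")

world_size = 8  # assumed GPU count (DeepSpeed ZeRO-3 suggests multi-GPU)
effective_batch = (cfg["per_device_train_batch_size"]
                   * cfg["gradient_accumulation_steps"]
                   * world_size)
print(effective_batch)  # 64 under the assumed world_size
```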