Prakamya Mishra committed on
Commit 27651a9
1 Parent(s): e76f89a

Upload README and Scripts

AMD-OLMo-1B-SFT-1st-phase.yaml ADDED
@@ -0,0 +1,143 @@
run_name: AMD-OLMo-1B-SFT-1st-phase
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: AMD-OLMo
  group: SFT

model:
  d_model: 2048
  n_heads: 16
  n_layers: 16
  mlp_ratio: 8
  weight_tying: true
  alibi: false
  rope: true
  flash_attention: false
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: default
  layer_norm_with_affine: false
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 2048
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 50279
  pad_token_id: 1
  init_device: meta
  init_fn: mitchell

compile:
  fullgraph: false

optimizer:
  name: adamw
  learning_rate: 2.0e-5
  weight_decay: 0
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: linear_with_warmup
  t_warmup: 200
  alpha_f: 0.001

tokenizer:
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
  truncate_direction: right

save_folder: ./outputs/${run_name}/
save_overwrite: true
# Sharded checkpoints (best for restarts)
save_interval: 1000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 10000
save_num_unsharded_checkpoints_to_keep: -1

load_path: path_to_unsharded_pretrain_checkpoint
reset_trainer_state: true

max_duration: 3ep  # train 3 epochs
global_train_batch_size: 128
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: null
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream

data:
  pad_direction: right
  num_workers: 0
  drop_last: true
  pin_memory: true
  prefetch_factor: 1
  persistent_workers: true
  timeout: 0
  generate_attention_mask: true
  paths:
    - ./datasets/tulu/input_ids.npy
  label_mask_paths:
    - ./datasets/tulu/label_mask.npy
AMD-OLMo-1B-SFT-2nd-phase.yaml ADDED
@@ -0,0 +1,143 @@
run_name: AMD-OLMo-1B-SFT-2nd-phase
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: AMD-OLMo
  group: SFT

model:
  d_model: 2048
  n_heads: 16
  n_layers: 16
  mlp_ratio: 8
  weight_tying: true
  alibi: false
  rope: true
  flash_attention: false
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: default
  layer_norm_with_affine: false
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 2048
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 50279
  pad_token_id: 1
  init_device: meta
  init_fn: mitchell

compile:
  fullgraph: false

optimizer:
  name: adamw
  learning_rate: 2.0e-5
  weight_decay: 0
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: linear_with_warmup
  t_warmup: 200
  alpha_f: 0.001

tokenizer:
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
  truncate_direction: right

save_folder: ./outputs/${run_name}/
save_overwrite: true
# Sharded checkpoints (best for restarts)
save_interval: 1000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: 10000
save_num_unsharded_checkpoints_to_keep: -1

load_path: path_to_unsharded_1st_phase_SFT_checkpoint
reset_trainer_state: true

max_duration: 3ep  # train 3 epochs
global_train_batch_size: 512
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: null
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  - label: piqa
    type: downstream

  - label: hellaswag
    type: downstream

  - label: winogrande
    type: downstream

  - label: openbook_qa
    type: downstream

  # - label: boolq  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: sciq
    type: downstream

  - label: arc_easy
    type: downstream

  # - label: arc_challenge  # requires implementation of the pmi_dc matrix
  #   type: downstream

  - label: copa
    type: downstream

  - label: rte
    type: downstream

  - label: commitment_bank
    type: downstream

  - label: mrpc
    type: downstream

  - label: sst2
    type: downstream

data:
  pad_direction: right
  num_workers: 0
  drop_last: true
  pin_memory: true
  prefetch_factor: 1
  persistent_workers: true
  timeout: 0
  generate_attention_mask: true
  paths:
    - ./datasets/OpenHermes_WebInstructSub_CodeFeedBack/input_ids.npy
  label_mask_paths:
    - ./datasets/OpenHermes_WebInstructSub_CodeFeedBack/label_mask.npy
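
As a quick cross-check of the two SFT recipes above, the per-step token budget follows directly from `global_train_batch_size` and `max_sequence_length`, and the gradient-accumulation factor from the micro-batch size and the number of GPUs (a small sketch; the 16-GPU world size used below is an illustrative assumption, not taken from the recipes):

```python
# Tokens consumed per optimizer step: sequences per global batch x sequence length.
seq_len = 2048
for phase, global_bs in [("SFT 1st phase", 128), ("SFT 2nd phase", 512)]:
    print(phase, global_bs * seq_len, "tokens/step")  # 262,144 and 1,048,576

# Gradient accumulation = global batch / (micro-batch x number of GPUs);
# the GPU count here is only an example, not part of the recipe.
micro_bs, num_gpus = 8, 16
print("grad accumulation (1st phase):", 128 // (micro_bs * num_gpus))
```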
AMD-OLMo-1B-dpo.yaml ADDED
@@ -0,0 +1,44 @@
# pytest: disable
# Model arguments
model_name_or_path: AMD-OLMo-1B-dpo
torch_dtype: null
use_flash_attention_2: false

chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
# Data training arguments
# For definitions, see: src/h4/training/config.py
dataset_mixer:
  csarron/argilla-ultrafeedback-binarized-preferences-cleaned: 1.0
dataset_splits:
  - train
  - test
preprocessing_num_workers: 16

# DPOTrainer arguments
bf16: true
beta: 0.01
do_eval: true
evaluation_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 2
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: False
hub_model_id: AMD-OLMo-1B-dpo
learning_rate: 5.0e-5
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_length: 1024
max_prompt_length: 512
num_train_epochs: 3
optim: adamw_torch
output_dir: data/AMD-OLMo-1B-dpo
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
push_to_hub: false
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
seed: 42
warmup_ratio: 0.1
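
For reference, the `chat_template` above produces the same `<|user|>` / `<|assistant|>` layout used for SFT. The snippet below is a minimal sketch of how such a template renders a conversation with `transformers` (the tokenizer repository name is the released SFT model, and the template string is passed explicitly rather than assumed to ship with the tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")

# Same Jinja template as the recipe's chat_template field.
template = (
    "{% for message in messages %}\n{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n' + message['content'] }}\n{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

messages = [{"role": "user", "content": "What is a large language model?"}]
# Render the prompt as text (no tokenization) and append the assistant tag
# so the model is cued to answer next.
text = tokenizer.apply_chat_template(
    messages, chat_template=template, tokenize=False, add_generation_prompt=True
)
print(text)
```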
AMD-OLMo-1B.yaml ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,3 +1,293 @@
- ---
- license: apache-2.0
- ---
---
license: apache-2.0
datasets:
- allenai/dolma
---
# AMD-OLMo

AMD-OLMo is a series of 1B language models trained from scratch by AMD on AMD Instinct™ MI250 GPUs. The training code used is based on [OLMo](https://github.com/allenai/OLMo).
We release the pre-trained model, the supervised fine-tuned model, and the DPO-aligned model as follows:

- [AMD-OLMo-1B](https://huggingface.co/amd/AMD-OLMo-1B): Pre-trained on a subset of [Dolma v1.7](https://huggingface.co/datasets/allenai/dolma) that consists of 1.3 trillion tokens.
- [AMD-OLMo-1B-SFT](https://huggingface.co/amd/AMD-OLMo-1B-SFT): Supervised fine-tuned (SFT) on the [Tulu V2](https://huggingface.co/datasets/allenai/tulu-v2-sft-mixture) dataset (1st phase) and then on the [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5), [WebInstructSub](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub), and [Code-Feedback](https://huggingface.co/datasets/m-a-p/Code-Feedback) datasets (2nd phase).
- [AMD-OLMo-1B-SFT-DPO](https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO): Aligned with human preferences using Direct Preference Optimization (DPO) on the [UltraFeedback](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned) dataset.

Description:

- **Hardware**: Each compute node consists of 4 AMD Instinct™ MI250 GPUs. We use 16 nodes for pretraining AMD-OLMo-1B.

- **Training throughput**: 12,200 tokens/sec/GPU

- **Model architecture**: AMD-OLMo-1B is based on the model architecture and training setup of the fully open-source 1-billion-parameter version of [OLMo-1B](https://github.com/allenai/OLMo), with the details below (a quick way to check these values against the released checkpoint is sketched after this list):

| Parameter size | Number of layers | Number of heads | Hidden size | Context length | Vocabulary Size |
|-----------------:|:------------------:|:-----------------:|:-------------:|:----------------:|:----------------:|
| 1.2B | 16 | 16 | 2048 | 2048 | 50,280 |

- **Hyper-parameters**:
|Stage | LR schedule | Peak LR | Warmup steps |Epochs| Batch size (tokens) |
|------------:|:--------------:|:---------:|:--------------:|:------:|:---------------------:|
|Pretraining | Cosine | 4.0e-4 | 2000 | 1 | 4M |
|SFT Phase 1 | Linear | 2.0e-5 | 200 | 3 | 262K |
|SFT Phase 2 | Linear | 2.0e-5 | 200 | 3 | 1024K |
|DPO | Cosine | 4.0e-6 | 47 | 1 | 64K |

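As a quick sanity check of the architecture table, the configuration that ships with the released checkpoint can be printed and compared against the values above (a minimal sketch; it only assumes the standard `transformers` API, and the exact field names depend on the config class bundled with the checkpoint):

```python
from transformers import AutoConfig

# Load and print the checkpoint's configuration; the hidden size, number of
# layers/heads, context length, and vocabulary size should match the table.
config = AutoConfig.from_pretrained("amd/AMD-OLMo-1B", trust_remote_code=True)
print(config)
```
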
## Usage

### PyTorch on AMD GPUs
To run PyTorch on AMD GPUs, you can use the ROCm Docker image from [Docker Hub](https://hub.docker.com/r/rocm/pytorch):

```bash
docker pull rocm/pytorch:latest
# Inside docker
pip install transformers
```

### Usage Example

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT").to("cuda") # remove .to("cuda") to load on CPU
tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")

prompt = "What is a large language model?"
bos = tokenizer.eos_token
template = bos + "<|user|>\n{prompt}\n<|assistant|>\n"

input_text = template.format(prompt=prompt)
inputs = tokenizer([input_text], return_tensors='pt', return_token_type_ids=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=1000, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```


## Main Results

### Pretraining Results

| **Standard Benchmarks** | [TinyLLaMA-v1.1](https://huggingface.co/TinyLlama/TinyLlama_v1.1) (1.1B) | [MobiLLaMA-1B](https://huggingface.co/MBZUAI/MobiLlama-1B) (1.2B) | [OLMo-1B](https://huggingface.co/allenai/OLMo-1B-hf) (1.2B) | [OpenELM-1_1B](https://huggingface.co/apple/OpenELM-1_1B) (1.1B) | [OLMo-1B-0724-hf](https://huggingface.co/allenai/OLMo-1B-0724-hf) (1.2B) | [AMD-OLMo-1B](https://huggingface.co/amd/AMD-OLMo-1B) (1.2B) |
|---------------------:|:-----------------:|:-----------:|:-----------:|:---------------:|:---------------:|:-----------:|
| **arc_easy** | 55.47 | 56.65 | 57.28 | 55.43 | 56.65 | **63.64** |
| **arc_challenge** | 32.68 | 32.00 | 31.06 | 32.34 | 32.34 | **33.70** |
| **hellaswag** | 61.47 | 61.80 | 62.92 | 64.81 | **66.12** | 63.61 |
| **piqa** | 73.56 | 75.30 | 75.14 | **75.57** | 75.08 | **75.57** |
| **boolq** | 55.99 | 60.83 | 61.74 | 63.58 | **66.18** | 60.58 |
| **sciq** | 89.30 | 88.20 | 87.00 | 90.60 | 92.70 | **93.20** |
| **winogrande** | 59.43 | 59.27 | 59.98 | **61.72** | **61.72** | 61.64 |
| **openbookqa** | **36.80** | 35.40 | 36.20 | 36.20 | 35.60 | 35.80 |
| **mmlu (0-shot)** | 25.02 | 24.81 | 24.23 | 25.26 | **25.45** | 24.88 |
| **gsm8k (8-shot)** | 1.82 | 0.00 | 2.50 | 2.81 | **8.95** | 2.88 |
| **bbh (3-shot)** | **25.63** | 0.00 | **25.63** | 16.77 | 21.67 | 20.95 |
| **Average** | 47.02 | 44.93 | 47.61 | 47.73 | **49.31** | 48.77 |


### Instruction Tuning Results

| **Standard Benchmarks**|[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) (1.1B)|[MobiLlama-1B-Chat](https://huggingface.co/MBZUAI/MobiLlama-1B-Chat) (1.2B)|[OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) (1.1B)|[AMD-OLMo-1B-SFT](https://huggingface.co/amd/AMD-OLMo-1B-SFT) (1.2B)|[AMD-OLMo-1B-SFT-DPO](https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO) (1.2B)|
|------------------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| **arc_easy** | 54.42 | 57.41 | 52.44 | 63.68 | **64.31** |
| **arc_challenge** | 32.85 | 34.56 | **37.80** | 37.12 | 37.37 |
| **hellaswag** | 60.40 | 62.51 | **71.29** | 61.63 | 61.91 |
| **piqa** | 74.48 | **75.73** | 75.03 | 74.43 | 74.16 |
| **boolq** | 61.04 | 55.66 | **70.28** | 68.53 | 70.24 |
| **sciq** | 88.40 | 87.10 | 89.50 | 91.20 | **92.10** |
| **winogrande** | 60.54 | 60.77 | **62.19** | 60.22 | 60.62 |
| **openbookqa** | 37.20 | 36.80 | 39.20 | 37.40 | **40.20** |
| **mmlu** | 24.61 | 25.25 | 25.54 | 29.97 | **30.52** |
| **gsm8k (8-shot)**| 2.81 | 0.23 | 1.82 | **18.20** | 15.77 |
| **bbh (3-shot)** | **26.83** | 0.00 | 13.40 | 25.17 | 25.45 |
| **Average** | 47.60 | 45.09 | 48.95 | 51.60 | **52.06** |

|**Chat Benchmarks**|[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) (1.1B)|[MobiLlama-1B-Chat](https://huggingface.co/MBZUAI/MobiLlama-1B-Chat) (1.2B)|[OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) (1.1B)|[AMD-OLMo-1B-SFT](https://huggingface.co/amd/AMD-OLMo-1B-SFT) (1.2B)|[AMD-OLMo-1B-SFT-DPO](https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO) (1.2B)|
|------------------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| **AlpacaEval 1 (Win Rate)** | 50.81 | 34.90 | 37.72 | 50.12 | **54.22** |
| **AlpacaEval 2 (LC Win Rate)**| 1.54 | 1.59 | 0.49 | **3.88** | 2.37 |
| **MTBench** | 3.38 | 2.89 | - | **4.35** | 4.10 |

|**Responsible AI Benchmarks**|[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) (1.1B)|[MobiLlama-1B-Chat](https://huggingface.co/MBZUAI/MobiLlama-1B-Chat) (1.2B)|[OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) (1.1B)|[AMD-OLMo-1B-SFT](https://huggingface.co/amd/AMD-OLMo-1B-SFT) (1.2B)|[AMD-OLMo-1B-SFT-DPO](https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO) (1.2B)|
|------------------:|:---------:|:---------:|:---------:|:---------:|:---------:|
| **ToxiGen** | 41.70 | **37.23** | 42.34 | 39.04 | 39.68 |
| **crows_pairs** | 60.35 | 58.50 | 59.93 | 60.29 | **61.00** |
| **TruthfulQA-mc2**| 37.92 | 38.46 | **45.84** | 37.45 | 40.06 |

*When generating tokens for the chat benchmark evaluations, we use `max_length=2048` for AlpacaEval and `max_new_tokens=2048` for MTBench.

*All numbers in the above tables were obtained from our evaluations.


## Evaluation
We use the following open-source evaluation frameworks for evaluating our models:
- [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness): For evaluating commonsense reasoning, multi-task understanding, and responsible AI benchmarks.
- [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): For evaluating instruction-following capabilities of chat models.
- [MT-Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge): For evaluating multi-turn capabilities of chat models.

### Setup
```bash
# lm-eval-harness
git clone https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .

# AlpacaEval
git clone https://github.com/tatsu-lab/alpaca_eval
cd alpaca_eval
pip install -e .

# MT-Bench
git clone https://github.com/lm-sys/FastChat.git
cd FastChat
pip install -e ".[model_worker,llm_judge]"
```

### Run Evaluation
```bash
# lm-eval-harness
HF_MODEL=amd/AMD-OLMo-1B-SFT-DPO
accelerate launch -m lm_eval --model hf \
    --model_args pretrained=$HF_MODEL,trust_remote_code=True \
    --tasks arc_easy,arc_challenge,hellaswag,piqa,boolq,sciq,winogrande,openbookqa,mmlu,gsm8k_cot,bbh_cot_fewshot,toxigen,truthfulqa,crows_pairs \
    --device cuda \
    --batch_size 32 \
    --output_path ./lm-eval-results/$HF_MODEL
```
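
The harness writes per-task metrics as JSON files under `--output_path`; in the appendix table, "Normalized Accuracy" corresponds to the harness's `acc_norm` metric and "Accuracy" to `acc`. A small helper like the one below can pull the headline numbers out of those files (a sketch only; the exact directory layout and metric key names vary across lm-eval versions, so treat them as assumptions):

```python
import glob
import json

# Walk every results JSON produced under --output_path and print the
# accuracy-style metrics reported for each task.
for path in sorted(glob.glob("./lm-eval-results/**/*.json", recursive=True)):
    with open(path) as f:
        results = json.load(f).get("results", {})
    for task, metrics in results.items():
        scores = {k: v for k, v in metrics.items()
                  if "acc" in k or "exact_match" in k}
        print(task, scores)
```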

## Training

### Setup
```bash
WORK_DIR="<path_to_your_working_directory>"
cd $WORK_DIR
# Clone the OLMo codebase:
git clone https://github.com/allenai/OLMo.git --branch v0.3.0
cd OLMo
# Clone AMD-OLMo, which contains the files needed to reproduce our model training
git clone https://huggingface.co/amd/AMD-OLMo

docker pull rocm/pytorch:latest
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 8G -v $WORK_DIR/OLMo:/OLMo -w /OLMo rocm/pytorch:latest

# Remove Line 17 since the Docker image already has ROCm PyTorch installed
sed -i '17d' pyproject.toml
pip install -e .[all]
```

### Download and prepare pretraining datasets
```bash
# Download
DATA_DIR=./datasets/dolma
mkdir -p $DATA_DIR

PARALLEL_DOWNLOADS="<number_of_parallel_downloads>"
cat "AMD-OLMo/dolma_v1_7_subset.txt" | xargs -n 1 -P $PARALLEL_DOWNLOADS wget -q -P $DATA_DIR

# Prepare
NUM_WORKERS="<number_of_workers>"
python scripts/prepare_memmap_dataset.py $DATA_DIR/*.json.gz -o $DATA_DIR/memmap_dataset --workers $NUM_WORKERS
```

### Download and prepare SFT datasets
```bash
# 1st phase SFT dataset
python AMD-OLMo/prepare_sft_data.py --output_dir ./datasets/tulu --tokenizer tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json --dataset tulu

# 2nd phase SFT dataset
python AMD-OLMo/prepare_sft_data.py --output_dir ./datasets/OpenHermes_WebInstructSub_CodeFeedBack --tokenizer tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json --dataset 2nd-phase
```
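
Each prepared SFT example packs the tokenized prompt and response into `input_ids.npy` together with a `label_mask.npy` entry that switches off the prompt tokens, so the loss is computed on the response only. The snippet below illustrates the idea on a single example (a conceptual sketch, not the actual `prepare_sft_data.py` logic; the chat markers follow the template shown in the usage example above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")

prompt = "<|user|>\nWhat is a large language model?\n<|assistant|>\n"
response = "A large language model is a neural network trained on text." + tokenizer.eos_token

prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

# Training sequence = prompt + response; the mask is False over prompt tokens
# and True over response tokens, so only the response contributes to the loss.
input_ids = prompt_ids + response_ids
label_mask = [False] * len(prompt_ids) + [True] * len(response_ids)
print(len(input_ids), sum(label_mask))
```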

### Run Training
Pretraining config: [AMD-OLMo-1B.yaml](AMD-OLMo-1B.yaml)

SFT configs: [AMD-OLMo-1B-SFT-1st-phase.yaml](AMD-OLMo-1B-SFT-1st-phase.yaml) and [AMD-OLMo-1B-SFT-2nd-phase.yaml](AMD-OLMo-1B-SFT-2nd-phase.yaml)
```bash
# Single node
HSA_FORCE_FINE_GRAIN_PCIE=1 OMP_NUM_THREADS=128 NCCL_DEBUG=INFO torchrun --nproc_per_node=8 ./scripts/train.py AMD-OLMo/AMD-OLMo-1B.yaml

# Multiple nodes
HSA_FORCE_FINE_GRAIN_PCIE=1 OMP_NUM_THREADS=128 NCCL_DEBUG=INFO torchrun --nnodes=$nnodes --node-rank=$node_rank --master_addr=$master_addr --master_port=$master_port --nproc_per_node=8 ./scripts/train.py AMD-OLMo/AMD-OLMo-1B.yaml
```

### Run DPO Training

DPO recipe: [AMD-OLMo-1B-dpo.yaml](AMD-OLMo-1B-dpo.yaml).
```bash
# Install the trl library
git clone https://github.com/huggingface/trl.git -b v0.8.6

# Replace dpo_trainer.py
cp AMD-OLMo/dpo_trainer.py trl/trl/trainer

pip install -e ./trl

# Install alignment-handbook
git clone https://github.com/huggingface/alignment-handbook.git hf-align
# 70769f9 is the main branch on 2024-04-11.
cd hf-align && git checkout 70769f9 && cd ..

pip install -e ./hf-align

# Copy the AMD OLMo DPO recipe to hf-align/recipes.
cp AMD-OLMo/AMD-OLMo-1B-dpo.yaml hf-align/recipes/

# Point ckpt_dir at the converted AMD-OLMo SFT Hugging Face model.
ckpt_dir=amd/AMD-OLMo-1B-SFT
local_tokenizer_dir=${ckpt_dir}

# Set the output checkpoint dir.
dpo_ckpt_dir=<your_output_checkpoint_dir>

accelerate launch --config_file hf-align/recipes/accelerate_configs/deepspeed_zero3.yaml \
    hf-align/scripts/run_dpo.py hf-align/recipes/AMD-OLMo-1B-dpo.yaml \
    --trust_remote_code=true \
    --model_name_or_path=${ckpt_dir} \
    --tokenizer_name_or_path=${local_tokenizer_dir} \
    --output_dir=${dpo_ckpt_dir} \
    --num_train_epochs=1 \
    --learning_rate=4e-6 \
    --beta=0.3 \
    --loss_type=sigmoid
```
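
For context, `--loss_type=sigmoid` selects the standard DPO objective, and `--beta` controls how strongly the policy is kept close to the SFT reference model:

$$\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]$$

Here $y_w$ and $y_l$ are the chosen and rejected responses from the preference data, $\pi_\theta$ is the model being trained, and $\pi_{\mathrm{ref}}$ is the frozen SFT checkpoint.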

## Bias, Risks, and Limitations

- The models are being released for research purposes only and are not intended for use cases that require high levels of factuality, safety-critical situations, health or medical applications, generating false information, or facilitating toxic conversations.
- Model checkpoints are made accessible without any safety guarantees. It is crucial for users to conduct comprehensive evaluations and implement safety filtering mechanisms as per their respective use cases.
- It may be possible to prompt the model to generate content that is factually inaccurate, harmful, violent, toxic, biased, or otherwise objectionable. Such content may also be produced by prompts that were not intended to generate such output. Users are therefore asked to be aware of this and to exercise caution and responsible judgment when using the model.
- The multilingual abilities of the models have not been tested; the models may misunderstand inputs and generate erroneous responses in languages other than English.

## Appendix
### Evaluation Metrics
| **Benchmark** | Metric |
|---------------------:|:-----------------:|
| **arc_easy** | Normalized Accuracy |
| **arc_challenge** | Normalized Accuracy |
| **hellaswag** | Normalized Accuracy |
| **piqa** | Accuracy |
| **boolq** | Accuracy |
| **sciq** | Accuracy |
| **winogrande** | Accuracy |
| **openbookqa** | Normalized Accuracy |
| **mmlu** | Accuracy |
| **gsm8k (8-shot)** | Exact Match (Flexible Extract) |
| **bbh (3-shot)** | Exact Match |
| **ToxiGen** | Accuracy |
| **crows_pairs** | PCT Stereotype |
| **TruthfulQA-mc2** | Accuracy |
| **AlpacaEval 1 (Win Rate)** | Win Rate (chatgpt_fn) |
| **AlpacaEval 2 (LC Win Rate)** | Length Control Win Rate (weighted_alpaca_eval_gpt4_turbo) |
| **MTBench** | Average score for single-answer grading (2 turns) |

#### License
Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
dolma_v1_7_subset.txt ADDED
The diff for this file is too large to render. See raw diff
 
dpo_trainer.py ADDED
@@ -0,0 +1,1268 @@
1
+ # (Modifications Copyright(C) [2024] Advanced Micro Devices, Inc. All rights reserved)
2
+ # DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
3
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import inspect
17
+ import random
18
+ import warnings
19
+ from collections import defaultdict
20
+ from contextlib import contextmanager, nullcontext
21
+ from copy import deepcopy
22
+ from functools import wraps
23
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from accelerate import PartialState
30
+ from accelerate.utils import is_deepspeed_available, tqdm
31
+ from datasets import Dataset
32
+ from torch.utils.data import DataLoader
33
+ from transformers import (
34
+ AutoModelForCausalLM,
35
+ DataCollator,
36
+ PreTrainedModel,
37
+ PreTrainedTokenizerBase,
38
+ Trainer,
39
+ TrainingArguments,
40
+ )
41
+ from transformers.trainer_callback import TrainerCallback
42
+ from transformers.trainer_utils import EvalLoopOutput
43
+
44
+ from ..import_utils import is_peft_available, is_wandb_available
45
+ from ..models import PreTrainedModelWrapper, create_reference_model
46
+ from .utils import (
47
+ DPODataCollatorWithPadding,
48
+ disable_dropout_in_model,
49
+ pad_to_length,
50
+ peft_module_casting_to_bf16,
51
+ trl_sanitze_kwargs_for_tagging,
52
+ )
53
+
54
+
55
+ if is_peft_available():
56
+ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
57
+
58
+
59
+ if is_wandb_available():
60
+ import wandb
61
+
62
+ if is_deepspeed_available():
63
+ import deepspeed
64
+
65
+
66
+ class DPOTrainer(Trainer):
67
+ r"""
68
+ Initialize DPOTrainer.
69
+
70
+ Args:
71
+ model (`transformers.PreTrainedModel`):
72
+ The model to train, preferably an `AutoModelForSequenceClassification`.
73
+ ref_model (`PreTrainedModelWrapper`):
74
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
75
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
76
+ beta (`float`, defaults to 0.1):
77
+ The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
78
+ label_smoothing (`float`, defaults to 0):
79
+ The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
80
+ loss_type (`str`, defaults to `"sigmoid"`):
81
+ The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
82
+ args (`transformers.TrainingArguments`):
83
+ The arguments to use for training.
84
+ data_collator (`transformers.DataCollator`):
85
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
86
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
87
+ label_pad_token_id (`int`, defaults to `-100`):
88
+ The label pad token id. This argument is required if you want to use the default data collator.
89
+ padding_value (`int`, defaults to `0`):
90
+ The padding value if it is different to the tokenizer's pad_token_id.
91
+ truncation_mode (`str`, defaults to `keep_end`):
92
+ The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
93
+ train_dataset (`datasets.Dataset`):
94
+ The dataset to use for training.
95
+ eval_dataset (`datasets.Dataset`):
96
+ The dataset to use for evaluation.
97
+ tokenizer (`transformers.PreTrainedTokenizerBase`):
98
+ The tokenizer to use for training. This argument is required if you want to use the default data collator.
99
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
100
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
101
+ callbacks (`List[transformers.TrainerCallback]`):
102
+ The callbacks to use for training.
103
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
104
+ The optimizer and scheduler to use for training.
105
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
106
+ The function to use to preprocess the logits before computing the metrics.
107
+ max_length (`int`, defaults to `None`):
108
+ The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
109
+ max_prompt_length (`int`, defaults to `None`):
110
+ The maximum length of the prompt. This argument is required if you want to use the default data collator.
111
+ max_target_length (`int`, defaults to `None`):
112
+ The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
113
+ peft_config (`Dict`, defaults to `None`):
114
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
115
+ is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
116
+ If no model is provided, we need to know if the model_init returns an encoder-decoder.
117
+ disable_dropout (`bool`, defaults to `True`):
118
+ Whether or not to disable dropouts in `model` and `ref_model`.
119
+ generate_during_eval (`bool`, defaults to `False`):
120
+ Whether to sample and log generations during evaluation step.
121
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
122
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
123
+ a dictionary string to metric values.
124
+ precompute_ref_log_probs (`bool`, defaults to `False`):
125
+ Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
126
+ without the reference model and reduce the total GPU memory needed.
127
+ dataset_num_proc (`Optional[int]`, *optional*):
128
+ The number of workers to use to tokenize the data. Defaults to None.
129
+ model_init_kwargs (`Optional[Dict]`, *optional*):
130
+ Dict of Optional kwargs to pass when instantiating the model from a string
131
+ ref_model_init_kwargs (`Optional[Dict]`, *optional*):
132
+ Dict of Optional kwargs to pass when instantiating the ref model from a string
133
+ model_adapter_name (`str`, defaults to `None`):
134
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
135
+ ref_adapter_name (`str`, defaults to `None`):
136
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
137
+ reference_free (`bool`):
138
+ If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
139
+ force_use_ref_model (`bool`, defaults to `False`):
140
+ In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
141
+ """
142
+
143
+ _tag_names = ["trl", "dpo"]
144
+
145
+ def __init__(
146
+ self,
147
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
148
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
149
+ beta: float = 0.1,
150
+ label_smoothing: float = 0,
151
+ loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
152
+ args: Optional[TrainingArguments] = None,
153
+ data_collator: Optional[DataCollator] = None,
154
+ label_pad_token_id: int = -100,
155
+ padding_value: Optional[int] = None,
156
+ truncation_mode: str = "keep_end",
157
+ train_dataset: Optional[Dataset] = None,
158
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
159
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
160
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
161
+ callbacks: Optional[List[TrainerCallback]] = None,
162
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
163
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
164
+ max_length: Optional[int] = None,
165
+ max_prompt_length: Optional[int] = None,
166
+ max_target_length: Optional[int] = None,
167
+ peft_config: Optional[Dict] = None,
168
+ is_encoder_decoder: Optional[bool] = None,
169
+ disable_dropout: bool = True,
170
+ generate_during_eval: bool = False,
171
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
172
+ precompute_ref_log_probs: bool = False,
173
+ dataset_num_proc: Optional[int] = None,
174
+ model_init_kwargs: Optional[Dict] = None,
175
+ ref_model_init_kwargs: Optional[Dict] = None,
176
+ model_adapter_name: Optional[str] = None,
177
+ ref_adapter_name: Optional[str] = None,
178
+ reference_free: bool = False,
179
+ force_use_ref_model: bool = False,
180
+ ):
181
+ if model_init_kwargs is None:
182
+ model_init_kwargs = {}
183
+ elif not isinstance(model, str):
184
+ raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.")
185
+
186
+ if ref_model_init_kwargs is None:
187
+ ref_model_init_kwargs = {}
188
+ elif not isinstance(ref_model, str):
189
+ raise ValueError(
190
+ "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated."
191
+ )
192
+
193
+ if isinstance(model, str):
194
+ warnings.warn(
195
+ "You passed a model_id to the DPOTrainer. This will automatically create an "
196
+ "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
197
+ )
198
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
199
+
200
+ if isinstance(ref_model, str):
201
+ warnings.warn(
202
+ "You passed a ref model_id to the DPOTrainer. This will automatically create an "
203
+ "`AutoModelForCausalLM`"
204
+ )
205
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
206
+
207
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
208
+ # has been called in order to properly call autocast if needed.
209
+ self._peft_has_been_casted_to_bf16 = False
210
+
211
+ if not is_peft_available() and peft_config is not None:
212
+ raise ValueError(
213
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
214
+ )
215
+ elif is_peft_available() and peft_config is not None:
216
+ # if model is a peft model and we have a peft_config, we merge and unload it first
217
+ if isinstance(model, PeftModel):
218
+ model = model.merge_and_unload()
219
+
220
+ if ref_model is not None and not force_use_ref_model:
221
+ raise ValueError(
222
+ "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
223
+ " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
224
+ " if you want to use a different ref_model."
225
+ )
226
+
227
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
228
+ _support_gc_kwargs = hasattr(
229
+ args, "gradient_checkpointing_kwargs"
230
+ ) and "gradient_checkpointing_kwargs" in list(
231
+ inspect.signature(prepare_model_for_kbit_training).parameters
232
+ )
233
+
234
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
235
+
236
+ if _support_gc_kwargs:
237
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
238
+
239
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
240
+ elif getattr(args, "gradient_checkpointing", False):
241
+ # For backward compatibility with older versions of transformers
242
+ if hasattr(model, "enable_input_require_grads"):
243
+ model.enable_input_require_grads()
244
+ else:
245
+
246
+ def make_inputs_require_grad(module, input, output):
247
+ output.requires_grad_(True)
248
+
249
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
250
+
251
+ # get peft model with the given config
252
+ model = get_peft_model(model, peft_config)
253
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
254
+ peft_module_casting_to_bf16(model)
255
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
256
+ self._peft_has_been_casted_to_bf16 = True
257
+
258
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
259
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
260
+ # fail or completely fail.
261
+ elif getattr(args, "gradient_checkpointing", False):
262
+ # For backward compatibility with older versions of transformers
263
+ if hasattr(model, "enable_input_require_grads"):
264
+ model.enable_input_require_grads()
265
+ else:
266
+
267
+ def make_inputs_require_grad(module, input, output):
268
+ output.requires_grad_(True)
269
+
270
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
271
+
272
+ if generate_during_eval and not is_wandb_available():
273
+ raise ValueError(
274
+ "`generate_during_eval=True` requires Weights and Biases to be installed."
275
+ " Please install `wandb` to resolve."
276
+ )
277
+
278
+ if model is not None:
279
+ self.is_encoder_decoder = model.config.is_encoder_decoder
280
+ elif is_encoder_decoder is None:
281
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
282
+ else:
283
+ self.is_encoder_decoder = is_encoder_decoder
284
+
285
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
286
+ self.model_adapter_name = model_adapter_name
287
+ self.ref_adapter_name = ref_adapter_name
288
+ self.reference_free = reference_free
289
+
290
+ if ref_model:
291
+ self.ref_model = ref_model
292
+ elif self.is_peft_model or precompute_ref_log_probs:
293
+ # The `model` with adapters turned off will be used as the reference model
294
+ self.ref_model = None
295
+ else:
296
+ self.ref_model = create_reference_model(model)
297
+
298
+ if tokenizer is None:
299
+ raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
300
+ if max_length is None:
301
+ warnings.warn(
302
+ "`max_length` is not set in the DPOTrainer's init"
303
+ " it will default to `512` by default, but you should do it yourself in the future.",
304
+ UserWarning,
305
+ )
306
+ max_length = 512
307
+ if max_prompt_length is None:
308
+ warnings.warn(
309
+ "`max_prompt_length` is not set in the DPOTrainer's init"
310
+ " it will default to `128` by default, but you should do it yourself in the future.",
311
+ UserWarning,
312
+ )
313
+ max_prompt_length = 128
314
+
315
+ if max_target_length is None and self.is_encoder_decoder:
316
+ warnings.warn(
317
+ "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init"
318
+ " it will default to `128` by default, but you should do it yourself in the future.",
319
+ UserWarning,
320
+ )
321
+ max_target_length = 128
322
+
323
+ if data_collator is None:
324
+ data_collator = DPODataCollatorWithPadding(
325
+ pad_token_id=tokenizer.pad_token_id,
326
+ label_pad_token_id=label_pad_token_id,
327
+ is_encoder_decoder=self.is_encoder_decoder,
328
+ )
329
+
330
+ if args.remove_unused_columns:
331
+ args.remove_unused_columns = False
332
+ # warn users
333
+ warnings.warn(
334
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
335
+ " we have set it for you, but you should do it yourself in the future.",
336
+ UserWarning,
337
+ )
338
+
339
+ self.use_dpo_data_collator = True
340
+ else:
341
+ self.use_dpo_data_collator = False
342
+
343
+ if disable_dropout:
344
+ disable_dropout_in_model(model)
345
+ if self.ref_model is not None:
346
+ disable_dropout_in_model(self.ref_model)
347
+
348
+ self.max_length = max_length
349
+ self.generate_during_eval = generate_during_eval
350
+ self.label_pad_token_id = label_pad_token_id
351
+ self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id
352
+ self.max_prompt_length = max_prompt_length
353
+ self.truncation_mode = truncation_mode
354
+ self.max_target_length = max_target_length
355
+ self.tokenizer = tokenizer
356
+ self.precompute_ref_log_probs = precompute_ref_log_probs
357
+
358
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
359
+ # keep track of first called to avoid computation of future calls
360
+ self._precomputed_train_ref_log_probs = False
361
+ self._precomputed_eval_ref_log_probs = False
362
+
363
+ if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
364
+ warnings.warn(
365
+ "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
366
+ )
367
+
368
+ self.beta = beta
369
+ self.label_smoothing = label_smoothing
370
+ self.loss_type = loss_type
371
+
372
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
373
+
374
+ self.dataset_num_proc = dataset_num_proc
375
+
376
+ # Compute that only on the main process for faster data processing.
377
+ # see: https://github.com/huggingface/trl/pull/1255
378
+ with PartialState().local_main_process_first():
379
+ # tokenize the dataset
380
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
381
+ if eval_dataset is not None:
382
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
383
+
384
+ super().__init__(
385
+ model=model,
386
+ args=args,
387
+ data_collator=data_collator,
388
+ train_dataset=train_dataset,
389
+ eval_dataset=eval_dataset,
390
+ tokenizer=tokenizer,
391
+ model_init=model_init,
392
+ compute_metrics=compute_metrics,
393
+ callbacks=callbacks,
394
+ optimizers=optimizers,
395
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
396
+ )
397
+
398
+ # Add tags for models that have been loaded with the correct transformers version
399
+ if hasattr(self.model, "add_model_tags"):
400
+ self.model.add_model_tags(self._tag_names)
401
+
402
+ if not hasattr(self, "accelerator"):
403
+ raise AttributeError(
404
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
405
+ )
406
+
407
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
408
+ if self.is_deepspeed_enabled:
409
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
410
+ raise ValueError(
411
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
412
+ )
413
+
414
+ if self.ref_model is None:
415
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
416
+ raise ValueError(
417
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
418
+ )
419
+ else:
420
+ if self.is_deepspeed_enabled:
421
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
422
+ else:
423
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
424
+
425
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
426
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
427
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
428
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
429
+
430
+ if model is not None:
431
+ if hasattr(model, "config"):
432
+ hidden_size = (
433
+ max(model.config.hidden_sizes)
434
+ if getattr(model.config, "hidden_sizes", None)
435
+ else getattr(model.config, "hidden_size", None)
436
+ )
437
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
438
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
439
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
440
+ config_kwargs.update(
441
+ {
442
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
443
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
444
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
445
+ }
446
+ )
447
+
448
+ # If ZeRO-3 is used, we shard both the active and reference model.
449
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
450
+ if config_kwargs["zero_optimization"]["stage"] != 3:
451
+ config_kwargs["zero_optimization"]["stage"] = 0
452
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
453
+ model.eval()
454
+ return model
455
+
456
+ def get_train_dataloader(self) -> DataLoader:
457
+ """
458
+ Returns the training [`~torch.utils.data.DataLoader`].
459
+
460
+ Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
461
+ """
462
+
463
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
464
+ dataloader_params = {
465
+ "batch_size": self.args.per_device_train_batch_size,
466
+ "collate_fn": self.data_collator,
467
+ "num_workers": self.args.dataloader_num_workers,
468
+ "pin_memory": self.args.dataloader_pin_memory,
469
+ "shuffle": False,
470
+ }
471
+
472
+ # prepare dataloader
473
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
474
+
475
+ reference_chosen_logps = []
476
+ reference_rejected_logps = []
477
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
478
+ reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
479
+ reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
480
+ (reference_chosen_logp, reference_rejected_logp)
481
+ )
482
+ reference_chosen_logps.append(reference_chosen_logp.cpu())
483
+ reference_rejected_logps.append(reference_rejected_logp.cpu())
484
+
485
+ all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
486
+ all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
487
+
488
+ self.train_dataset = self.train_dataset.add_column(
489
+ name="reference_chosen_logps", column=all_reference_chosen_logps
490
+ )
491
+ self.train_dataset = self.train_dataset.add_column(
492
+ name="reference_rejected_logps", column=all_reference_rejected_logps
493
+ )
494
+
495
+ self._precomputed_train_ref_log_probs = True
496
+
497
+ return super().get_train_dataloader()
498
+
499
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
500
+ """
501
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
502
+
503
+ Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
504
+
505
+ Args:
506
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
507
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
508
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
509
+ """
510
+ if eval_dataset is None and self.eval_dataset is None:
511
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
512
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
513
+
514
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
515
+ dataloader_params = {
516
+ "batch_size": self.args.per_device_eval_batch_size,
517
+ "collate_fn": self.data_collator,
518
+ "num_workers": self.args.dataloader_num_workers,
519
+ "pin_memory": self.args.dataloader_pin_memory,
520
+ "shuffle": False,
521
+ }
522
+
523
+ # prepare dataloader
524
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
525
+
526
+ reference_chosen_logps = []
527
+ reference_rejected_logps = []
528
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
529
+ reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
530
+ reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics(
531
+ (reference_chosen_logp, reference_rejected_logp)
532
+ )
533
+ reference_chosen_logps.append(reference_chosen_logp.cpu())
534
+ reference_rejected_logps.append(reference_rejected_logp.cpu())
535
+
536
+ all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
537
+ all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
538
+
539
+ eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
540
+ eval_dataset = eval_dataset.add_column(
541
+ name="reference_rejected_logps", column=all_reference_rejected_logps
542
+ )
543
+
544
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
545
+ if self.eval_dataset is not None:
546
+ self.eval_dataset = eval_dataset
547
+ self._precomputed_eval_ref_log_probs = True
548
+
549
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
550
+
551
+ def build_tokenized_answer(self, prompt, answer):
552
+ """
553
+ Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
554
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
555
+ Reference:
556
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
557
+ """
558
+
559
+ full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
560
+ prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
561
+
562
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
563
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
564
+
565
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
566
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
567
+
568
+ # Prepare input tokens for token by token comparison
569
+ full_input_ids = np.array(full_tokenized["input_ids"])
570
+
571
+ if len(full_input_ids) != len(full_concat_input_ids):
572
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
573
+
574
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
575
+ # can be merged together when tokenizing prompt+answer. This could result
576
+ # on the last token from the prompt being different when tokenized on its own
577
+ # vs when done as prompt+answer.
578
+ response_token_ids_start_idx = len(prompt_input_ids)
579
+
580
+ # If tokenized prompt is different than both prompt+answer, then it means the
581
+ # last token has changed due to merging.
582
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
583
+ response_token_ids_start_idx -= 1
584
+
585
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
586
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
587
+
588
+ if len(prompt_input_ids) != len(prompt_attention_mask):
589
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
590
+
591
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
592
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
593
+
594
+ return dict(
595
+ prompt_input_ids=prompt_input_ids,
596
+ prompt_attention_mask=prompt_attention_mask,
597
+ input_ids=answer_input_ids,
598
+ attention_mask=answer_attention_mask,
599
+ )
600
+
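`build_tokenized_answer` exists because BPE/SentencePiece tokenizers may merge the last prompt token with the start of the answer, so tokenizing the concatenation is not the same as concatenating the tokenizations. A small sketch of how to probe that property for any Hugging Face tokenizer (the `gpt2` checkpoint is only an illustrative choice and must be downloadable or cached locally):

from transformers import AutoTokenizer

def concat_is_exact(tokenizer, prompt: str, answer: str) -> bool:
    # True when enc(prompt + answer) == enc(prompt) + enc(answer) for this pair.
    joint = tokenizer(prompt + answer, add_special_tokens=False)["input_ids"]
    separate = (
        tokenizer(prompt, add_special_tokens=False)["input_ids"]
        + tokenizer(answer, add_special_tokens=False)["input_ids"]
    )
    return joint == separate

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
# When this prints False, the boundary token was merged, which is exactly the case
# the re-splitting logic in build_tokenized_answer compensates for.
print(concat_is_exact(tok, "The answer is", " 42"))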
601
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
602
+ """Tokenize a single row from a DPO specific dataset.
603
+
604
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
605
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
606
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
607
+
608
+ We also create the labels for the chosen/rejected responses, which are of length equal to
609
+ the sum of the length of the prompt and the chosen/rejected response, with
610
+ label_pad_token_id for the prompt tokens.
611
+ """
612
+ batch = {}
613
+ prompt = feature["prompt"]
614
+ chosen = feature["chosen"]
615
+ rejected = feature["rejected"]
616
+
617
+ if not self.tokenizer.bos_token_id:
618
+ self.tokenizer.bos_token_id = self.tokenizer.eos_token_id
619
+ self.tokenizer.add_special_tokens({"bos_token": self.tokenizer.eos_token})
620
+
621
+ if not self.is_encoder_decoder:
622
+ # Check issues below for more details
623
+ # 1. https://github.com/huggingface/trl/issues/907
624
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
625
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
626
+
627
+ if not isinstance(prompt, str):
628
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
629
+ prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
630
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
631
+
632
+ if not isinstance(chosen, str):
633
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
634
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
635
+
636
+ if not isinstance(rejected, str):
637
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
638
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
639
+
640
+ # Last prompt token might get merged by tokenizer and
641
+ # it should not be included for generation if that happens
642
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
643
+
644
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
645
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
646
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
647
+
648
+ for k, v in prompt_tokens.items():
649
+ prompt_tokens[k] = v[:prompt_len_input_ids]
650
+
651
+ # Make sure the prompts differ by at most one token
652
+ # and that their lengths differ by at most 1
653
+ num_diff_tokens = sum(
654
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
655
+ )
656
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
657
+ if num_diff_tokens > 1 or num_diff_len > 1:
658
+ raise ValueError(
659
+ "Chosen and rejected prompt_input_ids might only differ on the "
660
+ "last token due to tokenizer merge ops."
661
+ )
662
+
663
+ # add BOS token to head of prompt
664
+ prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
665
+ chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
666
+ rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
667
+
668
+ prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
669
+ chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
670
+ rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
671
+
672
+ # print(chosen_tokens["input_ids"])
673
+ # print(chosen_tokens["attention_mask"])
674
+ # add EOS token to end of answer
675
+ chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
676
+ # print(chosen_tokens["input_ids"])
677
+ chosen_tokens["attention_mask"].append(1)
678
+ # print(chosen_tokens["attention_mask"])
679
+
680
+ rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
681
+ rejected_tokens["attention_mask"].append(1)
682
+
683
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
684
+
685
+ # if combined sequence is too long, truncate the prompt
686
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
687
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
688
+ if self.truncation_mode == "keep_start":
689
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
690
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
691
+ elif self.truncation_mode == "keep_end":
692
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
693
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
694
+ else:
695
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
696
+
697
+ # if that's still too long, truncate the response
698
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
699
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
700
+ for k in ["input_ids", "attention_mask"]:
701
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
702
+
703
+ # Create labels
704
+ chosen_sequence_tokens = {
705
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
706
+ }
707
+ rejected_sequence_tokens = {
708
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
709
+ }
710
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
711
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
712
+ self.label_pad_token_id
713
+ ] * len(chosen_tokens["prompt_input_ids"])
714
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
715
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
716
+ self.label_pad_token_id
717
+ ] * len(rejected_tokens["prompt_input_ids"])
718
+
719
+ for k, toks in {
720
+ "chosen_": chosen_sequence_tokens,
721
+ "rejected_": rejected_sequence_tokens,
722
+ "": prompt_tokens,
723
+ }.items():
724
+ for type_key, tokens in toks.items():
725
+ if type_key == "token_type_ids":
726
+ continue
727
+ batch[f"{k}{type_key}"] = tokens
728
+ # print(f"{k}{type_key}", tokens)
729
+ # import pdb; pdb.set_trace()
730
+ # raise
731
+
732
+ else:
733
+ chosen_tokens = self.tokenizer(
734
+ chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
735
+ )
736
+ rejected_tokens = self.tokenizer(
737
+ rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
738
+ )
739
+ prompt_tokens = self.tokenizer(
740
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
741
+ )
742
+
743
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
744
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
745
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
746
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
747
+
748
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
749
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
750
+ labels=torch.tensor(batch["rejected_labels"])
751
+ )
752
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
753
+ labels=torch.tensor(batch["chosen_labels"])
754
+ )
755
+
756
+ return batch
757
+
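The label construction above is just a copy of the full sequence with the prompt positions overwritten by `label_pad_token_id`, so only response tokens are supervised. A minimal sketch with made-up token ids (the -100 value matches the default `label_pad_token_id` used here):

label_pad_token_id = -100
prompt_input_ids = [5, 6, 7]   # made-up prompt tokens
answer_input_ids = [8, 9]      # made-up chosen-response tokens

input_ids = prompt_input_ids + answer_input_ids
labels = input_ids[:]  # shallow copy, as in chosen_sequence_tokens["labels"] above
labels[: len(prompt_input_ids)] = [label_pad_token_id] * len(prompt_input_ids)

print(input_ids)  # [5, 6, 7, 8, 9]
print(labels)     # [-100, -100, -100, 8, 9] -> only the answer contributes to the loss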
758
+ @contextmanager
759
+ def null_ref_context(self):
760
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
761
+ with self.accelerator.unwrap_model(
762
+ self.model
763
+ ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
764
+ if self.ref_adapter_name:
765
+ self.model.set_adapter(self.ref_adapter_name)
766
+ yield
767
+ if self.ref_adapter_name:
768
+ self.model.set_adapter(self.model_adapter_name or "default")
769
+
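`null_ref_context` relies on the common trick of substituting `contextlib.nullcontext()` when there is nothing to disable, so the caller can always write a single `with` block. A stripped-down sketch of that branching pattern (the `disable_adapter_stub` context manager is a hypothetical stand-in for `unwrap_model(...).disable_adapter()`):

from contextlib import contextmanager, nullcontext

@contextmanager
def disable_adapter_stub():
    # Hypothetical stand-in for disabling a PEFT adapter around a forward pass.
    print("adapter disabled")
    yield
    print("adapter re-enabled")

def reference_forward(is_peft_model: bool):
    # Choose the real context manager or a no-op at runtime, mirroring null_ref_context.
    with disable_adapter_stub() if is_peft_model else nullcontext():
        print("running reference forward pass")

reference_forward(True)
reference_forward(False)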
770
+ def compute_reference_log_probs(self, padded_batch: Dict) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
771
+ """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
772
+ compute_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
773
+
774
+ # compute reference logps
775
+ with torch.no_grad(), compute_ref_context_manager():
776
+ if self.ref_model is None:
777
+ with self.null_ref_context():
778
+ (
779
+ reference_chosen_logps,
780
+ reference_rejected_logps,
781
+ _,
782
+ _,
783
+ ) = self.concatenated_forward(self.model, padded_batch)
784
+ else:
785
+ (
786
+ reference_chosen_logps,
787
+ reference_rejected_logps,
788
+ _,
789
+ _,
790
+ ) = self.concatenated_forward(self.ref_model, padded_batch)
791
+
792
+ return reference_chosen_logps, reference_rejected_logps
793
+
794
+ @staticmethod
795
+ def concatenated_inputs(
796
+ batch: Dict[str, Union[List, torch.LongTensor]],
797
+ is_encoder_decoder: bool = False,
798
+ label_pad_token_id: int = -100,
799
+ padding_value: int = 0,
800
+ device: Optional[torch.device] = None,
801
+ ) -> Dict[str, torch.LongTensor]:
802
+ """Concatenate the chosen and rejected inputs into a single tensor.
803
+
804
+ Args:
805
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
806
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
807
+ label_pad_token_id: The label pad token id.
808
+ padding_value: The padding value to use for the concatenated inputs_ids.
809
+ device: The device for the concatenated inputs.
810
+
811
+ Returns:
812
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
813
+ """
814
+ concatenated_batch = {}
815
+
816
+ if is_encoder_decoder:
817
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
818
+ else:
819
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
820
+
821
+ for k in batch:
822
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
823
+ if "labels" in k or is_encoder_decoder:
824
+ pad_value = label_pad_token_id
825
+ elif k.endswith("_input_ids"):
826
+ pad_value = padding_value
827
+ elif k.endswith("_attention_mask"):
828
+ pad_value = 0
829
+ concatenated_key = k.replace("chosen", "concatenated")
830
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
831
+ for k in batch:
832
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
833
+ if "labels" in k or is_encoder_decoder:
834
+ pad_value = label_pad_token_id
835
+ elif k.endswith("_input_ids"):
836
+ pad_value = padding_value
837
+ elif k.endswith("_attention_mask"):
838
+ pad_value = 0
839
+ concatenated_key = k.replace("rejected", "concatenated")
840
+ concatenated_batch[concatenated_key] = torch.cat(
841
+ (
842
+ concatenated_batch[concatenated_key],
843
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
844
+ ),
845
+ dim=0,
846
+ ).to(device=device)
847
+
848
+ if is_encoder_decoder:
849
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
850
+ concatenated_batch["concatenated_attention_mask"] = (
851
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
852
+ )
853
+
854
+ return concatenated_batch
855
+
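`concatenated_inputs` right-pads the chosen and rejected tensors to a shared length and stacks them along the batch dimension so one forward pass covers both. A toy sketch of that step, with a local `pad_to_length` helper standing in for the trl utility of the same name:

import torch

def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    # Minimal stand-in: right-pad the last dimension up to `length`.
    if tensor.size(-1) >= length:
        return tensor
    pad = tensor.new_full((*tensor.shape[:-1], length - tensor.size(-1)), pad_value)
    return torch.cat([tensor, pad], dim=-1)

chosen_input_ids = torch.tensor([[5, 6, 7, 8]])   # made-up ids, batch size 1
rejected_input_ids = torch.tensor([[5, 6, 9]])
max_length = max(chosen_input_ids.shape[1], rejected_input_ids.shape[1])

concatenated_input_ids = torch.cat(
    (
        pad_to_length(chosen_input_ids, max_length, pad_value=0),
        pad_to_length(rejected_input_ids, max_length, pad_value=0),
    ),
    dim=0,
)
print(concatenated_input_ids)  # shape (2, 4): chosen rows first, then rejected rows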
856
+ def dpo_loss(
857
+ self,
858
+ policy_chosen_logps: torch.FloatTensor,
859
+ policy_rejected_logps: torch.FloatTensor,
860
+ reference_chosen_logps: torch.FloatTensor,
861
+ reference_rejected_logps: torch.FloatTensor,
862
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
863
+ """Compute the DPO loss for a batch of policy and reference model log probabilities.
864
+
865
+ Args:
866
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
867
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
868
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
869
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
870
+
871
+ Returns:
872
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
873
+ The losses tensor contains the DPO loss for each example in the batch.
874
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
875
+ """
876
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
877
+ if self.reference_free:
878
+ ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
879
+ else:
880
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
881
+
882
+ pi_logratios = pi_logratios.to(self.accelerator.device)
883
+ ref_logratios = ref_logratios.to(self.accelerator.device)
884
+ logits = pi_logratios - ref_logratios
885
+
886
+ # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
887
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
888
+ # calculates a conservative DPO loss.
889
+ if self.loss_type == "sigmoid":
890
+ losses = (
891
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
892
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
893
+ )
894
+ elif self.loss_type == "hinge":
895
+ losses = torch.relu(1 - self.beta * logits)
896
+ elif self.loss_type == "ipo":
897
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
898
+ losses = (logits - 1 / (2 * self.beta)) ** 2
899
+ elif self.loss_type == "kto_pair":
900
+ # eqn (7) of the HALOs paper
901
+ chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
902
+ rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
903
+
904
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
905
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
906
+ # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
907
+ losses = torch.cat(
908
+ (
909
+ 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
910
+ 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
911
+ ),
912
+ 0,
913
+ )
914
+ else:
915
+ raise ValueError(
916
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
917
+ )
918
+
919
+ chosen_rewards = (
920
+ self.beta
921
+ * (
922
+ policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
923
+ ).detach()
924
+ )
925
+ rejected_rewards = (
926
+ self.beta
927
+ * (
928
+ policy_rejected_logps.to(self.accelerator.device)
929
+ - reference_rejected_logps.to(self.accelerator.device)
930
+ ).detach()
931
+ )
932
+
933
+ return losses, chosen_rewards, rejected_rewards
934
+
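For the default `sigmoid` loss (with `label_smoothing = 0`) the whole computation reduces to `-logsigmoid(beta * ((policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)))`, with rewards defined as `beta * (policy - reference)`. A self-contained numeric sketch with made-up log-probabilities and `beta = 0.1`:

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen_logps = torch.tensor([-10.0, -12.0])      # made-up values
policy_rejected_logps = torch.tensor([-14.0, -13.0])
reference_chosen_logps = torch.tensor([-11.0, -12.5])
reference_rejected_logps = torch.tensor([-13.0, -12.8])

pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios

losses = -F.logsigmoid(beta * logits)  # sigmoid loss, label_smoothing = 0
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)

print(losses)
print((chosen_rewards > rejected_rewards).float().mean())  # reward accuracy on this toy batch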
935
+ @staticmethod
936
+ def get_batch_logps(
937
+ logits: torch.FloatTensor,
938
+ labels: torch.LongTensor,
939
+ average_log_prob: bool = False,
940
+ label_pad_token_id: int = -100,
941
+ is_encoder_decoder: bool = False,
942
+ ) -> torch.FloatTensor:
943
+ """Compute the log probabilities of the given labels under the given logits.
944
+
945
+ Args:
946
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
947
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
948
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
949
+ label_pad_token_id: The label pad token id.
950
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
951
+
952
+ Returns:
953
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
954
+ """
955
+ if logits.shape[:-1] != labels.shape:
956
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
957
+
958
+ if not is_encoder_decoder:
959
+ labels = labels[:, 1:].clone()
960
+ logits = logits[:, :-1, :]
961
+ loss_mask = labels != label_pad_token_id
962
+
963
+ # dummy token; we'll ignore the losses on these tokens later
964
+ labels[labels == label_pad_token_id] = 0
965
+
966
+ per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
967
+
968
+ if average_log_prob:
969
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
970
+ else:
971
+ return (per_token_logps * loss_mask).sum(-1)
972
+
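After shifting the labels by one position against the logits, `get_batch_logps` gathers each target token's log-probability and masks out padding before summing (or averaging, for the IPO loss). A tiny sketch of the gather-and-mask step on made-up shapes:

import torch

label_pad_token_id = -100
logits = torch.randn(1, 4, 8)  # (batch, seq_len, vocab) - random toy logits
labels = torch.tensor([[3, 5, label_pad_token_id, label_pad_token_id]])

loss_mask = labels != label_pad_token_id
labels = labels.clone()
labels[labels == label_pad_token_id] = 0  # dummy index; masked out below anyway

per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
).squeeze(2)  # (batch, seq_len)

sum_logps = (per_token_logps * loss_mask).sum(-1)   # average_log_prob=False
mean_logps = sum_logps / loss_mask.sum(-1)           # average_log_prob=True
print(sum_logps, mean_logps)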
973
+ def concatenated_forward(
974
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
975
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
976
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
977
+
978
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
979
+ """
980
+ concatenated_batch = self.concatenated_inputs(
981
+ batch,
982
+ is_encoder_decoder=self.is_encoder_decoder,
983
+ label_pad_token_id=self.label_pad_token_id,
984
+ padding_value=self.padding_value,
985
+ device=self.accelerator.device,
986
+ )
987
+ len_chosen = batch["chosen_labels"].shape[0]
988
+
989
+ model_kwargs = (
990
+ {
991
+ "labels": concatenated_batch["concatenated_labels"],
992
+ "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
993
+ }
994
+ if self.is_encoder_decoder
995
+ else {}
996
+ )
997
+ all_logits = model(
998
+ concatenated_batch["concatenated_input_ids"],
999
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1000
+ use_cache=False,
1001
+ **model_kwargs,
1002
+ ).logits
1003
+
1004
+ all_logps = self.get_batch_logps(
1005
+ all_logits,
1006
+ concatenated_batch["concatenated_labels"],
1007
+ average_log_prob=self.loss_type == "ipo",
1008
+ is_encoder_decoder=self.is_encoder_decoder,
1009
+ label_pad_token_id=self.label_pad_token_id,
1010
+ )
1011
+
1012
+ chosen_logps = all_logps[:len_chosen]
1013
+ rejected_logps = all_logps[len_chosen:]
1014
+
1015
+ chosen_logits = all_logits[:len_chosen]
1016
+ rejected_logits = all_logits[len_chosen:]
1017
+
1018
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1019
+
1020
+ def get_batch_loss_metrics(
1021
+ self,
1022
+ model,
1023
+ batch: Dict[str, Union[List, torch.LongTensor]],
1024
+ train_eval: Literal["train", "eval"] = "train",
1025
+ ):
1026
+ """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
1027
+ metrics = {}
1028
+
1029
+ (
1030
+ policy_chosen_logps,
1031
+ policy_rejected_logps,
1032
+ policy_chosen_logits,
1033
+ policy_rejected_logits,
1034
+ ) = self.concatenated_forward(model, batch)
1035
+
1036
+ # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
1037
+ if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
1038
+ reference_chosen_logps = batch["reference_chosen_logps"]
1039
+ reference_rejected_logps = batch["reference_rejected_logps"]
1040
+ else:
1041
+ with torch.no_grad():
1042
+ if self.ref_model is None:
1043
+ with self.null_ref_context():
1044
+ (
1045
+ reference_chosen_logps,
1046
+ reference_rejected_logps,
1047
+ _,
1048
+ _,
1049
+ ) = self.concatenated_forward(self.model, batch)
1050
+ else:
1051
+ (
1052
+ reference_chosen_logps,
1053
+ reference_rejected_logps,
1054
+ _,
1055
+ _,
1056
+ ) = self.concatenated_forward(self.ref_model, batch)
1057
+
1058
+ losses, chosen_rewards, rejected_rewards = self.dpo_loss(
1059
+ policy_chosen_logps,
1060
+ policy_rejected_logps,
1061
+ reference_chosen_logps,
1062
+ reference_rejected_logps,
1063
+ )
1064
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1065
+
1066
+ prefix = "eval_" if train_eval == "eval" else ""
1067
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
1068
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
1069
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
1070
+ metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
1071
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
1072
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
1073
+ metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
1074
+ metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
1075
+
1076
+ return losses.mean(), metrics
1077
+
1078
+ def compute_loss(
1079
+ self,
1080
+ model: Union[PreTrainedModel, nn.Module],
1081
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1082
+ return_outputs=False,
1083
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
1084
+ if not self.use_dpo_data_collator:
1085
+ "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different from "
1087
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own compute_loss method if you are using a custom data collator"
1087
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1088
+ )
1089
+
1090
+ compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1091
+
1092
+ with compute_loss_context_manager():
1093
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1094
+
1095
+ # Move the loss to the same device as the accumulating loss inside the parent `Trainer` class:
1096
+ loss = loss.to(self.args.device)
1097
+ # force log the metrics
1098
+ self.store_metrics(metrics, train_eval="train")
1099
+
1100
+ if return_outputs:
1101
+ return (loss, metrics)
1102
+ return loss
1103
+
1104
+ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
1105
+ """Generate samples from the model and reference model for the given batch of inputs."""
1106
+
1107
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1108
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1109
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
1110
+
1111
+ with generate_context_manager():
1112
+ policy_output = model.generate(
1113
+ input_ids=batch["prompt_input_ids"],
1114
+ attention_mask=batch["prompt_attention_mask"],
1115
+ max_length=self.max_length,
1116
+ do_sample=True,
1117
+ pad_token_id=self.tokenizer.pad_token_id,
1118
+ )
1119
+
1120
+ # if reference_output in batch use that otherwise use the reference model
1121
+ if "reference_output" in batch:
1122
+ reference_output = batch["reference_output"]
1123
+ else:
1124
+ if self.ref_model is None:
1125
+ with self.null_ref_context():
1126
+ reference_output = self.model.generate(
1127
+ input_ids=batch["prompt_input_ids"],
1128
+ attention_mask=batch["prompt_attention_mask"],
1129
+ max_length=self.max_length,
1130
+ do_sample=True,
1131
+ pad_token_id=self.tokenizer.pad_token_id,
1132
+ )
1133
+ else:
1134
+ reference_output = self.ref_model.generate(
1135
+ input_ids=batch["prompt_input_ids"],
1136
+ attention_mask=batch["prompt_attention_mask"],
1137
+ max_length=self.max_length,
1138
+ do_sample=True,
1139
+ pad_token_id=self.tokenizer.pad_token_id,
1140
+ )
1141
+
1142
+ policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
1143
+ policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
1144
+
1145
+ reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
1146
+ reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
1147
+
1148
+ return policy_output_decoded, reference_output_decoded
1149
+
1150
+ def prediction_step(
1151
+ self,
1152
+ model: Union[PreTrainedModel, nn.Module],
1153
+ inputs: Dict[str, Union[torch.Tensor, Any]],
1154
+ prediction_loss_only: bool,
1155
+ ignore_keys: Optional[List[str]] = None,
1156
+ ):
1157
+ if not self.use_dpo_data_collator:
1158
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different from "
1159
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1160
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1161
+ )
1162
+ if ignore_keys is None:
1163
+ if hasattr(model, "config"):
1164
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1165
+ else:
1166
+ ignore_keys = []
1167
+
1168
+ prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
1169
+
1170
+ with torch.no_grad(), prediction_context_manager():
1171
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1172
+
1173
+ # force log the metrics
1174
+ self.store_metrics(metrics, train_eval="eval")
1175
+
1176
+ if prediction_loss_only:
1177
+ return (loss.detach(), None, None)
1178
+
1179
+ # logits for the chosen and rejected samples from model
1180
+ logits_dict = {
1181
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1182
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1183
+ }
1184
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1185
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1186
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1187
+
1188
+ return (loss.detach(), logits, labels)
1189
+
1190
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1191
+ for key, value in metrics.items():
1192
+ self._stored_metrics[train_eval][key].append(value)
1193
+
1194
+ def evaluation_loop(
1195
+ self,
1196
+ dataloader: DataLoader,
1197
+ description: str,
1198
+ prediction_loss_only: Optional[bool] = None,
1199
+ ignore_keys: Optional[List[str]] = None,
1200
+ metric_key_prefix: str = "eval",
1201
+ ) -> EvalLoopOutput:
1202
+ """
1203
+ Overriding built-in evaluation loop to store metrics for each batch.
1204
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1205
+
1206
+ Works both with or without labels.
1207
+ """
1208
+
1209
+ # Sample and save to game log if requested (for one batch to save time)
1210
+ if self.generate_during_eval:
1211
+ # Generate random indices within the range of the total number of samples
1212
+ num_samples = len(dataloader.dataset)
1213
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1214
+
1215
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1216
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1217
+ random_batch = self.data_collator(random_batch_dataset)
1218
+ random_batch = self._prepare_inputs(random_batch)
1219
+
1220
+ policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
1221
+
1222
+ self.log(
1223
+ {
1224
+ "game_log": wandb.Table(
1225
+ columns=["Prompt", "Policy", "Ref Model"],
1226
+ rows=[
1227
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1228
+ for prompt, pol, ref in zip(
1229
+ random_batch["prompt"], policy_output_decoded, ref_output_decoded
1230
+ )
1231
+ ],
1232
+ )
1233
+ }
1234
+ )
1235
+ self.state.log_history.pop()
1236
+
1237
+ # Base evaluation
1238
+ initial_output = super().evaluation_loop(
1239
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1240
+ )
1241
+
1242
+ return initial_output
1243
+
1244
+ def log(self, logs: Dict[str, float]) -> None:
1245
+ """
1246
+ Log `logs` on the various objects watching training, including stored metrics.
1247
+
1248
+ Args:
1249
+ logs (`Dict[str, float]`):
1250
+ The values to log.
1251
+ """
1252
+ # logs has either 'loss' or 'eval_loss'
1253
+ train_eval = "train" if "loss" in logs else "eval"
1254
+ # Add averaged stored metrics to logs
1255
+ for key, metrics in self._stored_metrics[train_eval].items():
1256
+ logs[key] = torch.tensor(metrics).mean().item()
1257
+ del self._stored_metrics[train_eval]
1258
+ return super().log(logs)
1259
+
1260
+ @wraps(Trainer.push_to_hub)
1261
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
1262
+ """
1263
+ Overrides the `push_to_hub` method in order to force-add the tag "dpo" when pushing the
1264
+ model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
1265
+ """
1266
+ kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
1267
+
1268
+ return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
prepare_sft_data.py ADDED
@@ -0,0 +1,215 @@
1
+ # (Modifications Copyright(C) [2024] Advanced Micro Devices, Inc. All rights reserved)
2
+ """
3
+ Script for preparing the SFT data for fine-tuning the AMD-OLMo model.
4
+ Modified from https://github.com/allenai/OLMo/blob/main/scripts/prepare_tulu_data.py
5
+ """
6
+
7
+ import logging
8
+ from argparse import ArgumentParser
9
+ from functools import partial
10
+ from pathlib import Path
11
+
12
+ import datasets as ds
13
+ import numpy as np
14
+ from rich.progress import track
15
+
16
+ from olmo.tokenizer import Tokenizer
17
+ from olmo.util import prepare_cli_environment
18
+ import random
19
+ from tqdm import tqdm
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+
24
+ def convert_code_feedback_to_tulu_format(dataset, mix=False):
25
+ log.info("Converting code_feedback ...")
26
+ y_all = []
27
+ for i, sample in enumerate(dataset):
28
+ y = {
29
+ "dataset": "code_feedback",
30
+ "id": "code_feedback_{}".format(i),
31
+ "messages": sample['messages']
32
+ }
33
+ y_all.append(y)
34
+
35
+ log.info(f"In total {len(y_all)} samples")
36
+ if mix:
37
+ return y_all
38
+ else:
39
+ new_dataset = ds.Dataset.from_list(y_all)
40
+ return new_dataset
41
+
42
+
43
+ def convert_OpenHermes_to_tulu_format(dataset, mix=False):
44
+ log.info("Converting OpenHermes ...")
45
+ role_map = {"human": "user", "gpt": "assistant", "system": "system"}
46
+ y_all = []
47
+ for i, sample in enumerate(dataset):
48
+ y = {
49
+ "dataset": "OpenHermes",
50
+ "id": "OpenHermes_{}".format(i),
51
+ "messages": [{"role": role_map[mssg["from"]], "content": mssg["value"]} for mssg in sample['conversations']]
52
+ }
53
+ y_all.append(y)
54
+
55
+ log.info(f"In total {len(y_all)} samples")
56
+ if mix:
57
+ return y_all
58
+ else:
59
+ new_dataset = ds.Dataset.from_list(y_all)
60
+ return new_dataset
61
+
62
+
63
+ def convert_WebInstructSub_to_tulu_format(dataset, mix=False):
64
+ log.info("Converting WebInstructSub ...")
65
+ y_all = []
66
+ for i, sample in tqdm(enumerate(dataset)):
67
+ y = {
68
+ "dataset": "WebInstructSub",
69
+ "id": "WebInstructSub_{}".format(i),
70
+ "messages": [{"role": "user", "content": sample["question"]}, {"role": "assistant", "content": sample["answer"]}]
71
+ }
72
+ y_all.append(y)
73
+
74
+ log.info(f"In total {len(y_all)} samples")
75
+ if mix:
76
+ return y_all
77
+ else:
78
+ new_dataset = ds.Dataset.from_list(y_all)
79
+ return new_dataset
80
+
81
+
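Each converter above maps a source-specific schema onto the same tulu-style record: a `dataset` tag, a unique `id`, and a `messages` list of `{"role", "content"}` turns. A minimal illustration for one made-up WebInstructSub-style row:

sample = {"question": "What is 2 + 2?", "answer": "2 + 2 = 4."}  # made-up source row

tulu_record = {
    "dataset": "WebInstructSub",
    "id": "WebInstructSub_0",
    "messages": [
        {"role": "user", "content": sample["question"]},
        {"role": "assistant", "content": sample["answer"]},
    ],
}
print(tulu_record)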
82
+ def main(opts) -> None:
83
+ tokenizer: Tokenizer
84
+ if Path(opts.tokenizer).is_file():
85
+ tokenizer = Tokenizer.from_file(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
86
+ else:
87
+ tokenizer = Tokenizer.from_pretrained(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
88
+
89
+ if opts.dataset == "tulu":
90
+ dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")
91
+ elif opts.dataset == "2nd-phase":
92
+ datasets = ["code-feedback", "OpenHermes", "WebInstructSub"]
93
+ combined_datasets = []
94
+ for dataset_name in datasets:
95
+ if dataset_name == "code-feedback":
96
+ dataset = ds.load_dataset("m-a-p/Code-Feedback", split="train")
97
+ dataset = convert_code_feedback_to_tulu_format(dataset, mix=True)
98
+ elif dataset_name == "OpenHermes":
99
+ dataset = ds.load_dataset("teknium/OpenHermes-2.5", split="train")
100
+ dataset = convert_OpenHermes_to_tulu_format(dataset, mix=True)
101
+ elif dataset_name == "WebInstructSub":
102
+ dataset = ds.load_dataset("TIGER-Lab/WebInstructSub", split="train")
103
+ dataset = convert_WebInstructSub_to_tulu_format(dataset, mix=True)
104
+
105
+ combined_datasets += dataset
106
+
107
+ random.seed(42)
108
+ random.shuffle(combined_datasets)
109
+ log.info(f"In total {len(combined_datasets)} samples")
110
+ dataset = ds.Dataset.from_list(combined_datasets)
111
+
112
+ log.info("Tokenizing dataset...")
113
+ dataset = dataset.map(
114
+ partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
115
+ batched=False,
116
+ remove_columns=["dataset", "id", "messages"],
117
+ num_proc=opts.num_proc, # type: ignore
118
+ )
119
+
120
+ log.info("Filtering dataset...")
121
+ n = len(dataset) # type: ignore
122
+ dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc) # type: ignore
123
+ log.info(f"Filtered out {n - len(dataset):,d} examples")
124
+
125
+ log.info("Counting tokens...")
126
+ total_tokens = 0
127
+ for ex in track(dataset):
128
+ assert len(ex["input_ids"]) == opts.seq_len # type: ignore
129
+ total_tokens += len(ex["input_ids"]) # type: ignore
130
+ log.info(f"Total tokens: {total_tokens:,d}")
131
+
132
+ log.info(f"Saving results to '{opts.output_dir}'...")
133
+ output_dir = Path(opts.output_dir)
134
+ output_dir.mkdir(exist_ok=True, parents=True)
135
+
136
+ input_ids_file = np.memmap(
137
+ str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,)
138
+ )
139
+ label_mask_file = np.memmap(
140
+ str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
141
+ )
142
+ offset = 0
143
+ for ex in track(dataset):
144
+ ex_len = len(ex["input_ids"]) # type: ignore
145
+ input_ids_file[offset : offset + ex_len] = ex["input_ids"] # type: ignore
146
+ label_mask_file[offset : offset + ex_len] = ex["label_mask"] # type: ignore
147
+ offset += ex_len
148
+ input_ids_file.flush()
149
+ label_mask_file.flush()
150
+
151
+ log.info("Done!")
152
+
153
+
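`main` writes the tokenized examples into two flat memmaps, `input_ids.npy` and `label_mask.npy`, each `total_tokens` entries long with every example occupying exactly `seq_len` slots. A hedged sketch of how such files could be read back and viewed per example (the directory path is a placeholder, and `seq_len` must match the value used when preparing the data):

import numpy as np

seq_len = 2048                      # must match the --seq-len used during preparation
output_dir = "path/to/output_dir"   # placeholder for the --output_dir that was used

input_ids = np.memmap(f"{output_dir}/input_ids.npy", dtype=np.uint16, mode="r")
label_mask = np.memmap(f"{output_dir}/label_mask.npy", dtype=np.bool_, mode="r")

num_examples = input_ids.shape[0] // seq_len
input_ids = input_ids.reshape(num_examples, seq_len)
label_mask = label_mask.reshape(num_examples, seq_len)

print(num_examples, input_ids.shape, int(label_mask[0].sum()))  # supervised tokens in example 0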
154
+ def filter(example):
155
+ return example["n_labels"] > 0
156
+
157
+
158
+ def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
159
+ input_ids = [tokenizer.eos_token_id]
160
+ label_mask = [False]
161
+
162
+ for msg in example["messages"]:
163
+ role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
164
+ label_mask += [False] * len(role_tokens)
165
+ input_ids += role_tokens
166
+
167
+ if msg["role"] == "assistant":
168
+ content_tokens = tokenizer.encode(
169
+ msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
170
+ )
171
+ label_mask += [True] * len(content_tokens)
172
+ # mask out the last '\n'
173
+ assert content_tokens[-2] == tokenizer.eos_token_id
174
+ label_mask[-1] = False
175
+ else:
176
+ content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
177
+ label_mask += [False] * len(content_tokens)
178
+ input_ids += content_tokens
179
+
180
+ input_ids = input_ids[:max_seq_len]
181
+ label_mask = label_mask[:max_seq_len]
182
+
183
+ if len(input_ids) < max_seq_len:
184
+ pad_len = max_seq_len - len(input_ids)
185
+ input_ids += [tokenizer.pad_token_id] * pad_len
186
+ label_mask += [False] * pad_len
187
+
188
+ assert len(input_ids) == len(label_mask)
189
+ n_labels = sum(label_mask)
190
+
191
+ return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}
192
+
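`preprocess` grows `label_mask` in lockstep with `input_ids`: role headers such as `<|assistant|>\n` and all non-assistant content are masked out, assistant content plus its EOS is kept, and the newline that follows the EOS is masked again. A schematic sketch with made-up token ids instead of a real tokenizer:

eos_id = 50279  # EOS id used by the default tokenizer config
user_header, user_content = [1, 2], [10, 11, 12]   # "<|user|>\n" + question (made-up ids)
assistant_header = [3, 4]                          # "<|assistant|>\n"
assistant_content = [20, 21, eos_id, 30]           # answer + EOS + "\n"

input_ids = [eos_id]   # leading EOS, as in preprocess()
label_mask = [False]

for tokens, supervised in [(user_header, False), (user_content, False),
                           (assistant_header, False), (assistant_content, True)]:
    input_ids += tokens
    label_mask += [supervised] * len(tokens)

label_mask[-1] = False  # mask the trailing "\n" after the assistant EOS

print(sum(label_mask), "supervised tokens out of", len(input_ids))  # 3 out of 12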
193
+
194
+ def get_parser() -> ArgumentParser:
195
+ parser = ArgumentParser(description="Prepare SFT dataset")
196
+ parser.add_argument("--output_dir", type=str, help="""Directory to save the results to.""")
197
+ parser.add_argument(
198
+ "-t",
199
+ "--tokenizer",
200
+ type=str,
201
+ help="""Tokenizer path or identifier.""",
202
+ default=Path(__file__).parent / "tokenizers" / "allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
203
+ )
204
+ parser.add_argument("-ds", "--dataset", type=str, help="""Dataset to process: 'tulu' or '2nd-phase'.""", default="tulu")
205
+ parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048)
206
+ parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=50279)
207
+ parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1)
208
+ parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8)
209
+ return parser
210
+
211
+
212
+ if __name__ == "__main__":
213
+ prepare_cli_environment()
214
+ opts = get_parser().parse_args()
215
+ main(opts)
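For reference, the script would typically be run from the OLMo repository root (so that `olmo` is importable); the output paths below are illustrative:

python prepare_sft_data.py --output_dir ./datasets/tulu -ds tulu -s 2048 -j 8
python prepare_sft_data.py --output_dir ./datasets/2nd-phase -ds 2nd-phase -s 2048 -j 8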