codewithdark commited on
Commit
d860eca
·
verified ·
1 Parent(s): 668d558

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .config/.last_opt_in_prompt.yaml +1 -0
  2. .config/.last_survey_prompt.yaml +1 -0
  3. .config/.last_update_check.json +1 -0
  4. .config/active_config +1 -0
  5. .config/config_sentinel +0 -0
  6. .config/configurations/config_default +6 -0
  7. .config/default_configs.db +0 -0
  8. .config/gce +1 -0
  9. .config/hidden_gcloud_config_universe_descriptor_data_cache_configs.db +0 -0
  10. .config/logs/2025.03.14/13.31.36.734686.log +765 -0
  11. .config/logs/2025.03.14/13.32.03.025824.log +5 -0
  12. .config/logs/2025.03.14/13.32.11.932574.log +153 -0
  13. .config/logs/2025.03.14/13.32.16.153180.log +5 -0
  14. .config/logs/2025.03.14/13.32.25.046318.log +8 -0
  15. .config/logs/2025.03.14/13.32.25.746375.log +8 -0
  16. .gitattributes +3 -0
  17. .gradio/certificate.pem +31 -0
  18. Gemma-Finetune/.gitignore +192 -0
  19. Gemma-Finetune/Gemma3_(4B).ipynb +0 -0
  20. Gemma-Finetune/LICENSE +21 -0
  21. Gemma-Finetune/README.md +41 -0
  22. Gemma-Finetune/main.py +295 -0
  23. Gemma-Finetune/requirements.txt +9 -0
  24. Gemma-Finetune/utils/__pycache__/check_dataset.cpython-311.pyc +0 -0
  25. Gemma-Finetune/utils/__pycache__/model.cpython-311.pyc +0 -0
  26. Gemma-Finetune/utils/__pycache__/sample_dataset.cpython-311.pyc +0 -0
  27. Gemma-Finetune/utils/check_dataset.py +272 -0
  28. Gemma-Finetune/utils/model.py +552 -0
  29. Gemma-Finetune/utils/sample_dataset.py +105 -0
  30. README.md +2 -8
  31. requirements.txt +6 -0
  32. sample_data/README.md +19 -0
  33. sample_data/anscombe.json +49 -0
  34. sample_data/california_housing_test.csv +0 -0
  35. sample_data/california_housing_train.csv +0 -0
  36. sample_data/mnist_test.csv +3 -0
  37. sample_data/mnist_train_small.csv +3 -0
  38. unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
  39. unsloth_compiled_cache/UnslothBCOTrainer.py +1822 -0
  40. unsloth_compiled_cache/UnslothCPOTrainer.py +1555 -0
  41. unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
  42. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  43. unsloth_compiled_cache/UnslothGKDTrainer.py +861 -0
  44. unsloth_compiled_cache/UnslothGRPOTrainer.py +1436 -0
  45. unsloth_compiled_cache/UnslothKTOTrainer.py +1838 -0
  46. unsloth_compiled_cache/UnslothNashMDTrainer.py +953 -0
  47. unsloth_compiled_cache/UnslothORPOTrainer.py +1541 -0
  48. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1267 -0
  49. unsloth_compiled_cache/UnslothPPOTrainer.py +1257 -0
  50. unsloth_compiled_cache/UnslothPRMTrainer.py +798 -0
.config/.last_opt_in_prompt.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
.config/.last_survey_prompt.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ last_prompt_time: 1741959131.3035824
.config/.last_update_check.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"last_update_check_time": 1741959135.630771, "last_update_check_revision": 20250307152352, "notifications": [], "last_nag_times": {}}
.config/active_config ADDED
@@ -0,0 +1 @@
 
 
1
+ default
.config/config_sentinel ADDED
File without changes
.config/configurations/config_default ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [component_manager]
2
+ disable_update_check = true
3
+
4
+ [compute]
5
+ gce_metadata_read_timeout_sec = 0
6
+
.config/default_configs.db ADDED
Binary file (12.3 kB). View file
 
.config/gce ADDED
@@ -0,0 +1 @@
 
 
1
+ False
.config/hidden_gcloud_config_universe_descriptor_data_cache_configs.db ADDED
Binary file (12.3 kB). View file
 
.config/logs/2025.03.14/13.31.36.734686.log ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-03-14 13:31:48,760 DEBUG root Loaded Command Group: ['gcloud', 'components']
2
+ 2025-03-14 13:31:48,763 DEBUG root Loaded Command Group: ['gcloud', 'components', 'update']
3
+ 2025-03-14 13:31:48,765 DEBUG root Running [gcloud.components.update] with arguments: [--compile-python: "True", --quiet: "True", COMPONENT-IDS:6: "['core', 'gcloud-deps', 'bq', 'gcloud', 'gcloud-crc32c', 'gsutil']"]
4
+ 2025-03-14 13:31:48,766 INFO ___FILE_ONLY___ Beginning update. This process may take several minutes.
5
+
6
+ 2025-03-14 13:31:48,801 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
7
+ 2025-03-14 13:31:49,745 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components-2.json HTTP/11" 200 226132
8
+ 2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
9
+
10
+ 2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
11
+ Your current Google Cloud CLI version is: 514.0.0
12
+
13
+ 2025-03-14 13:31:49,757 INFO ___FILE_ONLY___ Installing components from version: 514.0.0
14
+
15
+ 2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
16
+
17
+ 2025-03-14 13:31:49,757 DEBUG root Chosen display Format:table[box,title="These components will be removed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
18
+ 2025-03-14 13:31:49,758 DEBUG root Chosen display Format:table[box,title="These components will be updated."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
19
+ 2025-03-14 13:31:49,758 DEBUG root Chosen display Format:table[box,title="These components will be installed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
20
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ┌─────────────────────────────────────────────────────────────────────────────┐
21
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
22
+
23
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │ These components will be installed. │
24
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
25
+
26
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ├─────────────────────────────────────────────────────┬────────────┬──────────┤
27
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
28
+
29
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │ Name │ Version │ Size │
30
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
31
+
32
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ├─────────────────────────────────────────────────────┼────────────┼──────────┤
33
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
34
+
35
+ 2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │
36
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ BigQuery Command Line Tool
37
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
38
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
39
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 2.1.14
40
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
41
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
42
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 1.8 MiB
43
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
44
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
45
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
46
+
47
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
48
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ BigQuery Command Line Tool (Platform Specific)
49
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
50
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
51
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 2.1.8
52
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
53
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
54
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ < 1 MiB
55
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
56
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
57
+ 2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
58
+
59
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
60
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ Bundled Python 3.12 (Platform Specific)
61
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
62
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
63
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 3.12.8
64
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
65
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
66
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 89.2 MiB
67
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
68
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
69
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
70
+
71
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
72
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ Cloud Storage Command Line Tool
73
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
74
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
75
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 5.33
76
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
77
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
78
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 11.8 MiB
79
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
80
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
81
+ 2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
82
+
83
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
84
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ Cloud Storage Command Line Tool (Platform Specific)
85
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
86
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
87
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ 5.30
88
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
89
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
90
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ < 1 MiB
91
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
92
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
93
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
94
+
95
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
96
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ Google Cloud CLI Core Libraries (Platform Specific)
97
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
98
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
99
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ 2024.08.30
100
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
101
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
102
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ < 1 MiB
103
+ 2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
104
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
105
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
106
+
107
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
108
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ Google Cloud CRC32C Hash Tool (Platform Specific)
109
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
110
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
111
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 1.0.0
112
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
113
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
114
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 1.4 MiB
115
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
116
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
117
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
118
+
119
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
120
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ gcloud cli dependencies (Platform Specific)
121
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
122
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
123
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 2021.04.16
124
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
125
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
126
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ < 1 MiB
127
+ 2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
128
+ 2025-03-14 13:31:49,800 INFO ___FILE_ONLY___ │
129
+ 2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
130
+
131
+ 2025-03-14 13:31:49,800 INFO ___FILE_ONLY___ └─────────────────────────────────────────────────────┴────────────┴──────────┘
132
+ 2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
133
+
134
+ 2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
135
+
136
+ 2025-03-14 13:31:49,803 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
137
+ 2025-03-14 13:31:49,983 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/RELEASE_NOTES HTTP/11" 200 1377050
138
+ 2025-03-14 13:31:50,443 INFO ___FILE_ONLY___ For the latest full release notes, please visit:
139
+ https://cloud.google.com/sdk/release_notes
140
+
141
+
142
+ 2025-03-14 13:31:50,443 INFO ___FILE_ONLY___ Performing in place update...
143
+
144
+
145
+ 2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
146
+
147
+ 2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╠═ Downloading: BigQuery Command Line Tool ═╣
148
+
149
+ 2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╚
150
+ 2025-03-14 13:31:50,449 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
151
+ 2025-03-14 13:31:51,415 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bq-20250228155416.tar.gz HTTP/11" 200 1845321
152
+ 2025-03-14 13:31:51,427 INFO ___FILE_ONLY___ ═
153
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
154
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
155
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
156
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
157
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
158
+ 2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
159
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
160
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
161
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
162
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
163
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
164
+ 2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
165
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
166
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
167
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
168
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
169
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
170
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
171
+ 2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
172
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
173
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
174
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
175
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
176
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
177
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
178
+ 2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
179
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
180
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
181
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
182
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
183
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
184
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
185
+ 2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
186
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
187
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
188
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
189
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
190
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
191
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
192
+ 2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
193
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
194
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
195
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
196
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
197
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
198
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
199
+ 2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
200
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
201
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
202
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
203
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
204
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
205
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
206
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
207
+ 2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
208
+ 2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
209
+ 2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
210
+ 2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
211
+ 2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
212
+ 2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ╝
213
+
214
+ 2025-03-14 13:31:51,438 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
215
+
216
+ 2025-03-14 13:31:51,438 INFO ___FILE_ONLY___ ╠═ Downloading: BigQuery Command Line Tool (Platform Spe... ═╣
217
+
218
+ 2025-03-14 13:31:51,439 INFO ___FILE_ONLY___ ╚
219
+ 2025-03-14 13:31:51,443 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
220
+ 2025-03-14 13:31:52,383 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bq-nix-20240830134514.tar.gz HTTP/11" 200 1914
221
+ 2025-03-14 13:31:52,384 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
222
+ 2025-03-14 13:31:52,384 INFO ___FILE_ONLY___ ╝
223
+
224
+ 2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
225
+
226
+ 2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╠═ Downloading: Bundled Python 3.12 ═╣
227
+
228
+ 2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╚
229
+ 2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
230
+ 2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╝
231
+
232
+ 2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════���═══════════════╗
233
+
234
+ 2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╠═ Downloading: Bundled Python 3.12 (Platform Specific) ═╣
235
+
236
+ 2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╚
237
+ 2025-03-14 13:31:52,392 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
238
+ 2025-03-14 13:31:53,292 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bundled-python3-unix-linux-x86_64-20250131143518.tar.gz HTTP/11" 200 93520256
239
+ 2025-03-14 13:31:53,624 INFO ___FILE_ONLY___ ═
240
+ 2025-03-14 13:31:53,626 INFO ___FILE_ONLY___ ═
241
+ 2025-03-14 13:31:53,628 INFO ___FILE_ONLY___ ═
242
+ 2025-03-14 13:31:53,630 INFO ___FILE_ONLY___ ═
243
+ 2025-03-14 13:31:53,632 INFO ___FILE_ONLY___ ═
244
+ 2025-03-14 13:31:53,634 INFO ___FILE_ONLY___ ═
245
+ 2025-03-14 13:31:53,636 INFO ___FILE_ONLY___ ═
246
+ 2025-03-14 13:31:53,638 INFO ___FILE_ONLY___ ═
247
+ 2025-03-14 13:31:53,640 INFO ___FILE_ONLY___ ═
248
+ 2025-03-14 13:31:53,642 INFO ___FILE_ONLY___ ═
249
+ 2025-03-14 13:31:53,644 INFO ___FILE_ONLY___ ═
250
+ 2025-03-14 13:31:53,646 INFO ___FILE_ONLY___ ═
251
+ 2025-03-14 13:31:53,648 INFO ___FILE_ONLY___ ═
252
+ 2025-03-14 13:31:53,650 INFO ___FILE_ONLY___ ═
253
+ 2025-03-14 13:31:53,652 INFO ___FILE_ONLY___ ═
254
+ 2025-03-14 13:31:53,654 INFO ___FILE_ONLY___ ═
255
+ 2025-03-14 13:31:53,656 INFO ___FILE_ONLY___ ═
256
+ 2025-03-14 13:31:53,658 INFO ___FILE_ONLY___ ═
257
+ 2025-03-14 13:31:53,660 INFO ___FILE_ONLY___ ═
258
+ 2025-03-14 13:31:53,662 INFO ___FILE_ONLY___ ═
259
+ 2025-03-14 13:31:53,664 INFO ___FILE_ONLY___ ═
260
+ 2025-03-14 13:31:53,666 INFO ___FILE_ONLY___ ═
261
+ 2025-03-14 13:31:53,668 INFO ___FILE_ONLY___ ═
262
+ 2025-03-14 13:31:53,670 INFO ___FILE_ONLY___ ═
263
+ 2025-03-14 13:31:53,672 INFO ___FILE_ONLY___ ═
264
+ 2025-03-14 13:31:53,674 INFO ___FILE_ONLY___ ═
265
+ 2025-03-14 13:31:53,676 INFO ___FILE_ONLY___ ═
266
+ 2025-03-14 13:31:53,678 INFO ___FILE_ONLY___ ═
267
+ 2025-03-14 13:31:53,680 INFO ___FILE_ONLY___ ═
268
+ 2025-03-14 13:31:53,682 INFO ___FILE_ONLY___ ═
269
+ 2025-03-14 13:31:53,684 INFO ___FILE_ONLY___ ═
270
+ 2025-03-14 13:31:53,686 INFO ___FILE_ONLY___ ═
271
+ 2025-03-14 13:31:53,688 INFO ___FILE_ONLY___ ═
272
+ 2025-03-14 13:31:53,689 INFO ___FILE_ONLY___ ═
273
+ 2025-03-14 13:31:53,691 INFO ___FILE_ONLY___ ═
274
+ 2025-03-14 13:31:53,693 INFO ___FILE_ONLY___ ═
275
+ 2025-03-14 13:31:53,695 INFO ___FILE_ONLY___ ═
276
+ 2025-03-14 13:31:53,697 INFO ___FILE_ONLY___ ═
277
+ 2025-03-14 13:31:53,699 INFO ___FILE_ONLY___ ═
278
+ 2025-03-14 13:31:53,701 INFO ___FILE_ONLY___ ═
279
+ 2025-03-14 13:31:53,703 INFO ___FILE_ONLY___ ═
280
+ 2025-03-14 13:31:53,705 INFO ___FILE_ONLY___ ═
281
+ 2025-03-14 13:31:53,707 INFO ___FILE_ONLY___ ═
282
+ 2025-03-14 13:31:53,709 INFO ___FILE_ONLY___ ═
283
+ 2025-03-14 13:31:53,711 INFO ___FILE_ONLY___ ═
284
+ 2025-03-14 13:31:53,713 INFO ___FILE_ONLY___ ═
285
+ 2025-03-14 13:31:53,715 INFO ___FILE_ONLY___ ═
286
+ 2025-03-14 13:31:53,717 INFO ___FILE_ONLY___ ═
287
+ 2025-03-14 13:31:53,719 INFO ___FILE_ONLY___ ═
288
+ 2025-03-14 13:31:53,721 INFO ___FILE_ONLY___ ═
289
+ 2025-03-14 13:31:53,723 INFO ___FILE_ONLY___ ═
290
+ 2025-03-14 13:31:53,725 INFO ___FILE_ONLY___ ═
291
+ 2025-03-14 13:31:53,727 INFO ___FILE_ONLY___ ═
292
+ 2025-03-14 13:31:53,729 INFO ___FILE_ONLY___ ═
293
+ 2025-03-14 13:31:53,731 INFO ___FILE_ONLY___ ═
294
+ 2025-03-14 13:31:53,732 INFO ___FILE_ONLY___ ═
295
+ 2025-03-14 13:31:53,734 INFO ___FILE_ONLY___ ═
296
+ 2025-03-14 13:31:53,736 INFO ___FILE_ONLY___ ═
297
+ 2025-03-14 13:31:53,738 INFO ___FILE_ONLY___ ═
298
+ 2025-03-14 13:31:53,740 INFO ___FILE_ONLY___ ═
299
+ 2025-03-14 13:31:53,741 INFO ___FILE_ONLY___ ╝
300
+
301
+ 2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
302
+
303
+ 2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╠═ Downloading: Cloud Storage Command Line Tool ═╣
304
+
305
+ 2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╚
306
+ 2025-03-14 13:31:53,747 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
307
+ 2025-03-14 13:31:53,929 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gsutil-20241213184646.tar.gz HTTP/11" 200 12368759
308
+ 2025-03-14 13:31:53,995 INFO ___FILE_ONLY___ ═
309
+ 2025-03-14 13:31:53,995 INFO ___FILE_ONLY___ ═
310
+ 2025-03-14 13:31:53,996 INFO ___FILE_ONLY___ ═
311
+ 2025-03-14 13:31:53,996 INFO ___FILE_ONLY___ ═
312
+ 2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
313
+ 2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
314
+ 2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
315
+ 2025-03-14 13:31:53,998 INFO ___FILE_ONLY___ ═
316
+ 2025-03-14 13:31:53,998 INFO ___FILE_ONLY___ ═
317
+ 2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
318
+ 2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
319
+ 2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
320
+ 2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
321
+ 2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
322
+ 2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
323
+ 2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
324
+ 2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
325
+ 2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
326
+ 2025-03-14 13:31:54,002 INFO ___FILE_ONLY___ ═
327
+ 2025-03-14 13:31:54,002 INFO ___FILE_ONLY___ ═
328
+ 2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
329
+ 2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
330
+ 2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
331
+ 2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
332
+ 2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
333
+ 2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
334
+ 2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
335
+ 2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
336
+ 2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
337
+ 2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
338
+ 2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
339
+ 2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
340
+ 2025-03-14 13:31:54,007 INFO ___FILE_ONLY___ ═
341
+ 2025-03-14 13:31:54,007 INFO ___FILE_ONLY___ ═
342
+ 2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
343
+ 2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
344
+ 2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
345
+ 2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
346
+ 2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
347
+ 2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
348
+ 2025-03-14 13:31:54,010 INFO ___FILE_ONLY___ ═
349
+ 2025-03-14 13:31:54,010 INFO ___FILE_ONLY___ ═
350
+ 2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
351
+ 2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
352
+ 2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
353
+ 2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
354
+ 2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
355
+ 2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
356
+ 2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
357
+ 2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
358
+ 2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
359
+ 2025-03-14 13:31:54,014 INFO ___FILE_ONLY___ ═
360
+ 2025-03-14 13:31:54,014 INFO ___FILE_ONLY___ ═
361
+ 2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
362
+ 2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
363
+ 2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
364
+ 2025-03-14 13:31:54,016 INFO ___FILE_ONLY___ ═
365
+ 2025-03-14 13:31:54,016 INFO ___FILE_ONLY___ ═
366
+ 2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ═
367
+ 2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ═
368
+ 2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ╝
369
+
370
+ 2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
371
+
372
+ 2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╠═ Downloading: Cloud Storage Command Line Tool (Platfor... ═╣
373
+
374
+ 2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╚
375
+ 2025-03-14 13:31:54,023 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
376
+ 2025-03-14 13:31:54,217 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gsutil-nix-20240830134514.tar.gz HTTP/11" 200 1928
377
+ 2025-03-14 13:31:54,217 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
378
+ 2025-03-14 13:31:54,217 INFO ___FILE_ONLY___ ╝
379
+
380
+ 2025-03-14 13:31:54,219 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
381
+
382
+ 2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╠═ Downloading: Default set of gcloud commands ═╣
383
+
384
+ 2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╚
385
+ 2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
386
+ 2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╝
387
+
388
+ 2025-03-14 13:31:54,221 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
389
+
390
+ 2025-03-14 13:31:54,222 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CLI Core Libraries (Platfor... ═╣
391
+
392
+ 2025-03-14 13:31:54,222 INFO ___FILE_ONLY___ ╚
393
+ 2025-03-14 13:31:54,225 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
394
+ 2025-03-14 13:31:55,133 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-core-nix-20240830134514.tar.gz HTTP/11" 200 2306
395
+ 2025-03-14 13:31:55,134 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
396
+ 2025-03-14 13:31:55,134 INFO ___FILE_ONLY___ ╝
397
+
398
+ 2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
399
+
400
+ 2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CRC32C Hash Tool ═╣
401
+
402
+ 2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╚
403
+ 2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
404
+ 2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╝
405
+
406
+ 2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
407
+
408
+ 2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CRC32C Hash Tool (Platform ... ═╣
409
+
410
+ 2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╚
411
+ 2025-03-14 13:31:55,141 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
412
+ 2025-03-14 13:31:55,315 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gcloud-crc32c-linux-x86_64-20250110133808.tar.gz HTTP/11" 200 1478989
413
+ 2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
414
+ 2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
415
+ 2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
416
+ 2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
417
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
418
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
419
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
420
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
421
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
422
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
423
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
424
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
425
+ 2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
426
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
427
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
428
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
429
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
430
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
431
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
432
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
433
+ 2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
434
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
435
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
436
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
437
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
438
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
439
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
440
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
441
+ 2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
442
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
443
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
444
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
445
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
446
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
447
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
448
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
449
+ 2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
450
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
451
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
452
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
453
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
454
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
455
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
456
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
457
+ 2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
458
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
459
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
460
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
461
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
462
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
463
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
464
+ 2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
465
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
466
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
467
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
468
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
469
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
470
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
471
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
472
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
473
+ 2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ╝
474
+
475
+ 2025-03-14 13:31:55,335 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
476
+
477
+ 2025-03-14 13:31:55,336 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud cli dependencies (Platform Specific) ═╣
478
+
479
+ 2025-03-14 13:31:55,336 INFO ___FILE_ONLY___ ╚
480
+ 2025-03-14 13:31:55,339 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
481
+ 2025-03-14 13:31:55,513 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gcloud-deps-linux-x86_64-20210416153011.tar.gz HTTP/11" 200 104
482
+ 2025-03-14 13:31:55,515 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
483
+ 2025-03-14 13:31:55,515 INFO ___FILE_ONLY___ ╝
484
+
485
+ 2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
486
+
487
+ 2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╠═ Installing: BigQuery Command Line Tool ═╣
488
+
489
+ 2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╚
490
+ 2025-03-14 13:31:55,609 INFO ___FILE_ONLY___ ═
491
+ 2025-03-14 13:31:55,611 INFO ___FILE_ONLY___ ═
492
+ 2025-03-14 13:31:55,614 INFO ___FILE_ONLY___ ═
493
+ 2025-03-14 13:31:55,616 INFO ___FILE_ONLY___ ═
494
+ 2025-03-14 13:31:55,619 INFO ___FILE_ONLY___ ═
495
+ 2025-03-14 13:31:55,621 INFO ___FILE_ONLY___ ═
496
+ 2025-03-14 13:31:55,624 INFO ___FILE_ONLY___ ═
497
+ 2025-03-14 13:31:55,626 INFO ___FILE_ONLY___ ═
498
+ 2025-03-14 13:31:55,629 INFO ___FILE_ONLY___ ═
499
+ 2025-03-14 13:31:55,631 INFO ___FILE_ONLY___ ═
500
+ 2025-03-14 13:31:55,633 INFO ___FILE_ONLY___ ═
501
+ 2025-03-14 13:31:55,636 INFO ___FILE_ONLY___ ═
502
+ 2025-03-14 13:31:55,638 INFO ___FILE_ONLY___ ═
503
+ 2025-03-14 13:31:55,640 INFO ___FILE_ONLY___ ═
504
+ 2025-03-14 13:31:55,643 INFO ___FILE_ONLY___ ═
505
+ 2025-03-14 13:31:55,645 INFO ___FILE_ONLY___ ═
506
+ 2025-03-14 13:31:55,647 INFO ___FILE_ONLY___ ═
507
+ 2025-03-14 13:31:55,649 INFO ___FILE_ONLY___ ═
508
+ 2025-03-14 13:31:55,653 INFO ___FILE_ONLY___ ═
509
+ 2025-03-14 13:31:55,656 INFO ___FILE_ONLY___ ═
510
+ 2025-03-14 13:31:55,658 INFO ___FILE_ONLY___ ═
511
+ 2025-03-14 13:31:55,661 INFO ___FILE_ONLY___ ═
512
+ 2025-03-14 13:31:55,663 INFO ___FILE_ONLY___ ═
513
+ 2025-03-14 13:31:55,665 INFO ___FILE_ONLY___ ═
514
+ 2025-03-14 13:31:55,667 INFO ___FILE_ONLY___ ═
515
+ 2025-03-14 13:31:55,669 INFO ___FILE_ONLY___ ═
516
+ 2025-03-14 13:31:55,671 INFO ___FILE_ONLY___ ═
517
+ 2025-03-14 13:31:55,674 INFO ___FILE_ONLY___ ═
518
+ 2025-03-14 13:31:55,678 INFO ___FILE_ONLY___ ═
519
+ 2025-03-14 13:31:55,680 INFO ___FILE_ONLY___ ═
520
+ 2025-03-14 13:31:55,682 INFO ___FILE_ONLY___ ═
521
+ 2025-03-14 13:31:55,684 INFO ___FILE_ONLY___ ═
522
+ 2025-03-14 13:31:55,688 INFO ___FILE_ONLY___ ═
523
+ 2025-03-14 13:31:55,690 INFO ___FILE_ONLY___ ═
524
+ 2025-03-14 13:31:55,698 INFO ___FILE_ONLY___ ═
525
+ 2025-03-14 13:31:55,703 INFO ___FILE_ONLY___ ═
526
+ 2025-03-14 13:31:55,709 INFO ___FILE_ONLY___ ═
527
+ 2025-03-14 13:31:55,711 INFO ___FILE_ONLY___ ═
528
+ 2025-03-14 13:31:55,713 INFO ___FILE_ONLY___ ═
529
+ 2025-03-14 13:31:55,716 INFO ___FILE_ONLY___ ═
530
+ 2025-03-14 13:31:55,718 INFO ___FILE_ONLY___ ═
531
+ 2025-03-14 13:31:55,721 INFO ___FILE_ONLY___ ═
532
+ 2025-03-14 13:31:55,726 INFO ___FILE_ONLY___ ═
533
+ 2025-03-14 13:31:55,728 INFO ___FILE_ONLY___ ═
534
+ 2025-03-14 13:31:55,731 INFO ___FILE_ONLY___ ═
535
+ 2025-03-14 13:31:55,732 INFO ___FILE_ONLY___ ═
536
+ 2025-03-14 13:31:55,735 INFO ___FILE_ONLY___ ═
537
+ 2025-03-14 13:31:55,737 INFO ___FILE_ONLY___ ═
538
+ 2025-03-14 13:31:55,739 INFO ___FILE_ONLY___ ═
539
+ 2025-03-14 13:31:55,741 INFO ___FILE_ONLY___ ═
540
+ 2025-03-14 13:31:55,744 INFO ___FILE_ONLY___ ═
541
+ 2025-03-14 13:31:55,746 INFO ___FILE_ONLY___ ═
542
+ 2025-03-14 13:31:55,748 INFO ___FILE_ONLY___ ═
543
+ 2025-03-14 13:31:55,750 INFO ___FILE_ONLY___ ═
544
+ 2025-03-14 13:31:55,752 INFO ___FILE_ONLY___ ═
545
+ 2025-03-14 13:31:55,754 INFO ___FILE_ONLY___ ═
546
+ 2025-03-14 13:31:55,756 INFO ___FILE_ONLY___ ═
547
+ 2025-03-14 13:31:55,759 INFO ___FILE_ONLY___ ═
548
+ 2025-03-14 13:31:55,760 INFO ___FILE_ONLY___ ═
549
+ 2025-03-14 13:31:55,763 INFO ___FILE_ONLY___ ═
550
+ 2025-03-14 13:31:55,763 INFO ___FILE_ONLY___ ╝
551
+
552
+ 2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
553
+
554
+ 2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╠═ Installing: BigQuery Command Line Tool (Platform Spec... ═╣
555
+
556
+ 2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╚
557
+ 2025-03-14 13:31:55,772 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
558
+ 2025-03-14 13:31:55,772 INFO ___FILE_ONLY___ ╝
559
+
560
+ 2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
561
+
562
+ 2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╠═ Installing: Bundled Python 3.12 ═╣
563
+
564
+ 2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╚
565
+ 2025-03-14 13:31:55,780 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
566
+ 2025-03-14 13:31:55,780 INFO ___FILE_ONLY___ ╝
567
+
568
+ 2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
569
+
570
+ 2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╠═ Installing: Bundled Python 3.12 (Platform Specific) ═╣
571
+
572
+ 2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╚
573
+ 2025-03-14 13:31:58,899 INFO ___FILE_ONLY___ ═
574
+ 2025-03-14 13:31:59,398 INFO ___FILE_ONLY___ ═
575
+ 2025-03-14 13:31:59,419 INFO ___FILE_ONLY___ ═
576
+ 2025-03-14 13:31:59,456 INFO ___FILE_ONLY___ ═
577
+ 2025-03-14 13:31:59,479 INFO ___FILE_ONLY___ ═
578
+ 2025-03-14 13:31:59,496 INFO ___FILE_ONLY___ ═
579
+ 2025-03-14 13:31:59,528 INFO ___FILE_ONLY___ ═
580
+ 2025-03-14 13:31:59,550 INFO ___FILE_ONLY___ ═
581
+ 2025-03-14 13:31:59,569 INFO ___FILE_ONLY___ ═
582
+ 2025-03-14 13:31:59,588 INFO ___FILE_ONLY___ ═
583
+ 2025-03-14 13:31:59,604 INFO ___FILE_ONLY___ ═
584
+ 2025-03-14 13:31:59,632 INFO ___FILE_ONLY___ ═
585
+ 2025-03-14 13:31:59,775 INFO ___FILE_ONLY___ ═
586
+ 2025-03-14 13:31:59,793 INFO ___FILE_ONLY___ ═
587
+ 2025-03-14 13:31:59,810 INFO ___FILE_ONLY___ ═
588
+ 2025-03-14 13:31:59,827 INFO ___FILE_ONLY___ ═
589
+ 2025-03-14 13:31:59,842 INFO ___FILE_ONLY___ ═
590
+ 2025-03-14 13:31:59,861 INFO ___FILE_ONLY___ ═
591
+ 2025-03-14 13:31:59,879 INFO ___FILE_ONLY___ ═
592
+ 2025-03-14 13:31:59,900 INFO ___FILE_ONLY___ ═
593
+ 2025-03-14 13:31:59,917 INFO ___FILE_ONLY___ ═
594
+ 2025-03-14 13:32:00,014 INFO ___FILE_ONLY___ ═
595
+ 2025-03-14 13:32:00,038 INFO ___FILE_ONLY___ ═
596
+ 2025-03-14 13:32:00,539 INFO ___FILE_ONLY___ ═
597
+ 2025-03-14 13:32:00,556 INFO ___FILE_ONLY___ ═
598
+ 2025-03-14 13:32:00,571 INFO ___FILE_ONLY___ ═
599
+ 2025-03-14 13:32:00,584 INFO ___FILE_ONLY___ ═
600
+ 2025-03-14 13:32:00,597 INFO ___FILE_ONLY___ ═
601
+ 2025-03-14 13:32:00,610 INFO ___FILE_ONLY___ ═
602
+ 2025-03-14 13:32:00,623 INFO ___FILE_ONLY___ ═
603
+ 2025-03-14 13:32:00,635 INFO ___FILE_ONLY___ ═
604
+ 2025-03-14 13:32:00,647 INFO ___FILE_ONLY___ ═
605
+ 2025-03-14 13:32:00,659 INFO ___FILE_ONLY___ ═
606
+ 2025-03-14 13:32:00,672 INFO ___FILE_ONLY___ ═
607
+ 2025-03-14 13:32:00,684 INFO ___FILE_ONLY___ ═
608
+ 2025-03-14 13:32:00,697 INFO ___FILE_ONLY___ ═
609
+ 2025-03-14 13:32:00,711 INFO ___FILE_ONLY___ ═
610
+ 2025-03-14 13:32:00,724 INFO ___FILE_ONLY___ ═
611
+ 2025-03-14 13:32:00,737 INFO ___FILE_ONLY___ ═
612
+ 2025-03-14 13:32:00,750 INFO ___FILE_ONLY___ ═
613
+ 2025-03-14 13:32:00,763 INFO ___FILE_ONLY___ ═
614
+ 2025-03-14 13:32:00,775 INFO ___FILE_ONLY___ ═
615
+ 2025-03-14 13:32:00,788 INFO ___FILE_ONLY___ ═
616
+ 2025-03-14 13:32:00,802 INFO ___FILE_ONLY___ ═
617
+ 2025-03-14 13:32:00,815 INFO ___FILE_ONLY___ ═
618
+ 2025-03-14 13:32:00,828 INFO ___FILE_ONLY___ ═
619
+ 2025-03-14 13:32:00,841 INFO ___FILE_ONLY___ ═
620
+ 2025-03-14 13:32:00,854 INFO ___FILE_ONLY___ ═
621
+ 2025-03-14 13:32:00,867 INFO ___FILE_ONLY___ ═
622
+ 2025-03-14 13:32:00,881 INFO ___FILE_ONLY___ ═
623
+ 2025-03-14 13:32:00,893 INFO ___FILE_ONLY___ ═
624
+ 2025-03-14 13:32:00,905 INFO ___FILE_ONLY___ ═
625
+ 2025-03-14 13:32:00,919 INFO ___FILE_ONLY___ ═
626
+ 2025-03-14 13:32:00,932 INFO ___FILE_ONLY___ ═
627
+ 2025-03-14 13:32:00,946 INFO ___FILE_ONLY___ ═
628
+ 2025-03-14 13:32:00,959 INFO ___FILE_ONLY___ ═
629
+ 2025-03-14 13:32:00,972 INFO ___FILE_ONLY___ ═
630
+ 2025-03-14 13:32:00,985 INFO ___FILE_ONLY___ ═
631
+ 2025-03-14 13:32:00,999 INFO ___FILE_ONLY___ ═
632
+ 2025-03-14 13:32:01,012 INFO ___FILE_ONLY___ ═
633
+ 2025-03-14 13:32:01,012 INFO ___FILE_ONLY___ ╝
634
+
635
+ 2025-03-14 13:32:01,068 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
636
+
637
+ 2025-03-14 13:32:01,069 INFO ___FILE_ONLY___ ╠═ Installing: Cloud Storage Command Line Tool ═╣
638
+
639
+ 2025-03-14 13:32:01,069 INFO ___FILE_ONLY___ ╚
640
+ 2025-03-14 13:32:01,600 INFO ___FILE_ONLY___ ═
641
+ 2025-03-14 13:32:01,617 INFO ___FILE_ONLY___ ═
642
+ 2025-03-14 13:32:01,638 INFO ___FILE_ONLY___ ═
643
+ 2025-03-14 13:32:01,655 INFO ___FILE_ONLY___ ═
644
+ 2025-03-14 13:32:01,671 INFO ___FILE_ONLY___ ═
645
+ 2025-03-14 13:32:01,684 INFO ___FILE_ONLY___ ═
646
+ 2025-03-14 13:32:01,699 INFO ___FILE_ONLY___ ═
647
+ 2025-03-14 13:32:01,711 INFO ___FILE_ONLY___ ═
648
+ 2025-03-14 13:32:01,726 INFO ___FILE_ONLY___ ═
649
+ 2025-03-14 13:32:01,740 INFO ___FILE_ONLY___ ═
650
+ 2025-03-14 13:32:01,754 INFO ___FILE_ONLY___ ═
651
+ 2025-03-14 13:32:01,766 INFO ___FILE_ONLY___ ═
652
+ 2025-03-14 13:32:01,776 INFO ___FILE_ONLY___ ═
653
+ 2025-03-14 13:32:01,786 INFO ___FILE_ONLY___ ═
654
+ 2025-03-14 13:32:01,797 INFO ___FILE_ONLY___ ═
655
+ 2025-03-14 13:32:01,807 INFO ___FILE_ONLY___ ═
656
+ 2025-03-14 13:32:01,818 INFO ___FILE_ONLY___ ═
657
+ 2025-03-14 13:32:01,829 INFO ___FILE_ONLY___ ═
658
+ 2025-03-14 13:32:01,840 INFO ___FILE_ONLY___ ═
659
+ 2025-03-14 13:32:01,851 INFO ___FILE_ONLY___ ═
660
+ 2025-03-14 13:32:01,864 INFO ___FILE_ONLY___ ═
661
+ 2025-03-14 13:32:01,878 INFO ___FILE_ONLY___ ═
662
+ 2025-03-14 13:32:01,893 INFO ___FILE_ONLY___ ═
663
+ 2025-03-14 13:32:01,903 INFO ___FILE_ONLY___ ═
664
+ 2025-03-14 13:32:01,916 INFO ___FILE_ONLY___ ═
665
+ 2025-03-14 13:32:01,931 INFO ___FILE_ONLY___ ═
666
+ 2025-03-14 13:32:01,948 INFO ___FILE_ONLY___ ═
667
+ 2025-03-14 13:32:01,967 INFO ___FILE_ONLY___ ═
668
+ 2025-03-14 13:32:01,983 INFO ___FILE_ONLY___ ═
669
+ 2025-03-14 13:32:02,004 INFO ___FILE_ONLY___ ═
670
+ 2025-03-14 13:32:02,015 INFO ___FILE_ONLY___ ═
671
+ 2025-03-14 13:32:02,027 INFO ___FILE_ONLY___ ═
672
+ 2025-03-14 13:32:02,045 INFO ___FILE_ONLY___ ═
673
+ 2025-03-14 13:32:02,058 INFO ___FILE_ONLY___ ═
674
+ 2025-03-14 13:32:02,069 INFO ___FILE_ONLY___ ═
675
+ 2025-03-14 13:32:02,079 INFO ___FILE_ONLY___ ═
676
+ 2025-03-14 13:32:02,089 INFO ___FILE_ONLY___ ═
677
+ 2025-03-14 13:32:02,101 INFO ___FILE_ONLY___ ═
678
+ 2025-03-14 13:32:02,112 INFO ___FILE_ONLY___ ═
679
+ 2025-03-14 13:32:02,124 INFO ___FILE_ONLY___ ═
680
+ 2025-03-14 13:32:02,134 INFO ___FILE_ONLY___ ═
681
+ 2025-03-14 13:32:02,148 INFO ___FILE_ONLY___ ═
682
+ 2025-03-14 13:32:02,157 INFO ___FILE_ONLY___ ═
683
+ 2025-03-14 13:32:02,172 INFO ___FILE_ONLY___ ═
684
+ 2025-03-14 13:32:02,187 INFO ___FILE_ONLY___ ═
685
+ 2025-03-14 13:32:02,199 INFO ___FILE_ONLY___ ═
686
+ 2025-03-14 13:32:02,209 INFO ___FILE_ONLY___ ═
687
+ 2025-03-14 13:32:02,220 INFO ___FILE_ONLY___ ═
688
+ 2025-03-14 13:32:02,231 INFO ___FILE_ONLY___ ═
689
+ 2025-03-14 13:32:02,246 INFO ___FILE_ONLY___ ═
690
+ 2025-03-14 13:32:02,266 INFO ___FILE_ONLY___ ═
691
+ 2025-03-14 13:32:02,281 INFO ___FILE_ONLY___ ═
692
+ 2025-03-14 13:32:02,300 INFO ___FILE_ONLY___ ═
693
+ 2025-03-14 13:32:02,317 INFO ___FILE_ONLY___ ═
694
+ 2025-03-14 13:32:02,357 INFO ___FILE_ONLY___ ═
695
+ 2025-03-14 13:32:02,368 INFO ___FILE_ONLY___ ═
696
+ 2025-03-14 13:32:02,379 INFO ___FILE_ONLY___ ═
697
+ 2025-03-14 13:32:02,390 INFO ___FILE_ONLY___ ═
698
+ 2025-03-14 13:32:02,402 INFO ___FILE_ONLY___ ═
699
+ 2025-03-14 13:32:02,417 INFO ___FILE_ONLY___ ═
700
+ 2025-03-14 13:32:02,417 INFO ___FILE_ONLY___ ╝
701
+
702
+ 2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
703
+
704
+ 2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╠═ Installing: Cloud Storage Command Line Tool (Platform... ═╣
705
+
706
+ 2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╚
707
+ 2025-03-14 13:32:02,450 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
708
+ 2025-03-14 13:32:02,450 INFO ___FILE_ONLY___ ╝
709
+
710
+ 2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
711
+
712
+ 2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╠═ Installing: Default set of gcloud commands ═╣
713
+
714
+ 2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╚
715
+ 2025-03-14 13:32:02,457 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
716
+ 2025-03-14 13:32:02,457 INFO ___FILE_ONLY___ ╝
717
+
718
+ 2025-03-14 13:32:02,458 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
719
+
720
+ 2025-03-14 13:32:02,458 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CLI Core Libraries (Platform... ═╣
721
+
722
+ 2025-03-14 13:32:02,459 INFO ___FILE_ONLY___ ╚
723
+ 2025-03-14 13:32:02,459 INFO ___FILE_ONLY___ ══════════════════════════════
724
+ 2025-03-14 13:32:02,460 INFO ___FILE_ONLY___ ══════════════════════════════
725
+ 2025-03-14 13:32:02,460 INFO ___FILE_ONLY___ ╝
726
+
727
+ 2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
728
+
729
+ 2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CRC32C Hash Tool ═╣
730
+
731
+ 2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╚
732
+ 2025-03-14 13:32:02,466 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
733
+ 2025-03-14 13:32:02,466 INFO ___FILE_ONLY___ ╝
734
+
735
+ 2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
736
+
737
+ 2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CRC32C Hash Tool (Platform S... ═╣
738
+
739
+ 2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╚
740
+ 2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ══════════════════════════════
741
+ 2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ══════════════════════════════
742
+ 2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ╝
743
+
744
+ 2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
745
+
746
+ 2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╠═ Installing: gcloud cli dependencies (Platform Specific) ═╣
747
+
748
+ 2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╚
749
+ 2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
750
+ 2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╝
751
+
752
+ 2025-03-14 13:32:02,518 DEBUG root Updating notification cache...
753
+ 2025-03-14 13:32:02,518 INFO ___FILE_ONLY___
754
+
755
+ 2025-03-14 13:32:02,520 INFO ___FILE_ONLY___ Performing post processing steps...
756
+ 2025-03-14 13:32:02,521 DEBUG root Executing command: ['/tools/google-cloud-sdk/bin/gcloud', 'components', 'post-process']
757
+ 2025-03-14 13:32:11,259 DEBUG ___FILE_ONLY___
758
+ 2025-03-14 13:32:11,259 DEBUG ___FILE_ONLY___
759
+ 2025-03-14 13:32:11,299 INFO root descriptor_list: [{'universeDomain': 'googleapis.com', 'universeShortName': '', 'authenticationDomain': 'auth.cloud.google.com', 'projectPrefix': '', 'cloudWebDomain': 'cloud.google.com', 'documentationDomain': 'cloud.google.com', 'version': '1.0.0', 'state': 'primary', 'artifactRegistryDomain': 'pkg.dev'}]
760
+ 2025-03-14 13:32:11,299 INFO ___FILE_ONLY___
761
+ Update done!
762
+
763
+
764
+ 2025-03-14 13:32:11,302 DEBUG root Chosen display Format:none
765
+ 2025-03-14 13:32:11,303 INFO root Display format: "none"
.config/logs/2025.03.14/13.32.03.025824.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2025-03-14 13:32:03,026 DEBUG root Loaded Command Group: ['gcloud', 'components']
2
+ 2025-03-14 13:32:03,028 DEBUG root Loaded Command Group: ['gcloud', 'components', 'post_process']
3
+ 2025-03-14 13:32:03,030 DEBUG root Running [gcloud.components.post-process] with arguments: []
4
+ 2025-03-14 13:32:11,135 DEBUG root Chosen display Format:none
5
+ 2025-03-14 13:32:11,136 INFO root Display format: "none"
.config/logs/2025.03.14/13.32.11.932574.log ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-03-14 13:32:11,933 DEBUG root Loaded Command Group: ['gcloud', 'components']
2
+ 2025-03-14 13:32:11,935 DEBUG root Loaded Command Group: ['gcloud', 'components', 'update']
3
+ 2025-03-14 13:32:11,937 DEBUG root Running [gcloud.components.update] with arguments: [--quiet: "True", COMPONENT-IDS:8: "['gcloud', 'core', 'bq', 'gsutil', 'compute', 'preview', 'alpha', 'beta']"]
4
+ 2025-03-14 13:32:11,938 INFO ___FILE_ONLY___ Beginning update. This process may take several minutes.
5
+
6
+ 2025-03-14 13:32:11,947 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
7
+ 2025-03-14 13:32:12,903 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components-2.json HTTP/11" 200 226132
8
+ 2025-03-14 13:32:12,918 WARNING root Component [compute] no longer exists.
9
+ 2025-03-14 13:32:12,919 INFO ___FILE_ONLY___
10
+
11
+ 2025-03-14 13:32:12,919 INFO ___FILE_ONLY___
12
+ Your current Google Cloud CLI version is: 514.0.0
13
+
14
+ 2025-03-14 13:32:12,920 INFO ___FILE_ONLY___ Installing components from version: 514.0.0
15
+
16
+ 2025-03-14 13:32:12,920 INFO ___FILE_ONLY___
17
+
18
+ 2025-03-14 13:32:12,920 DEBUG root Chosen display Format:table[box,title="These components will be removed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
19
+ 2025-03-14 13:32:12,920 DEBUG root Chosen display Format:table[box,title="These components will be updated."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
20
+ 2025-03-14 13:32:12,921 DEBUG root Chosen display Format:table[box,title="These components will be installed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
21
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ┌────────────────────────────────────────────────┐
22
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
23
+
24
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │ These components will be installed. │
25
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
26
+
27
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ├─────────────────────────┬────────────┬─────────┤
28
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
29
+
30
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │ Name │ Version │ Size │
31
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
32
+
33
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ├─────────────────────────┼────────────┼─────────┤
34
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
35
+
36
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │
37
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ gcloud Alpha Commands
38
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
39
+ 2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │
40
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ 2025.03.07
41
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
42
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
43
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ < 1 MiB
44
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
45
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
46
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
47
+
48
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
49
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ gcloud Beta Commands
50
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
51
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
52
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ 2025.03.07
53
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
54
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
55
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ < 1 MiB
56
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
57
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
58
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
59
+
60
+ 2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
61
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ gcloud Preview Commands
62
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
63
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
64
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
65
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
66
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
67
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ < 1 MiB
68
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
69
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
70
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
71
+
72
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ └─────────────────────────┴────────────┴─────────┘
73
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
74
+
75
+ 2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
76
+
77
+ 2025-03-14 13:32:12,940 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
78
+ 2025-03-14 13:32:13,829 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/RELEASE_NOTES HTTP/11" 200 1377050
79
+ 2025-03-14 13:32:14,277 INFO ___FILE_ONLY___ For the latest full release notes, please visit:
80
+ https://cloud.google.com/sdk/release_notes
81
+
82
+
83
+ 2025-03-14 13:32:14,277 INFO ___FILE_ONLY___ Performing in place update...
84
+
85
+
86
+ 2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
87
+
88
+ 2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Alpha Commands ═╣
89
+
90
+ 2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╚
91
+ 2025-03-14 13:32:14,283 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
92
+ 2025-03-14 13:32:15,248 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-alpha-20250307152352.tar.gz HTTP/11" 200 800
93
+ 2025-03-14 13:32:15,249 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
94
+ 2025-03-14 13:32:15,249 INFO ___FILE_ONLY___ ╝
95
+
96
+ 2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
97
+
98
+ 2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Beta Commands ═╣
99
+
100
+ 2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╚
101
+ 2025-03-14 13:32:15,254 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
102
+ 2025-03-14 13:32:15,413 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-beta-20250307152352.tar.gz HTTP/11" 200 797
103
+ 2025-03-14 13:32:15,413 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
104
+ 2025-03-14 13:32:15,413 INFO ___FILE_ONLY___ ╝
105
+
106
+ 2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
107
+
108
+ 2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Preview Commands ═╣
109
+
110
+ 2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╚
111
+ 2025-03-14 13:32:15,419 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
112
+ 2025-03-14 13:32:15,610 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-preview-20241115154308.tar.gz HTTP/11" 200 823
113
+ 2025-03-14 13:32:15,611 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
114
+ 2025-03-14 13:32:15,611 INFO ___FILE_ONLY___ ╝
115
+
116
+ 2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
117
+
118
+ 2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Alpha Commands ═╣
119
+
120
+ 2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╚
121
+ 2025-03-14 13:32:15,614 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
122
+ 2025-03-14 13:32:15,614 INFO ___FILE_ONLY___ ╝
123
+
124
+ 2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
125
+
126
+ 2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Beta Commands ═╣
127
+
128
+ 2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╚
129
+ 2025-03-14 13:32:15,620 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
130
+ 2025-03-14 13:32:15,620 INFO ___FILE_ONLY___ ╝
131
+
132
+ 2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
133
+
134
+ 2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Preview Commands ═╣
135
+
136
+ 2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╚
137
+ 2025-03-14 13:32:15,626 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
138
+ 2025-03-14 13:32:15,626 INFO ___FILE_ONLY___ ╝
139
+
140
+ 2025-03-14 13:32:15,630 DEBUG root Updating notification cache...
141
+ 2025-03-14 13:32:15,631 INFO ___FILE_ONLY___
142
+
143
+ 2025-03-14 13:32:15,632 INFO ___FILE_ONLY___ Performing post processing steps...
144
+ 2025-03-14 13:32:15,633 DEBUG root Executing command: ['/tools/google-cloud-sdk/bin/gcloud', 'components', 'post-process']
145
+ 2025-03-14 13:32:24,180 DEBUG ___FILE_ONLY___
146
+ 2025-03-14 13:32:24,180 DEBUG ___FILE_ONLY___
147
+ 2025-03-14 13:32:24,406 INFO root descriptor_list: [{'universeDomain': 'googleapis.com', 'universeShortName': '', 'authenticationDomain': 'auth.cloud.google.com', 'projectPrefix': '', 'cloudWebDomain': 'cloud.google.com', 'documentationDomain': 'cloud.google.com', 'version': '1.0.0', 'state': 'primary', 'artifactRegistryDomain': 'pkg.dev'}]
148
+ 2025-03-14 13:32:24,407 INFO ___FILE_ONLY___
149
+ Update done!
150
+
151
+
152
+ 2025-03-14 13:32:24,409 DEBUG root Chosen display Format:none
153
+ 2025-03-14 13:32:24,410 INFO root Display format: "none"
.config/logs/2025.03.14/13.32.16.153180.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 2025-03-14 13:32:16,154 DEBUG root Loaded Command Group: ['gcloud', 'components']
2
+ 2025-03-14 13:32:16,155 DEBUG root Loaded Command Group: ['gcloud', 'components', 'post_process']
3
+ 2025-03-14 13:32:16,157 DEBUG root Running [gcloud.components.post-process] with arguments: []
4
+ 2025-03-14 13:32:24,057 DEBUG root Chosen display Format:none
5
+ 2025-03-14 13:32:24,058 INFO root Display format: "none"
.config/logs/2025.03.14/13.32.25.046318.log ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ 2025-03-14 13:32:25,048 DEBUG root Loaded Command Group: ['gcloud', 'config']
2
+ 2025-03-14 13:32:25,095 DEBUG root Loaded Command Group: ['gcloud', 'config', 'set']
3
+ 2025-03-14 13:32:25,098 DEBUG root Running [gcloud.config.set] with arguments: [SECTION/PROPERTY: "component_manager/disable_update_check", VALUE: "true"]
4
+ 2025-03-14 13:32:25,098 INFO ___FILE_ONLY___ Updated property [component_manager/disable_update_check].
5
+
6
+ 2025-03-14 13:32:25,099 DEBUG root Chosen display Format:default
7
+ 2025-03-14 13:32:25,100 INFO root Display format: "default"
8
+ 2025-03-14 13:32:25,100 DEBUG root SDK update checks are disabled.
.config/logs/2025.03.14/13.32.25.746375.log ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ 2025-03-14 13:32:25,748 DEBUG root Loaded Command Group: ['gcloud', 'config']
2
+ 2025-03-14 13:32:25,794 DEBUG root Loaded Command Group: ['gcloud', 'config', 'set']
3
+ 2025-03-14 13:32:25,796 DEBUG root Running [gcloud.config.set] with arguments: [SECTION/PROPERTY: "compute/gce_metadata_read_timeout_sec", VALUE: "0"]
4
+ 2025-03-14 13:32:25,797 INFO ___FILE_ONLY___ Updated property [compute/gce_metadata_read_timeout_sec].
5
+
6
+ 2025-03-14 13:32:25,798 DEBUG root Chosen display Format:default
7
+ 2025-03-14 13:32:25,799 INFO root Display format: "default"
8
+ 2025-03-14 13:32:25,799 DEBUG root SDK update checks are disabled.
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample_data/mnist_test.csv filter=lfs diff=lfs merge=lfs -text
37
+ sample_data/mnist_train_small.csv filter=lfs diff=lfs merge=lfs -text
38
+ unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
Gemma-Finetune/.gitignore ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ env/
26
+ ENV/
27
+
28
+ # Model files and datasets
29
+ models/
30
+ sample_datasets/
31
+ *.pt
32
+ *.pth
33
+ *.bin
34
+ *.gguf
35
+ *.onnx
36
+
37
+ # IDE
38
+ .idea/
39
+ .vscode/
40
+ *.swp
41
+ *.swo
42
+
43
+ # Logs and databases
44
+ *.log
45
+ *.sqlite
46
+ wandb/
47
+
48
+ # OS generated files
49
+ .DS_Store
50
+ .DS_Store?
51
+ ._*
52
+ .Spotlight-V100
53
+ .Trashes
54
+ ehthumbs.db
55
+ Thumbs.db
56
+
57
+ # Installer logs
58
+ pip-log.txt
59
+ pip-delete-this-directory.txt
60
+
61
+ # Unit test / coverage reports
62
+ htmlcov/
63
+ .tox/
64
+ .nox/
65
+ .coverage
66
+ .coverage.*
67
+ .cache
68
+ nosetests.xml
69
+ coverage.xml
70
+ *.cover
71
+ *.py,cover
72
+ .hypothesis/
73
+ .pytest_cache/
74
+ cover/
75
+
76
+ # Translations
77
+ *.mo
78
+ *.pot
79
+
80
+ # Django stuff:
81
+ local_settings.py
82
+ db.sqlite3
83
+ db.sqlite3-journal
84
+
85
+ # Flask stuff:
86
+ instance/
87
+ .webassets-cache
88
+
89
+ # Scrapy stuff:
90
+ .scrapy
91
+
92
+ # Sphinx documentation
93
+ docs/_build/
94
+
95
+ # PyBuilder
96
+ .pybuilder/
97
+ target/
98
+
99
+ # Jupyter Notebook
100
+ .ipynb_checkpoints
101
+
102
+ # IPython
103
+ profile_default/
104
+ ipython_config.py
105
+
106
+ # pyenv
107
+ # For a library or package, you might want to ignore these files since the code is
108
+ # intended to run in multiple environments; otherwise, check them in:
109
+ # .python-version
110
+
111
+ # pipenv
112
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
113
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
114
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
115
+ # install all needed dependencies.
116
+ #Pipfile.lock
117
+
118
+ # UV
119
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
120
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
121
+ # commonly ignored for libraries.
122
+ #uv.lock
123
+
124
+ # poetry
125
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
126
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
127
+ # commonly ignored for libraries.
128
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
129
+ #poetry.lock
130
+
131
+ # pdm
132
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
133
+ #pdm.lock
134
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
135
+ # in version control.
136
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
137
+ .pdm.toml
138
+ .pdm-python
139
+ .pdm-build/
140
+
141
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
142
+ __pypackages__/
143
+
144
+ # Celery stuff
145
+ celerybeat-schedule
146
+ celerybeat.pid
147
+
148
+ # SageMath parsed files
149
+ *.sage.py
150
+
151
+ # Environments
152
+ .env
153
+ .venv
154
+ env.bak/
155
+ venv.bak/
156
+
157
+ # Spyder project settings
158
+ .spyderproject
159
+ .spyproject
160
+
161
+ # Rope project settings
162
+ .ropeproject
163
+
164
+ # mkdocs documentation
165
+ /site
166
+
167
+ # mypy
168
+ .mypy_cache/
169
+ .dmypy.json
170
+ dmypy.json
171
+
172
+ # Pyre type checker
173
+ .pyre/
174
+
175
+ # pytype static type analyzer
176
+ .pytype/
177
+
178
+ # Cython debug symbols
179
+ cython_debug/
180
+
181
+ # PyCharm
182
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
183
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
184
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
185
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
186
+ #.idea/
187
+
188
+ # Ruff stuff:
189
+ .ruff_cache/
190
+
191
+ # PyPI configuration file
192
+ .pypirc
Gemma-Finetune/Gemma3_(4B).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Gemma-Finetune/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Dark Coder
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Gemma-Finetune/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gemma Fine-tuning UI
2
+
3
+ A user-friendly interface for fine-tuning Google's Gemma models using Unsloth optimizations.
4
+
5
+ ## Features
6
+
7
+ - Easy-to-use web interface for model fine-tuning
8
+ - Support for multiple data formats (CSV, JSONL, TEXT)
9
+ - Parameter-efficient fine-tuning with LoRA
10
+ - Real-time training progress visualization
11
+ - Model export in multiple formats
12
+ - Integrated text generation testing
13
+
14
+ ## Installation
15
+
16
+ ```bash
17
+ git clone https://github.com/codewithdark-git/Gemma-Finetune.git
18
+ cd Gemma-Finetune
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ 1. Run the application:
25
+ ```bash
26
+ python main.py
27
+ ```
28
+
29
+ 2. Follow the UI steps:
30
+ - Upload your dataset
31
+ - Configure model parameters
32
+ - Start training
33
+ - Test and export your model
34
+
35
+ ## Requirements
36
+
37
+ See requirements.txt for detailed dependencies.
38
+
39
+ ## License
40
+
41
+ MIT License
Gemma-Finetune/main.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.check_dataset import validate_dataset, generate_dataset_report
3
+ from utils.sample_dataset import generate_sample_datasets
4
+ from utils.model import GemmaFineTuning
5
+
6
+ class GemmaUI:
7
+ def __init__(self):
8
+ self.model_instance = GemmaFineTuning()
9
+ self.default_params = self.model_instance.default_params
10
+
11
+ def create_ui(self):
12
+ """Create the Gradio interface"""
13
+ with gr.Blocks(title="Gemma Fine-tuning UI") as app:
14
+ gr.Markdown("# Gemma Model Fine-tuning Interface")
15
+ gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")
16
+
17
+ with gr.Tabs():
18
+ with gr.TabItem("1. Data Upload & Preprocessing"):
19
+ with gr.Row():
20
+ with gr.Column():
21
+ file_upload = gr.File(label="Upload Dataset")
22
+ file_format = gr.Radio(
23
+ ["csv", "jsonl", "text"],
24
+ label="File Format",
25
+ value="csv"
26
+ )
27
+ preprocess_button = gr.Button("Preprocess Dataset")
28
+ dataset_info = gr.TextArea(label="Dataset Information", interactive=False)
29
+
30
+ with gr.TabItem("2. Model & Hyperparameters"):
31
+ with gr.Row():
32
+ with gr.Column():
33
+ model_name = gr.Dropdown(
34
+ choices=[
35
+ "google/gemma-2b",
36
+ "google/gemma-7b",
37
+ "google/gemma-2b-it",
38
+ "google/gemma-7b-it"
39
+ ],
40
+ value=self.default_params["model_name"],
41
+ label="Model Name",
42
+ info="Select a Gemma model to fine-tune"
43
+ )
44
+ learning_rate = gr.Slider(
45
+ minimum=1e-6,
46
+ maximum=5e-4,
47
+ value=self.default_params["learning_rate"],
48
+ label="Learning Rate",
49
+ info="Learning rate for the optimizer"
50
+ )
51
+ batch_size = gr.Slider(
52
+ minimum=1,
53
+ maximum=32,
54
+ step=1,
55
+ value=self.default_params["batch_size"],
56
+ label="Batch Size",
57
+ info="Number of samples in each training batch"
58
+ )
59
+ epochs = gr.Slider(
60
+ minimum=1,
61
+ maximum=10,
62
+ step=1,
63
+ value=self.default_params["epochs"],
64
+ label="Epochs",
65
+ info="Number of training epochs"
66
+ )
67
+
68
+ with gr.Column():
69
+ max_length = gr.Slider(
70
+ minimum=128,
71
+ maximum=2048,
72
+ step=16,
73
+ value=self.default_params["max_length"],
74
+ label="Max Sequence Length",
75
+ info="Maximum token length for inputs"
76
+ )
77
+ use_lora = gr.Checkbox(
78
+ value=self.default_params["use_lora"],
79
+ label="Use LoRA for Parameter-Efficient Fine-tuning",
80
+ info="Recommended for faster training and lower memory usage"
81
+ )
82
+ lora_r = gr.Slider(
83
+ minimum=4,
84
+ maximum=64,
85
+ step=4,
86
+ value=self.default_params["lora_r"],
87
+ label="LoRA Rank (r)",
88
+ info="Rank of the LoRA update matrices",
89
+ visible=lambda: use_lora.value
90
+ )
91
+ lora_alpha = gr.Slider(
92
+ minimum=4,
93
+ maximum=64,
94
+ step=4,
95
+ value=self.default_params["lora_alpha"],
96
+ label="LoRA Alpha",
97
+ info="Scaling factor for LoRA updates",
98
+ visible=lambda: use_lora.value
99
+ )
100
+ eval_ratio = gr.Slider(
101
+ minimum=0.05,
102
+ maximum=0.3,
103
+ step=0.05,
104
+ value=self.default_params["eval_ratio"],
105
+ label="Validation Split Ratio",
106
+ info="Portion of data to use for validation"
107
+ )
108
+
109
+ with gr.TabItem("3. Training"):
110
+ with gr.Row():
111
+ with gr.Column():
112
+ start_training_button = gr.Button("Start Fine-tuning")
113
+ stop_training_button = gr.Button("Stop Training", variant="stop")
114
+ training_status = gr.Textbox(label="Training Status", interactive=False)
115
+
116
+ with gr.Column():
117
+ progress_plot = gr.Plot(label="Training Progress")
118
+ refresh_plot_button = gr.Button("Refresh Plot")
119
+
120
+ with gr.TabItem("4. Evaluation & Export"):
121
+ with gr.Row():
122
+ with gr.Column():
123
+ test_prompt = gr.Textbox(
124
+ label="Test Prompt",
125
+ placeholder="Enter a prompt to test the model...",
126
+ lines=3
127
+ )
128
+ max_gen_length = gr.Slider(
129
+ minimum=10,
130
+ maximum=500,
131
+ step=10,
132
+ value=100,
133
+ label="Max Generation Length"
134
+ )
135
+ generate_button = gr.Button("Generate Text")
136
+ generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)
137
+
138
+ with gr.Column():
139
+ export_format = gr.Radio(
140
+ ["pytorch", "tensorflow", "gguf"],
141
+ label="Export Format",
142
+ value="pytorch"
143
+ )
144
+ export_button = gr.Button("Export Model")
145
+ export_status = gr.Textbox(label="Export Status", interactive=False)
146
+
147
+ # Functionality
148
+ def preprocess_data(file, format_type):
149
+ try:
150
+ if file is None:
151
+ return "Please upload a file first."
152
+
153
+ # Process the uploaded file
154
+ dataset = self.model_instance.prepare_dataset(file.name, format_type)
155
+ self.model_instance.dataset = dataset
156
+
157
+ # Create a summary of the dataset
158
+ num_samples = len(dataset["train"])
159
+
160
+
161
+ # Sample a few examples
162
+ examples = dataset["train"].select(range(min(3, num_samples)))
163
+ sample_text = []
164
+ for ex in examples:
165
+ text_key = list(ex.keys())[0] if "text" not in ex else "text"
166
+ sample = ex[text_key]
167
+ if isinstance(sample, str):
168
+ sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)
169
+
170
+ info = f"Dataset loaded successfully!\n"
171
+ info += f"Number of training examples: {num_samples}\n"
172
+ info += f"Sample data:\n" + "\n---\n".join(sample_text)
173
+
174
+ return info
175
+ except Exception as e:
176
+ return f"Error preprocessing data: {str(e)}"
177
+
178
+ def start_training(
179
+ model_name, learning_rate, batch_size, epochs, max_length,
180
+ use_lora, lora_r, lora_alpha, eval_ratio
181
+ ):
182
+ try:
183
+ if self.model_instance.dataset is None:
184
+ return "Please preprocess a dataset first."
185
+
186
+ # Validate parameters
187
+ if not model_name:
188
+ return "Please select a model."
189
+
190
+ # Prepare training parameters with proper type conversion
191
+ training_params = {
192
+ "model_name": str(model_name),
193
+ "learning_rate": float(learning_rate),
194
+ "batch_size": int(batch_size),
195
+ "epochs": int(epochs),
196
+ "max_length": int(max_length),
197
+ "use_lora": bool(use_lora),
198
+ "lora_r": int(lora_r) if use_lora else None,
199
+ "lora_alpha": int(lora_alpha) if use_lora else None,
200
+ "eval_ratio": float(eval_ratio),
201
+ "weight_decay": float(self.default_params["weight_decay"]),
202
+ "warmup_ratio": float(self.default_params["warmup_ratio"]),
203
+ "lora_dropout": float(self.default_params["lora_dropout"])
204
+ }
205
+
206
+ # Start training in a separate thread
207
+ import threading
208
+ def train_thread():
209
+ status = self.model_instance.train(training_params)
210
+ return status
211
+
212
+ thread = threading.Thread(target=train_thread)
213
+ thread.start()
214
+
215
+ return "Training started! Monitor the progress in the Training tab."
216
+ except Exception as e:
217
+ return f"Error starting training: {str(e)}"
218
+
219
+ def stop_training():
220
+ if self.model_instance.trainer is not None:
221
+ # Attempt to stop the trainer
222
+ self.model_instance.trainer.stop_training = True
223
+ return "Training stop signal sent. It may take a moment to complete the current step."
224
+ return "No active training to stop."
225
+
226
+ def update_progress_plot():
227
+ try:
228
+ return self.model_instance.plot_training_progress()
229
+ except Exception as e:
230
+ return None
231
+
232
+ def run_text_generation(prompt, max_length):
233
+ try:
234
+ if self.model_instance.model is None:
235
+ return "Please fine-tune a model first."
236
+
237
+ return self.model_instance.generate_text(prompt, int(max_length))
238
+ except Exception as e:
239
+ return f"Error generating text: {str(e)}"
240
+
241
+ def export_model_fn(format_type):
242
+ try:
243
+ if self.model_instance.model is None:
244
+ return "Please fine-tune a model first."
245
+
246
+ return self.model_instance.export_model(format_type)
247
+ except Exception as e:
248
+ return f"Error exporting model: {str(e)}"
249
+
250
+ # Connect UI components to functions
251
+ preprocess_button.click(
252
+ preprocess_data,
253
+ inputs=[file_upload, file_format],
254
+ outputs=dataset_info
255
+ )
256
+
257
+ start_training_button.click(
258
+ start_training,
259
+ inputs=[
260
+ model_name, learning_rate, batch_size, epochs, max_length,
261
+ use_lora, lora_r, lora_alpha, eval_ratio
262
+ ],
263
+ outputs=training_status
264
+ )
265
+
266
+ stop_training_button.click(
267
+ stop_training,
268
+ inputs=[],
269
+ outputs=training_status
270
+ )
271
+
272
+ refresh_plot_button.click(
273
+ update_progress_plot,
274
+ inputs=[],
275
+ outputs=progress_plot
276
+ )
277
+
278
+ generate_button.click(
279
+ run_text_generation,
280
+ inputs=[test_prompt, max_gen_length],
281
+ outputs=generated_output
282
+ )
283
+
284
+ export_button.click(
285
+ export_model_fn,
286
+ inputs=[export_format],
287
+ outputs=export_status
288
+ )
289
+
290
+ return app
291
+
292
+ if __name__ == '__main__':
293
+ ui = GemmaUI()
294
+ app = ui.create_ui()
295
+ app.launch(share=True)
Gemma-Finetune/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.36.0
3
+ unsloth>=0.1.0
4
+ gradio>=4.0.0
5
+ pandas>=1.5.0
6
+ numpy>=1.24.0
7
+ matplotlib>=3.7.0
8
+ peft>=0.7.0
9
+ datasets>=2.14.0
Gemma-Finetune/utils/__pycache__/check_dataset.cpython-311.pyc ADDED
Binary file (13.3 kB). View file
 
Gemma-Finetune/utils/__pycache__/model.cpython-311.pyc ADDED
Binary file (28.4 kB). View file
 
Gemma-Finetune/utils/__pycache__/sample_dataset.cpython-311.pyc ADDED
Binary file (7.67 kB). View file
 
Gemma-Finetune/utils/check_dataset.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ def validate_dataset(self, file_path, format_type):
5
+ """
6
+ Validate and analyze the dataset format, providing detailed feedback
7
+
8
+ Parameters:
9
+ file_path (str): Path to the dataset file
10
+ format_type (str): File format (csv, jsonl, text)
11
+
12
+ Returns:
13
+ dict: Validation results including format, structure, and statistics
14
+ """
15
+ import pandas as pd
16
+ import json
17
+ import os
18
+ import re
19
+
20
+ validation_results = {
21
+ "is_valid": False,
22
+ "format": format_type,
23
+ "detected_structure": None,
24
+ "statistics": {},
25
+ "issues": [],
26
+ "recommendations": []
27
+ }
28
+
29
+ try:
30
+ # Check if file exists
31
+ if not os.path.exists(file_path):
32
+ validation_results["issues"].append(f"File not found: {file_path}")
33
+ return validation_results
34
+
35
+ # Check file size
36
+ file_size = os.path.getsize(file_path)
37
+ validation_results["statistics"]["file_size_bytes"] = file_size
38
+ validation_results["statistics"]["file_size_mb"] = round(file_size / (1024 * 1024), 2)
39
+
40
+ if file_size == 0:
41
+ validation_results["issues"].append("File is empty")
42
+ return validation_results
43
+
44
+ if format_type == "csv":
45
+ # Load CSV file
46
+ try:
47
+ df = pd.read_csv(file_path)
48
+ validation_results["statistics"]["total_rows"] = len(df)
49
+ validation_results["statistics"]["total_columns"] = len(df.columns)
50
+ validation_results["statistics"]["column_names"] = list(df.columns)
51
+
52
+ # Check for null values
53
+ null_counts = df.isnull().sum().to_dict()
54
+ validation_results["statistics"]["null_counts"] = null_counts
55
+
56
+ if validation_results["statistics"]["total_rows"] == 0:
57
+ validation_results["issues"].append("CSV file has no rows")
58
+ return validation_results
59
+
60
+ # Detect structure
61
+ if "instruction" in df.columns and "response" in df.columns:
62
+ validation_results["detected_structure"] = "instruction-response"
63
+ validation_results["is_valid"] = True
64
+ elif "input" in df.columns and "output" in df.columns:
65
+ validation_results["detected_structure"] = "input-output"
66
+ validation_results["is_valid"] = True
67
+ elif "prompt" in df.columns and "completion" in df.columns:
68
+ validation_results["detected_structure"] = "prompt-completion"
69
+ validation_results["is_valid"] = True
70
+ elif "text" in df.columns:
71
+ validation_results["detected_structure"] = "text-only"
72
+ validation_results["is_valid"] = True
73
+ else:
74
+ # Look for text columns
75
+ text_columns = [col for col in df.columns if df[col].dtype == 'object']
76
+ if text_columns:
77
+ validation_results["detected_structure"] = "custom"
78
+ validation_results["statistics"]["potential_text_columns"] = text_columns
79
+ validation_results["is_valid"] = True
80
+ validation_results["recommendations"].append(
81
+ f"Consider renaming columns to match standard formats: instruction/response, input/output, prompt/completion, or text"
82
+ )
83
+ else:
84
+ validation_results["issues"].append("No suitable text columns found in CSV")
85
+
86
+ # Check for short text
87
+ if validation_results["detected_structure"] == "instruction-response":
88
+ short_instructions = (df["instruction"].str.len() < 10).sum()
89
+ short_responses = (df["response"].str.len() < 10).sum()
90
+ validation_results["statistics"]["short_instructions"] = short_instructions
91
+ validation_results["statistics"]["short_responses"] = short_responses
92
+
93
+ if short_instructions > 0:
94
+ validation_results["issues"].append(f"Found {short_instructions} instructions shorter than 10 characters")
95
+ if short_responses > 0:
96
+ validation_results["issues"].append(f"Found {short_responses} responses shorter than 10 characters")
97
+
98
+ except Exception as e:
99
+ validation_results["issues"].append(f"Error parsing CSV: {str(e)}")
100
+ return validation_results
101
+
102
+ elif format_type == "jsonl":
103
+ try:
104
+ # Load JSONL file
105
+ data = []
106
+ with open(file_path, 'r', encoding='utf-8') as f:
107
+ for line_num, line in enumerate(f, 1):
108
+ line = line.strip()
109
+ if not line:
110
+ continue
111
+ try:
112
+ json_obj = json.loads(line)
113
+ data.append(json_obj)
114
+ except json.JSONDecodeError:
115
+ validation_results["issues"].append(f"Invalid JSON at line {line_num}")
116
+
117
+ validation_results["statistics"]["total_examples"] = len(data)
118
+
119
+ if len(data) == 0:
120
+ validation_results["issues"].append("No valid JSON objects found in file")
121
+ return validation_results
122
+
123
+ # Get sample of keys from first object
124
+ if data:
125
+ validation_results["statistics"]["sample_keys"] = list(data[0].keys())
126
+
127
+ # Detect structure
128
+ structures = []
129
+ for item in data:
130
+ if "instruction" in item and "response" in item:
131
+ structures.append("instruction-response")
132
+ elif "input" in item and "output" in item:
133
+ structures.append("input-output")
134
+ elif "prompt" in item and "completion" in item:
135
+ structures.append("prompt-completion")
136
+ elif "text" in item:
137
+ structures.append("text-only")
138
+ else:
139
+ structures.append("custom")
140
+
141
+ # Count structure types
142
+ from collections import Counter
143
+ structure_counts = Counter(structures)
144
+ validation_results["statistics"]["structure_counts"] = structure_counts
145
+
146
+ # Set detected structure to most common
147
+ if structures:
148
+ most_common = structure_counts.most_common(1)[0][0]
149
+ validation_results["detected_structure"] = most_common
150
+ validation_results["is_valid"] = True
151
+
152
+ # Check if mixed
153
+ if len(structure_counts) > 1:
154
+ validation_results["issues"].append(f"Mixed structures detected: {dict(structure_counts)}")
155
+ validation_results["recommendations"].append("Consider standardizing all records to the same structure")
156
+
157
+ except Exception as e:
158
+ validation_results["issues"].append(f"Error parsing JSONL: {str(e)}")
159
+ return validation_results
160
+
161
+ elif format_type == "text":
162
+ try:
163
+ # Read text file
164
+ with open(file_path, 'r', encoding='utf-8') as f:
165
+ content = f.read()
166
+
167
+ # Get basic stats
168
+ total_chars = len(content)
169
+ total_words = len(content.split())
170
+ total_lines = content.count('\n') + 1
171
+
172
+ validation_results["statistics"]["total_characters"] = total_chars
173
+ validation_results["statistics"]["total_words"] = total_words
174
+ validation_results["statistics"]["total_lines"] = total_lines
175
+
176
+ # Check if it's a single large document or multiple examples
177
+ paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
178
+ validation_results["statistics"]["total_paragraphs"] = len(paragraphs)
179
+
180
+ # Try to detect structure
181
+ # Look for common patterns like "Q: ... A: ...", "Input: ... Output: ..."
182
+ has_qa_pattern = re.search(r"Q:.*?A:", content, re.DOTALL) is not None
183
+ has_input_output = re.search(r"Input:.*?Output:", content, re.DOTALL) is not None
184
+ has_prompt_completion = re.search(r"Prompt:.*?Completion:", content, re.DOTALL) is not None
185
+
186
+ if has_qa_pattern:
187
+ validation_results["detected_structure"] = "Q&A-format"
188
+ elif has_input_output:
189
+ validation_results["detected_structure"] = "input-output-format"
190
+ elif has_prompt_completion:
191
+ validation_results["detected_structure"] = "prompt-completion-format"
192
+ elif len(paragraphs) > 1:
193
+ validation_results["detected_structure"] = "paragraphs"
194
+ else:
195
+ validation_results["detected_structure"] = "continuous-text"
196
+
197
+ validation_results["is_valid"] = True
198
+
199
+ if validation_results["detected_structure"] == "continuous-text" and total_chars < 1000:
200
+ validation_results["issues"].append("Text file is very short for fine-tuning")
201
+ validation_results["recommendations"].append("Consider adding more content or examples")
202
+
203
+ except Exception as e:
204
+ validation_results["issues"].append(f"Error parsing text file: {str(e)}")
205
+ return validation_results
206
+ else:
207
+ validation_results["issues"].append(f"Unsupported file format: {format_type}")
208
+ return validation_results
209
+
210
+ # General recommendations
211
+ if validation_results["is_valid"]:
212
+ if not validation_results["issues"]:
213
+ validation_results["recommendations"].append("Dataset looks good and ready for fine-tuning!")
214
+ else:
215
+ validation_results["recommendations"].append("Address the issues above before proceeding with fine-tuning")
216
+
217
+ return validation_results
218
+
219
+ except Exception as e:
220
+ validation_results["issues"].append(f"Unexpected error: {str(e)}")
221
+ return validation_results
222
+
223
+ def generate_dataset_report(validation_results):
224
+ """
225
+ Generate a user-friendly report from validation results
226
+
227
+ Parameters:
228
+ validation_results (dict): Results from validate_dataset
229
+
230
+ Returns:
231
+ str: Formatted report
232
+ """
233
+ report = []
234
+
235
+ # Add header
236
+ report.append("# Dataset Validation Report")
237
+ report.append("")
238
+
239
+ # Add validation status
240
+ if validation_results["is_valid"]:
241
+ report.append("✅ Dataset is valid and can be used for fine-tuning")
242
+ else:
243
+ report.append("❌ Dataset has issues that need to be addressed")
244
+ report.append("")
245
+
246
+ # Add format info
247
+ report.append(f"**File Format:** {validation_results['format']}")
248
+ report.append(f"**Detected Structure:** {validation_results['detected_structure']}")
249
+ report.append("")
250
+
251
+ # Add statistics
252
+ report.append("## Statistics")
253
+ for key, value in validation_results["statistics"].items():
254
+ # Format the key for better readability
255
+ formatted_key = key.replace("_", " ").title()
256
+ report.append(f"- {formatted_key}: {value}")
257
+ report.append("")
258
+
259
+ # Add issues
260
+ if validation_results["issues"]:
261
+ report.append("## Issues")
262
+ for issue in validation_results["issues"]:
263
+ report.append(f"- ⚠️ {issue}")
264
+ report.append("")
265
+
266
+ # Add recommendations
267
+ if validation_results["recommendations"]:
268
+ report.append("## Recommendations")
269
+ for recommendation in validation_results["recommendations"]:
270
+ report.append(f"- 💡 {recommendation}")
271
+
272
+ return "\n".join(report)
Gemma-Finetune/utils/model.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import gradio as gr
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+ from datetime import datetime
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ TrainingArguments,
16
+ Trainer,
17
+ DataCollatorForLanguageModeling,
18
+ TrainerCallback
19
+ )
20
+ from peft import (
21
+ LoraConfig,
22
+ get_peft_model,
23
+ prepare_model_for_kbit_training
24
+ )
25
+ from datasets import load_dataset
26
+ from unsloth import FastModel
27
+
28
+
29
+ class GemmaFineTuning:
30
+ def __init__(self):
31
+ self.model = None
32
+ self.tokenizer = None
33
+ self.dataset = None
34
+ self.trainer = None
35
+ self.training_history = {"loss": [], "eval_loss": [], "step": []}
36
+ self.model_save_path = None
37
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ self.fourbit_models = [
40
+ "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
41
+ "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
42
+ "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
43
+ "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
44
+ ]
45
+ # Default hyperparameters
46
+ self.default_params = {
47
+ "model_name": "google/gemma-2b",
48
+ "learning_rate": 2e-5,
49
+ "batch_size": 8,
50
+ "epochs": 3,
51
+ "max_length": 512,
52
+ "weight_decay": 0.01,
53
+ "warmup_ratio": 0.1,
54
+ "use_lora": True,
55
+ "lora_r": 16,
56
+ "lora_alpha": 32,
57
+ "lora_dropout": 0.05,
58
+ "eval_ratio": 0.1,
59
+ }
60
+
61
+ def load_model_and_tokenizer(self, model_name: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
62
+ """Load the model and tokenizer"""
63
+ try:
64
+ # Map UI model names to actual model IDs
65
+ model_mapping = {
66
+ "google/gemma-2b": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
67
+ "google/gemma-7b": "unsloth/gemma-7b-it-unsloth-bnb-4bit",
68
+ "google/gemma-2b-it": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
69
+ "google/gemma-7b-it": "unsloth/gemma-7b-it-unsloth-bnb-4bit"
70
+ }
71
+
72
+ actual_model_name = model_mapping.get(model_name, model_name)
73
+
74
+ model, tokenizer = FastModel.from_pretrained(
75
+ model_name=actual_model_name,
76
+ max_seq_length=2048,
77
+ load_in_4bit=True,
78
+ load_in_8bit=False,
79
+ full_finetuning=False,
80
+ )
81
+
82
+ # Move model to device
83
+ model = model.to(self.device)
84
+ return model, tokenizer
85
+
86
+ except Exception as e:
87
+ raise ValueError(f"Error loading model {model_name}: {str(e)}")
88
+
89
+ def prepare_dataset(self, file_path, format_type):
90
+ """
91
+ Prepare and normalize dataset from various formats
92
+
93
+ Parameters:
94
+ file_path (str): Path to the dataset file
95
+ format_type (str): File format (csv, jsonl, text)
96
+
97
+ Returns:
98
+ dict: Dataset dictionary with train split
99
+ """
100
+ import pandas as pd
101
+ import json
102
+ import os
103
+ from datasets import Dataset, DatasetDict
104
+
105
+ try:
106
+ if format_type == "csv":
107
+ # Load CSV file
108
+ df = pd.read_csv(file_path)
109
+
110
+ # Check if the CSV has the expected columns (looking for either instruction-response pairs or text)
111
+ if "instruction" in df.columns and "response" in df.columns:
112
+ # Instruction-following dataset format
113
+ dataset_format = "instruction-response"
114
+ # Ensure no nulls
115
+ df = df.dropna(subset=["instruction", "response"])
116
+ # Create formatted text by combining instruction and response
117
+ df["text"] = df.apply(lambda row: f"<instruction>{row['instruction']}</instruction>\n<response>{row['response']}</response>", axis=1)
118
+ elif "input" in df.columns and "output" in df.columns:
119
+ # Another common format
120
+ dataset_format = "input-output"
121
+ df = df.dropna(subset=["input", "output"])
122
+ df["text"] = df.apply(lambda row: f"<input>{row['input']}</input>\n<output>{row['output']}</output>", axis=1)
123
+ elif "prompt" in df.columns and "completion" in df.columns:
124
+ # OpenAI-style format
125
+ dataset_format = "prompt-completion"
126
+ df = df.dropna(subset=["prompt", "completion"])
127
+ df["text"] = df.apply(lambda row: f"<prompt>{row['prompt']}</prompt>\n<completion>{row['completion']}</completion>", axis=1)
128
+ elif "text" in df.columns:
129
+ # Simple text format
130
+ dataset_format = "text-only"
131
+ df = df.dropna(subset=["text"])
132
+ else:
133
+ # Try to infer format from the first text column
134
+ text_columns = [col for col in df.columns if df[col].dtype == 'object']
135
+ if len(text_columns) > 0:
136
+ dataset_format = "inferred"
137
+ df["text"] = df[text_columns[0]]
138
+ df = df.dropna(subset=["text"])
139
+ else:
140
+ raise ValueError("CSV file must contain either 'instruction'/'response', 'input'/'output', 'prompt'/'completion', or 'text' columns")
141
+
142
+ # Create dataset
143
+ dataset = Dataset.from_pandas(df)
144
+
145
+ elif format_type == "jsonl":
146
+ # Load JSONL file
147
+ with open(file_path, 'r', encoding='utf-8') as f:
148
+ data = [json.loads(line) for line in f if line.strip()]
149
+
150
+ # Check and normalize the format
151
+ normalized_data = []
152
+ for item in data:
153
+ normalized_item = {}
154
+
155
+ # Try to find either instruction-response pairs or text
156
+ if "instruction" in item and "response" in item:
157
+ normalized_item["text"] = f"<instruction>{item['instruction']}</instruction>\n<response>{item['response']}</response>"
158
+ normalized_item["instruction"] = item["instruction"]
159
+ normalized_item["response"] = item["response"]
160
+ elif "input" in item and "output" in item:
161
+ normalized_item["text"] = f"<input>{item['input']}</input>\n<output>{item['output']}</output>"
162
+ normalized_item["input"] = item["input"]
163
+ normalized_item["output"] = item["output"]
164
+ elif "prompt" in item and "completion" in item:
165
+ normalized_item["text"] = f"<prompt>{item['prompt']}</prompt>\n<completion>{item['completion']}</completion>"
166
+ normalized_item["prompt"] = item["prompt"]
167
+ normalized_item["completion"] = item["completion"]
168
+ elif "text" in item:
169
+ normalized_item["text"] = item["text"]
170
+ else:
171
+ # Try to infer from the first string value
172
+ text_keys = [k for k, v in item.items() if isinstance(v, str) and len(v.strip()) > 0]
173
+ if text_keys:
174
+ normalized_item["text"] = item[text_keys[0]]
175
+ else:
176
+ continue # Skip this item if no usable text found
177
+
178
+ normalized_data.append(normalized_item)
179
+
180
+ if not normalized_data:
181
+ raise ValueError("No valid data items found in the JSONL file")
182
+
183
+ # Create dataset
184
+ dataset = Dataset.from_list(normalized_data)
185
+
186
+ elif format_type == "text":
187
+ # For text files, split by newlines and create entries
188
+ with open(file_path, 'r', encoding='utf-8') as f:
189
+ content = f.read()
190
+
191
+ # Check if it's a single large document or multiple examples
192
+ # If file size > 10KB, try to split into paragraphs
193
+ if os.path.getsize(file_path) > 10240:
194
+ # Split by double newlines (paragraphs)
195
+ paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
196
+ # Filter out very short paragraphs (less than 20 chars)
197
+ paragraphs = [p for p in paragraphs if len(p) >= 20]
198
+ data = [{"text": p} for p in paragraphs]
199
+ else:
200
+ # Treat as a single example
201
+ data = [{"text": content}]
202
+
203
+ # Create dataset
204
+ dataset = Dataset.from_list(data)
205
+ else:
206
+ raise ValueError(f"Unsupported file format: {format_type}")
207
+
208
+ # Return as a DatasetDict with a train split
209
+ return DatasetDict({"train": dataset})
210
+
211
+ except Exception as e:
212
+ import traceback
213
+ error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
214
+ print(error_msg)
215
+ raise ValueError(error_msg)
216
+
217
+ def chunk_text(self, text: str, chunk_size: int) -> List[str]:
218
+ """Split text into chunks of approximately chunk_size characters"""
219
+ words = text.split()
220
+ chunks = []
221
+ current_chunk = []
222
+ current_length = 0
223
+
224
+ for word in words:
225
+ if current_length + len(word) + 1 > chunk_size and current_chunk:
226
+ chunks.append(" ".join(current_chunk))
227
+ current_chunk = [word]
228
+ current_length = len(word)
229
+ else:
230
+ current_chunk.append(word)
231
+ current_length += len(word) + 1 # +1 for the space
232
+
233
+ if current_chunk:
234
+ chunks.append(" ".join(current_chunk))
235
+
236
+ return chunks
237
+
238
+ def preprocess_dataset(self, dataset, tokenizer, max_length):
239
+ """
240
+ Tokenize and format the dataset for training
241
+
242
+ Parameters:
243
+ dataset (DatasetDict): Dataset dictionary with train and validation splits
244
+ tokenizer: HuggingFace tokenizer
245
+ max_length (int): Maximum sequence length
246
+
247
+ Returns:
248
+ DatasetDict: Tokenized dataset ready for training
249
+ """
250
+ def tokenize_function(examples):
251
+ # Check if the dataset has both input and target text columns
252
+ if "text" in examples:
253
+ texts = examples["text"]
254
+ inputs = tokenizer(
255
+ texts,
256
+ padding="max_length",
257
+ truncation=True,
258
+ max_length=max_length,
259
+ return_tensors="pt"
260
+ )
261
+ inputs["labels"] = inputs["input_ids"].clone()
262
+ return inputs
263
+ else:
264
+ # Try to find text columns based on common naming patterns
265
+ potential_text_cols = [col for col in examples.keys() if isinstance(examples[col], list) and
266
+ all(isinstance(item, str) for item in examples[col])]
267
+
268
+ if not potential_text_cols:
269
+ raise ValueError("No suitable text columns found in the dataset")
270
+
271
+ # Use the first text column found
272
+ text_col = potential_text_cols[0]
273
+ texts = examples[text_col]
274
+
275
+ inputs = tokenizer(
276
+ texts,
277
+ padding="max_length",
278
+ truncation=True,
279
+ max_length=max_length,
280
+ return_tensors="pt"
281
+ )
282
+ inputs["labels"] = inputs["input_ids"].clone()
283
+ return inputs
284
+
285
+ # Apply tokenization to each split
286
+ tokenized_dataset = {}
287
+ for split, ds in dataset.items():
288
+ tokenized_dataset[split] = ds.map(
289
+ tokenize_function,
290
+ batched=True,
291
+ remove_columns=ds.column_names
292
+ )
293
+
294
+ return tokenized_dataset
295
+
296
+ def prepare_training_args(self, params: Dict) -> TrainingArguments:
297
+ """Set up training arguments based on hyperparameters"""
298
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
299
+ self.model_save_path = f"gemma-finetuned-{timestamp}"
300
+
301
+ args = TrainingArguments(
302
+ output_dir=self.model_save_path,
303
+ per_device_train_batch_size=params.get("batch_size", self.default_params["batch_size"]),
304
+ gradient_accumulation_steps=4,
305
+ per_device_eval_batch_size=params.get("batch_size", self.default_params["batch_size"]),
306
+ learning_rate=params.get("learning_rate", self.default_params["learning_rate"]),
307
+ num_train_epochs=params.get("epochs", self.default_params["epochs"]),
308
+ warmup_ratio=params.get("warmup_ratio", self.default_params["warmup_ratio"]),
309
+ weight_decay=params.get("weight_decay", self.default_params["weight_decay"]),
310
+ logging_steps=1,
311
+ evaluation_strategy="steps" if params.get("eval_ratio", 0) > 0 else "no",
312
+ eval_steps=100 if params.get("eval_ratio", 0) > 0 else None,
313
+ save_strategy="steps",
314
+ save_steps=100,
315
+ save_total_limit=2,
316
+ load_best_model_at_end=True if params.get("eval_ratio", 0) > 0 else False,
317
+ report_to="none"
318
+ )
319
+ return args
320
+
321
+ def train(self, training_params: Dict) -> str:
322
+ """Main training method that handles the complete training pipeline"""
323
+ try:
324
+ if self.dataset is None:
325
+ return "Error: No dataset loaded. Please preprocess a dataset first."
326
+
327
+ # Reset training history
328
+ self.training_history = {"loss": [], "eval_loss": [], "step": []}
329
+
330
+ # Load model and tokenizer if not already loaded or if model name changed
331
+ current_model_name = training_params.get("model_name", self.default_params["model_name"])
332
+ if (self.model is None or self.tokenizer is None or
333
+ getattr(self, '_current_model_name', None) != current_model_name):
334
+
335
+ self.model, self.tokenizer = self.load_model_and_tokenizer(current_model_name)
336
+ self._current_model_name = current_model_name
337
+
338
+ # Create validation split if needed
339
+ eval_ratio = float(training_params.get("eval_ratio", self.default_params["eval_ratio"]))
340
+ if eval_ratio > 0 and "validation" not in self.dataset:
341
+ split_dataset = self.dataset["train"].train_test_split(test_size=eval_ratio)
342
+ self.dataset = {
343
+ "train": split_dataset["train"],
344
+ "validation": split_dataset["test"]
345
+ }
346
+
347
+ # Apply LoRA if selected
348
+ if training_params.get("use_lora", self.default_params["use_lora"]):
349
+ self.model = self.setup_lora(self.model, {
350
+ "lora_r": int(training_params.get("lora_r", self.default_params["lora_r"])),
351
+ "lora_alpha": int(training_params.get("lora_alpha", self.default_params["lora_alpha"])),
352
+ "lora_dropout": float(training_params.get("lora_dropout", self.default_params["lora_dropout"]))
353
+ })
354
+
355
+ # Preprocess dataset
356
+ max_length = int(training_params.get("max_length", self.default_params["max_length"]))
357
+ tokenized_dataset = self.preprocess_dataset(self.dataset, self.tokenizer, max_length)
358
+
359
+ # Update training arguments with proper type conversion
360
+ training_args = self.prepare_training_args({
361
+ "batch_size": int(training_params.get("batch_size", self.default_params["batch_size"])),
362
+ "learning_rate": float(training_params.get("learning_rate", self.default_params["learning_rate"])),
363
+ "epochs": int(training_params.get("epochs", self.default_params["epochs"])),
364
+ "weight_decay": float(training_params.get("weight_decay", self.default_params["weight_decay"])),
365
+ "warmup_ratio": float(training_params.get("warmup_ratio", self.default_params["warmup_ratio"])),
366
+ "eval_ratio": eval_ratio
367
+ })
368
+
369
+ # Create trainer with proper callback
370
+ self.trainer = self.create_trainer(
371
+ self.model,
372
+ self.tokenizer,
373
+ tokenized_dataset,
374
+ training_args
375
+ )
376
+
377
+ # Start training
378
+ self.trainer.train()
379
+
380
+ # Save the model
381
+ save_path = f"models/gemma-finetuned-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
382
+ os.makedirs(save_path, exist_ok=True)
383
+ self.trainer.save_model(save_path)
384
+ self.tokenizer.save_pretrained(save_path)
385
+ self.model_save_path = save_path
386
+
387
+ return f"Training completed successfully! Model saved to {save_path}"
388
+
389
+ except Exception as e:
390
+ import traceback
391
+ return f"Error during training: {str(e)}\n{traceback.format_exc()}"
392
+
393
+ def setup_lora(self, model, params: Dict) -> torch.nn.Module:
394
+ """Configure LoRA for parameter-efficient fine-tuning"""
395
+ # Prepare the model for training if using 8-bit or 4-bit quantization
396
+ if hasattr(model, "is_quantized") and model.is_quantized:
397
+ model = prepare_model_for_kbit_training(model)
398
+
399
+ lora_config = LoraConfig(
400
+ r=params["lora_r"],
401
+ lora_alpha=params["lora_alpha"],
402
+ target_modules=["q_proj", "k_proj", "v_proj"],
403
+ lora_dropout=params["lora_dropout"],
404
+ bias="none",
405
+ task_type="CAUSAL_LM",
406
+ )
407
+
408
+ model = FastModel.get_peft_model(
409
+ model,
410
+ finetune_vision_layers = False, # Turn off for just text!
411
+ finetune_language_layers = True, # Should leave on!
412
+ finetune_attention_modules = True, # Attention good for GRPO
413
+ finetune_mlp_modules = True, # SHould leave on always!
414
+
415
+ r = 8, # Larger = higher accuracy, but might overfit
416
+ lora_alpha = 8, # Recommended alpha == r at least
417
+ lora_dropout = 0,
418
+ bias = "none",
419
+ random_state = 3407,
420
+ )
421
+ model.print_trainable_parameters()
422
+ model = model.to(self.device)
423
+ return model
424
+
425
+ def create_trainer(self, model, tokenizer, dataset, training_args):
426
+ """Set up the Trainer for model fine-tuning"""
427
+ # Create data collator
428
+ data_collator = DataCollatorForLanguageModeling(
429
+ tokenizer=tokenizer,
430
+ mlm=False
431
+ )
432
+
433
+ # Custom callback to store training history
434
+ class TrainingCallback(TrainerCallback):
435
+ def __init__(self, app):
436
+ self.app = app
437
+
438
+ def on_log(self, args, state, control, logs=None, **kwargs):
439
+ if logs:
440
+ for key in ['loss', 'eval_loss']:
441
+ if key in logs:
442
+ self.app.training_history[key].append(logs[key])
443
+ if 'step' in logs:
444
+ self.app.training_history['step'].append(logs['step'])
445
+
446
+ # Create trainer
447
+ trainer = Trainer(
448
+ model=model,
449
+ args=training_args,
450
+ train_dataset=dataset["train"],
451
+ eval_dataset=dataset["validation"] if "validation" in dataset else None,
452
+ data_collator=data_collator,
453
+ callbacks=[TrainingCallback]
454
+ )
455
+
456
+ return trainer
457
+
458
+ def plot_training_progress(self):
459
+ """Generate a plot of the training progress"""
460
+ if not self.training_history["loss"]:
461
+ return None
462
+
463
+ plt.figure(figsize=(10, 6))
464
+ plt.plot(self.training_history["step"], self.training_history["loss"], label="Training Loss")
465
+
466
+ if self.training_history["eval_loss"]:
467
+ # Get the steps where eval happened
468
+ eval_steps = self.training_history["step"][:len(self.training_history["eval_loss"])]
469
+ plt.plot(eval_steps, self.training_history["eval_loss"], label="Validation Loss", linestyle="--")
470
+
471
+ plt.xlabel("Training Steps")
472
+ plt.ylabel("Loss")
473
+ plt.title("Training Progress")
474
+ plt.legend()
475
+ plt.grid(True)
476
+
477
+ return plt
478
+
479
+ def export_model(self, output_format: str) -> str:
480
+ """Export the fine-tuned model in various formats"""
481
+ if self.model is None or self.model_save_path is None:
482
+ return "No model has been trained yet."
483
+
484
+ export_path = f"{self.model_save_path}/exported_{output_format}"
485
+ os.makedirs(export_path, exist_ok=True)
486
+
487
+ if output_format == "pytorch":
488
+ # Save as PyTorch format
489
+ self.model.save_pretrained(export_path)
490
+ self.tokenizer.save_pretrained(export_path)
491
+ return f"Model exported in PyTorch format to {export_path}"
492
+
493
+ elif output_format == "tensorflow":
494
+ # Convert to TensorFlow format
495
+ try:
496
+ from transformers.modeling_tf_utils import convert_pt_to_tf
497
+
498
+ # First save the PyTorch model
499
+ self.model.save_pretrained(export_path)
500
+ self.tokenizer.save_pretrained(export_path)
501
+
502
+ # Then convert to TF SavedModel format
503
+ tf_model = convert_pt_to_tf(self.model)
504
+ tf_model.save_pretrained(f"{export_path}/tf_saved_model")
505
+
506
+ return f"Model exported in TensorFlow format to {export_path}/tf_saved_model"
507
+ except Exception as e:
508
+ return f"Failed to export as TensorFlow model: {str(e)}"
509
+
510
+ elif output_format == "gguf":
511
+ # Export as GGUF format for local inference
512
+ try:
513
+ import subprocess
514
+
515
+ # First save the model in PyTorch format
516
+ self.model.save_pretrained(export_path)
517
+ self.tokenizer.save_pretrained(export_path)
518
+
519
+ # Use llama.cpp's conversion script (must be installed)
520
+ subprocess.run([
521
+ "python", "-m", "llama_cpp.convert",
522
+ "--outtype", "gguf",
523
+ "--outfile", f"{export_path}/model.gguf",
524
+ export_path
525
+ ])
526
+
527
+ return f"Model exported in GGUF format to {export_path}/model.gguf"
528
+ except Exception as e:
529
+ return f"Failed to export as GGUF model: {str(e)}"
530
+
531
+ else:
532
+ return f"Unsupported export format: {output_format}"
533
+
534
+ def generate_text(self, prompt: str, max_length: int = 100) -> str:
535
+ """Generate text using the fine-tuned model"""
536
+ if self.model is None or self.tokenizer is None:
537
+ return "No model has been loaded or fine-tuned yet."
538
+
539
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
540
+
541
+ with torch.no_grad():
542
+ outputs = self.model.generate(
543
+ **inputs,
544
+ max_length=max_length + inputs.input_ids.shape[1],
545
+ temperature=0.7,
546
+ top_p=0.9,
547
+ do_sample=True,
548
+ pad_token_id=self.tokenizer.pad_token_id
549
+ )
550
+
551
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
552
+ return generated_text
Gemma-Finetune/utils/sample_dataset.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import os
4
+ import random
5
+
6
+ def generate_sample_datasets(output_dir="./sample_datasets"):
7
+ """
8
+ Generate sample datasets in CSV, JSONL, and text formats for testing the Gemma fine-tuning UI
9
+
10
+ Parameters:
11
+ output_dir (str): Directory to save the sample datasets
12
+ """
13
+ # Create output directory if it doesn't exist
14
+ os.makedirs(output_dir, exist_ok=True)
15
+
16
+ # Sample data for instruction-response pairs
17
+ instruction_data = [
18
+ {"instruction": "Explain the concept of machine learning in simple terms.",
19
+ "response": "Machine learning is when computers learn from examples rather than being explicitly programmed. It's like teaching a child by showing examples instead of giving exact rules."},
20
+ {"instruction": "What are the health benefits of regular exercise?",
21
+ "response": "Regular exercise improves cardiovascular health, strengthens muscles and bones, helps control weight, reduces stress and anxiety, improves sleep quality, and increases overall energy levels."},
22
+ {"instruction": "How do I make a basic pasta sauce?",
23
+ "response": "For a basic tomato pasta sauce, sauté minced garlic and onions in olive oil until soft. Add canned crushed tomatoes, salt, pepper, and Italian herbs. Simmer for 20-30 minutes, stirring occasionally. Add basil at the end for fresh flavor."},
24
+ {"instruction": "Summarize the water cycle.",
25
+ "response": "The water cycle is the continuous movement of water on, above, and below Earth's surface. It involves evaporation (water turning to vapor), condensation (vapor forming clouds), precipitation (rain or snow), and collection (water returning to bodies of water)."},
26
+ {"instruction": "What is the difference between a simile and a metaphor?",
27
+ "response": "A simile compares two things using 'like' or 'as' (e.g., 'as brave as a lion'). A metaphor directly states that one thing is another (e.g., 'he is a lion in battle'). Both are figurative language techniques used to create vivid descriptions."}
28
+ ]
29
+
30
+ # Generate more instruction-response pairs
31
+ topics = ["history", "science", "literature", "cooking", "technology", "health", "travel", "sports", "music", "art"]
32
+ question_starters = ["Explain", "Describe", "How to", "What is", "Why does", "Compare", "Summarize", "List ways to", "Define", "Analyze"]
33
+
34
+ for _ in range(20):
35
+ topic = random.choice(topics)
36
+ starter = random.choice(question_starters)
37
+ instruction = f"{starter} {topic.lower()} {random.choice(['concepts', 'principles', 'ideas', 'techniques', 'methods'])}"
38
+ response = f"This is a sample response about {topic} that would be more detailed in a real dataset. It would contain multiple sentences explaining {topic} concepts in depth."
39
+ instruction_data.append({"instruction": instruction, "response": response})
40
+
41
+ # Create a dictionary to store sample datasets
42
+ datasets = {}
43
+
44
+ # 1. Create CSV in instruction-response format
45
+ df_instruction = pd.DataFrame(instruction_data)
46
+ datasets["instruction_response.csv"] = df_instruction
47
+
48
+ # 2. Create CSV in input-output format
49
+ input_output_data = [{"input": item["instruction"], "output": item["response"]} for item in instruction_data]
50
+ df_input_output = pd.DataFrame(input_output_data)
51
+ datasets["input_output.csv"] = df_input_output
52
+
53
+ # 3. Create CSV in text-only format
54
+ text_data = [{"text": f"Q: {item['instruction']}\nA: {item['response']}"} for item in instruction_data]
55
+ df_text = pd.DataFrame(text_data)
56
+ datasets["text_only.csv"] = df_text
57
+
58
+ # 4. Create CSV with non-standard format
59
+ custom_data = [{"question": item["instruction"], "answer": item["response"]} for item in instruction_data]
60
+ df_custom = pd.DataFrame(custom_data)
61
+ datasets["custom_format.csv"] = df_custom
62
+
63
+ # 5. Create JSONL in instruction-response format
64
+ jsonl_instruction = instruction_data
65
+ datasets["instruction_response.jsonl"] = jsonl_instruction
66
+
67
+ # 6. Create JSONL in prompt-completion format
68
+ prompt_completion_data = [{"prompt": item["instruction"], "completion": item["response"]} for item in instruction_data]
69
+ datasets["prompt_completion.jsonl"] = prompt_completion_data
70
+
71
+ # 7. Create JSONL with non-standard format
72
+ jsonl_custom = [{"query": item["instruction"], "result": item["response"]} for item in instruction_data]
73
+ datasets["custom_format.jsonl"] = jsonl_custom
74
+
75
+ # 8. Create text format (paragraphs)
76
+ text_paragraphs = "\n\n".join([f"Q: {item['instruction']}\nA: {item['response']}" for item in instruction_data])
77
+ datasets["paragraphs.txt"] = text_paragraphs
78
+
79
+ # 9. Create text format (single examples per line)
80
+ text_lines = "\n".join([f"{item['instruction']} => {item['response']}" for item in instruction_data])
81
+ datasets["single_lines.txt"] = text_lines
82
+
83
+ # Save all datasets
84
+ for filename, data in datasets.items():
85
+ file_path = os.path.join(output_dir, filename)
86
+
87
+ if filename.endswith('.csv'):
88
+ data.to_csv(file_path, index=False)
89
+ elif filename.endswith('.jsonl'):
90
+ with open(file_path, 'w', encoding='utf-8') as f:
91
+ for item in data:
92
+ f.write(json.dumps(item) + '\n')
93
+ elif filename.endswith('.txt'):
94
+ with open(file_path, 'w', encoding='utf-8') as f:
95
+ f.write(data)
96
+
97
+ print(f"Sample datasets generated in {output_dir}")
98
+ return list(datasets.keys())
99
+
100
+ # if __name__ == "__main__":
101
+ # # Generate sample datasets
102
+ # generated_files = generate_sample_datasets()
103
+ # print(f"Generated {len(generated_files)} sample dataset files:")
104
+ # for file in generated_files:
105
+ # print(f" - {file}")
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Gemma Finetuner
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.21.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Gemma_Finetuner
3
+ app_file: /content/Gemma-Finetune/main.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.21.0
 
 
6
  ---
 
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ unsloth
4
+ peft
5
+ datasets
6
+ torch
sample_data/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This directory includes a few sample datasets to get you started.
2
+
3
+ * `california_housing_data*.csv` is California housing data from the 1990 US
4
+ Census; more information is available at:
5
+ https://docs.google.com/document/d/e/2PACX-1vRhYtsvc5eOR2FWNCwaBiKL6suIOrxJig8LcSBbmCbyYsayia_DvPOOBlXZ4CAlQ5nlDD8kTaIDRwrN/pub
6
+
7
+ * `mnist_*.csv` is a small sample of the
8
+ [MNIST database](https://en.wikipedia.org/wiki/MNIST_database), which is
9
+ described at: http://yann.lecun.com/exdb/mnist/
10
+
11
+ * `anscombe.json` contains a copy of
12
+ [Anscombe's quartet](https://en.wikipedia.org/wiki/Anscombe%27s_quartet); it
13
+ was originally described in
14
+
15
+ Anscombe, F. J. (1973). 'Graphs in Statistical Analysis'. American
16
+ Statistician. 27 (1): 17-21. JSTOR 2682899.
17
+
18
+ and our copy was prepared by the
19
+ [vega_datasets library](https://github.com/altair-viz/vega_datasets/blob/4f67bdaad10f45e3549984e17e1b3088c731503d/vega_datasets/_data/anscombe.json).
sample_data/anscombe.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {"Series":"I", "X":10.0, "Y":8.04},
3
+ {"Series":"I", "X":8.0, "Y":6.95},
4
+ {"Series":"I", "X":13.0, "Y":7.58},
5
+ {"Series":"I", "X":9.0, "Y":8.81},
6
+ {"Series":"I", "X":11.0, "Y":8.33},
7
+ {"Series":"I", "X":14.0, "Y":9.96},
8
+ {"Series":"I", "X":6.0, "Y":7.24},
9
+ {"Series":"I", "X":4.0, "Y":4.26},
10
+ {"Series":"I", "X":12.0, "Y":10.84},
11
+ {"Series":"I", "X":7.0, "Y":4.81},
12
+ {"Series":"I", "X":5.0, "Y":5.68},
13
+
14
+ {"Series":"II", "X":10.0, "Y":9.14},
15
+ {"Series":"II", "X":8.0, "Y":8.14},
16
+ {"Series":"II", "X":13.0, "Y":8.74},
17
+ {"Series":"II", "X":9.0, "Y":8.77},
18
+ {"Series":"II", "X":11.0, "Y":9.26},
19
+ {"Series":"II", "X":14.0, "Y":8.10},
20
+ {"Series":"II", "X":6.0, "Y":6.13},
21
+ {"Series":"II", "X":4.0, "Y":3.10},
22
+ {"Series":"II", "X":12.0, "Y":9.13},
23
+ {"Series":"II", "X":7.0, "Y":7.26},
24
+ {"Series":"II", "X":5.0, "Y":4.74},
25
+
26
+ {"Series":"III", "X":10.0, "Y":7.46},
27
+ {"Series":"III", "X":8.0, "Y":6.77},
28
+ {"Series":"III", "X":13.0, "Y":12.74},
29
+ {"Series":"III", "X":9.0, "Y":7.11},
30
+ {"Series":"III", "X":11.0, "Y":7.81},
31
+ {"Series":"III", "X":14.0, "Y":8.84},
32
+ {"Series":"III", "X":6.0, "Y":6.08},
33
+ {"Series":"III", "X":4.0, "Y":5.39},
34
+ {"Series":"III", "X":12.0, "Y":8.15},
35
+ {"Series":"III", "X":7.0, "Y":6.42},
36
+ {"Series":"III", "X":5.0, "Y":5.73},
37
+
38
+ {"Series":"IV", "X":8.0, "Y":6.58},
39
+ {"Series":"IV", "X":8.0, "Y":5.76},
40
+ {"Series":"IV", "X":8.0, "Y":7.71},
41
+ {"Series":"IV", "X":8.0, "Y":8.84},
42
+ {"Series":"IV", "X":8.0, "Y":8.47},
43
+ {"Series":"IV", "X":8.0, "Y":7.04},
44
+ {"Series":"IV", "X":8.0, "Y":5.25},
45
+ {"Series":"IV", "X":19.0, "Y":12.50},
46
+ {"Series":"IV", "X":8.0, "Y":5.56},
47
+ {"Series":"IV", "X":8.0, "Y":7.91},
48
+ {"Series":"IV", "X":8.0, "Y":6.89}
49
+ ]
sample_data/california_housing_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/california_housing_train.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/mnist_test.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51c292478d94ec3a01461bdfa82eb0885d262eb09e615679b2d69dedb6ad09e7
3
+ size 18289443
sample_data/mnist_train_small.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ef64781aa03180f4f5ce504314f058f5d0227277df86060473d973cf43b033e
3
+ size 36523880
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothAlignPropConfig(AlignPropConfig):
44
+ """
45
+
46
+ Configuration class for the [`AlignPropTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (defaults to the file name without the extension).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed for reproducibility.
59
+ log_with (`str` or `None`, *optional*, defaults to `None`):
60
+ Log with either `"wandb"` or `"tensorboard"`. Check
61
+ [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
62
+ log_image_freq (`int`, *optional*, defaults to `1`):
63
+ Frequency for logging images.
64
+ tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the tracker (e.g., `wandb_project`).
66
+ accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator.
68
+ project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
69
+ Keyword arguments for the accelerator project config (e.g., `logging_dir`).
70
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
71
+ Name of project to use for tracking.
72
+ logdir (`str`, *optional*, defaults to `"logs"`):
73
+ Top-level logging directory for checkpoint saving.
74
+ num_epochs (`int`, *optional*, defaults to `100`):
75
+ Number of epochs to train.
76
+ save_freq (`int`, *optional*, defaults to `1`):
77
+ Number of epochs between saving model checkpoints.
78
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
79
+ Number of checkpoints to keep before overwriting old ones.
80
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
81
+ Mixed precision training.
82
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
83
+ Allow `tf32` on Ampere GPUs.
84
+ resume_from (`str`, *optional*, defaults to `""`):
85
+ Path to resume training from a checkpoint.
86
+ sample_num_steps (`int`, *optional*, defaults to `50`):
87
+ Number of sampler inference steps.
88
+ sample_eta (`float`, *optional*, defaults to `1.0`):
89
+ Eta parameter for the DDIM sampler.
90
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
91
+ Classifier-free guidance weight.
92
+ train_batch_size (`int`, *optional*, defaults to `1`):
93
+ Batch size for training.
94
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
95
+ Whether to use the 8bit Adam optimizer from `bitsandbytes`.
96
+ train_learning_rate (`float`, *optional*, defaults to `1e-3`):
97
+ Learning rate.
98
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
99
+ Beta1 for Adam optimizer.
100
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
101
+ Beta2 for Adam optimizer.
102
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
103
+ Weight decay for Adam optimizer.
104
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
105
+ Epsilon value for Adam optimizer.
106
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
107
+ Number of gradient accumulation steps.
108
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
109
+ Maximum gradient norm for gradient clipping.
110
+ negative_prompts (`str` or `None`, *optional*, defaults to `None`):
111
+ Comma-separated list of prompts to use as negative examples.
112
+ truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
113
+ If `True`, randomized truncation to different diffusion timesteps is used.
114
+ truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
115
+ Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
116
+ truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
117
+ Range of diffusion timesteps for randomized truncated backpropagation.
118
+ push_to_hub (`bool`, *optional*, defaults to `False`):
119
+ Whether to push the final model to the Hub.
120
+
121
+ """
122
+ vllm_sampling_params: Optional[Any] = field(
123
+ default = None,
124
+ metadata = {'help': 'vLLM SamplingParams'},
125
+ )
126
+ unsloth_num_chunks : Optional[int] = field(
127
+ default = -1,
128
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
129
+ )
130
+ def __init__(
131
+ self,
132
+ exp_name = 'main',
133
+ run_name = '',
134
+ seed = 3407,
135
+ log_with = None,
136
+ log_image_freq = 1,
137
+ tracker_project_name = 'trl',
138
+ logdir = 'logs',
139
+ num_epochs = 100,
140
+ save_freq = 1,
141
+ num_checkpoint_limit = 5,
142
+ mixed_precision = 'fp16',
143
+ allow_tf32 = True,
144
+ resume_from = '',
145
+ sample_num_steps = 50,
146
+ sample_eta = 1.0,
147
+ sample_guidance_scale = 5.0,
148
+ train_batch_size = 1,
149
+ train_use_8bit_adam = False,
150
+ train_learning_rate = 5e-05,
151
+ train_adam_beta1 = 0.9,
152
+ train_adam_beta2 = 0.999,
153
+ train_adam_weight_decay = 0.01,
154
+ train_adam_epsilon = 1e-08,
155
+ train_gradient_accumulation_steps = 2,
156
+ train_max_grad_norm = 1.0,
157
+ negative_prompts = None,
158
+ truncated_backprop_rand = True,
159
+ truncated_backprop_timestep = 49,
160
+ push_to_hub = False,
161
+ vllm_sampling_params = None,
162
+ unsloth_num_chunks = -1,
163
+ **kwargs,
164
+ ):
165
+
166
+ super().__init__(
167
+ exp_name = exp_name,
168
+ run_name = run_name,
169
+ seed = seed,
170
+ log_with = log_with,
171
+ log_image_freq = log_image_freq,
172
+ tracker_project_name = tracker_project_name,
173
+ logdir = logdir,
174
+ num_epochs = num_epochs,
175
+ save_freq = save_freq,
176
+ num_checkpoint_limit = num_checkpoint_limit,
177
+ mixed_precision = mixed_precision,
178
+ allow_tf32 = allow_tf32,
179
+ resume_from = resume_from,
180
+ sample_num_steps = sample_num_steps,
181
+ sample_eta = sample_eta,
182
+ sample_guidance_scale = sample_guidance_scale,
183
+ train_batch_size = train_batch_size,
184
+ train_use_8bit_adam = train_use_8bit_adam,
185
+ train_learning_rate = train_learning_rate,
186
+ train_adam_beta1 = train_adam_beta1,
187
+ train_adam_beta2 = train_adam_beta2,
188
+ train_adam_weight_decay = train_adam_weight_decay,
189
+ train_adam_epsilon = train_adam_epsilon,
190
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
191
+ train_max_grad_norm = train_max_grad_norm,
192
+ negative_prompts = negative_prompts,
193
+ truncated_backprop_rand = truncated_backprop_rand,
194
+ truncated_backprop_timestep = truncated_backprop_timestep,
195
+ push_to_hub = push_to_hub,**kwargs)
196
+ self.vllm_sampling_params = vllm_sampling_params
197
+ self.unsloth_num_chunks = unsloth_num_chunks
198
+ pass
199
+
200
+ class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
201
+ """"""
202
+
203
+ _tag_names = ["trl", "alignprop"]
204
+
205
+ def __init__(
206
+ self,
207
+ config: AlignPropConfig,
208
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
209
+ prompt_function: Callable[[], tuple[str, Any]],
210
+ sd_pipeline: DDPOStableDiffusionPipeline,
211
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
212
+ ):
213
+ if image_samples_hook is None:
214
+ warn("No image_samples_hook provided; no images will be logged")
215
+
216
+ self.prompt_fn = prompt_function
217
+ self.reward_fn = reward_function
218
+ self.config = config
219
+ self.image_samples_callback = image_samples_hook
220
+
221
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
222
+
223
+ if self.config.resume_from:
224
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
225
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
226
+ # get the most recent checkpoint in this directory
227
+ checkpoints = list(
228
+ filter(
229
+ lambda x: "checkpoint_" in x,
230
+ os.listdir(self.config.resume_from),
231
+ )
232
+ )
233
+ if len(checkpoints) == 0:
234
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
235
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
236
+ self.config.resume_from = os.path.join(
237
+ self.config.resume_from,
238
+ f"checkpoint_{checkpoint_numbers[-1]}",
239
+ )
240
+
241
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
242
+
243
+ self.accelerator = Accelerator(
244
+ log_with=self.config.log_with,
245
+ mixed_precision=self.config.mixed_precision,
246
+ project_config=accelerator_project_config,
247
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
248
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
249
+ # the total number of optimizer steps to accumulate across.
250
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
251
+ **self.config.accelerator_kwargs,
252
+ )
253
+
254
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
255
+
256
+ if self.accelerator.is_main_process:
257
+ self.accelerator.init_trackers(
258
+ self.config.tracker_project_name,
259
+ config=dict(alignprop_trainer_config=config.to_dict())
260
+ if not is_using_tensorboard
261
+ else config.to_dict(),
262
+ init_kwargs=self.config.tracker_kwargs,
263
+ )
264
+
265
+ logger.info(f"\n{config}")
266
+
267
+ set_seed(self.config.seed, device_specific=True)
268
+
269
+ self.sd_pipeline = sd_pipeline
270
+
271
+ self.sd_pipeline.set_progress_bar_config(
272
+ position=1,
273
+ disable=not self.accelerator.is_local_main_process,
274
+ leave=False,
275
+ desc="Timestep",
276
+ dynamic_ncols=True,
277
+ )
278
+
279
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
280
+ # as these weights are only used for inference, keeping weights in full precision is not required.
281
+ if self.accelerator.mixed_precision == "fp16":
282
+ inference_dtype = torch.float16
283
+ elif self.accelerator.mixed_precision == "bf16":
284
+ inference_dtype = torch.bfloat16
285
+ else:
286
+ inference_dtype = torch.float32
287
+
288
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
289
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
290
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
291
+
292
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
293
+
294
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
295
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
296
+
297
+ # Enable TF32 for faster training on Ampere GPUs,
298
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
299
+ if self.config.allow_tf32:
300
+ torch.backends.cuda.matmul.allow_tf32 = True
301
+
302
+ self.optimizer = self._setup_optimizer(
303
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
304
+ )
305
+
306
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
307
+ self.sd_pipeline.tokenizer(
308
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
309
+ return_tensors="pt",
310
+ padding="max_length",
311
+ truncation=True,
312
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
313
+ ).input_ids.to(self.accelerator.device)
314
+ )[0]
315
+
316
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
317
+ # more memory
318
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
319
+
320
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
321
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
322
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
323
+ else:
324
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
325
+
326
+ if config.resume_from:
327
+ logger.info(f"Resuming from {config.resume_from}")
328
+ self.accelerator.load_state(config.resume_from)
329
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
330
+ else:
331
+ self.first_epoch = 0
332
+
333
+ def compute_rewards(self, prompt_image_pairs):
334
+ reward, reward_metadata = self.reward_fn(
335
+ prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
336
+ )
337
+ return reward
338
+
339
+ def step(self, epoch: int, global_step: int):
340
+ """
341
+ Perform a single step of training.
342
+
343
+ Args:
344
+ epoch (int): The current epoch.
345
+ global_step (int): The current global step.
346
+
347
+ Side Effects:
348
+ - Model weights are updated
349
+ - Logs the statistics to the accelerator trackers.
350
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
351
+
352
+ Returns:
353
+ global_step (int): The updated global step.
354
+ """
355
+ info = defaultdict(list)
356
+
357
+ self.sd_pipeline.unet.train()
358
+
359
+ for _ in range(self.config.train_gradient_accumulation_steps):
360
+ with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
361
+ prompt_image_pairs = self._generate_samples(
362
+ batch_size=self.config.train_batch_size,
363
+ )
364
+
365
+ rewards = self.compute_rewards(prompt_image_pairs)
366
+
367
+ prompt_image_pairs["rewards"] = rewards
368
+
369
+ rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
370
+
371
+ loss = self.calculate_loss(rewards)
372
+
373
+ self.accelerator.backward(loss)
374
+
375
+ if self.accelerator.sync_gradients:
376
+ self.accelerator.clip_grad_norm_(
377
+ self.trainable_layers.parameters()
378
+ if not isinstance(self.trainable_layers, list)
379
+ else self.trainable_layers,
380
+ self.config.train_max_grad_norm,
381
+ )
382
+
383
+ self.optimizer.step()
384
+ self.optimizer.zero_grad()
385
+
386
+ info["reward_mean"].append(rewards_vis.mean())
387
+ info["reward_std"].append(rewards_vis.std())
388
+ info["loss"].append(loss.item())
389
+
390
+ # Checks if the accelerator has performed an optimization step behind the scenes
391
+ if self.accelerator.sync_gradients:
392
+ # log training-related stuff
393
+ info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
394
+ info = self.accelerator.reduce(info, reduction="mean")
395
+ info.update({"epoch": epoch})
396
+ self.accelerator.log(info, step=global_step)
397
+ global_step += 1
398
+ info = defaultdict(list)
399
+ else:
400
+ raise ValueError(
401
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
402
+ )
403
+ # Logs generated images
404
+ if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
405
+ self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
406
+
407
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
408
+ self.accelerator.save_state()
409
+
410
+ return global_step
411
+
412
+ def calculate_loss(self, rewards):
413
+ """
414
+ Calculate the loss for a batch of an unpacked sample
415
+
416
+ Args:
417
+ rewards (torch.Tensor):
418
+ Differentiable reward scalars for each generated image, shape: [batch_size]
419
+
420
+ Returns:
421
+ loss (torch.Tensor)
422
+ (all of these are of shape (1,))
423
+ """
424
+ # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
425
+ loss = 10.0 - (rewards).mean()
426
+ return loss
427
+
428
+ def loss(
429
+ self,
430
+ advantages: torch.Tensor,
431
+ clip_range: float,
432
+ ratio: torch.Tensor,
433
+ ):
434
+ unclipped_loss = -advantages * ratio
435
+ clipped_loss = -advantages * torch.clamp(
436
+ ratio,
437
+ 1.0 - clip_range,
438
+ 1.0 + clip_range,
439
+ )
440
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
441
+
442
+ def _setup_optimizer(self, trainable_layers_parameters):
443
+ if self.config.train_use_8bit_adam:
444
+ import bitsandbytes
445
+
446
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
447
+ else:
448
+ optimizer_cls = torch.optim.AdamW
449
+
450
+ return optimizer_cls(
451
+ trainable_layers_parameters,
452
+ lr=self.config.train_learning_rate,
453
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
454
+ weight_decay=self.config.train_adam_weight_decay,
455
+ eps=self.config.train_adam_epsilon,
456
+ )
457
+
458
+ def _save_model_hook(self, models, weights, output_dir):
459
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
460
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
461
+
462
+ def _load_model_hook(self, models, input_dir):
463
+ self.sd_pipeline.load_checkpoint(models, input_dir)
464
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
465
+
466
+ def _generate_samples(self, batch_size, with_grad=True, prompts=None):
467
+ """
468
+ Generate samples from the model
469
+
470
+ Args:
471
+ batch_size (int): Batch size to use for sampling
472
+ with_grad (bool): Whether the generated RGBs should have gradients attached to it.
473
+
474
+ Returns:
475
+ prompt_image_pairs (dict[Any])
476
+ """
477
+ prompt_image_pairs = {}
478
+
479
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
480
+
481
+ if prompts is None:
482
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
483
+ else:
484
+ prompt_metadata = [{} for _ in range(batch_size)]
485
+
486
+ prompt_ids = self.sd_pipeline.tokenizer(
487
+ prompts,
488
+ return_tensors="pt",
489
+ padding="max_length",
490
+ truncation=True,
491
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
492
+ ).input_ids.to(self.accelerator.device)
493
+
494
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
495
+
496
+ if with_grad:
497
+ sd_output = self.sd_pipeline.rgb_with_grad(
498
+ prompt_embeds=prompt_embeds,
499
+ negative_prompt_embeds=sample_neg_prompt_embeds,
500
+ num_inference_steps=self.config.sample_num_steps,
501
+ guidance_scale=self.config.sample_guidance_scale,
502
+ eta=self.config.sample_eta,
503
+ truncated_backprop_rand=self.config.truncated_backprop_rand,
504
+ truncated_backprop_timestep=self.config.truncated_backprop_timestep,
505
+ truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
506
+ output_type="pt",
507
+ )
508
+ else:
509
+ sd_output = self.sd_pipeline(
510
+ prompt_embeds=prompt_embeds,
511
+ negative_prompt_embeds=sample_neg_prompt_embeds,
512
+ num_inference_steps=self.config.sample_num_steps,
513
+ guidance_scale=self.config.sample_guidance_scale,
514
+ eta=self.config.sample_eta,
515
+ output_type="pt",
516
+ )
517
+
518
+ images = sd_output.images
519
+
520
+ prompt_image_pairs["images"] = images
521
+ prompt_image_pairs["prompts"] = prompts
522
+ prompt_image_pairs["prompt_metadata"] = prompt_metadata
523
+
524
+ return prompt_image_pairs
525
+
526
+ def train(self, epochs: Optional[int] = None):
527
+ """
528
+ Train the model for a given number of epochs
529
+ """
530
+ global_step = 0
531
+ if epochs is None:
532
+ epochs = self.config.num_epochs
533
+ for epoch in range(self.first_epoch, epochs):
534
+ global_step = self.step(epoch, global_step)
535
+
536
+ def _save_pretrained(self, save_directory):
537
+ self.sd_pipeline.save_pretrained(save_directory)
538
+ self.create_model_card()
539
+
540
+ def create_model_card(
541
+ self,
542
+ model_name: Optional[str] = None,
543
+ dataset_name: Optional[str] = None,
544
+ tags: Union[str, list[str], None] = None,
545
+ ):
546
+ """
547
+ Creates a draft of a model card using the information available to the `Trainer`.
548
+
549
+ Args:
550
+ model_name (`str` or `None`, *optional*, defaults to `None`):
551
+ Name of the model.
552
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
553
+ Name of the dataset used for training.
554
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
555
+ Tags to be associated with the model card.
556
+ """
557
+ if not self.is_world_process_zero():
558
+ return
559
+
560
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
561
+ base_model = self.model.config._name_or_path
562
+ else:
563
+ base_model = None
564
+
565
+ tags = tags or []
566
+ if isinstance(tags, str):
567
+ tags = [tags]
568
+
569
+ if hasattr(self.model.config, "unsloth_version"):
570
+ tags.append("unsloth")
571
+
572
+ citation = textwrap.dedent("""\
573
+ @article{prabhudesai2024aligning,
574
+ title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
575
+ author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
576
+ year = 2024,
577
+ eprint = {arXiv:2310.03739}
578
+ }""")
579
+
580
+ model_card = generate_model_card(
581
+ base_model=base_model,
582
+ model_name=model_name,
583
+ hub_model_id=self.hub_model_id,
584
+ dataset_name=dataset_name,
585
+ tags=tags,
586
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
587
+ comet_url=get_comet_experiment_url(),
588
+ trainer_name="AlignProp",
589
+ trainer_citation=citation,
590
+ paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
591
+ paper_id="2310.03739",
592
+ )
593
+
594
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
595
+ class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
596
+ """
597
+
598
+ The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
599
+ Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
600
+ As of now only Stable Diffusion based pipelines are supported
601
+
602
+ Attributes:
603
+ config (`AlignPropConfig`):
604
+ Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
605
+ reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
606
+ Reward function to be used
607
+ prompt_function (`Callable[[], tuple[str, Any]]`):
608
+ Function to generate prompts to guide model
609
+ sd_pipeline (`DDPOStableDiffusionPipeline`):
610
+ Stable Diffusion pipeline to be used for training.
611
+ image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
612
+ Hook to be called to log images
613
+
614
+ """
615
+ def __init__(
616
+ self,
617
+ config,
618
+ reward_function,
619
+ prompt_function,
620
+ sd_pipeline,
621
+ image_samples_hook = None,
622
+ **kwargs
623
+ ):
624
+ if args is None: args = UnslothAlignPropConfig()
625
+ other_metrics = []
626
+
627
+ from unsloth_zoo.logging_utils import PatchRLStatistics
628
+ PatchRLStatistics('alignprop_trainer', other_metrics)
629
+
630
+ super().__init__(
631
+ config = config,
632
+ reward_function = reward_function,
633
+ prompt_function = prompt_function,
634
+ sd_pipeline = sd_pipeline,
635
+ image_samples_hook = image_samples_hook,**kwargs)
636
+
637
+ pass
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1822 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothBCOConfig(BCOConfig):
44
+ """
45
+
46
+ Configuration class for the [`BCOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
54
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
55
+ to use the default data collator.
56
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
57
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
58
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
59
+ Maximum length of the completion. This argument is required if you want to use the default data collator
60
+ and your model is an encoder-decoder.
61
+ beta (`float`, *optional*, defaults to `0.1`):
62
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
63
+ reference model.
64
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
65
+ Label pad token id. This argument is required if you want to use the default data collator.
66
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
67
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
68
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
69
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
70
+ This argument is required if you want to use the default data collator.
71
+ disable_dropout (`bool`, *optional*, defaults to `True`):
72
+ Whether to disable dropout in the model and reference model.
73
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
74
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
75
+ evaluation.
76
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
77
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
78
+ you need to specify if the model returned by the callable is an encoder-decoder model.
79
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
80
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
81
+ useful when training without the reference model to reduce the total GPU memory needed.
82
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
83
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
84
+ string.
85
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
86
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
87
+ from a string.
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ prompt_sample_size (`int`, *optional*, defaults to `1024`):
91
+ Number of prompts that are fed to density ratio classifier.
92
+ min_density_ratio (`float`, *optional*, defaults to `0.5`):
93
+ Minimum value of the density ratio. The estimated density ratio is clamped to this value.
94
+ max_density_ratio (`float`, *optional*, defaults to `10.0`):
95
+ Maximum value of the density ratio. The estimated density ratio is clamped to this value.
96
+
97
+ """
98
+ vllm_sampling_params: Optional[Any] = field(
99
+ default = None,
100
+ metadata = {'help': 'vLLM SamplingParams'},
101
+ )
102
+ unsloth_num_chunks : Optional[int] = field(
103
+ default = -1,
104
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
105
+ )
106
+ def __init__(
107
+ self,
108
+ output_dir = None,
109
+ overwrite_output_dir = None,
110
+ do_train = False,
111
+ do_eval = False,
112
+ do_predict = False,
113
+ eval_strategy = 'no',
114
+ prediction_loss_only = False,
115
+ per_device_train_batch_size = 4,
116
+ per_device_eval_batch_size = 4,
117
+ per_gpu_train_batch_size = None,
118
+ per_gpu_eval_batch_size = None,
119
+ gradient_accumulation_steps = 2,
120
+ eval_accumulation_steps = 2,
121
+ eval_delay = 0,
122
+ torch_empty_cache_steps = 250,
123
+ learning_rate = 5e-05,
124
+ weight_decay = 0.01,
125
+ adam_beta1 = 0.9,
126
+ adam_beta2 = 0.999,
127
+ adam_epsilon = 1e-08,
128
+ max_grad_norm = 1.0,
129
+ num_train_epochs = 3.0,
130
+ max_steps = -1,
131
+ lr_scheduler_type = 'linear',
132
+ warmup_ratio = 0.1,
133
+ warmup_steps = 0,
134
+ log_level = 'passive',
135
+ log_level_replica = 'warning',
136
+ log_on_each_node = True,
137
+ logging_dir = None,
138
+ logging_strategy = 'steps',
139
+ logging_first_step = False,
140
+ logging_steps = 1,
141
+ logging_nan_inf_filter = False,
142
+ save_strategy = 'steps',
143
+ save_steps = 500,
144
+ save_total_limit = None,
145
+ save_safetensors = True,
146
+ save_on_each_node = False,
147
+ save_only_model = False,
148
+ restore_callback_states_from_checkpoint = False,
149
+ no_cuda = False,
150
+ use_cpu = False,
151
+ use_mps_device = False,
152
+ seed = 3407,
153
+ data_seed = 3407,
154
+ jit_mode_eval = False,
155
+ use_ipex = False,
156
+ bf16 = False,
157
+ fp16 = False,
158
+ fp16_opt_level = 'O1',
159
+ half_precision_backend = 'auto',
160
+ bf16_full_eval = False,
161
+ fp16_full_eval = False,
162
+ tf32 = None,
163
+ local_rank = -1,
164
+ ddp_backend = None,
165
+ tpu_num_cores = None,
166
+ tpu_metrics_debug = False,
167
+ debug = '',
168
+ dataloader_drop_last = False,
169
+ eval_steps = None,
170
+ dataloader_num_workers = 0,
171
+ dataloader_prefetch_factor = None,
172
+ past_index = -1,
173
+ run_name = None,
174
+ disable_tqdm = None,
175
+ remove_unused_columns = True,
176
+ label_names = None,
177
+ load_best_model_at_end = False,
178
+ metric_for_best_model = None,
179
+ greater_is_better = None,
180
+ ignore_data_skip = False,
181
+ fsdp = '',
182
+ fsdp_min_num_params = 0,
183
+ fsdp_config = None,
184
+ fsdp_transformer_layer_cls_to_wrap = None,
185
+ accelerator_config = None,
186
+ deepspeed = None,
187
+ label_smoothing_factor = 0.0,
188
+ optim = 'adamw_8bit',
189
+ optim_args = None,
190
+ adafactor = False,
191
+ group_by_length = False,
192
+ length_column_name = 'length',
193
+ report_to = None,
194
+ ddp_find_unused_parameters = None,
195
+ ddp_bucket_cap_mb = None,
196
+ ddp_broadcast_buffers = None,
197
+ dataloader_pin_memory = True,
198
+ dataloader_persistent_workers = False,
199
+ skip_memory_metrics = True,
200
+ use_legacy_prediction_loop = False,
201
+ push_to_hub = False,
202
+ resume_from_checkpoint = None,
203
+ hub_model_id = None,
204
+ hub_strategy = 'every_save',
205
+ hub_token = None,
206
+ hub_private_repo = None,
207
+ hub_always_push = False,
208
+ gradient_checkpointing = False,
209
+ gradient_checkpointing_kwargs = None,
210
+ include_inputs_for_metrics = False,
211
+ eval_do_concat_batches = True,
212
+ fp16_backend = 'auto',
213
+ evaluation_strategy = None,
214
+ push_to_hub_model_id = None,
215
+ push_to_hub_organization = None,
216
+ push_to_hub_token = None,
217
+ mp_parameters = '',
218
+ auto_find_batch_size = False,
219
+ full_determinism = False,
220
+ torchdynamo = None,
221
+ ray_scope = 'last',
222
+ ddp_timeout = 1800,
223
+ torch_compile = False,
224
+ torch_compile_backend = None,
225
+ torch_compile_mode = None,
226
+ dispatch_batches = None,
227
+ split_batches = None,
228
+ include_tokens_per_second = False,
229
+ include_num_input_tokens_seen = False,
230
+ neftune_noise_alpha = None,
231
+ optim_target_modules = None,
232
+ batch_eval_metrics = False,
233
+ eval_on_start = False,
234
+ use_liger_kernel = False,
235
+ eval_use_gather_object = False,
236
+ average_tokens_across_devices = False,
237
+ max_length = 1024,
238
+ max_prompt_length = 512,
239
+ max_completion_length = None,
240
+ beta = 0.1,
241
+ label_pad_token_id = -100,
242
+ padding_value = None,
243
+ truncation_mode = 'keep_end',
244
+ disable_dropout = True,
245
+ generate_during_eval = False,
246
+ is_encoder_decoder = None,
247
+ precompute_ref_log_probs = False,
248
+ model_init_kwargs = None,
249
+ ref_model_init_kwargs = None,
250
+ dataset_num_proc = None,
251
+ prompt_sample_size = 1024,
252
+ min_density_ratio = 0.5,
253
+ max_density_ratio = 10.0,
254
+ vllm_sampling_params = None,
255
+ unsloth_num_chunks = -1,
256
+ **kwargs,
257
+ ):
258
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
259
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
260
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
261
+ output_dir = 'unsloth_training_checkpoints'
262
+ save_strategy = 'no'
263
+ if dataset_num_proc is None:
264
+ from multiprocessing import cpu_count
265
+ dataset_num_proc = cpu_count()
266
+
267
+ super().__init__(
268
+ output_dir = output_dir,
269
+ overwrite_output_dir = overwrite_output_dir,
270
+ do_train = do_train,
271
+ do_eval = do_eval,
272
+ do_predict = do_predict,
273
+ eval_strategy = eval_strategy,
274
+ prediction_loss_only = prediction_loss_only,
275
+ per_device_train_batch_size = per_device_train_batch_size,
276
+ per_device_eval_batch_size = per_device_eval_batch_size,
277
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
278
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
279
+ gradient_accumulation_steps = gradient_accumulation_steps,
280
+ eval_accumulation_steps = eval_accumulation_steps,
281
+ eval_delay = eval_delay,
282
+ torch_empty_cache_steps = torch_empty_cache_steps,
283
+ learning_rate = learning_rate,
284
+ weight_decay = weight_decay,
285
+ adam_beta1 = adam_beta1,
286
+ adam_beta2 = adam_beta2,
287
+ adam_epsilon = adam_epsilon,
288
+ max_grad_norm = max_grad_norm,
289
+ num_train_epochs = num_train_epochs,
290
+ max_steps = max_steps,
291
+ lr_scheduler_type = lr_scheduler_type,
292
+ warmup_ratio = warmup_ratio,
293
+ warmup_steps = warmup_steps,
294
+ log_level = log_level,
295
+ log_level_replica = log_level_replica,
296
+ log_on_each_node = log_on_each_node,
297
+ logging_dir = logging_dir,
298
+ logging_strategy = logging_strategy,
299
+ logging_first_step = logging_first_step,
300
+ logging_steps = logging_steps,
301
+ logging_nan_inf_filter = logging_nan_inf_filter,
302
+ save_strategy = save_strategy,
303
+ save_steps = save_steps,
304
+ save_total_limit = save_total_limit,
305
+ save_safetensors = save_safetensors,
306
+ save_on_each_node = save_on_each_node,
307
+ save_only_model = save_only_model,
308
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
309
+ no_cuda = no_cuda,
310
+ use_cpu = use_cpu,
311
+ use_mps_device = use_mps_device,
312
+ seed = seed,
313
+ data_seed = data_seed,
314
+ jit_mode_eval = jit_mode_eval,
315
+ use_ipex = use_ipex,
316
+ bf16 = bf16,
317
+ fp16 = fp16,
318
+ fp16_opt_level = fp16_opt_level,
319
+ half_precision_backend = half_precision_backend,
320
+ bf16_full_eval = bf16_full_eval,
321
+ fp16_full_eval = fp16_full_eval,
322
+ tf32 = tf32,
323
+ local_rank = local_rank,
324
+ ddp_backend = ddp_backend,
325
+ tpu_num_cores = tpu_num_cores,
326
+ tpu_metrics_debug = tpu_metrics_debug,
327
+ debug = debug,
328
+ dataloader_drop_last = dataloader_drop_last,
329
+ eval_steps = eval_steps,
330
+ dataloader_num_workers = dataloader_num_workers,
331
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
332
+ past_index = past_index,
333
+ run_name = run_name,
334
+ disable_tqdm = disable_tqdm,
335
+ remove_unused_columns = remove_unused_columns,
336
+ label_names = label_names,
337
+ load_best_model_at_end = load_best_model_at_end,
338
+ metric_for_best_model = metric_for_best_model,
339
+ greater_is_better = greater_is_better,
340
+ ignore_data_skip = ignore_data_skip,
341
+ fsdp = fsdp,
342
+ fsdp_min_num_params = fsdp_min_num_params,
343
+ fsdp_config = fsdp_config,
344
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
345
+ accelerator_config = accelerator_config,
346
+ deepspeed = deepspeed,
347
+ label_smoothing_factor = label_smoothing_factor,
348
+ optim = optim,
349
+ optim_args = optim_args,
350
+ adafactor = adafactor,
351
+ group_by_length = group_by_length,
352
+ length_column_name = length_column_name,
353
+ report_to = report_to,
354
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
355
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
356
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
357
+ dataloader_pin_memory = dataloader_pin_memory,
358
+ dataloader_persistent_workers = dataloader_persistent_workers,
359
+ skip_memory_metrics = skip_memory_metrics,
360
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
361
+ push_to_hub = push_to_hub,
362
+ resume_from_checkpoint = resume_from_checkpoint,
363
+ hub_model_id = hub_model_id,
364
+ hub_strategy = hub_strategy,
365
+ hub_token = hub_token,
366
+ hub_private_repo = hub_private_repo,
367
+ hub_always_push = hub_always_push,
368
+ gradient_checkpointing = gradient_checkpointing,
369
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
370
+ include_inputs_for_metrics = include_inputs_for_metrics,
371
+ eval_do_concat_batches = eval_do_concat_batches,
372
+ fp16_backend = fp16_backend,
373
+ evaluation_strategy = evaluation_strategy,
374
+ push_to_hub_model_id = push_to_hub_model_id,
375
+ push_to_hub_organization = push_to_hub_organization,
376
+ push_to_hub_token = push_to_hub_token,
377
+ mp_parameters = mp_parameters,
378
+ auto_find_batch_size = auto_find_batch_size,
379
+ full_determinism = full_determinism,
380
+ torchdynamo = torchdynamo,
381
+ ray_scope = ray_scope,
382
+ ddp_timeout = ddp_timeout,
383
+ torch_compile = torch_compile,
384
+ torch_compile_backend = torch_compile_backend,
385
+ torch_compile_mode = torch_compile_mode,
386
+ dispatch_batches = dispatch_batches,
387
+ split_batches = split_batches,
388
+ include_tokens_per_second = include_tokens_per_second,
389
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
390
+ neftune_noise_alpha = neftune_noise_alpha,
391
+ optim_target_modules = optim_target_modules,
392
+ batch_eval_metrics = batch_eval_metrics,
393
+ eval_on_start = eval_on_start,
394
+ use_liger_kernel = use_liger_kernel,
395
+ eval_use_gather_object = eval_use_gather_object,
396
+ average_tokens_across_devices = average_tokens_across_devices,
397
+ max_length = max_length,
398
+ max_prompt_length = max_prompt_length,
399
+ max_completion_length = max_completion_length,
400
+ beta = beta,
401
+ label_pad_token_id = label_pad_token_id,
402
+ padding_value = padding_value,
403
+ truncation_mode = truncation_mode,
404
+ disable_dropout = disable_dropout,
405
+ generate_during_eval = generate_during_eval,
406
+ is_encoder_decoder = is_encoder_decoder,
407
+ precompute_ref_log_probs = precompute_ref_log_probs,
408
+ model_init_kwargs = model_init_kwargs,
409
+ ref_model_init_kwargs = ref_model_init_kwargs,
410
+ dataset_num_proc = dataset_num_proc,
411
+ prompt_sample_size = prompt_sample_size,
412
+ min_density_ratio = min_density_ratio,
413
+ max_density_ratio = max_density_ratio,**kwargs)
414
+ self.vllm_sampling_params = vllm_sampling_params
415
+ self.unsloth_num_chunks = unsloth_num_chunks
416
+ pass
417
+
418
+ class _UnslothBCOTrainer(Trainer):
419
+ r""""""
420
+
421
+ _tag_names = ["trl", "bco"]
422
+
423
+ def __init__(
424
+ self,
425
+ model: Union[PreTrainedModel, nn.Module, str] = None,
426
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
427
+ args: BCOConfig = None,
428
+ train_dataset: Optional[Dataset] = None,
429
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
430
+ processing_class: Optional[
431
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
432
+ ] = None,
433
+ data_collator: Optional[DataCollator] = None,
434
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
435
+ callbacks: Optional[list[TrainerCallback]] = None,
436
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
437
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
438
+ peft_config: Optional[dict] = None,
439
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
440
+ model_adapter_name: Optional[str] = None,
441
+ ref_adapter_name: Optional[str] = None,
442
+ embedding_func: Optional[Callable] = None,
443
+ embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
444
+ ):
445
+ if not is_sklearn_available():
446
+ raise ImportError(
447
+ "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
448
+ )
449
+
450
+ if type(args) is TrainingArguments:
451
+ raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
452
+
453
+ if not isinstance(model, str) and ref_model is model:
454
+ raise ValueError(
455
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
456
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
457
+ )
458
+
459
+ if args.model_init_kwargs is None:
460
+ model_init_kwargs = {}
461
+ elif not isinstance(model, str):
462
+ raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
463
+ else:
464
+ model_init_kwargs = args.model_init_kwargs
465
+ torch_dtype = model_init_kwargs.get("torch_dtype")
466
+ if torch_dtype is not None:
467
+ # Convert to `torch.dtype` if an str is passed
468
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
469
+ torch_dtype = getattr(torch, torch_dtype)
470
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
471
+ raise ValueError(
472
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
473
+ )
474
+ model_init_kwargs["torch_dtype"] = torch_dtype
475
+
476
+ if args.ref_model_init_kwargs is None:
477
+ ref_model_init_kwargs = {}
478
+ elif not isinstance(ref_model, str):
479
+ raise ValueError(
480
+ "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
481
+ )
482
+ else:
483
+ ref_model_init_kwargs = args.ref_model_init_kwargs
484
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
485
+ if torch_dtype is not None:
486
+ # Convert to `torch.dtype` if an str is passed
487
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
488
+ torch_dtype = getattr(torch, torch_dtype)
489
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
490
+ raise ValueError(
491
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
492
+ )
493
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
494
+
495
+ if isinstance(model, str):
496
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
497
+
498
+ if isinstance(ref_model, str):
499
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
500
+
501
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
502
+ # has been called in order to properly call autocast if needed.
503
+ self._peft_has_been_casted_to_bf16 = False
504
+
505
+ if not is_peft_available() and peft_config is not None:
506
+ raise ValueError(
507
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
508
+ )
509
+ elif is_peft_available() and peft_config is not None:
510
+ # if model is a peft model and we have a peft_config, we merge and unload it first
511
+ if isinstance(model, PeftModel):
512
+ model = model.merge_and_unload()
513
+
514
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
515
+ _support_gc_kwargs = hasattr(
516
+ args, "gradient_checkpointing_kwargs"
517
+ ) and "gradient_checkpointing_kwargs" in list(
518
+ inspect.signature(prepare_model_for_kbit_training).parameters
519
+ )
520
+
521
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
522
+
523
+ if _support_gc_kwargs:
524
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
525
+
526
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
527
+ elif getattr(args, "gradient_checkpointing", False):
528
+ # For backward compatibility with older versions of transformers
529
+ if hasattr(model, "enable_input_require_grads"):
530
+ model.enable_input_require_grads()
531
+ else:
532
+
533
+ def make_inputs_require_grad(module, input, output):
534
+ output.requires_grad_(True)
535
+
536
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
537
+
538
+ # get peft model with the given config
539
+ model = model
540
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
541
+ peft_module_casting_to_bf16(model)
542
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
543
+ self._peft_has_been_casted_to_bf16 = True
544
+
545
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
546
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
547
+ # fail or completely fail.
548
+ elif getattr(args, "gradient_checkpointing", False):
549
+ # For backward compatibility with older versions of transformers
550
+ if hasattr(model, "enable_input_require_grads"):
551
+ model.enable_input_require_grads()
552
+ else:
553
+
554
+ def make_inputs_require_grad(module, input, output):
555
+ output.requires_grad_(True)
556
+
557
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
558
+
559
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
560
+ raise ValueError(
561
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
562
+ " Please install `wandb` or `comet-ml` to resolve."
563
+ )
564
+
565
+ if model is not None:
566
+ self.is_encoder_decoder = model.config.is_encoder_decoder
567
+ elif args.is_encoder_decoder is None:
568
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
569
+ else:
570
+ self.is_encoder_decoder = args.is_encoder_decoder
571
+
572
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
573
+ self.model_adapter_name = model_adapter_name
574
+ self.ref_adapter_name = ref_adapter_name
575
+
576
+ if ref_model:
577
+ self.ref_model = ref_model
578
+ elif self.is_peft_model or args.precompute_ref_log_probs:
579
+ # The `model` with adapters turned off will be used as the reference model
580
+ self.ref_model = None
581
+ else:
582
+ self.ref_model = create_reference_model(model)
583
+
584
+ if processing_class is None:
585
+ raise ValueError(
586
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
587
+ )
588
+ if args.max_length is None:
589
+ warnings.warn(
590
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
591
+ "It will be set to `512` by default, but you should do it yourself in the future.",
592
+ UserWarning,
593
+ )
594
+ max_length = 512
595
+ if args.max_length is not None:
596
+ max_length = args.max_length
597
+
598
+ if args.max_prompt_length is None:
599
+ warnings.warn(
600
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
601
+ "It will be set to `128` by default, but you should do it yourself in the future.",
602
+ UserWarning,
603
+ )
604
+ max_prompt_length = 128
605
+ if args.max_prompt_length is not None:
606
+ max_prompt_length = args.max_prompt_length
607
+
608
+ max_completion_length = None
609
+ if args.max_completion_length is None and self.is_encoder_decoder:
610
+ warnings.warn(
611
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
612
+ " it will be set to `128` by default, but you should do it yourself in the future.",
613
+ UserWarning,
614
+ )
615
+ max_completion_length = 128
616
+ if args.max_completion_length is not None and self.is_encoder_decoder:
617
+ max_completion_length = args.max_completion_length
618
+
619
+ if data_collator is None:
620
+ data_collator = DPODataCollatorWithPadding(
621
+ pad_token_id=processing_class.pad_token_id,
622
+ label_pad_token_id=args.label_pad_token_id,
623
+ is_encoder_decoder=self.is_encoder_decoder,
624
+ )
625
+
626
+ if args.remove_unused_columns:
627
+ args.remove_unused_columns = False
628
+ # warn users
629
+ warnings.warn(
630
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
631
+ " we have set it for you, but you should do it yourself in the future.",
632
+ UserWarning,
633
+ )
634
+
635
+ self.use_dpo_data_collator = True
636
+ else:
637
+ self.use_dpo_data_collator = False
638
+
639
+ # Disable dropout in the model and reference model
640
+ if args.disable_dropout:
641
+ disable_dropout_in_model(model)
642
+ if self.ref_model is not None:
643
+ disable_dropout_in_model(self.ref_model)
644
+
645
+ self.max_length = max_length
646
+ self.generate_during_eval = args.generate_during_eval
647
+ self.label_pad_token_id = args.label_pad_token_id
648
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
649
+ self.max_prompt_length = max_prompt_length
650
+ self.truncation_mode = args.truncation_mode
651
+ self.max_completion_length = max_completion_length
652
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
653
+
654
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
655
+ # keep track of first called to avoid computation of future calls
656
+ self._precomputed_train_ref_log_probs = False
657
+ self._precomputed_eval_ref_log_probs = False
658
+
659
+ # metric
660
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
661
+
662
+ # BCO parameter
663
+ self.beta = args.beta
664
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
665
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
666
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
667
+ warnings.warn(
668
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
669
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
670
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
671
+ "loss.",
672
+ UserWarning,
673
+ )
674
+
675
+ # Underlying Distribution Matching argument
676
+ self.embedding_func = embedding_func
677
+ self.embedding_tokenizer = embedding_tokenizer
678
+
679
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
680
+ # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
681
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
682
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
683
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
684
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
685
+ # issued.
686
+ model.warnings_issued["estimate_tokens"] = True
687
+
688
+ with PartialState().local_main_process_first():
689
+ # Apply the chat template if needed
690
+ train_dataset = train_dataset.map(
691
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
692
+ )
693
+ if eval_dataset is not None:
694
+ eval_dataset = eval_dataset.map(
695
+ maybe_apply_chat_template,
696
+ fn_kwargs={"tokenizer": processing_class},
697
+ num_proc=args.dataset_num_proc,
698
+ )
699
+ # Shuffle the datasets
700
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
701
+ if eval_dataset is not None:
702
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
703
+ # Tokenize and prepare the training datasets
704
+ train_dataset = train_dataset.map(
705
+ _tokenize,
706
+ batched=True,
707
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
708
+ num_proc=args.dataset_num_proc,
709
+ desc="Tokenizing train dataset",
710
+ )
711
+
712
+ # Prepare the datasets
713
+ fn_kwargs = {
714
+ "prefix": "",
715
+ "is_encoder_decoder": self.is_encoder_decoder,
716
+ "tokenizer": processing_class,
717
+ "max_length": self.max_length,
718
+ "truncation_mode": self.truncation_mode,
719
+ "label_pad_token_id": self.label_pad_token_id,
720
+ "max_prompt_length": self.max_prompt_length,
721
+ "max_completion_length": self.max_completion_length,
722
+ }
723
+ train_dataset = train_dataset.map(
724
+ _process_tokens,
725
+ fn_kwargs=fn_kwargs,
726
+ num_proc=args.dataset_num_proc,
727
+ desc="Processing tokenized train dataset",
728
+ )
729
+
730
+ if eval_dataset is not None:
731
+ # Tokenize
732
+ eval_dataset = eval_dataset.map(
733
+ _tokenize,
734
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
735
+ batched=True,
736
+ num_proc=args.dataset_num_proc,
737
+ desc="Tokenizing eval dataset",
738
+ )
739
+
740
+ # Process
741
+ fn_kwargs = {
742
+ "prefix": "",
743
+ "is_encoder_decoder": self.is_encoder_decoder,
744
+ "tokenizer": processing_class,
745
+ "max_length": self.max_length,
746
+ "truncation_mode": self.truncation_mode,
747
+ "label_pad_token_id": self.label_pad_token_id,
748
+ "max_prompt_length": self.max_prompt_length,
749
+ "max_completion_length": self.max_completion_length,
750
+ }
751
+ eval_dataset = eval_dataset.map(
752
+ _process_tokens,
753
+ fn_kwargs=fn_kwargs,
754
+ num_proc=args.dataset_num_proc,
755
+ desc="Processing tokenized eval dataset",
756
+ )
757
+
758
+ desirable = train_dataset.filter(
759
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
760
+ )
761
+ undesirable = train_dataset.filter(
762
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
763
+ )
764
+
765
+ desirable = desirable.shuffle(seed=args.data_seed)
766
+ undesirable = undesirable.shuffle(seed=args.data_seed)
767
+
768
+ super().__init__(
769
+ model=model,
770
+ args=args,
771
+ data_collator=data_collator,
772
+ train_dataset=train_dataset,
773
+ eval_dataset=eval_dataset,
774
+ processing_class=processing_class,
775
+ model_init=model_init,
776
+ compute_metrics=compute_metrics,
777
+ callbacks=callbacks,
778
+ optimizers=optimizers,
779
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
780
+ )
781
+
782
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
783
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
784
+ # self.model_accepts_loss_kwargs to False to enable scaling.
785
+ self.model_accepts_loss_kwargs = False
786
+
787
+ # Add tags for models that have been loaded with the correct transformers version
788
+ if hasattr(self.model, "add_model_tags"):
789
+ self.model.add_model_tags(self._tag_names)
790
+
791
+ if not hasattr(self, "accelerator"):
792
+ raise AttributeError(
793
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
794
+ )
795
+
796
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
797
+ if self.is_deepspeed_enabled:
798
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
799
+ raise ValueError(
800
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
801
+ )
802
+
803
+ if self.ref_model is None:
804
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
805
+ raise ValueError(
806
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
807
+ )
808
+ else:
809
+ if self.is_deepspeed_enabled:
810
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
811
+ else:
812
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
813
+
814
+ self.running = RunningMoments(accelerator=self.accelerator)
815
+
816
+ if self.embedding_func is None:
817
+ return
818
+
819
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
820
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
821
+
822
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
823
+ labels = torch.cat(
824
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
825
+ )
826
+
827
+ self.clf = LogisticRegression(class_weight="balanced").fit(
828
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
829
+ )
830
+
831
+ @property
832
+ def match_underlying_distribution(self):
833
+ return self.embedding_func is not None and self.embedding_tokenizer is not None
834
+
835
+ def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
836
+ """
837
+ Calculates the probability if the given prompt embedding is from desirable dataset.
838
+ This function calculates the probability in the process and ensemble across processes.
839
+ """
840
+ dtype = prompt_embeddings.dtype
841
+ device = prompt_embeddings.device
842
+ rank = self.accelerator.process_index
843
+
844
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
845
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
846
+ )
847
+ sample_size = padded_prompt_embeddings.shape[0]
848
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
849
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
850
+
851
+ # cannot predict for all empty values
852
+ if prompt_embeddings.shape[0] == 0:
853
+ return torch.tensor([], device=device, dtype=dtype)
854
+
855
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
856
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
857
+ prob = self.accelerator.reduce(prob, reduction="mean")
858
+
859
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
860
+ prob = prob[nonzero]
861
+
862
+ return prob
863
+
864
+ def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
865
+ """
866
+ Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
867
+ and applies self.embedding_func
868
+ """
869
+ input_ids = torch.where(
870
+ input_ids == self.processing_class.pad_token_id,
871
+ self.embedding_tokenizer.pad_token_id,
872
+ input_ids,
873
+ )
874
+
875
+ with torch.no_grad():
876
+ embeddings = self.embedding_func(
877
+ input_ids=input_ids,
878
+ attention_mask=attention_mask,
879
+ )
880
+
881
+ return embeddings
882
+
883
+ def _get_prompt_embeddings(
884
+ self, batch: dict[str, Union[list, torch.LongTensor]]
885
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
886
+ """Extract embeddings from frozen embedding model"""
887
+
888
+ if not self.match_underlying_distribution:
889
+ return None, None
890
+
891
+ embeddings = self._vectorize_prompt(
892
+ input_ids=batch["embedding_input_ids"],
893
+ attention_mask=batch["embedding_attention_mask"],
894
+ )
895
+
896
+ chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
897
+ rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
898
+
899
+ chosen_embeddings = embeddings[chosen_idx, ...]
900
+ rejected_embeddings = embeddings[rejected_idx, ...]
901
+
902
+ return (chosen_embeddings, rejected_embeddings)
903
+
904
+ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
905
+ """
906
+ Sample instances from dataset and get prompt embeddings.
907
+ Used for density ratio classifier training.
908
+ """
909
+ n_samples = min(len(dataset), sample_size)
910
+ rand_indices = np.random.choice(len(dataset), size=(n_samples,))
911
+
912
+ embedding_dataset = dataset.select(rand_indices)
913
+
914
+ dataloader_params = {
915
+ "batch_size": self.args.per_device_train_batch_size,
916
+ "collate_fn": self.data_collator,
917
+ "num_workers": self.args.dataloader_num_workers,
918
+ "pin_memory": self.args.dataloader_pin_memory,
919
+ "shuffle": False,
920
+ }
921
+
922
+ # prepare dataloader
923
+ data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
924
+
925
+ with torch.no_grad():
926
+ all_embeddings = torch.empty(0)
927
+ for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
928
+ embeddings = self._vectorize_prompt(
929
+ input_ids=padded_batch["embedding_input_ids"],
930
+ attention_mask=padded_batch["embedding_attention_mask"],
931
+ )
932
+ embeddings = self.accelerator.gather_for_metrics(embeddings)
933
+ all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
934
+
935
+ return all_embeddings
936
+
937
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
938
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
939
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
940
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
941
+
942
+ if model is not None:
943
+ if hasattr(model, "config"):
944
+ hidden_size = (
945
+ max(model.config.hidden_sizes)
946
+ if getattr(model.config, "hidden_sizes", None)
947
+ else getattr(model.config, "hidden_size", None)
948
+ )
949
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
950
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
951
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
952
+ config_kwargs.update(
953
+ {
954
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
955
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
956
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
957
+ }
958
+ )
959
+
960
+ # If ZeRO-3 is used, we shard both the active and reference model.
961
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
962
+ if config_kwargs["zero_optimization"]["stage"] != 3:
963
+ config_kwargs["zero_optimization"]["stage"] = 0
964
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
965
+ model.eval()
966
+ return model
967
+
968
+ def _save_optimizer_and_scheduler(self, output_dir):
969
+ super()._save_optimizer_and_scheduler(output_dir)
970
+
971
+ # When saving optimizer and scheduler to checkpoint, save also the running delta object.
972
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
973
+
974
+ self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
975
+
976
+ if self.match_underlying_distribution:
977
+ torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
978
+
979
+ def _load_optimizer_and_scheduler(self, checkpoint):
980
+ super()._load_optimizer_and_scheduler(checkpoint)
981
+
982
+ if checkpoint is None:
983
+ return
984
+ # when loading optimizer and scheduler from checkpoint, also load the running delta object.
985
+ running_file = os.path.join(checkpoint, RUNNING_NAME)
986
+ if os.path.isfile(running_file):
987
+ self.running = RunningMoments.load_from_json(self.accelerator, running_file)
988
+
989
+ if self.match_underlying_distribution:
990
+ clf_file = os.path.join(checkpoint, CLF_NAME)
991
+ if os.path.isfile(running_file):
992
+ self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
993
+
994
+ @contextmanager
995
+ def null_ref_context(self):
996
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
997
+ with (
998
+ self.accelerator.unwrap_model(self.model).disable_adapter()
999
+ if self.is_peft_model and not self.ref_adapter_name
1000
+ else nullcontext()
1001
+ ):
1002
+ if self.ref_adapter_name:
1003
+ self.model.set_adapter(self.ref_adapter_name)
1004
+ yield
1005
+ if self.ref_adapter_name:
1006
+ self.model.set_adapter(self.model_adapter_name or "default")
1007
+
1008
+ def get_train_dataloader(self) -> DataLoader:
1009
+ """
1010
+ Returns the training [`~torch.utils.data.DataLoader`].
1011
+
1012
+ Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
1013
+ """
1014
+
1015
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1016
+ dataloader_params = {
1017
+ "batch_size": self.args.per_device_train_batch_size,
1018
+ "collate_fn": self.data_collator,
1019
+ "num_workers": self.args.dataloader_num_workers,
1020
+ "pin_memory": self.args.dataloader_pin_memory,
1021
+ "shuffle": False,
1022
+ }
1023
+
1024
+ # prepare dataloader
1025
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1026
+ reference_completion_logps = []
1027
+
1028
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1029
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1030
+
1031
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1032
+ reference_completion_logps.append(reference_completion_logp.cpu())
1033
+
1034
+ self.train_dataset = self.train_dataset.add_column(
1035
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1036
+ )
1037
+
1038
+ self._precomputed_train_ref_log_probs = True
1039
+
1040
+ return super().get_train_dataloader()
1041
+
1042
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
1043
+ """
1044
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
1045
+
1046
+ Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
1047
+
1048
+ Args:
1049
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
1050
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
1051
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
1052
+ """
1053
+ if eval_dataset is None and self.eval_dataset is None:
1054
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
1055
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
1056
+
1057
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
1058
+ dataloader_params = {
1059
+ "batch_size": self.args.per_device_eval_batch_size,
1060
+ "collate_fn": self.data_collator,
1061
+ "num_workers": self.args.dataloader_num_workers,
1062
+ "pin_memory": self.args.dataloader_pin_memory,
1063
+ "shuffle": False,
1064
+ }
1065
+
1066
+ # prepare dataloader
1067
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1068
+
1069
+ reference_completion_logps = []
1070
+
1071
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1072
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1073
+
1074
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1075
+ reference_completion_logps.append(reference_completion_logp.cpu())
1076
+
1077
+ eval_dataset = eval_dataset.add_column(
1078
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1079
+ )
1080
+
1081
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
1082
+ if self.eval_dataset is not None:
1083
+ self.eval_dataset = eval_dataset
1084
+ self._precomputed_eval_ref_log_probs = True
1085
+
1086
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1087
+
1088
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1089
+ """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
1090
+ with torch.no_grad():
1091
+ if self.ref_model is None:
1092
+ with self.null_ref_context():
1093
+ if self.is_encoder_decoder:
1094
+ completion_logits = self.model(
1095
+ padded_batch["prompt_input_ids"],
1096
+ attention_mask=padded_batch["prompt_attention_mask"],
1097
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1098
+ labels=padded_batch["completion_labels"],
1099
+ ).logits
1100
+
1101
+ else:
1102
+ completion_logits = self.model(
1103
+ padded_batch["completion_input_ids"],
1104
+ attention_mask=padded_batch["completion_attention_mask"],
1105
+ ).logits
1106
+
1107
+ else:
1108
+ if self.is_encoder_decoder:
1109
+ completion_logits = self.ref_model(
1110
+ padded_batch["prompt_input_ids"],
1111
+ attention_mask=padded_batch["prompt_attention_mask"],
1112
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1113
+ labels=padded_batch["completion_labels"],
1114
+ ).logits
1115
+
1116
+ else:
1117
+ completion_logits = self.ref_model(
1118
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1119
+ ).logits
1120
+
1121
+ completion_logps = self.get_batch_logps(
1122
+ completion_logits,
1123
+ padded_batch["completion_labels"],
1124
+ average_log_prob=False,
1125
+ is_encoder_decoder=self.is_encoder_decoder,
1126
+ label_pad_token_id=self.label_pad_token_id,
1127
+ )
1128
+
1129
+ return completion_logps
1130
+
1131
+ @staticmethod
1132
+ def get_batch_logps(
1133
+ logits: torch.FloatTensor,
1134
+ labels: torch.LongTensor,
1135
+ average_log_prob: bool = False,
1136
+ label_pad_token_id: int = -100,
1137
+ is_encoder_decoder: bool = False,
1138
+ ) -> torch.FloatTensor:
1139
+ """Compute the log probabilities of the given labels under the given logits.
1140
+
1141
+ Args:
1142
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1143
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1144
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1145
+
1146
+ Returns:
1147
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1148
+ """
1149
+ if logits.shape[:-1] != labels.shape:
1150
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1151
+
1152
+ if not is_encoder_decoder:
1153
+ labels = labels[:, 1:].clone()
1154
+ logits = logits[:, :-1, :]
1155
+ else:
1156
+ # Fixes end-dec RuntimeError
1157
+ labels = labels.clone()
1158
+
1159
+ loss_mask = labels != label_pad_token_id
1160
+
1161
+ # dummy token; we'll ignore the losses on these tokens later
1162
+ labels[labels == label_pad_token_id] = 0
1163
+
1164
+ per_token_logps = selective_log_softmax(logits, labels)
1165
+
1166
+ if average_log_prob:
1167
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1168
+ else:
1169
+ return (per_token_logps * loss_mask).sum(-1)
1170
+
1171
+ def forward(
1172
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1173
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1174
+ model_kwargs = (
1175
+ {
1176
+ "labels": batch["completion_labels"],
1177
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1178
+ }
1179
+ if self.is_encoder_decoder
1180
+ else {}
1181
+ )
1182
+ if self.aux_loss_enabled:
1183
+ model_kwargs["output_router_logits"] = True
1184
+
1185
+ outputs = model(
1186
+ batch["completion_input_ids"],
1187
+ attention_mask=batch["completion_attention_mask"],
1188
+ **model_kwargs,
1189
+ )
1190
+ completion_logits = outputs.logits
1191
+
1192
+ completion_logps = self.get_batch_logps(
1193
+ completion_logits,
1194
+ batch["completion_labels"],
1195
+ average_log_prob=False,
1196
+ is_encoder_decoder=self.is_encoder_decoder,
1197
+ label_pad_token_id=self.label_pad_token_id,
1198
+ )
1199
+
1200
+ if completion_logps.shape[0] != len(batch["label"]):
1201
+ raise ValueError(
1202
+ "There is a mismatch between the number of examples in this batch and the number of "
1203
+ "examples for which an output sequence was predicted."
1204
+ )
1205
+
1206
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1207
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1208
+
1209
+ chosen_logps = completion_logps[chosen_idx, ...]
1210
+ rejected_logps = completion_logps[rejected_idx, ...]
1211
+
1212
+ chosen_logits = completion_logits[chosen_idx, ...]
1213
+ rejected_logits = completion_logits[rejected_idx, ...]
1214
+
1215
+ if self.aux_loss_enabled:
1216
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
1217
+ else:
1218
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1219
+
1220
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1221
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
1222
+ min_ratio = self.args.min_density_ratio
1223
+ max_ratio = self.args.max_density_ratio
1224
+
1225
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1226
+
1227
+ return weight
1228
+
1229
+ def bco_loss(
1230
+ self,
1231
+ policy_chosen_logps: torch.FloatTensor,
1232
+ policy_rejected_logps: torch.FloatTensor,
1233
+ reference_chosen_logps: torch.FloatTensor,
1234
+ reference_rejected_logps: torch.FloatTensor,
1235
+ chosen_embeddings: Optional[torch.FloatTensor],
1236
+ rejected_embeddings: Optional[torch.FloatTensor],
1237
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1238
+ """Compute the BCO loss for a batch of policy and reference model log probabilities.
1239
+
1240
+ Args:
1241
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1242
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1243
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1244
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1245
+ chosen_embeddings: embeddings of desirable prompts
1246
+ rejected_embeddings: embeddings of undesirable prompts
1247
+
1248
+ Returns:
1249
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
1250
+ The losses tensor contains the BCO loss for each example in the batch.
1251
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1252
+ The delta value contains the moving average of all implicit rewards.
1253
+ """
1254
+
1255
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1256
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1257
+ chosen_rewards = self.beta * chosen_logratios
1258
+ else:
1259
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1260
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1261
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1262
+
1263
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1264
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1265
+ rejected_rewards = self.beta * rejected_logratios
1266
+ else:
1267
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1268
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1269
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1270
+
1271
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
1272
+ self.running.update(rewards)
1273
+ delta = self.running.mean
1274
+
1275
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1276
+ chosen_losses = -F.logsigmoid(chosen_rewards - delta)
1277
+
1278
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1279
+ rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
1280
+
1281
+ if self.match_underlying_distribution:
1282
+ chosen_weight = torch.ones_like(chosen_losses)
1283
+ rejected_weight = self._get_udm_weight(rejected_embeddings)
1284
+
1285
+ losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
1286
+ else:
1287
+ losses = torch.cat((chosen_losses, rejected_losses), dim=0)
1288
+
1289
+ return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
1290
+
1291
+ def get_batch_loss_metrics(
1292
+ self,
1293
+ model,
1294
+ batch: dict[str, Union[list, torch.LongTensor]],
1295
+ ):
1296
+ """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
1297
+ metrics = {}
1298
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1299
+
1300
+ forward_output = self.forward(model, batch)
1301
+ (
1302
+ policy_chosen_logps,
1303
+ policy_rejected_logps,
1304
+ policy_chosen_logits,
1305
+ policy_rejected_logits,
1306
+ ) = forward_output[:4]
1307
+ if self.aux_loss_enabled:
1308
+ aux_loss = forward_output[4]
1309
+
1310
+ # if reference_logps in batch use them, otherwise use the reference model
1311
+ if "reference_logps" in batch:
1312
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1313
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1314
+
1315
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1316
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1317
+ else:
1318
+ with torch.no_grad():
1319
+ if self.ref_model is None:
1320
+ with self.null_ref_context():
1321
+ (
1322
+ reference_chosen_logps,
1323
+ reference_rejected_logps,
1324
+ _,
1325
+ _,
1326
+ ) = self.forward(self.model, batch)[:4]
1327
+ else:
1328
+ (
1329
+ reference_chosen_logps,
1330
+ reference_rejected_logps,
1331
+ _,
1332
+ _,
1333
+ ) = self.forward(self.ref_model, batch)[:4]
1334
+
1335
+ chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
1336
+
1337
+ losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
1338
+ policy_chosen_logps,
1339
+ policy_rejected_logps,
1340
+ reference_chosen_logps,
1341
+ reference_rejected_logps,
1342
+ chosen_embeddings,
1343
+ rejected_embeddings,
1344
+ )
1345
+ metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
1346
+
1347
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1348
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1349
+
1350
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1351
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1352
+
1353
+ if all_num_chosen > 0:
1354
+ metrics["rewards/chosen_sum"] = (
1355
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1356
+ )
1357
+ metrics["logps/chosen_sum"] = (
1358
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1359
+ )
1360
+ metrics["logits/chosen_sum"] = (
1361
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1362
+ )
1363
+ metrics["count/chosen"] = all_num_chosen
1364
+
1365
+ if all_num_rejected > 0:
1366
+ metrics["rewards/rejected_sum"] = (
1367
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1368
+ )
1369
+ metrics["logps/rejected_sum"] = (
1370
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1371
+ )
1372
+ metrics["logits/rejected_sum"] = (
1373
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1374
+ )
1375
+ metrics["count/rejected"] = all_num_rejected
1376
+
1377
+ loss = losses.nanmean()
1378
+ if self.aux_loss_enabled:
1379
+ loss += self.aux_loss_coef * aux_loss
1380
+
1381
+ return loss, metrics
1382
+
1383
+ def compute_loss(
1384
+ self,
1385
+ model: Union[PreTrainedModel, nn.Module],
1386
+ inputs: dict[str, Union[torch.Tensor, Any]],
1387
+ return_outputs=False,
1388
+ num_items_in_batch=None,
1389
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1390
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1391
+
1392
+ with compute_loss_context_manager:
1393
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1394
+
1395
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1396
+ loss = loss.to(self.args.device)
1397
+ # force log the metrics
1398
+ if self.accelerator.is_main_process:
1399
+ self.store_metrics(metrics, train_eval="train")
1400
+
1401
+ if return_outputs:
1402
+ return (loss, metrics)
1403
+ return loss
1404
+
1405
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1406
+ for key, value in metrics.items():
1407
+ self._stored_metrics[train_eval][key].append(value)
1408
+
1409
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1410
+ if self.train_dataset is None or not has_length(self.train_dataset):
1411
+ return None
1412
+ return SequentialSampler(self.train_dataset)
1413
+
1414
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1415
+ """Generate samples from the model and reference model for the given batch of inputs."""
1416
+
1417
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1418
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1419
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1420
+ with generate_context_manager:
1421
+ policy_output = model.generate(
1422
+ input_ids=batch["prompt_input_ids"],
1423
+ attention_mask=batch["prompt_attention_mask"],
1424
+ max_length=self.max_length,
1425
+ do_sample=True,
1426
+ pad_token_id=self.processing_class.pad_token_id,
1427
+ )
1428
+
1429
+ # if reference_output in batch use that otherwise use the reference model
1430
+ if "reference_output" in batch:
1431
+ reference_output = batch["reference_output"]
1432
+ else:
1433
+ if self.ref_model is None:
1434
+ with self.null_ref_context():
1435
+ reference_output = self.model.generate(
1436
+ input_ids=batch["prompt_input_ids"],
1437
+ attention_mask=batch["prompt_attention_mask"],
1438
+ max_length=self.max_length,
1439
+ do_sample=True,
1440
+ pad_token_id=self.processing_class.pad_token_id,
1441
+ )
1442
+ else:
1443
+ reference_output = self.ref_model.generate(
1444
+ input_ids=batch["prompt_input_ids"],
1445
+ attention_mask=batch["prompt_attention_mask"],
1446
+ max_length=self.max_length,
1447
+ do_sample=True,
1448
+ pad_token_id=self.processing_class.pad_token_id,
1449
+ )
1450
+
1451
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1452
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1453
+
1454
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1455
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1456
+
1457
+ return policy_output_decoded, reference_output_decoded
1458
+
1459
+ def prediction_step(
1460
+ self,
1461
+ model: Union[PreTrainedModel, nn.Module],
1462
+ inputs: dict[str, Union[torch.Tensor, Any]],
1463
+ prediction_loss_only: bool,
1464
+ ignore_keys: Optional[list[str]] = None,
1465
+ ):
1466
+ if ignore_keys is None:
1467
+ if hasattr(model, "config"):
1468
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1469
+ else:
1470
+ ignore_keys = []
1471
+
1472
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1473
+ with torch.no_grad(), prediction_context_manager:
1474
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1475
+
1476
+ # force log the metrics
1477
+ if self.accelerator.is_main_process:
1478
+ self.store_metrics(metrics, train_eval="eval")
1479
+
1480
+ if prediction_loss_only:
1481
+ return (loss.detach(), None, None)
1482
+
1483
+ # logits for the chosen and rejected samples from model
1484
+ logits_dict = {
1485
+ "eval_logits/chosen": metrics["logits/chosen"],
1486
+ "eval_logits/rejected": metrics["logits/rejected"],
1487
+ }
1488
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1489
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1490
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1491
+
1492
+ return (loss.detach(), logits, labels)
1493
+
1494
+ def evaluation_loop(
1495
+ self,
1496
+ dataloader: DataLoader,
1497
+ description: str,
1498
+ prediction_loss_only: Optional[bool] = None,
1499
+ ignore_keys: Optional[list[str]] = None,
1500
+ metric_key_prefix: str = "eval",
1501
+ ) -> EvalLoopOutput:
1502
+ """
1503
+ Overriding built-in evaluation loop to store metrics for each batch.
1504
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1505
+
1506
+ Works both with or without labels.
1507
+ """
1508
+
1509
+ # Sample and save to game log if requested (for one batch to save time)
1510
+ if self.generate_during_eval:
1511
+ # Generate random indices within the range of the total number of samples
1512
+ num_samples = len(dataloader.dataset)
1513
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1514
+
1515
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1516
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1517
+ random_batch = self.data_collator(random_batch_dataset)
1518
+ random_batch = self._prepare_inputs(random_batch)
1519
+
1520
+ target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1521
+ target_batch = {
1522
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
1523
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
1524
+ "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
1525
+ }
1526
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1527
+
1528
+ table = pd.DataFrame(
1529
+ columns=["Prompt", "Policy", "Ref Model"],
1530
+ data=[
1531
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1532
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1533
+ ],
1534
+ )
1535
+ if "wandb" in self.args.report_to:
1536
+ wandb.log({"game_log": wandb.Table(data=table)})
1537
+
1538
+ if "comet_ml" in self.args.report_to:
1539
+ log_table_to_comet_experiment(
1540
+ name="game_log.csv",
1541
+ table=table,
1542
+ )
1543
+
1544
+ # Base evaluation
1545
+ initial_output = super().evaluation_loop(
1546
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1547
+ )
1548
+
1549
+ return initial_output
1550
+
1551
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1552
+ """
1553
+ Log `logs` on the various objects watching training, including stored metrics.
1554
+
1555
+ Args:
1556
+ logs (`dict[str, float]`):
1557
+ The values to log.
1558
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1559
+ Start time of the training.
1560
+ """
1561
+ # logs either has 'loss' or 'eval_loss'
1562
+ train_eval = "train" if "loss" in logs else "eval"
1563
+ # train metrics should have no prefix, eval should have 'eval_'
1564
+ prefix = "eval_" if train_eval == "eval" else ""
1565
+ # accumulate average metrics from sums and lengths
1566
+ for split in ["chosen", "rejected"]:
1567
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1568
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1569
+ for metric in ["rewards", "logps", "logits"]:
1570
+ logs[f"{prefix}{metric}/{split}"] = (
1571
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1572
+ / count_sum
1573
+ )
1574
+ # delete obsolete metric
1575
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1576
+ del self._stored_metrics[train_eval][f"count/{split}"]
1577
+ # calculate reward margin
1578
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1579
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1580
+ # Add averaged stored metrics to logs
1581
+ for key, metrics in self._stored_metrics[train_eval].items():
1582
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1583
+ del self._stored_metrics[train_eval]
1584
+
1585
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1586
+ return super().log(logs, start_time)
1587
+ else: # transformers<=4.46
1588
+ return super().log(logs)
1589
+
1590
+ def create_model_card(
1591
+ self,
1592
+ model_name: Optional[str] = None,
1593
+ dataset_name: Optional[str] = None,
1594
+ tags: Union[str, list[str], None] = None,
1595
+ ):
1596
+ """
1597
+ Creates a draft of a model card using the information available to the `Trainer`.
1598
+
1599
+ Args:
1600
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1601
+ Name of the model.
1602
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1603
+ Name of the dataset used for training.
1604
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1605
+ Tags to be associated with the model card.
1606
+ """
1607
+ if not self.is_world_process_zero():
1608
+ return
1609
+
1610
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1611
+ base_model = self.model.config._name_or_path
1612
+ else:
1613
+ base_model = None
1614
+
1615
+ tags = tags or []
1616
+ if isinstance(tags, str):
1617
+ tags = [tags]
1618
+
1619
+ if hasattr(self.model.config, "unsloth_version"):
1620
+ tags.append("unsloth")
1621
+
1622
+ citation = textwrap.dedent("""\
1623
+ @article{jung2024binary,
1624
+ title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1625
+ author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1626
+ year = 2024,
1627
+ eprint = {arXiv:2404.04656}
1628
+ }""")
1629
+
1630
+ model_card = generate_model_card(
1631
+ base_model=base_model,
1632
+ model_name=model_name,
1633
+ hub_model_id=self.hub_model_id,
1634
+ dataset_name=dataset_name,
1635
+ tags=tags,
1636
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1637
+ comet_url=get_comet_experiment_url(),
1638
+ trainer_name="BCO",
1639
+ trainer_citation=citation,
1640
+ paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1641
+ paper_id="2404.04656",
1642
+ )
1643
+
1644
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1645
+ class UnslothBCOTrainer(_UnslothBCOTrainer):
1646
+ """
1647
+
1648
+ Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1649
+
1650
+ Args:
1651
+ model (`transformers.PreTrainedModel`):
1652
+ The model to train, preferably an `AutoModelForSequenceClassification`.
1653
+ ref_model (`PreTrainedModelWrapper`):
1654
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
1655
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1656
+ args (`BCOConfig`):
1657
+ The arguments to use for training.
1658
+ train_dataset (`datasets.Dataset`):
1659
+ The dataset to use for training.
1660
+ eval_dataset (`datasets.Dataset`):
1661
+ The dataset to use for evaluation.
1662
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1663
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1664
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1665
+ reuse the fine-tuned model.
1666
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1667
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1668
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1669
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1670
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1671
+ callbacks (`list[transformers.TrainerCallback]`):
1672
+ The callbacks to use for training.
1673
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1674
+ The optimizer and scheduler to use for training.
1675
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1676
+ The function to use to preprocess the logits before computing the metrics.
1677
+ peft_config (`dict`, defaults to `None`):
1678
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1679
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1680
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1681
+ a dictionary string to metric values.
1682
+ model_adapter_name (`str`, defaults to `None`):
1683
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1684
+ ref_adapter_name (`str`, defaults to `None`):
1685
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1686
+
1687
+ """
1688
+ def __init__(
1689
+ self,
1690
+ model = None,
1691
+ ref_model = None,
1692
+ args = None,
1693
+ train_dataset = None,
1694
+ eval_dataset = None,
1695
+ processing_class = None,
1696
+ data_collator = None,
1697
+ model_init = None,
1698
+ callbacks = None,
1699
+ preprocess_logits_for_metrics = None,
1700
+ peft_config = None,
1701
+ compute_metrics = None,
1702
+ model_adapter_name = None,
1703
+ ref_adapter_name = None,
1704
+ embedding_func = None,
1705
+ embedding_tokenizer = None,
1706
+ **kwargs
1707
+ ):
1708
+ if args is None: args = UnslothBCOConfig()
1709
+ use_bf16 = getattr(args, 'bf16', False)
1710
+ use_fp16 = getattr(args, 'fp16', False)
1711
+ force_float32 = False
1712
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1713
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1714
+ force_float32 = True
1715
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1716
+ dtype = getattr(model.config, 'torch_dtype', None)
1717
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1718
+ from unsloth_zoo.utils import _get_dtype
1719
+ dtype = _get_dtype(dtype)
1720
+ float16 = dtype == torch.float16
1721
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1722
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1723
+ if force_float32:
1724
+ args.fp16 = False
1725
+ args.bf16 = False
1726
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1727
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1728
+ args.fp16 = float16
1729
+ args.bf16 = not float16
1730
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1731
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1732
+ args.eval_strategy = 'steps'
1733
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1734
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1735
+ if ga_steps is not None and ga_steps > 1:
1736
+ from transformers import __version__ as transformers_version
1737
+ if Version(transformers_version) <= Version('4.45.2'):
1738
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1739
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1740
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1741
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1742
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1743
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1744
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1745
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1746
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1747
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1748
+ if force_float32:
1749
+ args.bf16_full_eval = False
1750
+ args.fp16_full_eval = False
1751
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1752
+ args.bf16_full_eval = True
1753
+ args.fp16_full_eval = False
1754
+ elif not bf16_full_eval and not fp16_full_eval:
1755
+ args.bf16_full_eval = args.bf16
1756
+ args.fp16_full_eval = args.fp16
1757
+ _output_logits = False
1758
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1759
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1760
+ if _output_logits:
1761
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1762
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1763
+ pass
1764
+ else:
1765
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1766
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1767
+ if args_max_seq_length is None and model_max_seq_length is not None:
1768
+ max_seq_length = model.max_seq_length
1769
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1770
+ if model is not None and hasattr(model, 'for_training'):
1771
+ model.for_training()
1772
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1773
+ if 'processing_class' in locals():
1774
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1775
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1776
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1777
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1778
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1779
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1780
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1781
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1782
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1783
+ else:
1784
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1785
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1786
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1787
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1788
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1789
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1790
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1791
+ else:
1792
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1793
+ other_metrics = []
1794
+
1795
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1796
+ PatchRLStatistics('bco_trainer', other_metrics)
1797
+
1798
+ super().__init__(
1799
+ model = model,
1800
+ ref_model = ref_model,
1801
+ args = args,
1802
+ train_dataset = train_dataset,
1803
+ eval_dataset = eval_dataset,
1804
+ processing_class = processing_class,
1805
+ data_collator = data_collator,
1806
+ model_init = model_init,
1807
+ callbacks = callbacks,
1808
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1809
+ peft_config = peft_config,
1810
+ compute_metrics = compute_metrics,
1811
+ model_adapter_name = model_adapter_name,
1812
+ ref_adapter_name = ref_adapter_name,
1813
+ embedding_func = embedding_func,
1814
+ embedding_tokenizer = embedding_tokenizer,**kwargs)
1815
+ if hasattr(self, 'neftune_hook_handle'):
1816
+ self.neftune_hook_handle.remove()
1817
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1818
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1819
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1820
+ pass
1821
+
1822
+ pass
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothCPOConfig(CPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`CPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
67
+ the [paper](https://huggingface.co/papers/2310.12036).
68
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
69
+ Label smoothing factor. This argument is required if you want to use the default data collator.
70
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
71
+ Type of loss to use. Possible values are:
72
+
73
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
74
+ - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
75
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
76
+ - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
77
+
78
+ disable_dropout (`bool`, *optional*, defaults to `True`):
79
+ Whether to disable dropout in the model.
80
+ cpo_alpha (`float`, *optional*, defaults to `1.0`):
81
+ Weight of the BC regularizer in CPO training.
82
+ simpo_gamma (`float`, *optional*, defaults to `0.5`):
83
+ Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
84
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
85
+ Label pad token id. This argument is required if you want to use the default data collator.
86
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
87
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
88
+ truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
89
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
90
+ This argument is required if you want to use the default data collator.
91
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
92
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
93
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
94
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
95
+ you need to specify if the model returned by the callable is an encoder-decoder model.
96
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
98
+ string.
99
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+
102
+ """
103
+ vllm_sampling_params: Optional[Any] = field(
104
+ default = None,
105
+ metadata = {'help': 'vLLM SamplingParams'},
106
+ )
107
+ unsloth_num_chunks : Optional[int] = field(
108
+ default = -1,
109
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
110
+ )
111
+ def __init__(
112
+ self,
113
+ output_dir = None,
114
+ overwrite_output_dir = None,
115
+ do_train = False,
116
+ do_eval = False,
117
+ do_predict = False,
118
+ eval_strategy = 'no',
119
+ prediction_loss_only = False,
120
+ per_device_train_batch_size = 4,
121
+ per_device_eval_batch_size = 4,
122
+ per_gpu_train_batch_size = None,
123
+ per_gpu_eval_batch_size = None,
124
+ gradient_accumulation_steps = 2,
125
+ eval_accumulation_steps = 2,
126
+ eval_delay = 0,
127
+ torch_empty_cache_steps = 250,
128
+ learning_rate = 5e-05,
129
+ weight_decay = 0.01,
130
+ adam_beta1 = 0.9,
131
+ adam_beta2 = 0.999,
132
+ adam_epsilon = 1e-08,
133
+ max_grad_norm = 1.0,
134
+ num_train_epochs = 3.0,
135
+ max_steps = -1,
136
+ lr_scheduler_type = 'linear',
137
+ warmup_ratio = 0.1,
138
+ warmup_steps = 0,
139
+ log_level = 'passive',
140
+ log_level_replica = 'warning',
141
+ log_on_each_node = True,
142
+ logging_dir = None,
143
+ logging_strategy = 'steps',
144
+ logging_first_step = False,
145
+ logging_steps = 1,
146
+ logging_nan_inf_filter = False,
147
+ save_strategy = 'steps',
148
+ save_steps = 500,
149
+ save_total_limit = None,
150
+ save_safetensors = True,
151
+ save_on_each_node = False,
152
+ save_only_model = False,
153
+ restore_callback_states_from_checkpoint = False,
154
+ no_cuda = False,
155
+ use_cpu = False,
156
+ use_mps_device = False,
157
+ seed = 3407,
158
+ data_seed = 3407,
159
+ jit_mode_eval = False,
160
+ use_ipex = False,
161
+ bf16 = False,
162
+ fp16 = False,
163
+ fp16_opt_level = 'O1',
164
+ half_precision_backend = 'auto',
165
+ bf16_full_eval = False,
166
+ fp16_full_eval = False,
167
+ tf32 = None,
168
+ local_rank = -1,
169
+ ddp_backend = None,
170
+ tpu_num_cores = None,
171
+ tpu_metrics_debug = False,
172
+ debug = '',
173
+ dataloader_drop_last = False,
174
+ eval_steps = None,
175
+ dataloader_num_workers = 0,
176
+ dataloader_prefetch_factor = None,
177
+ past_index = -1,
178
+ run_name = None,
179
+ disable_tqdm = None,
180
+ remove_unused_columns = True,
181
+ label_names = None,
182
+ load_best_model_at_end = False,
183
+ metric_for_best_model = None,
184
+ greater_is_better = None,
185
+ ignore_data_skip = False,
186
+ fsdp = '',
187
+ fsdp_min_num_params = 0,
188
+ fsdp_config = None,
189
+ fsdp_transformer_layer_cls_to_wrap = None,
190
+ accelerator_config = None,
191
+ deepspeed = None,
192
+ label_smoothing_factor = 0.0,
193
+ optim = 'adamw_8bit',
194
+ optim_args = None,
195
+ adafactor = False,
196
+ group_by_length = False,
197
+ length_column_name = 'length',
198
+ report_to = None,
199
+ ddp_find_unused_parameters = None,
200
+ ddp_bucket_cap_mb = None,
201
+ ddp_broadcast_buffers = None,
202
+ dataloader_pin_memory = True,
203
+ dataloader_persistent_workers = False,
204
+ skip_memory_metrics = True,
205
+ use_legacy_prediction_loop = False,
206
+ push_to_hub = False,
207
+ resume_from_checkpoint = None,
208
+ hub_model_id = None,
209
+ hub_strategy = 'every_save',
210
+ hub_token = None,
211
+ hub_private_repo = None,
212
+ hub_always_push = False,
213
+ gradient_checkpointing = False,
214
+ gradient_checkpointing_kwargs = None,
215
+ include_inputs_for_metrics = False,
216
+ eval_do_concat_batches = True,
217
+ fp16_backend = 'auto',
218
+ evaluation_strategy = None,
219
+ push_to_hub_model_id = None,
220
+ push_to_hub_organization = None,
221
+ push_to_hub_token = None,
222
+ mp_parameters = '',
223
+ auto_find_batch_size = False,
224
+ full_determinism = False,
225
+ torchdynamo = None,
226
+ ray_scope = 'last',
227
+ ddp_timeout = 1800,
228
+ torch_compile = False,
229
+ torch_compile_backend = None,
230
+ torch_compile_mode = None,
231
+ dispatch_batches = None,
232
+ split_batches = None,
233
+ include_tokens_per_second = False,
234
+ include_num_input_tokens_seen = False,
235
+ neftune_noise_alpha = None,
236
+ optim_target_modules = None,
237
+ batch_eval_metrics = False,
238
+ eval_on_start = False,
239
+ use_liger_kernel = False,
240
+ eval_use_gather_object = False,
241
+ average_tokens_across_devices = False,
242
+ max_length = 1024,
243
+ max_prompt_length = 512,
244
+ max_completion_length = None,
245
+ beta = 0.1,
246
+ label_smoothing = 0.0,
247
+ loss_type = 'sigmoid',
248
+ disable_dropout = True,
249
+ cpo_alpha = 1.0,
250
+ simpo_gamma = 0.5,
251
+ label_pad_token_id = -100,
252
+ padding_value = None,
253
+ truncation_mode = 'keep_end',
254
+ generate_during_eval = False,
255
+ is_encoder_decoder = None,
256
+ model_init_kwargs = None,
257
+ dataset_num_proc = None,
258
+ vllm_sampling_params = None,
259
+ unsloth_num_chunks = -1,
260
+ **kwargs,
261
+ ):
262
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
263
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
264
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
265
+ output_dir = 'unsloth_training_checkpoints'
266
+ save_strategy = 'no'
267
+ if dataset_num_proc is None:
268
+ from multiprocessing import cpu_count
269
+ dataset_num_proc = cpu_count()
270
+
271
+ super().__init__(
272
+ output_dir = output_dir,
273
+ overwrite_output_dir = overwrite_output_dir,
274
+ do_train = do_train,
275
+ do_eval = do_eval,
276
+ do_predict = do_predict,
277
+ eval_strategy = eval_strategy,
278
+ prediction_loss_only = prediction_loss_only,
279
+ per_device_train_batch_size = per_device_train_batch_size,
280
+ per_device_eval_batch_size = per_device_eval_batch_size,
281
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
282
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
283
+ gradient_accumulation_steps = gradient_accumulation_steps,
284
+ eval_accumulation_steps = eval_accumulation_steps,
285
+ eval_delay = eval_delay,
286
+ torch_empty_cache_steps = torch_empty_cache_steps,
287
+ learning_rate = learning_rate,
288
+ weight_decay = weight_decay,
289
+ adam_beta1 = adam_beta1,
290
+ adam_beta2 = adam_beta2,
291
+ adam_epsilon = adam_epsilon,
292
+ max_grad_norm = max_grad_norm,
293
+ num_train_epochs = num_train_epochs,
294
+ max_steps = max_steps,
295
+ lr_scheduler_type = lr_scheduler_type,
296
+ warmup_ratio = warmup_ratio,
297
+ warmup_steps = warmup_steps,
298
+ log_level = log_level,
299
+ log_level_replica = log_level_replica,
300
+ log_on_each_node = log_on_each_node,
301
+ logging_dir = logging_dir,
302
+ logging_strategy = logging_strategy,
303
+ logging_first_step = logging_first_step,
304
+ logging_steps = logging_steps,
305
+ logging_nan_inf_filter = logging_nan_inf_filter,
306
+ save_strategy = save_strategy,
307
+ save_steps = save_steps,
308
+ save_total_limit = save_total_limit,
309
+ save_safetensors = save_safetensors,
310
+ save_on_each_node = save_on_each_node,
311
+ save_only_model = save_only_model,
312
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
313
+ no_cuda = no_cuda,
314
+ use_cpu = use_cpu,
315
+ use_mps_device = use_mps_device,
316
+ seed = seed,
317
+ data_seed = data_seed,
318
+ jit_mode_eval = jit_mode_eval,
319
+ use_ipex = use_ipex,
320
+ bf16 = bf16,
321
+ fp16 = fp16,
322
+ fp16_opt_level = fp16_opt_level,
323
+ half_precision_backend = half_precision_backend,
324
+ bf16_full_eval = bf16_full_eval,
325
+ fp16_full_eval = fp16_full_eval,
326
+ tf32 = tf32,
327
+ local_rank = local_rank,
328
+ ddp_backend = ddp_backend,
329
+ tpu_num_cores = tpu_num_cores,
330
+ tpu_metrics_debug = tpu_metrics_debug,
331
+ debug = debug,
332
+ dataloader_drop_last = dataloader_drop_last,
333
+ eval_steps = eval_steps,
334
+ dataloader_num_workers = dataloader_num_workers,
335
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
336
+ past_index = past_index,
337
+ run_name = run_name,
338
+ disable_tqdm = disable_tqdm,
339
+ remove_unused_columns = remove_unused_columns,
340
+ label_names = label_names,
341
+ load_best_model_at_end = load_best_model_at_end,
342
+ metric_for_best_model = metric_for_best_model,
343
+ greater_is_better = greater_is_better,
344
+ ignore_data_skip = ignore_data_skip,
345
+ fsdp = fsdp,
346
+ fsdp_min_num_params = fsdp_min_num_params,
347
+ fsdp_config = fsdp_config,
348
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
349
+ accelerator_config = accelerator_config,
350
+ deepspeed = deepspeed,
351
+ label_smoothing_factor = label_smoothing_factor,
352
+ optim = optim,
353
+ optim_args = optim_args,
354
+ adafactor = adafactor,
355
+ group_by_length = group_by_length,
356
+ length_column_name = length_column_name,
357
+ report_to = report_to,
358
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
359
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
360
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
361
+ dataloader_pin_memory = dataloader_pin_memory,
362
+ dataloader_persistent_workers = dataloader_persistent_workers,
363
+ skip_memory_metrics = skip_memory_metrics,
364
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
365
+ push_to_hub = push_to_hub,
366
+ resume_from_checkpoint = resume_from_checkpoint,
367
+ hub_model_id = hub_model_id,
368
+ hub_strategy = hub_strategy,
369
+ hub_token = hub_token,
370
+ hub_private_repo = hub_private_repo,
371
+ hub_always_push = hub_always_push,
372
+ gradient_checkpointing = gradient_checkpointing,
373
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
374
+ include_inputs_for_metrics = include_inputs_for_metrics,
375
+ eval_do_concat_batches = eval_do_concat_batches,
376
+ fp16_backend = fp16_backend,
377
+ evaluation_strategy = evaluation_strategy,
378
+ push_to_hub_model_id = push_to_hub_model_id,
379
+ push_to_hub_organization = push_to_hub_organization,
380
+ push_to_hub_token = push_to_hub_token,
381
+ mp_parameters = mp_parameters,
382
+ auto_find_batch_size = auto_find_batch_size,
383
+ full_determinism = full_determinism,
384
+ torchdynamo = torchdynamo,
385
+ ray_scope = ray_scope,
386
+ ddp_timeout = ddp_timeout,
387
+ torch_compile = torch_compile,
388
+ torch_compile_backend = torch_compile_backend,
389
+ torch_compile_mode = torch_compile_mode,
390
+ dispatch_batches = dispatch_batches,
391
+ split_batches = split_batches,
392
+ include_tokens_per_second = include_tokens_per_second,
393
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
394
+ neftune_noise_alpha = neftune_noise_alpha,
395
+ optim_target_modules = optim_target_modules,
396
+ batch_eval_metrics = batch_eval_metrics,
397
+ eval_on_start = eval_on_start,
398
+ use_liger_kernel = use_liger_kernel,
399
+ eval_use_gather_object = eval_use_gather_object,
400
+ average_tokens_across_devices = average_tokens_across_devices,
401
+ max_length = max_length,
402
+ max_prompt_length = max_prompt_length,
403
+ max_completion_length = max_completion_length,
404
+ beta = beta,
405
+ label_smoothing = label_smoothing,
406
+ loss_type = loss_type,
407
+ disable_dropout = disable_dropout,
408
+ cpo_alpha = cpo_alpha,
409
+ simpo_gamma = simpo_gamma,
410
+ label_pad_token_id = label_pad_token_id,
411
+ padding_value = padding_value,
412
+ truncation_mode = truncation_mode,
413
+ generate_during_eval = generate_during_eval,
414
+ is_encoder_decoder = is_encoder_decoder,
415
+ model_init_kwargs = model_init_kwargs,
416
+ dataset_num_proc = dataset_num_proc,**kwargs)
417
+ self.vllm_sampling_params = vllm_sampling_params
418
+ self.unsloth_num_chunks = unsloth_num_chunks
419
+ pass
420
+
421
+ class _UnslothCPOTrainer(Trainer):
422
+ r""""""
423
+
424
+ _tag_names = ["trl", "cpo"]
425
+
426
+ def __init__(
427
+ self,
428
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
429
+ args: Optional[CPOConfig] = None,
430
+ data_collator: Optional[DataCollator] = None,
431
+ train_dataset: Optional[Dataset] = None,
432
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
433
+ processing_class: Optional[
434
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
435
+ ] = None,
436
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
437
+ callbacks: Optional[list[TrainerCallback]] = None,
438
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
439
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
440
+ peft_config: Optional[dict] = None,
441
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
442
+ ):
443
+ if args.model_init_kwargs is None:
444
+ model_init_kwargs = {}
445
+ elif not isinstance(model, str):
446
+ raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
447
+ else:
448
+ model_init_kwargs = args.model_init_kwargs
449
+ torch_dtype = model_init_kwargs.get("torch_dtype")
450
+ if torch_dtype is not None:
451
+ # Convert to `torch.dtype` if an str is passed
452
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
453
+ torch_dtype = getattr(torch, torch_dtype)
454
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
455
+ raise ValueError(
456
+ f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
457
+ )
458
+ model_init_kwargs["torch_dtype"] = torch_dtype
459
+
460
+ if isinstance(model, str):
461
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
462
+
463
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
464
+ # has been called in order to properly call autocast if needed.
465
+ self._peft_has_been_casted_to_bf16 = False
466
+
467
+ if not is_peft_available() and peft_config is not None:
468
+ raise ValueError(
469
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
470
+ )
471
+ elif is_peft_available() and peft_config is not None:
472
+ # if model is a peft model and we have a peft_config, we merge and unload it first
473
+ if isinstance(model, PeftModel):
474
+ model = model.merge_and_unload()
475
+
476
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
477
+ _support_gc_kwargs = hasattr(
478
+ args, "gradient_checkpointing_kwargs"
479
+ ) and "gradient_checkpointing_kwargs" in list(
480
+ inspect.signature(prepare_model_for_kbit_training).parameters
481
+ )
482
+
483
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
484
+
485
+ if _support_gc_kwargs:
486
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
487
+
488
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
489
+ elif getattr(args, "gradient_checkpointing", False):
490
+ # For backward compatibility with older versions of transformers
491
+ if hasattr(model, "enable_input_require_grads"):
492
+ model.enable_input_require_grads()
493
+ else:
494
+
495
+ def make_inputs_require_grad(module, input, output):
496
+ output.requires_grad_(True)
497
+
498
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
499
+
500
+ # get peft model with the given config
501
+ model = model
502
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
503
+ peft_module_casting_to_bf16(model)
504
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
505
+ self._peft_has_been_casted_to_bf16 = True
506
+
507
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
508
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
509
+ # fail or completely fail.
510
+ elif getattr(args, "gradient_checkpointing", False):
511
+ # For backward compatibility with older versions of transformers
512
+ if hasattr(model, "enable_input_require_grads"):
513
+ model.enable_input_require_grads()
514
+ else:
515
+
516
+ def make_inputs_require_grad(module, input, output):
517
+ output.requires_grad_(True)
518
+
519
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
520
+
521
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
522
+ raise ValueError(
523
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
524
+ " Please install `wandb` or `comet-ml` to resolve."
525
+ )
526
+
527
+ if model is not None:
528
+ self.is_encoder_decoder = model.config.is_encoder_decoder
529
+ elif args.is_encoder_decoder is None:
530
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
531
+ else:
532
+ self.is_encoder_decoder = args.is_encoder_decoder
533
+
534
+ if self.is_encoder_decoder:
535
+ self.decoder_start_token_id = model.config.decoder_start_token_id
536
+ self.pad_token_id = model.config.pad_token_id
537
+
538
+ if processing_class is None:
539
+ raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
540
+ if args.max_length is None:
541
+ warnings.warn(
542
+ "`max_length` is not set in the CPOConfig's init"
543
+ " it will default to `512` by default, but you should do it yourself in the future.",
544
+ UserWarning,
545
+ )
546
+ max_length = 512
547
+ else:
548
+ max_length = args.max_length
549
+ if args.max_prompt_length is None:
550
+ warnings.warn(
551
+ "`max_prompt_length` is not set in the CPOConfig's init"
552
+ " it will default to `128` by default, but you should do it yourself in the future.",
553
+ UserWarning,
554
+ )
555
+ max_prompt_length = 128
556
+ else:
557
+ max_prompt_length = args.max_prompt_length
558
+
559
+ if args.max_completion_length is None and self.is_encoder_decoder:
560
+ warnings.warn(
561
+ "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
562
+ " it will default to `128` by default, but you should do it yourself in the future.",
563
+ UserWarning,
564
+ )
565
+ max_completion_length = 128
566
+ else:
567
+ max_completion_length = args.max_completion_length
568
+
569
+ if data_collator is None:
570
+ data_collator = DPODataCollatorWithPadding(
571
+ pad_token_id=processing_class.pad_token_id,
572
+ label_pad_token_id=args.label_pad_token_id,
573
+ is_encoder_decoder=self.is_encoder_decoder,
574
+ )
575
+
576
+ if args.remove_unused_columns:
577
+ args.remove_unused_columns = False
578
+ # warn users
579
+ warnings.warn(
580
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
581
+ " we have set it for you, but you should do it yourself in the future.",
582
+ UserWarning,
583
+ )
584
+
585
+ self.use_dpo_data_collator = True
586
+ else:
587
+ self.use_dpo_data_collator = False
588
+
589
+ # Disable dropout in the model
590
+ if args.disable_dropout:
591
+ disable_dropout_in_model(model)
592
+
593
+ self.max_length = max_length
594
+ self.generate_during_eval = args.generate_during_eval
595
+ self.label_pad_token_id = args.label_pad_token_id
596
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
597
+ self.max_prompt_length = max_prompt_length
598
+ self.truncation_mode = args.truncation_mode
599
+ self.max_completion_length = max_completion_length
600
+ self.processing_class = processing_class
601
+
602
+ if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
603
+ warnings.warn(
604
+ f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
605
+ "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
606
+ UserWarning,
607
+ )
608
+ if args.loss_type == "kto_pair":
609
+ raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
610
+
611
+ self.beta = args.beta
612
+ self.label_smoothing = args.label_smoothing
613
+ self.loss_type = args.loss_type
614
+ self.cpo_alpha = args.cpo_alpha
615
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
616
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
617
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
618
+ warnings.warn(
619
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
620
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
621
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
622
+ "loss.",
623
+ UserWarning,
624
+ )
625
+
626
+ if args.loss_type == "simpo":
627
+ self.simpo_gamma = args.simpo_gamma
628
+
629
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
630
+
631
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
632
+ # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
633
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
634
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
635
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
636
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
637
+ # that the warning has already been issued.
638
+ model.warnings_issued["estimate_tokens"] = True
639
+
640
+ # Compute that only on the main process for faster data processing.
641
+ # see: https://github.com/huggingface/trl/pull/1255
642
+ with PartialState().local_main_process_first():
643
+ # Extract the prompt if needed, and apply the chat template if needed
644
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
645
+ train_dataset = train_dataset.map(
646
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
647
+ )
648
+ if eval_dataset is not None:
649
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
650
+ eval_dataset = eval_dataset.map(
651
+ maybe_apply_chat_template,
652
+ fn_kwargs={"tokenizer": processing_class},
653
+ num_proc=args.dataset_num_proc,
654
+ )
655
+
656
+ # tokenize the dataset
657
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
658
+ if eval_dataset is not None:
659
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
660
+
661
+ super().__init__(
662
+ model=model,
663
+ args=args,
664
+ data_collator=data_collator,
665
+ train_dataset=train_dataset,
666
+ eval_dataset=eval_dataset,
667
+ processing_class=processing_class,
668
+ model_init=model_init,
669
+ compute_metrics=compute_metrics,
670
+ callbacks=callbacks,
671
+ optimizers=optimizers,
672
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
673
+ )
674
+
675
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
676
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
677
+ # self.model_accepts_loss_kwargs to False to enable scaling.
678
+ self.model_accepts_loss_kwargs = False
679
+
680
+ # Add tags for models that have been loaded with the correct transformers version
681
+ if hasattr(self.model, "add_model_tags"):
682
+ self.model.add_model_tags(self._tag_names)
683
+
684
+ if not hasattr(self, "accelerator"):
685
+ raise AttributeError(
686
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
687
+ )
688
+
689
+ def build_tokenized_answer(self, prompt, answer):
690
+ """
691
+ Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
692
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
693
+ Reference:
694
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
695
+ """
696
+
697
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
698
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
699
+
700
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
701
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
702
+
703
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
704
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
705
+
706
+ # Prepare input tokens for token by token comparison
707
+ full_input_ids = np.array(full_tokenized["input_ids"])
708
+
709
+ if len(full_input_ids) != len(full_concat_input_ids):
710
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
711
+
712
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
713
+ # can be merged together when tokenizing prompt+answer. This could result
714
+ # on the last token from the prompt being different when tokenized on its own
715
+ # vs when done as prompt+answer.
716
+ response_token_ids_start_idx = len(prompt_input_ids)
717
+
718
+ # If tokenized prompt is different than both prompt+answer, then it means the
719
+ # last token has changed due to merging.
720
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
721
+ response_token_ids_start_idx -= 1
722
+
723
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
724
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
725
+
726
+ if len(prompt_input_ids) != len(prompt_attention_mask):
727
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
728
+
729
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
730
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
731
+
732
+ return dict(
733
+ prompt_input_ids=prompt_input_ids,
734
+ prompt_attention_mask=prompt_attention_mask,
735
+ input_ids=answer_input_ids,
736
+ attention_mask=answer_attention_mask,
737
+ )
738
+
739
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
740
+ """Tokenize a single row from a CPO specific dataset.
741
+
742
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
743
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
744
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
745
+
746
+ We also create the labels for the chosen/rejected responses, which are of length equal to
747
+ the sum of the length of the prompt and the chosen/rejected response, with
748
+ label_pad_token_id for the prompt tokens.
749
+ """
750
+ batch = {}
751
+ prompt = feature["prompt"]
752
+ chosen = feature["chosen"]
753
+ rejected = feature["rejected"]
754
+
755
+ if not self.is_encoder_decoder:
756
+ # Check issues below for more details
757
+ # 1. https://github.com/huggingface/trl/issues/907
758
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
759
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
760
+
761
+ if not isinstance(prompt, str):
762
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
763
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
764
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
765
+
766
+ if not isinstance(chosen, str):
767
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
768
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
769
+
770
+ if not isinstance(rejected, str):
771
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
772
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
773
+
774
+ # Last prompt token might get merged by tokenizer and
775
+ # it should not be included for generation if that happens
776
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
777
+
778
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
779
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
780
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
781
+
782
+ for k, v in prompt_tokens.items():
783
+ prompt_tokens[k] = v[:prompt_len_input_ids]
784
+
785
+ # Make sure prompts only have one different token at most an
786
+ # and length only differs by 1 at most
787
+ num_diff_tokens = sum(
788
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
789
+ )
790
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
791
+ if num_diff_tokens > 1 or num_diff_len > 1:
792
+ raise ValueError(
793
+ "Chosen and rejected prompt_input_ids might only differ on the "
794
+ "last token due to tokenizer merge ops."
795
+ )
796
+
797
+ # add BOS token to head of prompt. Avoid adding if it's already there
798
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
799
+ self.processing_class.bos_token_id,
800
+ prompt_len_input_ids,
801
+ prompt_tokens,
802
+ chosen_prompt_len_input_ids,
803
+ chosen_tokens,
804
+ rejected_prompt_len_input_ids,
805
+ rejected_tokens,
806
+ )
807
+
808
+ # add EOS token to end of answer. Avoid adding if it's already there
809
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
810
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
811
+ )
812
+
813
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
814
+
815
+ # if combined sequence is too long, truncate the prompt
816
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
817
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
818
+ if self.truncation_mode == "keep_start":
819
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
820
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
821
+ elif self.truncation_mode == "keep_end":
822
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
823
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
824
+ else:
825
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
826
+
827
+ # if that's still too long, truncate the response
828
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
829
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
830
+ for k in ["input_ids", "attention_mask"]:
831
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
832
+
833
+ # Create labels
834
+ chosen_sequence_tokens = {
835
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
836
+ }
837
+ rejected_sequence_tokens = {
838
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
839
+ }
840
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
841
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
842
+ self.label_pad_token_id
843
+ ] * len(chosen_tokens["prompt_input_ids"])
844
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
845
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
846
+ self.label_pad_token_id
847
+ ] * len(rejected_tokens["prompt_input_ids"])
848
+
849
+ for k, toks in {
850
+ "chosen_": chosen_sequence_tokens,
851
+ "rejected_": rejected_sequence_tokens,
852
+ "": prompt_tokens,
853
+ }.items():
854
+ for type_key, tokens in toks.items():
855
+ if type_key == "token_type_ids":
856
+ continue
857
+ batch[f"{k}{type_key}"] = tokens
858
+
859
+ else:
860
+ chosen_tokens = self.processing_class(
861
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
862
+ )
863
+ rejected_tokens = self.processing_class(
864
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
865
+ )
866
+ prompt_tokens = self.processing_class(
867
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
868
+ )
869
+
870
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
871
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
872
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
873
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
874
+
875
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
876
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
877
+ labels=torch.tensor(batch["rejected_labels"])
878
+ )
879
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
880
+ labels=torch.tensor(batch["chosen_labels"])
881
+ )
882
+
883
+ return batch
884
+
885
+ @staticmethod
886
+ def concatenated_inputs(
887
+ batch: dict[str, Union[list, torch.LongTensor]],
888
+ is_encoder_decoder: bool = False,
889
+ label_pad_token_id: int = -100,
890
+ padding_value: int = 0,
891
+ device: Optional[torch.device] = None,
892
+ ) -> dict[str, torch.LongTensor]:
893
+ """Concatenate the chosen and rejected inputs into a single tensor.
894
+
895
+ Args:
896
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
897
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
898
+ label_pad_token_id: The label pad token id.
899
+ padding_value: The padding value to use for the concatenated inputs_ids.
900
+ device: The device for the concatenated inputs.
901
+
902
+ Returns:
903
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
904
+ """
905
+ concatenated_batch = {}
906
+
907
+ if is_encoder_decoder:
908
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
909
+ else:
910
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
911
+
912
+ for k in batch:
913
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
914
+ if "labels" in k or is_encoder_decoder:
915
+ pad_value = label_pad_token_id
916
+ elif k.endswith("_input_ids"):
917
+ pad_value = padding_value
918
+ elif k.endswith("_attention_mask"):
919
+ pad_value = 0
920
+ concatenated_key = k.replace("chosen", "concatenated")
921
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
922
+ for k in batch:
923
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
924
+ if "labels" in k or is_encoder_decoder:
925
+ pad_value = label_pad_token_id
926
+ elif k.endswith("_input_ids"):
927
+ pad_value = padding_value
928
+ elif k.endswith("_attention_mask"):
929
+ pad_value = 0
930
+ concatenated_key = k.replace("rejected", "concatenated")
931
+ concatenated_batch[concatenated_key] = torch.cat(
932
+ (
933
+ concatenated_batch[concatenated_key],
934
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
935
+ ),
936
+ dim=0,
937
+ ).to(device=device)
938
+
939
+ if is_encoder_decoder:
940
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
941
+ concatenated_batch["concatenated_attention_mask"] = (
942
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
943
+ )
944
+
945
+ return concatenated_batch
946
+
947
+ def cpo_loss(
948
+ self,
949
+ policy_chosen_logps: torch.FloatTensor,
950
+ policy_rejected_logps: torch.FloatTensor,
951
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
952
+ """Compute the CPO loss for a batch of policy and reference model log probabilities.
953
+
954
+ Args:
955
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
956
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
957
+
958
+ Returns:
959
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
960
+ The losses tensor contains the CPO loss for each example in the batch.
961
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
962
+ """
963
+ logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
964
+
965
+ # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
966
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
967
+ # calculates a conservative CPO loss.
968
+
969
+ if self.loss_type == "simpo":
970
+ gamma_logratios = self.simpo_gamma / self.beta
971
+ logits = logits - gamma_logratios
972
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
973
+ losses = (
974
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
975
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
976
+ )
977
+ elif self.loss_type == "sigmoid":
978
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
979
+ losses = (
980
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
981
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
982
+ )
983
+ elif self.loss_type == "hinge":
984
+ losses = torch.relu(1 - self.beta * logits)
985
+ elif self.loss_type == "ipo":
986
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
987
+ losses = (logits - 1 / (2 * self.beta)) ** 2
988
+ else:
989
+ raise ValueError(
990
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
991
+ )
992
+
993
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
994
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
995
+
996
+ return losses, chosen_rewards, rejected_rewards
997
+
998
+ @staticmethod
999
+ def get_batch_logps(
1000
+ logits: torch.FloatTensor,
1001
+ labels: torch.LongTensor,
1002
+ average_log_prob: bool = False,
1003
+ label_pad_token_id: int = -100,
1004
+ is_encoder_decoder: bool = False,
1005
+ ) -> torch.FloatTensor:
1006
+ """Compute the log probabilities of the given labels under the given logits.
1007
+
1008
+ Args:
1009
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1010
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1011
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1012
+ label_pad_token_id: The label pad token id.
1013
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
1014
+
1015
+ Returns:
1016
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1017
+ """
1018
+ if logits.shape[:-1] != labels.shape:
1019
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1020
+
1021
+ if not is_encoder_decoder:
1022
+ labels = labels[:, 1:].clone()
1023
+ logits = logits[:, :-1, :]
1024
+ loss_mask = labels != label_pad_token_id
1025
+
1026
+ # dummy token; we'll ignore the losses on these tokens later
1027
+ labels[labels == label_pad_token_id] = 0
1028
+
1029
+ per_token_logps = selective_log_softmax(logits, labels)
1030
+
1031
+ if average_log_prob:
1032
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1033
+ else:
1034
+ return (per_token_logps * loss_mask).sum(-1)
1035
+
1036
+ def concatenated_forward(
1037
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1038
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1039
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1040
+
1041
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1042
+ """
1043
+ concatenated_batch = self.concatenated_inputs(
1044
+ batch,
1045
+ is_encoder_decoder=self.is_encoder_decoder,
1046
+ label_pad_token_id=self.label_pad_token_id,
1047
+ padding_value=self.padding_value,
1048
+ device=self.accelerator.device,
1049
+ )
1050
+ len_chosen = batch["chosen_labels"].shape[0]
1051
+
1052
+ model_kwargs = (
1053
+ {
1054
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1055
+ }
1056
+ if self.is_encoder_decoder
1057
+ else {}
1058
+ )
1059
+
1060
+ if self.aux_loss_enabled:
1061
+ model_kwargs["output_router_logits"] = True
1062
+
1063
+ outputs = model(
1064
+ concatenated_batch["concatenated_input_ids"],
1065
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1066
+ use_cache=False,
1067
+ **model_kwargs,
1068
+ )
1069
+ all_logits = outputs.logits
1070
+
1071
+ def cross_entropy_loss(logits, labels):
1072
+ if not self.is_encoder_decoder:
1073
+ # Shift so that tokens < n predict n
1074
+ logits = logits[..., :-1, :].contiguous()
1075
+ labels = labels[..., 1:].contiguous()
1076
+ # Flatten the tokens
1077
+ loss_fct = nn.CrossEntropyLoss()
1078
+ logits = logits.view(-1, logits.shape[-1])
1079
+ labels = labels.view(-1)
1080
+ # Enable model parallelism
1081
+ labels = labels.to(logits.device)
1082
+ loss = loss_fct(logits, labels)
1083
+ return loss
1084
+
1085
+ labels = concatenated_batch["concatenated_labels"].clone()
1086
+
1087
+ if self.cpo_alpha == 0:
1088
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)
1089
+ else:
1090
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1091
+
1092
+ all_logps = self.get_batch_logps(
1093
+ all_logits,
1094
+ concatenated_batch["concatenated_labels"],
1095
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
1096
+ is_encoder_decoder=self.is_encoder_decoder,
1097
+ label_pad_token_id=self.label_pad_token_id,
1098
+ )
1099
+
1100
+ chosen_logps = all_logps[:len_chosen]
1101
+ rejected_logps = all_logps[len_chosen:]
1102
+
1103
+ chosen_logits = all_logits[:len_chosen]
1104
+ rejected_logits = all_logits[len_chosen:]
1105
+
1106
+ if self.aux_loss_enabled:
1107
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
1108
+
1109
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1110
+
1111
+ def get_batch_loss_metrics(
1112
+ self,
1113
+ model,
1114
+ batch: dict[str, Union[list, torch.LongTensor]],
1115
+ train_eval: Literal["train", "eval"] = "train",
1116
+ ):
1117
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1118
+ metrics = {}
1119
+
1120
+ forward_output = self.concatenated_forward(model, batch)
1121
+ (
1122
+ policy_chosen_logps,
1123
+ policy_rejected_logps,
1124
+ policy_chosen_logits,
1125
+ policy_rejected_logits,
1126
+ policy_nll_loss,
1127
+ ) = forward_output[:5]
1128
+ if self.aux_loss_enabled:
1129
+ aux_loss = forward_output[5]
1130
+
1131
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1132
+ policy_chosen_logps,
1133
+ policy_rejected_logps,
1134
+ )
1135
+
1136
+ loss = losses.mean() + self.cpo_alpha * policy_nll_loss
1137
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1138
+
1139
+ prefix = "eval_" if train_eval == "eval" else ""
1140
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
1141
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
1142
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
1143
+ metrics[f"{prefix}rewards/margins"] = (
1144
+ self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
1145
+ )
1146
+ metrics[f"{prefix}logps/rejected"] = (
1147
+ self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
1148
+ )
1149
+ metrics[f"{prefix}logps/chosen"] = (
1150
+ self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
1151
+ )
1152
+ metrics[f"{prefix}logits/rejected"] = (
1153
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
1154
+ )
1155
+ metrics[f"{prefix}logits/chosen"] = (
1156
+ self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
1157
+ )
1158
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
1159
+
1160
+ if self.aux_loss_enabled:
1161
+ loss += self.aux_loss_coef * aux_loss
1162
+
1163
+ return loss, metrics
1164
+
1165
+ def compute_loss(
1166
+ self,
1167
+ model: Union[PreTrainedModel, nn.Module],
1168
+ inputs: dict[str, Union[torch.Tensor, Any]],
1169
+ return_outputs=False,
1170
+ num_items_in_batch=None,
1171
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1172
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1173
+
1174
+ with compute_loss_context_manager:
1175
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1176
+
1177
+ # force log the metrics
1178
+ self.store_metrics(metrics, train_eval="train")
1179
+
1180
+ if return_outputs:
1181
+ return (loss, metrics)
1182
+ return loss
1183
+
1184
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1185
+ """Generate samples from the model and reference model for the given batch of inputs."""
1186
+
1187
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1188
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1189
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1190
+
1191
+ with generate_context_manager:
1192
+ policy_output = model.generate(
1193
+ input_ids=batch["prompt_input_ids"],
1194
+ attention_mask=batch["prompt_attention_mask"],
1195
+ max_length=self.max_length,
1196
+ do_sample=True,
1197
+ pad_token_id=self.processing_class.pad_token_id,
1198
+ )
1199
+
1200
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1201
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1202
+
1203
+ return policy_output_decoded
1204
+
1205
+ def prediction_step(
1206
+ self,
1207
+ model: Union[PreTrainedModel, nn.Module],
1208
+ inputs: dict[str, Union[torch.Tensor, Any]],
1209
+ prediction_loss_only: bool,
1210
+ ignore_keys: Optional[list[str]] = None,
1211
+ ):
1212
+ if ignore_keys is None:
1213
+ if hasattr(model, "config"):
1214
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1215
+ else:
1216
+ ignore_keys = []
1217
+
1218
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1219
+
1220
+ with torch.no_grad(), prediction_context_manager:
1221
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1222
+
1223
+ # force log the metrics
1224
+ self.store_metrics(metrics, train_eval="eval")
1225
+
1226
+ if prediction_loss_only:
1227
+ return (loss.detach(), None, None)
1228
+
1229
+ # logits for the chosen and rejected samples from model
1230
+ logits_dict = {
1231
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1232
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1233
+ }
1234
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1235
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1236
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1237
+
1238
+ return (loss.detach(), logits, labels)
1239
+
1240
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1241
+ for key, value in metrics.items():
1242
+ self._stored_metrics[train_eval][key].append(value)
1243
+
1244
+ def evaluation_loop(
1245
+ self,
1246
+ dataloader: DataLoader,
1247
+ description: str,
1248
+ prediction_loss_only: Optional[bool] = None,
1249
+ ignore_keys: Optional[list[str]] = None,
1250
+ metric_key_prefix: str = "eval",
1251
+ ) -> EvalLoopOutput:
1252
+ """
1253
+ Overriding built-in evaluation loop to store metrics for each batch.
1254
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1255
+
1256
+ Works both with or without labels.
1257
+ """
1258
+
1259
+ # Sample and save to game log if requested (for one batch to save time)
1260
+ if self.generate_during_eval:
1261
+ # Generate random indices within the range of the total number of samples
1262
+ num_samples = len(dataloader.dataset)
1263
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1264
+
1265
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1266
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1267
+ random_batch = self.data_collator(random_batch_dataset)
1268
+ random_batch = self._prepare_inputs(random_batch)
1269
+
1270
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1271
+
1272
+ table = pd.DataFrame(
1273
+ columns=["Prompt", "Policy"],
1274
+ data=[
1275
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1276
+ ],
1277
+ )
1278
+ if "wandb" in self.args.report_to:
1279
+ wandb.log({"game_log": wandb.Table(data=table)})
1280
+
1281
+ if "comet_ml" in self.args.report_to:
1282
+ log_table_to_comet_experiment(
1283
+ name="game_log.csv",
1284
+ table=table,
1285
+ )
1286
+
1287
+ # Base evaluation
1288
+ initial_output = super().evaluation_loop(
1289
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1290
+ )
1291
+
1292
+ return initial_output
1293
+
1294
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1295
+ """
1296
+ Log `logs` on the various objects watching training, including stored metrics.
1297
+
1298
+ Args:
1299
+ logs (`dict[str, float]`):
1300
+ The values to log.
1301
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1302
+ Start time of the training.
1303
+ """
1304
+ # logs either has 'loss' or 'eval_loss'
1305
+ train_eval = "train" if "loss" in logs else "eval"
1306
+ # Add averaged stored metrics to logs
1307
+ for key, metrics in self._stored_metrics[train_eval].items():
1308
+ logs[key] = torch.tensor(metrics).mean().item()
1309
+ del self._stored_metrics[train_eval]
1310
+
1311
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1312
+ return super().log(logs, start_time)
1313
+ else: # transformers<=4.46
1314
+ return super().log(logs)
1315
+
1316
+ def _shift_right(self, input_ids):
1317
+ if self.decoder_start_token_id is None:
1318
+ raise ValueError(
1319
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1320
+ )
1321
+
1322
+ # shift inputs to the right
1323
+ if is_torch_fx_proxy(input_ids):
1324
+ # Item assignment is not supported natively for proxies.
1325
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1326
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1327
+ else:
1328
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1329
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1330
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1331
+
1332
+ if self.pad_token_id is None:
1333
+ raise ValueError("model.config.pad_token_id has to be defined.")
1334
+ # replace possible -100 values in labels by `pad_token_id`
1335
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1336
+
1337
+ return shifted_input_ids
1338
+
1339
+ def create_model_card(
1340
+ self,
1341
+ model_name: Optional[str] = None,
1342
+ dataset_name: Optional[str] = None,
1343
+ tags: Union[str, list[str], None] = None,
1344
+ ):
1345
+ """
1346
+ Creates a draft of a model card using the information available to the `Trainer`.
1347
+
1348
+ Args:
1349
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1350
+ Name of the model.
1351
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1352
+ Name of the dataset used for training.
1353
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1354
+ Tags to be associated with the model card.
1355
+ """
1356
+ if not self.is_world_process_zero():
1357
+ return
1358
+
1359
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1360
+ base_model = self.model.config._name_or_path
1361
+ else:
1362
+ base_model = None
1363
+
1364
+ tags = tags or []
1365
+ if isinstance(tags, str):
1366
+ tags = [tags]
1367
+
1368
+ if hasattr(self.model.config, "unsloth_version"):
1369
+ tags.append("unsloth")
1370
+
1371
+ citation = textwrap.dedent("""\
1372
+ @inproceedings{xu2024contrastive,
1373
+ title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1374
+ author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1375
+ year = 2024,
1376
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1377
+ publisher = {OpenReview.net},
1378
+ url = {https://openreview.net/forum?id=51iwkioZpn}
1379
+ }""")
1380
+
1381
+ model_card = generate_model_card(
1382
+ base_model=base_model,
1383
+ model_name=model_name,
1384
+ hub_model_id=self.hub_model_id,
1385
+ dataset_name=dataset_name,
1386
+ tags=tags,
1387
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1388
+ comet_url=get_comet_experiment_url(),
1389
+ trainer_name="CPO",
1390
+ trainer_citation=citation,
1391
+ paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1392
+ paper_id="2401.08417",
1393
+ )
1394
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1395
+ class UnslothCPOTrainer(_UnslothCPOTrainer):
1396
+ """
1397
+
1398
+ Initialize CPOTrainer.
1399
+
1400
+ Args:
1401
+ model (`transformers.PreTrainedModel`):
1402
+ The model to train, preferably an `AutoModelForSequenceClassification`.
1403
+ args (`CPOConfig`):
1404
+ The CPO config arguments to use for training.
1405
+ data_collator (`transformers.DataCollator`):
1406
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1407
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1408
+ train_dataset (`datasets.Dataset`):
1409
+ The dataset to use for training.
1410
+ eval_dataset (`datasets.Dataset`):
1411
+ The dataset to use for evaluation.
1412
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1413
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1414
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1415
+ reuse the fine-tuned model.
1416
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1417
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1418
+ callbacks (`list[transformers.TrainerCallback]`):
1419
+ The callbacks to use for training.
1420
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1421
+ The optimizer and scheduler to use for training.
1422
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1423
+ The function to use to preprocess the logits before computing the metrics.
1424
+ peft_config (`dict`, defaults to `None`):
1425
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1426
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1427
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1428
+ a dictionary string to metric values.
1429
+
1430
+ """
1431
+ def __init__(
1432
+ self,
1433
+ model = None,
1434
+ args = None,
1435
+ data_collator = None,
1436
+ train_dataset = None,
1437
+ eval_dataset = None,
1438
+ processing_class = None,
1439
+ model_init = None,
1440
+ callbacks = None,
1441
+ preprocess_logits_for_metrics = None,
1442
+ peft_config = None,
1443
+ compute_metrics = None,
1444
+ **kwargs
1445
+ ):
1446
+ if args is None: args = UnslothCPOConfig()
1447
+ use_bf16 = getattr(args, 'bf16', False)
1448
+ use_fp16 = getattr(args, 'fp16', False)
1449
+ force_float32 = False
1450
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1451
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1452
+ force_float32 = True
1453
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1454
+ dtype = getattr(model.config, 'torch_dtype', None)
1455
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1456
+ from unsloth_zoo.utils import _get_dtype
1457
+ dtype = _get_dtype(dtype)
1458
+ float16 = dtype == torch.float16
1459
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1460
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1461
+ if force_float32:
1462
+ args.fp16 = False
1463
+ args.bf16 = False
1464
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1465
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1466
+ args.fp16 = float16
1467
+ args.bf16 = not float16
1468
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1469
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1470
+ args.eval_strategy = 'steps'
1471
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1472
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1473
+ if ga_steps is not None and ga_steps > 1:
1474
+ from transformers import __version__ as transformers_version
1475
+ if Version(transformers_version) <= Version('4.45.2'):
1476
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1477
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1478
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1479
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1480
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1481
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1482
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1483
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1484
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1485
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1486
+ if force_float32:
1487
+ args.bf16_full_eval = False
1488
+ args.fp16_full_eval = False
1489
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1490
+ args.bf16_full_eval = True
1491
+ args.fp16_full_eval = False
1492
+ elif not bf16_full_eval and not fp16_full_eval:
1493
+ args.bf16_full_eval = args.bf16
1494
+ args.fp16_full_eval = args.fp16
1495
+ _output_logits = False
1496
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1497
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1498
+ if _output_logits:
1499
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1500
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1501
+ pass
1502
+ else:
1503
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1504
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1505
+ if args_max_seq_length is None and model_max_seq_length is not None:
1506
+ max_seq_length = model.max_seq_length
1507
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1508
+ if model is not None and hasattr(model, 'for_training'):
1509
+ model.for_training()
1510
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1511
+ if 'processing_class' in locals():
1512
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1513
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1514
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1515
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1516
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1517
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1518
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1519
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1520
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1521
+ else:
1522
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1523
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1524
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1525
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1526
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1527
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1528
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1529
+ else:
1530
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1531
+ other_metrics = []
1532
+
1533
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1534
+ PatchRLStatistics('cpo_trainer', other_metrics)
1535
+
1536
+ super().__init__(
1537
+ model = model,
1538
+ args = args,
1539
+ data_collator = data_collator,
1540
+ train_dataset = train_dataset,
1541
+ eval_dataset = eval_dataset,
1542
+ processing_class = processing_class,
1543
+ model_init = model_init,
1544
+ callbacks = callbacks,
1545
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1546
+ peft_config = peft_config,
1547
+ compute_metrics = compute_metrics,**kwargs)
1548
+ if hasattr(self, 'neftune_hook_handle'):
1549
+ self.neftune_hook_handle.remove()
1550
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1551
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1552
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1553
+ pass
1554
+
1555
+ pass
unsloth_compiled_cache/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,872 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothDDPOConfig(DDPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`DDPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (by default is the file name without the extension name).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed.
59
+ log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
60
+ Log with either 'wandb' or 'tensorboard', check
61
+ https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
62
+ tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
63
+ Keyword arguments for the tracker (e.g. wandb_project).
64
+ accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the accelerator.
66
+ project_kwargs (`Dict`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator project config (e.g. `logging_dir`).
68
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
69
+ Name of project to use for tracking.
70
+ logdir (`str`, *optional*, defaults to `"logs"`):
71
+ Top-level logging directory for checkpoint saving.
72
+ num_epochs (`int`, *optional*, defaults to `100`):
73
+ Number of epochs to train.
74
+ save_freq (`int`, *optional*, defaults to `1`):
75
+ Number of epochs between saving model checkpoints.
76
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
77
+ Number of checkpoints to keep before overwriting old ones.
78
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
79
+ Mixed precision training.
80
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
81
+ Allow `tf32` on Ampere GPUs.
82
+ resume_from (`str`, *optional*, defaults to `""`):
83
+ Resume training from a checkpoint.
84
+ sample_num_steps (`int`, *optional*, defaults to `50`):
85
+ Number of sampler inference steps.
86
+ sample_eta (`float`, *optional*, defaults to `1.0`):
87
+ Eta parameter for the DDIM sampler.
88
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
89
+ Classifier-free guidance weight.
90
+ sample_batch_size (`int`, *optional*, defaults to `1`):
91
+ Batch size (per GPU) to use for sampling.
92
+ sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
93
+ Number of batches to sample per epoch.
94
+ train_batch_size (`int`, *optional*, defaults to `1`):
95
+ Batch size (per GPU) to use for training.
96
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
97
+ Use 8bit Adam optimizer from bitsandbytes.
98
+ train_learning_rate (`float`, *optional*, defaults to `3e-4`):
99
+ Learning rate.
100
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
101
+ Adam beta1.
102
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
103
+ Adam beta2.
104
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
105
+ Adam weight decay.
106
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
107
+ Adam epsilon.
108
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
109
+ Number of gradient accumulation steps.
110
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
111
+ Maximum gradient norm for gradient clipping.
112
+ train_num_inner_epochs (`int`, *optional*, defaults to `1`):
113
+ Number of inner epochs per outer epoch.
114
+ train_cfg (`bool`, *optional*, defaults to `True`):
115
+ Whether to use classifier-free guidance during training.
116
+ train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
117
+ Clip advantages to the range.
118
+ train_clip_range (`float`, *optional*, defaults to `1e-4`):
119
+ PPO clip range.
120
+ train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
121
+ Fraction of timesteps to train on.
122
+ per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
123
+ Whether to track statistics for each prompt separately.
124
+ per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
125
+ Number of reward values to store in the buffer for each prompt.
126
+ per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
127
+ Minimum number of reward values to store in the buffer.
128
+ async_reward_computation (`bool`, *optional*, defaults to `False`):
129
+ Whether to compute rewards asynchronously.
130
+ max_workers (`int`, *optional*, defaults to `2`):
131
+ Maximum number of workers to use for async reward computation.
132
+ negative_prompts (`str`, *optional*, defaults to `""`):
133
+ Comma-separated list of prompts to use as negative examples.
134
+ push_to_hub (`bool`, *optional*, defaults to `False`):
135
+ Whether to push the final model checkpoint to the Hub.
136
+
137
+ """
138
+ vllm_sampling_params: Optional[Any] = field(
139
+ default = None,
140
+ metadata = {'help': 'vLLM SamplingParams'},
141
+ )
142
+ unsloth_num_chunks : Optional[int] = field(
143
+ default = -1,
144
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
145
+ )
146
+ def __init__(
147
+ self,
148
+ exp_name = 'main',
149
+ run_name = '',
150
+ seed = 3407,
151
+ log_with = None,
152
+ tracker_project_name = 'trl',
153
+ logdir = 'logs',
154
+ num_epochs = 100,
155
+ save_freq = 1,
156
+ num_checkpoint_limit = 5,
157
+ mixed_precision = 'fp16',
158
+ allow_tf32 = True,
159
+ resume_from = '',
160
+ sample_num_steps = 50,
161
+ sample_eta = 1.0,
162
+ sample_guidance_scale = 5.0,
163
+ sample_batch_size = 1,
164
+ sample_num_batches_per_epoch = 2,
165
+ train_batch_size = 1,
166
+ train_use_8bit_adam = False,
167
+ train_learning_rate = 5e-05,
168
+ train_adam_beta1 = 0.9,
169
+ train_adam_beta2 = 0.999,
170
+ train_adam_weight_decay = 0.01,
171
+ train_adam_epsilon = 1e-08,
172
+ train_gradient_accumulation_steps = 2,
173
+ train_max_grad_norm = 1.0,
174
+ train_num_inner_epochs = 1,
175
+ train_cfg = True,
176
+ train_adv_clip_max = 5.0,
177
+ train_clip_range = 0.0001,
178
+ train_timestep_fraction = 1.0,
179
+ per_prompt_stat_tracking = False,
180
+ per_prompt_stat_tracking_buffer_size = 16,
181
+ per_prompt_stat_tracking_min_count = 16,
182
+ async_reward_computation = False,
183
+ max_workers = 2,
184
+ negative_prompts = '',
185
+ push_to_hub = False,
186
+ vllm_sampling_params = None,
187
+ unsloth_num_chunks = -1,
188
+ **kwargs,
189
+ ):
190
+
191
+ super().__init__(
192
+ exp_name = exp_name,
193
+ run_name = run_name,
194
+ seed = seed,
195
+ log_with = log_with,
196
+ tracker_project_name = tracker_project_name,
197
+ logdir = logdir,
198
+ num_epochs = num_epochs,
199
+ save_freq = save_freq,
200
+ num_checkpoint_limit = num_checkpoint_limit,
201
+ mixed_precision = mixed_precision,
202
+ allow_tf32 = allow_tf32,
203
+ resume_from = resume_from,
204
+ sample_num_steps = sample_num_steps,
205
+ sample_eta = sample_eta,
206
+ sample_guidance_scale = sample_guidance_scale,
207
+ sample_batch_size = sample_batch_size,
208
+ sample_num_batches_per_epoch = sample_num_batches_per_epoch,
209
+ train_batch_size = train_batch_size,
210
+ train_use_8bit_adam = train_use_8bit_adam,
211
+ train_learning_rate = train_learning_rate,
212
+ train_adam_beta1 = train_adam_beta1,
213
+ train_adam_beta2 = train_adam_beta2,
214
+ train_adam_weight_decay = train_adam_weight_decay,
215
+ train_adam_epsilon = train_adam_epsilon,
216
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
217
+ train_max_grad_norm = train_max_grad_norm,
218
+ train_num_inner_epochs = train_num_inner_epochs,
219
+ train_cfg = train_cfg,
220
+ train_adv_clip_max = train_adv_clip_max,
221
+ train_clip_range = train_clip_range,
222
+ train_timestep_fraction = train_timestep_fraction,
223
+ per_prompt_stat_tracking = per_prompt_stat_tracking,
224
+ per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
225
+ per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
226
+ async_reward_computation = async_reward_computation,
227
+ max_workers = max_workers,
228
+ negative_prompts = negative_prompts,
229
+ push_to_hub = push_to_hub,**kwargs)
230
+ self.vllm_sampling_params = vllm_sampling_params
231
+ self.unsloth_num_chunks = unsloth_num_chunks
232
+ pass
233
+
234
+ class _UnslothDDPOTrainer(PyTorchModelHubMixin):
235
+ """"""
236
+
237
+ _tag_names = ["trl", "ddpo"]
238
+
239
+ def __init__(
240
+ self,
241
+ config: DDPOConfig,
242
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
243
+ prompt_function: Callable[[], tuple[str, Any]],
244
+ sd_pipeline: DDPOStableDiffusionPipeline,
245
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
246
+ ):
247
+ if image_samples_hook is None:
248
+ warn("No image_samples_hook provided; no images will be logged")
249
+
250
+ self.prompt_fn = prompt_function
251
+ self.reward_fn = reward_function
252
+ self.config = config
253
+ self.image_samples_callback = image_samples_hook
254
+
255
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
256
+
257
+ if self.config.resume_from:
258
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
259
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
260
+ # get the most recent checkpoint in this directory
261
+ checkpoints = list(
262
+ filter(
263
+ lambda x: "checkpoint_" in x,
264
+ os.listdir(self.config.resume_from),
265
+ )
266
+ )
267
+ if len(checkpoints) == 0:
268
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
269
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
270
+ self.config.resume_from = os.path.join(
271
+ self.config.resume_from,
272
+ f"checkpoint_{checkpoint_numbers[-1]}",
273
+ )
274
+
275
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
276
+
277
+ # number of timesteps within each trajectory to train on
278
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
279
+
280
+ self.accelerator = Accelerator(
281
+ log_with=self.config.log_with,
282
+ mixed_precision=self.config.mixed_precision,
283
+ project_config=accelerator_project_config,
284
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
285
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
286
+ # the total number of optimizer steps to accumulate across.
287
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
288
+ **self.config.accelerator_kwargs,
289
+ )
290
+
291
+ is_okay, message = self._config_check()
292
+ if not is_okay:
293
+ raise ValueError(message)
294
+
295
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
296
+
297
+ if self.accelerator.is_main_process:
298
+ self.accelerator.init_trackers(
299
+ self.config.tracker_project_name,
300
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
301
+ init_kwargs=self.config.tracker_kwargs,
302
+ )
303
+
304
+ logger.info(f"\n{config}")
305
+
306
+ set_seed(self.config.seed, device_specific=True)
307
+
308
+ self.sd_pipeline = sd_pipeline
309
+
310
+ self.sd_pipeline.set_progress_bar_config(
311
+ position=1,
312
+ disable=not self.accelerator.is_local_main_process,
313
+ leave=False,
314
+ desc="Timestep",
315
+ dynamic_ncols=True,
316
+ )
317
+
318
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
319
+ # as these weights are only used for inference, keeping weights in full precision is not required.
320
+ if self.accelerator.mixed_precision == "fp16":
321
+ inference_dtype = torch.float16
322
+ elif self.accelerator.mixed_precision == "bf16":
323
+ inference_dtype = torch.bfloat16
324
+ else:
325
+ inference_dtype = torch.float32
326
+
327
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
328
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
329
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
330
+
331
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
332
+
333
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
334
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
335
+
336
+ # Enable TF32 for faster training on Ampere GPUs,
337
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
338
+ if self.config.allow_tf32:
339
+ torch.backends.cuda.matmul.allow_tf32 = True
340
+
341
+ self.optimizer = self._setup_optimizer(
342
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
343
+ )
344
+
345
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
346
+ self.sd_pipeline.tokenizer(
347
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
348
+ return_tensors="pt",
349
+ padding="max_length",
350
+ truncation=True,
351
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
352
+ ).input_ids.to(self.accelerator.device)
353
+ )[0]
354
+
355
+ if config.per_prompt_stat_tracking:
356
+ self.stat_tracker = PerPromptStatTracker(
357
+ config.per_prompt_stat_tracking_buffer_size,
358
+ config.per_prompt_stat_tracking_min_count,
359
+ )
360
+
361
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
362
+ # more memory
363
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
364
+
365
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
366
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
367
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
368
+ else:
369
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
370
+
371
+ if self.config.async_reward_computation:
372
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
373
+
374
+ if config.resume_from:
375
+ logger.info(f"Resuming from {config.resume_from}")
376
+ self.accelerator.load_state(config.resume_from)
377
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
378
+ else:
379
+ self.first_epoch = 0
380
+
381
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
382
+ if not is_async:
383
+ rewards = []
384
+ for images, prompts, prompt_metadata in prompt_image_pairs:
385
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
386
+ rewards.append(
387
+ (
388
+ torch.as_tensor(reward, device=self.accelerator.device),
389
+ reward_metadata,
390
+ )
391
+ )
392
+ else:
393
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
394
+ rewards = [
395
+ (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
396
+ for reward, reward_metadata in rewards
397
+ ]
398
+
399
+ return zip(*rewards)
400
+
401
+ def step(self, epoch: int, global_step: int):
402
+ """
403
+ Perform a single step of training.
404
+
405
+ Args:
406
+ epoch (int): The current epoch.
407
+ global_step (int): The current global step.
408
+
409
+ Side Effects:
410
+ - Model weights are updated
411
+ - Logs the statistics to the accelerator trackers.
412
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
413
+
414
+ Returns:
415
+ global_step (int): The updated global step.
416
+
417
+ """
418
+ samples, prompt_image_data = self._generate_samples(
419
+ iterations=self.config.sample_num_batches_per_epoch,
420
+ batch_size=self.config.sample_batch_size,
421
+ )
422
+
423
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
424
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
425
+ rewards, rewards_metadata = self.compute_rewards(
426
+ prompt_image_data, is_async=self.config.async_reward_computation
427
+ )
428
+
429
+ for i, image_data in enumerate(prompt_image_data):
430
+ image_data.extend([rewards[i], rewards_metadata[i]])
431
+
432
+ if self.image_samples_callback is not None:
433
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
434
+
435
+ rewards = torch.cat(rewards)
436
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
437
+
438
+ self.accelerator.log(
439
+ {
440
+ "reward": rewards,
441
+ "epoch": epoch,
442
+ "reward_mean": rewards.mean(),
443
+ "reward_std": rewards.std(),
444
+ },
445
+ step=global_step,
446
+ )
447
+
448
+ if self.config.per_prompt_stat_tracking:
449
+ # gather the prompts across processes
450
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
451
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
452
+ advantages = self.stat_tracker.update(prompts, rewards)
453
+ else:
454
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
455
+
456
+ # ungather advantages; keep the entries corresponding to the samples on this process
457
+ samples["advantages"] = (
458
+ torch.as_tensor(advantages)
459
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
460
+ .to(self.accelerator.device)
461
+ )
462
+
463
+ del samples["prompt_ids"]
464
+
465
+ total_batch_size, num_timesteps = samples["timesteps"].shape
466
+
467
+ for inner_epoch in range(self.config.train_num_inner_epochs):
468
+ # shuffle samples along batch dimension
469
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
470
+ samples = {k: v[perm] for k, v in samples.items()}
471
+
472
+ # shuffle along time dimension independently for each sample
473
+ # still trying to understand the code below
474
+ perms = torch.stack(
475
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
476
+ )
477
+
478
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
479
+ samples[key] = samples[key][
480
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
481
+ perms,
482
+ ]
483
+
484
+ original_keys = samples.keys()
485
+ original_values = samples.values()
486
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
487
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
488
+
489
+ # Transpose the list of original values
490
+ transposed_values = zip(*reshaped_values)
491
+ # Create new dictionaries for each row of transposed values
492
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
493
+
494
+ self.sd_pipeline.unet.train()
495
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
496
+ # ensure optimization step at the end of the inner epoch
497
+ if not self.accelerator.sync_gradients:
498
+ raise ValueError(
499
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
500
+ )
501
+
502
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
503
+ self.accelerator.save_state()
504
+
505
+ return global_step
506
+
507
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
508
+ """
509
+ Calculate the loss for a batch of an unpacked sample
510
+
511
+ Args:
512
+ latents (torch.Tensor):
513
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
514
+ timesteps (torch.Tensor):
515
+ The timesteps sampled from the diffusion model, shape: [batch_size]
516
+ next_latents (torch.Tensor):
517
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
518
+ log_probs (torch.Tensor):
519
+ The log probabilities of the latents, shape: [batch_size]
520
+ advantages (torch.Tensor):
521
+ The advantages of the latents, shape: [batch_size]
522
+ embeds (torch.Tensor):
523
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
524
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
525
+
526
+ Returns:
527
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
528
+ (all of these are of shape (1,))
529
+ """
530
+ with self.autocast():
531
+ if self.config.train_cfg:
532
+ noise_pred = self.sd_pipeline.unet(
533
+ torch.cat([latents] * 2),
534
+ torch.cat([timesteps] * 2),
535
+ embeds,
536
+ ).sample
537
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
538
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
539
+ noise_pred_text - noise_pred_uncond
540
+ )
541
+ else:
542
+ noise_pred = self.sd_pipeline.unet(
543
+ latents,
544
+ timesteps,
545
+ embeds,
546
+ ).sample
547
+ # compute the log prob of next_latents given latents under the current model
548
+
549
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
550
+ noise_pred,
551
+ timesteps,
552
+ latents,
553
+ eta=self.config.sample_eta,
554
+ prev_sample=next_latents,
555
+ )
556
+
557
+ log_prob = scheduler_step_output.log_probs
558
+
559
+ advantages = torch.clamp(
560
+ advantages,
561
+ -self.config.train_adv_clip_max,
562
+ self.config.train_adv_clip_max,
563
+ )
564
+
565
+ ratio = torch.exp(log_prob - log_probs)
566
+
567
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
568
+
569
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
570
+
571
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
572
+
573
+ return loss, approx_kl, clipfrac
574
+
575
+ def loss(
576
+ self,
577
+ advantages: torch.Tensor,
578
+ clip_range: float,
579
+ ratio: torch.Tensor,
580
+ ):
581
+ unclipped_loss = -advantages * ratio
582
+ clipped_loss = -advantages * torch.clamp(
583
+ ratio,
584
+ 1.0 - clip_range,
585
+ 1.0 + clip_range,
586
+ )
587
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
588
+
589
+ def _setup_optimizer(self, trainable_layers_parameters):
590
+ if self.config.train_use_8bit_adam:
591
+ import bitsandbytes
592
+
593
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
594
+ else:
595
+ optimizer_cls = torch.optim.AdamW
596
+
597
+ return optimizer_cls(
598
+ trainable_layers_parameters,
599
+ lr=self.config.train_learning_rate,
600
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
601
+ weight_decay=self.config.train_adam_weight_decay,
602
+ eps=self.config.train_adam_epsilon,
603
+ )
604
+
605
+ def _save_model_hook(self, models, weights, output_dir):
606
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
607
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
608
+
609
+ def _load_model_hook(self, models, input_dir):
610
+ self.sd_pipeline.load_checkpoint(models, input_dir)
611
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
612
+
613
+ def _generate_samples(self, iterations, batch_size):
614
+ """
615
+ Generate samples from the model
616
+
617
+ Args:
618
+ iterations (int): Number of iterations to generate samples for
619
+ batch_size (int): Batch size to use for sampling
620
+
621
+ Returns:
622
+ samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
623
+ """
624
+ samples = []
625
+ prompt_image_pairs = []
626
+ self.sd_pipeline.unet.eval()
627
+
628
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
629
+
630
+ for _ in range(iterations):
631
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
632
+
633
+ prompt_ids = self.sd_pipeline.tokenizer(
634
+ prompts,
635
+ return_tensors="pt",
636
+ padding="max_length",
637
+ truncation=True,
638
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
639
+ ).input_ids.to(self.accelerator.device)
640
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
641
+
642
+ with self.autocast():
643
+ sd_output = self.sd_pipeline(
644
+ prompt_embeds=prompt_embeds,
645
+ negative_prompt_embeds=sample_neg_prompt_embeds,
646
+ num_inference_steps=self.config.sample_num_steps,
647
+ guidance_scale=self.config.sample_guidance_scale,
648
+ eta=self.config.sample_eta,
649
+ output_type="pt",
650
+ )
651
+
652
+ images = sd_output.images
653
+ latents = sd_output.latents
654
+ log_probs = sd_output.log_probs
655
+
656
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
657
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
658
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
659
+
660
+ samples.append(
661
+ {
662
+ "prompt_ids": prompt_ids,
663
+ "prompt_embeds": prompt_embeds,
664
+ "timesteps": timesteps,
665
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
666
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
667
+ "log_probs": log_probs,
668
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
669
+ }
670
+ )
671
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
672
+
673
+ return samples, prompt_image_pairs
674
+
675
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
676
+ """
677
+ Train on a batch of samples. Main training segment
678
+
679
+ Args:
680
+ inner_epoch (int): The current inner epoch
681
+ epoch (int): The current epoch
682
+ global_step (int): The current global step
683
+ batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
684
+
685
+ Side Effects:
686
+ - Model weights are updated
687
+ - Logs the statistics to the accelerator trackers.
688
+
689
+ Returns:
690
+ global_step (int): The updated global step
691
+ """
692
+ info = defaultdict(list)
693
+ for _i, sample in enumerate(batched_samples):
694
+ if self.config.train_cfg:
695
+ # concat negative prompts to sample prompts to avoid two forward passes
696
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
697
+ else:
698
+ embeds = sample["prompt_embeds"]
699
+
700
+ for j in range(self.num_train_timesteps):
701
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
702
+ loss, approx_kl, clipfrac = self.calculate_loss(
703
+ sample["latents"][:, j],
704
+ sample["timesteps"][:, j],
705
+ sample["next_latents"][:, j],
706
+ sample["log_probs"][:, j],
707
+ sample["advantages"],
708
+ embeds,
709
+ )
710
+ info["approx_kl"].append(approx_kl)
711
+ info["clipfrac"].append(clipfrac)
712
+ info["loss"].append(loss)
713
+
714
+ self.accelerator.backward(loss)
715
+ if self.accelerator.sync_gradients:
716
+ self.accelerator.clip_grad_norm_(
717
+ self.trainable_layers.parameters()
718
+ if not isinstance(self.trainable_layers, list)
719
+ else self.trainable_layers,
720
+ self.config.train_max_grad_norm,
721
+ )
722
+ self.optimizer.step()
723
+ self.optimizer.zero_grad()
724
+
725
+ # Checks if the accelerator has performed an optimization step behind the scenes
726
+ if self.accelerator.sync_gradients:
727
+ # log training-related stuff
728
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
729
+ info = self.accelerator.reduce(info, reduction="mean")
730
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
731
+ self.accelerator.log(info, step=global_step)
732
+ global_step += 1
733
+ info = defaultdict(list)
734
+ return global_step
735
+
736
+ def _config_check(self) -> tuple[bool, str]:
737
+ samples_per_epoch = (
738
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
739
+ )
740
+ total_train_batch_size = (
741
+ self.config.train_batch_size
742
+ * self.accelerator.num_processes
743
+ * self.config.train_gradient_accumulation_steps
744
+ )
745
+
746
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
747
+ return (
748
+ False,
749
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
750
+ )
751
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
752
+ return (
753
+ False,
754
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
755
+ )
756
+ if not samples_per_epoch % total_train_batch_size == 0:
757
+ return (
758
+ False,
759
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
760
+ )
761
+ return True, ""
762
+
763
+ def train(self, epochs: Optional[int] = None):
764
+ """
765
+ Train the model for a given number of epochs
766
+ """
767
+ global_step = 0
768
+ if epochs is None:
769
+ epochs = self.config.num_epochs
770
+ for epoch in range(self.first_epoch, epochs):
771
+ global_step = self.step(epoch, global_step)
772
+
773
+ def _save_pretrained(self, save_directory):
774
+ self.sd_pipeline.save_pretrained(save_directory)
775
+ self.create_model_card()
776
+
777
+ def create_model_card(
778
+ self,
779
+ model_name: Optional[str] = None,
780
+ dataset_name: Optional[str] = None,
781
+ tags: Union[str, list[str], None] = None,
782
+ ):
783
+ """
784
+ Creates a draft of a model card using the information available to the `Trainer`.
785
+
786
+ Args:
787
+ model_name (`str` or `None`, *optional*, defaults to `None`):
788
+ Name of the model.
789
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
790
+ Name of the dataset used for training.
791
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
792
+ Tags to be associated with the model card.
793
+ """
794
+ if not self.is_world_process_zero():
795
+ return
796
+
797
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
798
+ base_model = self.model.config._name_or_path
799
+ else:
800
+ base_model = None
801
+
802
+ tags = tags or []
803
+ if isinstance(tags, str):
804
+ tags = [tags]
805
+
806
+ if hasattr(self.model.config, "unsloth_version"):
807
+ tags.append("unsloth")
808
+
809
+ citation = textwrap.dedent("""\
810
+ @inproceedings{black2024training,
811
+ title = {{Training Diffusion Models with Reinforcement Learning}},
812
+ author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
813
+ year = 2024,
814
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
815
+ publisher = {OpenReview.net},
816
+ url = {https://openreview.net/forum?id=YCWjhGrJFD},
817
+ }""")
818
+
819
+ model_card = generate_model_card(
820
+ base_model=base_model,
821
+ model_name=model_name,
822
+ hub_model_id=self.hub_model_id,
823
+ dataset_name=dataset_name,
824
+ tags=tags,
825
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
826
+ comet_url=get_comet_experiment_url(),
827
+ trainer_name="DDPO",
828
+ trainer_citation=citation,
829
+ paper_title="Training Diffusion Models with Reinforcement Learning",
830
+ paper_id="2305.13301",
831
+ )
832
+
833
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
834
+ class UnslothDDPOTrainer(_UnslothDDPOTrainer):
835
+ """
836
+
837
+ The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
838
+ Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
839
+ As of now only Stable Diffusion based pipelines are supported
840
+
841
+ Attributes:
842
+ **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
843
+ details.
844
+ **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
845
+ **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
846
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
847
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
848
+
849
+ """
850
+ def __init__(
851
+ self,
852
+ config,
853
+ reward_function,
854
+ prompt_function,
855
+ sd_pipeline,
856
+ image_samples_hook = None,
857
+ **kwargs
858
+ ):
859
+ if args is None: args = UnslothDDPOConfig()
860
+ other_metrics = []
861
+
862
+ from unsloth_zoo.logging_utils import PatchRLStatistics
863
+ PatchRLStatistics('ddpo_trainer', other_metrics)
864
+
865
+ super().__init__(
866
+ config = config,
867
+ reward_function = reward_function,
868
+ prompt_function = prompt_function,
869
+ sd_pipeline = sd_pipeline,
870
+ image_samples_hook = image_samples_hook,**kwargs)
871
+
872
+ pass
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,861 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothGKDConfig(GKDConfig):
44
+ """
45
+
46
+ Configuration class for [`GKDTrainer`].
47
+
48
+ Args:
49
+ temperature (`float`, *optional*, defaults to `0.9`):
50
+ Temperature for sampling. The higher the temperature, the more random the completions.
51
+ lmbda (`float`, *optional*, defaults to `0.5`):
52
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
53
+ student-generated outputs).
54
+ beta (`float`, *optional*, defaults to `0.5`):
55
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
56
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
57
+ max_new_tokens (`int`, *optional*, defaults to `128`):
58
+ Maximum number of tokens to generate per completion.
59
+ teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
60
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
61
+ being trained.
62
+ teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
63
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
64
+ from a string.
65
+ disable_dropout (`bool`, *optional*, defaults to `True`):
66
+ Whether to disable dropout in the model.
67
+ seq_kd (`bool`, *optional*, defaults to `False`):
68
+ Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
69
+ on teacher-generated output).
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ fsdp_transformer_layer_cls_to_wrap = None,
159
+ accelerator_config = None,
160
+ deepspeed = None,
161
+ label_smoothing_factor = 0.0,
162
+ optim = 'adamw_8bit',
163
+ optim_args = None,
164
+ adafactor = False,
165
+ group_by_length = False,
166
+ length_column_name = 'length',
167
+ report_to = None,
168
+ ddp_find_unused_parameters = None,
169
+ ddp_bucket_cap_mb = None,
170
+ ddp_broadcast_buffers = None,
171
+ dataloader_pin_memory = True,
172
+ dataloader_persistent_workers = False,
173
+ skip_memory_metrics = True,
174
+ use_legacy_prediction_loop = False,
175
+ push_to_hub = False,
176
+ resume_from_checkpoint = None,
177
+ hub_model_id = None,
178
+ hub_strategy = 'every_save',
179
+ hub_token = None,
180
+ hub_private_repo = None,
181
+ hub_always_push = False,
182
+ gradient_checkpointing = False,
183
+ gradient_checkpointing_kwargs = None,
184
+ include_inputs_for_metrics = False,
185
+ eval_do_concat_batches = True,
186
+ fp16_backend = 'auto',
187
+ evaluation_strategy = None,
188
+ push_to_hub_model_id = None,
189
+ push_to_hub_organization = None,
190
+ push_to_hub_token = None,
191
+ mp_parameters = '',
192
+ auto_find_batch_size = False,
193
+ full_determinism = False,
194
+ torchdynamo = None,
195
+ ray_scope = 'last',
196
+ ddp_timeout = 1800,
197
+ torch_compile = False,
198
+ torch_compile_backend = None,
199
+ torch_compile_mode = None,
200
+ dispatch_batches = None,
201
+ split_batches = None,
202
+ include_tokens_per_second = False,
203
+ include_num_input_tokens_seen = False,
204
+ neftune_noise_alpha = None,
205
+ optim_target_modules = None,
206
+ batch_eval_metrics = False,
207
+ eval_on_start = False,
208
+ use_liger_kernel = False,
209
+ eval_use_gather_object = False,
210
+ average_tokens_across_devices = False,
211
+ model_init_kwargs = None,
212
+ use_liger = False,
213
+ dataset_text_field = 'text',
214
+ dataset_kwargs = None,
215
+ dataset_num_proc = None,
216
+ max_seq_length = None,
217
+ packing = False,
218
+ eval_packing = None,
219
+ dataset_batch_size = None,
220
+ num_of_sequences = None,
221
+ chars_per_token = None,
222
+ temperature = 0.9,
223
+ lmbda = 0.5,
224
+ beta = 0.5,
225
+ max_new_tokens = 128,
226
+ teacher_model_name_or_path = None,
227
+ teacher_model_init_kwargs = None,
228
+ disable_dropout = True,
229
+ seq_kd = False,
230
+ vllm_sampling_params = None,
231
+ unsloth_num_chunks = -1,
232
+ **kwargs,
233
+ ):
234
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
235
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
236
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
237
+ output_dir = 'unsloth_training_checkpoints'
238
+ save_strategy = 'no'
239
+ if dataset_num_proc is None:
240
+ from multiprocessing import cpu_count
241
+ dataset_num_proc = cpu_count()
242
+
243
+ super().__init__(
244
+ output_dir = output_dir,
245
+ overwrite_output_dir = overwrite_output_dir,
246
+ do_train = do_train,
247
+ do_eval = do_eval,
248
+ do_predict = do_predict,
249
+ eval_strategy = eval_strategy,
250
+ prediction_loss_only = prediction_loss_only,
251
+ per_device_train_batch_size = per_device_train_batch_size,
252
+ per_device_eval_batch_size = per_device_eval_batch_size,
253
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
254
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
255
+ gradient_accumulation_steps = gradient_accumulation_steps,
256
+ eval_accumulation_steps = eval_accumulation_steps,
257
+ eval_delay = eval_delay,
258
+ torch_empty_cache_steps = torch_empty_cache_steps,
259
+ learning_rate = learning_rate,
260
+ weight_decay = weight_decay,
261
+ adam_beta1 = adam_beta1,
262
+ adam_beta2 = adam_beta2,
263
+ adam_epsilon = adam_epsilon,
264
+ max_grad_norm = max_grad_norm,
265
+ num_train_epochs = num_train_epochs,
266
+ max_steps = max_steps,
267
+ lr_scheduler_type = lr_scheduler_type,
268
+ warmup_ratio = warmup_ratio,
269
+ warmup_steps = warmup_steps,
270
+ log_level = log_level,
271
+ log_level_replica = log_level_replica,
272
+ log_on_each_node = log_on_each_node,
273
+ logging_dir = logging_dir,
274
+ logging_strategy = logging_strategy,
275
+ logging_first_step = logging_first_step,
276
+ logging_steps = logging_steps,
277
+ logging_nan_inf_filter = logging_nan_inf_filter,
278
+ save_strategy = save_strategy,
279
+ save_steps = save_steps,
280
+ save_total_limit = save_total_limit,
281
+ save_safetensors = save_safetensors,
282
+ save_on_each_node = save_on_each_node,
283
+ save_only_model = save_only_model,
284
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
285
+ no_cuda = no_cuda,
286
+ use_cpu = use_cpu,
287
+ use_mps_device = use_mps_device,
288
+ seed = seed,
289
+ data_seed = data_seed,
290
+ jit_mode_eval = jit_mode_eval,
291
+ use_ipex = use_ipex,
292
+ bf16 = bf16,
293
+ fp16 = fp16,
294
+ fp16_opt_level = fp16_opt_level,
295
+ half_precision_backend = half_precision_backend,
296
+ bf16_full_eval = bf16_full_eval,
297
+ fp16_full_eval = fp16_full_eval,
298
+ tf32 = tf32,
299
+ local_rank = local_rank,
300
+ ddp_backend = ddp_backend,
301
+ tpu_num_cores = tpu_num_cores,
302
+ tpu_metrics_debug = tpu_metrics_debug,
303
+ debug = debug,
304
+ dataloader_drop_last = dataloader_drop_last,
305
+ eval_steps = eval_steps,
306
+ dataloader_num_workers = dataloader_num_workers,
307
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
308
+ past_index = past_index,
309
+ run_name = run_name,
310
+ disable_tqdm = disable_tqdm,
311
+ remove_unused_columns = remove_unused_columns,
312
+ label_names = label_names,
313
+ load_best_model_at_end = load_best_model_at_end,
314
+ metric_for_best_model = metric_for_best_model,
315
+ greater_is_better = greater_is_better,
316
+ ignore_data_skip = ignore_data_skip,
317
+ fsdp = fsdp,
318
+ fsdp_min_num_params = fsdp_min_num_params,
319
+ fsdp_config = fsdp_config,
320
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
321
+ accelerator_config = accelerator_config,
322
+ deepspeed = deepspeed,
323
+ label_smoothing_factor = label_smoothing_factor,
324
+ optim = optim,
325
+ optim_args = optim_args,
326
+ adafactor = adafactor,
327
+ group_by_length = group_by_length,
328
+ length_column_name = length_column_name,
329
+ report_to = report_to,
330
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
331
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
332
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
333
+ dataloader_pin_memory = dataloader_pin_memory,
334
+ dataloader_persistent_workers = dataloader_persistent_workers,
335
+ skip_memory_metrics = skip_memory_metrics,
336
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
337
+ push_to_hub = push_to_hub,
338
+ resume_from_checkpoint = resume_from_checkpoint,
339
+ hub_model_id = hub_model_id,
340
+ hub_strategy = hub_strategy,
341
+ hub_token = hub_token,
342
+ hub_private_repo = hub_private_repo,
343
+ hub_always_push = hub_always_push,
344
+ gradient_checkpointing = gradient_checkpointing,
345
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
346
+ include_inputs_for_metrics = include_inputs_for_metrics,
347
+ eval_do_concat_batches = eval_do_concat_batches,
348
+ fp16_backend = fp16_backend,
349
+ evaluation_strategy = evaluation_strategy,
350
+ push_to_hub_model_id = push_to_hub_model_id,
351
+ push_to_hub_organization = push_to_hub_organization,
352
+ push_to_hub_token = push_to_hub_token,
353
+ mp_parameters = mp_parameters,
354
+ auto_find_batch_size = auto_find_batch_size,
355
+ full_determinism = full_determinism,
356
+ torchdynamo = torchdynamo,
357
+ ray_scope = ray_scope,
358
+ ddp_timeout = ddp_timeout,
359
+ torch_compile = torch_compile,
360
+ torch_compile_backend = torch_compile_backend,
361
+ torch_compile_mode = torch_compile_mode,
362
+ dispatch_batches = dispatch_batches,
363
+ split_batches = split_batches,
364
+ include_tokens_per_second = include_tokens_per_second,
365
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
366
+ neftune_noise_alpha = neftune_noise_alpha,
367
+ optim_target_modules = optim_target_modules,
368
+ batch_eval_metrics = batch_eval_metrics,
369
+ eval_on_start = eval_on_start,
370
+ use_liger_kernel = use_liger_kernel,
371
+ eval_use_gather_object = eval_use_gather_object,
372
+ average_tokens_across_devices = average_tokens_across_devices,
373
+ model_init_kwargs = model_init_kwargs,
374
+ use_liger = use_liger,
375
+ dataset_text_field = dataset_text_field,
376
+ dataset_kwargs = dataset_kwargs,
377
+ dataset_num_proc = dataset_num_proc,
378
+ max_seq_length = max_seq_length,
379
+ packing = packing,
380
+ eval_packing = eval_packing,
381
+ dataset_batch_size = dataset_batch_size,
382
+ num_of_sequences = num_of_sequences,
383
+ chars_per_token = chars_per_token,
384
+ temperature = temperature,
385
+ lmbda = lmbda,
386
+ beta = beta,
387
+ max_new_tokens = max_new_tokens,
388
+ teacher_model_name_or_path = teacher_model_name_or_path,
389
+ teacher_model_init_kwargs = teacher_model_init_kwargs,
390
+ disable_dropout = disable_dropout,
391
+ seq_kd = seq_kd,**kwargs)
392
+ self.vllm_sampling_params = vllm_sampling_params
393
+ self.unsloth_num_chunks = unsloth_num_chunks
394
+ pass
395
+
396
+ class _UnslothGKDTrainer(SFTTrainer):
397
+ _tag_names = ["trl", "gkd"]
398
+
399
+ def __init__(
400
+ self,
401
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
402
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
403
+ args: Optional[GKDConfig] = None,
404
+ data_collator: Optional[DataCollator] = None, # type: ignore
405
+ train_dataset: Optional[Dataset] = None,
406
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
407
+ processing_class: Optional[
408
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
409
+ ] = None,
410
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
411
+ callbacks: Optional[list[TrainerCallback]] = None,
412
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
413
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
414
+ peft_config: Optional["PeftConfig"] = None,
415
+ formatting_func: Optional[Callable] = None,
416
+ ):
417
+ # add remove_unused_columns=False to the dataclass args
418
+ args.remove_unused_columns = False
419
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
420
+
421
+ super().__init__(
422
+ model,
423
+ args=args,
424
+ data_collator=data_collator,
425
+ train_dataset=train_dataset,
426
+ eval_dataset=eval_dataset,
427
+ processing_class=processing_class,
428
+ compute_metrics=compute_metrics,
429
+ callbacks=callbacks,
430
+ optimizers=optimizers,
431
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
432
+ peft_config=peft_config,
433
+ formatting_func=formatting_func,
434
+ )
435
+
436
+ if args.teacher_model_init_kwargs is None:
437
+ teacher_model_init_kwargs = {}
438
+ elif not isinstance(teacher_model, str):
439
+ raise ValueError(
440
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
441
+ )
442
+ else:
443
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
444
+ teacher_model_init_kwargs["torch_dtype"] = (
445
+ teacher_model_init_kwargs["torch_dtype"]
446
+ if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
447
+ else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
448
+ )
449
+
450
+ if isinstance(teacher_model, str):
451
+ if args.use_liger:
452
+ teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
453
+ else:
454
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
455
+
456
+ # Disable dropout in the model
457
+ if args.disable_dropout:
458
+ disable_dropout_in_model(self.model)
459
+
460
+ if self.is_deepspeed_enabled:
461
+ self.teacher_model = self._prepare_deepspeed(teacher_model)
462
+ else:
463
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
464
+
465
+ self.lmbda = args.lmbda
466
+ self.beta = args.beta
467
+ self.temperature = args.temperature
468
+ self.seq_kd = args.seq_kd
469
+
470
+ self.generation_config = GenerationConfig(
471
+ max_new_tokens=args.max_new_tokens,
472
+ temperature=args.temperature,
473
+ do_sample=True,
474
+ top_k=0,
475
+ use_cache=False if args.gradient_checkpointing else True,
476
+ pad_token_id=self.processing_class.pad_token_id,
477
+ )
478
+ # Set custom EOS tokens if they are specified by the model's generation
479
+ # config. This is important for models with the Llama 3 chat template,
480
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
481
+ # turns or messages.
482
+ if (
483
+ hasattr(self.model.generation_config, "eos_token_id")
484
+ and self.model.generation_config.eos_token_id is not None
485
+ ):
486
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
487
+
488
+ def _prepare_dataset(self, dataset, *args):
489
+ # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
490
+ # need to keep the messages column as it is. We use the following workaround to keep the messages column.
491
+ dataset = dataset.add_column("_messages", dataset["messages"])
492
+ dataset = super()._prepare_dataset(dataset, *args)
493
+ dataset = dataset.rename_column("_messages", "messages")
494
+ return dataset
495
+
496
+ @staticmethod
497
+ def generalized_jsd_loss(
498
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
499
+ ):
500
+ """
501
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
502
+ of https://huggingface.co/papers/2306.13649 for the definition.
503
+
504
+ Args:
505
+ student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
506
+ teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
507
+ labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
508
+ beta: Interpolation coefficient between 0 and 1 (default: 0.5)
509
+ temperature: Softmax temperature (default: 1.0)
510
+ reduction: Specifies the reduction to apply to the output (default: 'batchmean')
511
+
512
+ Returns:
513
+ loss: Scalar tensor with the generalized JSD loss
514
+ """
515
+
516
+ # Apply temperature scaling
517
+ student_logits = student_logits / temperature
518
+ teacher_logits = teacher_logits / temperature
519
+
520
+ # Compute log probabilities for student and probabilities for teacher
521
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
522
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
523
+
524
+ # Compute the log of the mixture distribution
525
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
526
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
527
+ mixture_log_probs = torch.logsumexp(
528
+ torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
529
+ dim=0,
530
+ )
531
+
532
+ # Compute KL divergences using F.kl_div
533
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
534
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
535
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
536
+
537
+ # Compute the Generalized Jensen-Shannon Divergence
538
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
539
+
540
+ # Masking
541
+ if labels is not None:
542
+ mask = labels != -100
543
+ jsd = jsd[mask]
544
+
545
+ # Apply reduction
546
+ if reduction == "batchmean":
547
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
548
+ elif reduction == "sum":
549
+ return jsd.sum()
550
+ elif reduction == "mean":
551
+ return jsd.mean()
552
+ else:
553
+ return jsd
554
+
555
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
556
+ # compute student output
557
+ outputs_student = model(
558
+ input_ids=inputs["input_ids"],
559
+ attention_mask=inputs["attention_mask"],
560
+ )
561
+
562
+ # compute teacher output in eval mode
563
+ self.teacher_model.eval()
564
+ with torch.no_grad():
565
+ outputs_teacher = self.teacher_model(
566
+ input_ids=inputs["input_ids"],
567
+ attention_mask=inputs["attention_mask"],
568
+ )
569
+
570
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
571
+ prompt_lengths = inputs["prompts"].shape[1]
572
+ shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
573
+ shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
574
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
575
+
576
+ # compute loss
577
+ loss = self.generalized_jsd_loss(
578
+ student_logits=shifted_student_logits,
579
+ teacher_logits=shifted_teacher_logits,
580
+ labels=shifted_labels,
581
+ beta=self.beta,
582
+ )
583
+
584
+ # empty cache
585
+ empty_cache()
586
+
587
+ # Return loss
588
+ return (loss, outputs_student) if return_outputs else loss
589
+
590
+ @staticmethod
591
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
592
+ # Generate output with respect to the prompt only
593
+ generated_outputs = model.generate(
594
+ input_ids=inputs["prompts"],
595
+ attention_mask=inputs.get("prompt_attention_mask", None),
596
+ generation_config=generation_config,
597
+ return_dict_in_generate=True,
598
+ )
599
+
600
+ # Get the generated token IDs
601
+ generated_tokens = generated_outputs.sequences
602
+ # Calculate new attention mask
603
+ new_attention_mask = torch.ones_like(generated_tokens)
604
+ new_labels = generated_tokens.clone()
605
+
606
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
607
+ if pad_token_id is not None:
608
+ new_labels[new_labels == pad_token_id] = -100
609
+ new_attention_mask[generated_tokens == pad_token_id] = 0
610
+
611
+ return generated_tokens, new_attention_mask, new_labels
612
+
613
+ def training_step(
614
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
615
+ ) -> torch.Tensor:
616
+ """
617
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
618
+
619
+ This method implements the on-policy learning approach described in the GKD paper.
620
+ With probability `self.lmbda`, it generates new responses using the student model,
621
+ which are then used for training instead of the original inputs.
622
+ """
623
+ if self.seq_kd:
624
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
625
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
626
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
627
+ )
628
+ inputs["input_ids"] = new_input_ids
629
+ inputs["attention_mask"] = new_attention_mask
630
+ inputs["labels"] = new_labels
631
+ if random.random() <= self.lmbda:
632
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
633
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
634
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
635
+ )
636
+ inputs["input_ids"] = new_input_ids
637
+ inputs["attention_mask"] = new_attention_mask
638
+ inputs["labels"] = new_labels
639
+
640
+ loss = super().training_step(model, inputs, num_items_in_batch)
641
+ return loss
642
+
643
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
644
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
645
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
646
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
647
+
648
+ if model is not None:
649
+ if hasattr(model, "config"):
650
+ hidden_size = (
651
+ max(model.config.hidden_sizes)
652
+ if getattr(model.config, "hidden_sizes", None)
653
+ else getattr(model.config, "hidden_size", None)
654
+ )
655
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
656
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
657
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
658
+ config_kwargs.update(
659
+ {
660
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
661
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
662
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
663
+ }
664
+ )
665
+
666
+ # If ZeRO-3 is used, we shard both the active and reference model.
667
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
668
+ if config_kwargs["zero_optimization"]["stage"] != 3:
669
+ config_kwargs["zero_optimization"]["stage"] = 0
670
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
671
+ model.eval()
672
+ return model
673
+
674
+ def create_model_card(
675
+ self,
676
+ model_name: Optional[str] = None,
677
+ dataset_name: Optional[str] = None,
678
+ tags: Union[str, list[str], None] = None,
679
+ ):
680
+ """
681
+ Creates a draft of a model card using the information available to the `Trainer`.
682
+
683
+ Args:
684
+ model_name (`str` or `None`, *optional*, defaults to `None`):
685
+ Name of the model.
686
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
687
+ Name of the dataset used for training.
688
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
689
+ Tags to be associated with the model card.
690
+ """
691
+ if not self.is_world_process_zero():
692
+ return
693
+
694
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
695
+ base_model = self.model.config._name_or_path
696
+ else:
697
+ base_model = None
698
+
699
+ tags = tags or []
700
+ if isinstance(tags, str):
701
+ tags = [tags]
702
+
703
+ if hasattr(self.model.config, "unsloth_version"):
704
+ tags.append("unsloth")
705
+
706
+ citation = textwrap.dedent("""\
707
+ @inproceedings{agarwal2024on-policy,
708
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
709
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
710
+ year = 2024,
711
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
712
+ publisher = {OpenReview.net},
713
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
714
+ }""")
715
+
716
+ model_card = generate_model_card(
717
+ base_model=base_model,
718
+ model_name=model_name,
719
+ hub_model_id=self.hub_model_id,
720
+ dataset_name=dataset_name,
721
+ tags=tags,
722
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
723
+ comet_url=get_comet_experiment_url(),
724
+ trainer_name="GKD",
725
+ trainer_citation=citation,
726
+ paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
727
+ paper_id="2306.13649",
728
+ )
729
+
730
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
731
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
732
+ """
733
+
734
+ """
735
+ def __init__(
736
+ self,
737
+ model = None,
738
+ teacher_model = None,
739
+ args = None,
740
+ data_collator = None,
741
+ train_dataset = None,
742
+ eval_dataset = None,
743
+ processing_class = None,
744
+ compute_metrics = None,
745
+ callbacks = None,
746
+ preprocess_logits_for_metrics = None,
747
+ peft_config = None,
748
+ formatting_func = None,
749
+ **kwargs
750
+ ):
751
+ if args is None: args = UnslothGKDConfig()
752
+ use_bf16 = getattr(args, 'bf16', False)
753
+ use_fp16 = getattr(args, 'fp16', False)
754
+ force_float32 = False
755
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
756
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
757
+ force_float32 = True
758
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
759
+ dtype = getattr(model.config, 'torch_dtype', None)
760
+ if dtype is None: dtype = model.get_input_embeddings().dtype
761
+ from unsloth_zoo.utils import _get_dtype
762
+ dtype = _get_dtype(dtype)
763
+ float16 = dtype == torch.float16
764
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
765
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
766
+ if force_float32:
767
+ args.fp16 = False
768
+ args.bf16 = False
769
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
770
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
771
+ args.fp16 = float16
772
+ args.bf16 = not float16
773
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
774
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
775
+ args.eval_strategy = 'steps'
776
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
777
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
778
+ if ga_steps is not None and ga_steps > 1:
779
+ from transformers import __version__ as transformers_version
780
+ if Version(transformers_version) <= Version('4.45.2'):
781
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
782
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
783
+ if getattr(args, 'eval_strategy', 'no') != 'no':
784
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
785
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
786
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
787
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
788
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
789
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
790
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
791
+ if force_float32:
792
+ args.bf16_full_eval = False
793
+ args.fp16_full_eval = False
794
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
795
+ args.bf16_full_eval = True
796
+ args.fp16_full_eval = False
797
+ elif not bf16_full_eval and not fp16_full_eval:
798
+ args.bf16_full_eval = args.bf16
799
+ args.fp16_full_eval = args.fp16
800
+ _output_logits = False
801
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
802
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
803
+ if _output_logits:
804
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
805
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
806
+ pass
807
+ else:
808
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
809
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
810
+ if args_max_seq_length is None and model_max_seq_length is not None:
811
+ max_seq_length = model.max_seq_length
812
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
813
+ if model is not None and hasattr(model, 'for_training'):
814
+ model.for_training()
815
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
816
+ if 'processing_class' in locals():
817
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
818
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
819
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
820
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
821
+ if not isinstance(data_collator, UnslothVisionDataCollator):
822
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
823
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
824
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
825
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
826
+ else:
827
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
828
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
829
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
830
+ if not isinstance(data_collator, UnslothVisionDataCollator):
831
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
832
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
833
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
834
+ else:
835
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
836
+ other_metrics = []
837
+
838
+ from unsloth_zoo.logging_utils import PatchRLStatistics
839
+ PatchRLStatistics('gkd_trainer', other_metrics)
840
+
841
+ super().__init__(
842
+ model = model,
843
+ teacher_model = teacher_model,
844
+ args = args,
845
+ data_collator = data_collator,
846
+ train_dataset = train_dataset,
847
+ eval_dataset = eval_dataset,
848
+ processing_class = processing_class,
849
+ compute_metrics = compute_metrics,
850
+ callbacks = callbacks,
851
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
852
+ peft_config = peft_config,
853
+ formatting_func = formatting_func,**kwargs)
854
+ if hasattr(self, 'neftune_hook_handle'):
855
+ self.neftune_hook_handle.remove()
856
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
857
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
858
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
859
+ pass
860
+
861
+ pass
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
@@ -0,0 +1,1436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+
43
+ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
44
+ # All Unsloth Zoo code licensed under LGPLv3
45
+ old_logits = old_logits.to(torch.float32)
46
+ new_logits = new_logits.to(torch.float32)
47
+ input_ids = input_ids.unsqueeze(-1)
48
+
49
+ # x_i - logsumexp(x_i)
50
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
51
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
52
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
53
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
54
+
55
+ # Reverse KL
56
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
57
+ # Full correct reverse KL divergence?? Missing term maybe?
58
+ # kl_i = torch.exp(new) * kl_i
59
+
60
+ # Below is forward KL (normal KL)
61
+ # kl_i = torch.exp(old) * (old - new)
62
+
63
+ # Must detach - otherwise gradients are not propagated correctly!
64
+ # exp(x - x) == 1
65
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
66
+ loss_i = -(loss_i - beta * kl_i)
67
+
68
+ mask = mask.to(torch.float32)
69
+ n_mask_per_reward = mask.sum(1)
70
+
71
+ # See https://github.com/huggingface/trl/pull/2881
72
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
73
+ loss = loss_per_reward.mean()
74
+ # loss = (loss_i * mask).sum() / mask.sum()
75
+
76
+ # Get metrics as well which are folded
77
+ with torch.inference_mode():
78
+ completion_length = n_mask_per_reward.mean()
79
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
80
+ mean_kl = mean_kl_per_reward.mean()
81
+ pass
82
+ return loss, completion_length, mean_kl
83
+
84
+ class UnslothEfficientGRPO(torch.autograd.Function):
85
+ # All Unsloth Zoo code licensed under LGPLv3
86
+ @staticmethod
87
+ def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
88
+ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
89
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
90
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
91
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
92
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
93
+ loss, completion_length, mean_kl = grpo_compute_loss(
94
+ old_logits, new_logits, input_ids, mask, beta, advantages,
95
+ )
96
+ # Scale loss if needed for mixed precision training
97
+ scaled_loss = loss * scaling
98
+ # Must add .loss.detach otherwise autograd uses 2x VRAM
99
+ return scaled_loss, (loss.detach(), completion_length, mean_kl,)
100
+ pass
101
+
102
+ device =_new_hidden_states.device
103
+ grad_inputs = torch.empty_like(_new_hidden_states)
104
+ accumulated_loss = torch.zeros(1, device = device)
105
+ accumulated_completion_length = torch.zeros(1, device = device)
106
+ accumulated_mean_kl = torch.zeros(1, device = device)
107
+
108
+ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
109
+ (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
110
+ compute_loss,
111
+ argnums = (0,),
112
+ has_aux = True,
113
+ )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
114
+ accumulated_loss .add_(unscaled_loss)
115
+ accumulated_completion_length.add_(chunk_completion_length)
116
+ accumulated_mean_kl .add_(chunk_mean_kl)
117
+ return chunk_grad_input
118
+ pass
119
+
120
+ accumulate_chunk = torch.compile(
121
+ accumulate_chunk,
122
+ fullgraph = True,
123
+ options = torch_compile_options,
124
+ )
125
+
126
+ grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
127
+ new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
128
+ old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
129
+ input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
130
+ mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
131
+ advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
132
+
133
+ # Get mixed precision scaling if seen
134
+ scaling = scaler.get_scale() if scaler is not None else 1.0
135
+
136
+ # Force torch.compile to use dynamic shapes for seqlen dim
137
+ mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
138
+
139
+ for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
140
+ zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
141
+
142
+ mark_dynamic(new_hidden_states_j)
143
+ mark_dynamic(old_hidden_states_j)
144
+ mark_dynamic(input_ids_j)
145
+ mark_dynamic(mask_j)
146
+
147
+ grad_inputs_j.copy_(
148
+ accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
149
+ )
150
+ pass
151
+
152
+ grad_inputs .div_(n_chunks)
153
+ accumulated_loss .div_(n_chunks)
154
+ accumulated_completion_length.div_(n_chunks)
155
+ accumulated_mean_kl .div_(n_chunks)
156
+ ctx.save_for_backward(grad_inputs)
157
+
158
+ return (
159
+ accumulated_loss,
160
+ accumulated_completion_length,
161
+ accumulated_mean_kl,
162
+ )
163
+ pass
164
+
165
+ @staticmethod
166
+ def backward(ctx, grad_output, dcompletion_length, dmean_kl):
167
+ (grad_input,) = ctx.saved_tensors
168
+ return (grad_input, None, None, None, None, None, None, None, None,)
169
+ pass
170
+
171
+ def grpo_accumulated_loss(
172
+ trainer,
173
+ input_ids,
174
+ logits_to_keep,
175
+ completion_mask,
176
+ advantages,
177
+ n_chunks = -1,
178
+ ):
179
+ # All Unsloth Zoo code licensed under LGPLv3
180
+ bsz, qlen = input_ids.shape
181
+ # Find closest multiple
182
+ factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
183
+ if n_chunks == -1: n_chunks = bsz
184
+ n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
185
+
186
+ mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
187
+ os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
188
+
189
+ completion_input_ids = input_ids[:, -logits_to_keep:]
190
+ lm_head = trainer.model.get_output_embeddings().weight
191
+
192
+ with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
193
+ with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
194
+ old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
195
+ pass
196
+
197
+ new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
198
+
199
+ loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
200
+ new_hidden_states, old_hidden_states, lm_head,
201
+ completion_input_ids, completion_mask, advantages, trainer.beta,
202
+ trainer.accelerator.scaler,
203
+ n_chunks,
204
+ )
205
+ return loss, completion_length, mean_kl
206
+
207
+ # Old non efficient code path
208
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
209
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
210
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
211
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
212
+ loss, completion_length, mean_kl = grpo_compute_loss(
213
+ old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
214
+ )
215
+ return loss, completion_length, mean_kl
216
+ pass
217
+
218
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
219
+ def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
220
+ # All Unsloth Zoo code licensed under LGPLv3
221
+ old_logits = old_logits.to(torch.float32)
222
+ new_logits = new_logits.to(torch.float32)
223
+ input_ids = input_ids.unsqueeze(-1)
224
+
225
+ # x_i - logsumexp(x_i)
226
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
227
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
228
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
229
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
230
+
231
+ # Reverse KL
232
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
233
+ # Full correct reverse KL divergence?? Missing term maybe?
234
+ # kl_i = torch.exp(new) * kl_i
235
+
236
+ # Below is forward KL (normal KL)
237
+ # kl_i = torch.exp(old) * (old - new)
238
+
239
+ # Must detach - otherwise gradients are not propagated correctly!
240
+ # exp(x - x) == 1
241
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
242
+ loss_i = -(loss_i - beta * kl_i)
243
+
244
+ mask = mask.to(torch.float32)
245
+ n_mask_per_reward = mask.sum(1)
246
+
247
+ # See https://github.com/huggingface/trl/pull/2881
248
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
249
+ loss = loss_per_reward.mean()
250
+ # loss = (loss_i * mask).sum() / mask.sum()
251
+
252
+ # Get metrics as well which are folded
253
+ with torch.inference_mode():
254
+ completion_length = n_mask_per_reward.mean()
255
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
256
+ mean_kl = mean_kl_per_reward.mean()
257
+ pass
258
+ return loss, completion_length, mean_kl
259
+
260
+ def vLLMSamplingParams(**kwargs):
261
+ from vllm import SamplingParams
262
+ sampling_params = SamplingParams(**kwargs)
263
+ sampling_params._set_kwargs = kwargs
264
+ return sampling_params
265
+ @dataclass
266
+ class UnslothGRPOConfig(GRPOConfig):
267
+ """
268
+
269
+ Configuration class for the [`GRPOTrainer`].
270
+
271
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
272
+ [`~transformers.TrainingArguments`] documentation.
273
+
274
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
275
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
276
+ command line.
277
+
278
+ Parameters:
279
+ > Parameters that control the model and reference model
280
+
281
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
282
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
283
+ argument of the [`GRPOTrainer`] is provided as a string.
284
+
285
+ > Parameters that control the data preprocessing
286
+
287
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
288
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
289
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
290
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
291
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
292
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
293
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
294
+ must be divisible by this value.
295
+ temperature (`float`, *optional*, defaults to `0.9`):
296
+ Temperature for sampling. The higher the temperature, the more random the completions.
297
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
298
+ Maximum length of the generated completion.
299
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
300
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
301
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
302
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
303
+ with vLLM generation.
304
+
305
+ > Parameters that control generation acceleration powered by vLLM
306
+
307
+ use_vllm (`bool`, *optional*, defaults to `False`):
308
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
309
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
310
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
311
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
312
+ automatically select the next available GPU after the last one used for training. This assumes that
313
+ training has not already occupied all available GPUs. If only one device is available, the device will be
314
+ shared between both training and vLLM.
315
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
316
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
317
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
318
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
319
+ during initialization.
320
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
321
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
322
+ based on the model configuration. Find the supported values in the vLLM documentation.
323
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
324
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
325
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
326
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
327
+
328
+ > Parameters that control the training
329
+
330
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
331
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
332
+ [`~transformers.TrainingArguments`].
333
+ beta (`float`, *optional*, defaults to `0.04`):
334
+ KL coefficient.
335
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
336
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
337
+ weighted equally with weight `1.0`.
338
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
339
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
340
+ the `ref_model_mixup_alpha` parameter. This synchronization originites from the
341
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
342
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
343
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
344
+ between the current policy and the previous reference policy during updates. The reference policy is
345
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
346
+ must set `sync_ref_model=True`.
347
+ ref_model_sync_steps (`int`, *optional*, defaults to `64`):
348
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
349
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
350
+ set `sync_ref_model=True`.
351
+
352
+ > Parameters that control the logging
353
+
354
+ log_completions (`bool`, *optional*, defaults to `False`):
355
+ Whether to log the completions during training.
356
+
357
+ """
358
+ vllm_sampling_params: Optional[Any] = field(
359
+ default = None,
360
+ metadata = {'help': 'vLLM SamplingParams'},
361
+ )
362
+ unsloth_num_chunks : Optional[int] = field(
363
+ default = -1,
364
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
365
+ )
366
+ def __init__(
367
+ self,
368
+ output_dir = None,
369
+ overwrite_output_dir = None,
370
+ do_train = False,
371
+ do_eval = False,
372
+ do_predict = False,
373
+ eval_strategy = 'no',
374
+ prediction_loss_only = False,
375
+ per_device_train_batch_size = 4,
376
+ per_device_eval_batch_size = 4,
377
+ per_gpu_train_batch_size = None,
378
+ per_gpu_eval_batch_size = None,
379
+ gradient_accumulation_steps = 2,
380
+ eval_accumulation_steps = 2,
381
+ eval_delay = 0,
382
+ torch_empty_cache_steps = 250,
383
+ learning_rate = 5e-05,
384
+ weight_decay = 0.01,
385
+ adam_beta1 = 0.9,
386
+ adam_beta2 = 0.999,
387
+ adam_epsilon = 1e-08,
388
+ max_grad_norm = 1.0,
389
+ num_train_epochs = 3.0,
390
+ max_steps = -1,
391
+ lr_scheduler_type = 'linear',
392
+ warmup_ratio = 0.1,
393
+ warmup_steps = 0,
394
+ log_level = 'passive',
395
+ log_level_replica = 'warning',
396
+ log_on_each_node = True,
397
+ logging_dir = None,
398
+ logging_strategy = 'steps',
399
+ logging_first_step = False,
400
+ logging_steps = 1,
401
+ logging_nan_inf_filter = False,
402
+ save_strategy = 'steps',
403
+ save_steps = 500,
404
+ save_total_limit = None,
405
+ save_safetensors = True,
406
+ save_on_each_node = False,
407
+ save_only_model = False,
408
+ restore_callback_states_from_checkpoint = False,
409
+ no_cuda = False,
410
+ use_cpu = False,
411
+ use_mps_device = False,
412
+ seed = 3407,
413
+ data_seed = 3407,
414
+ jit_mode_eval = False,
415
+ use_ipex = False,
416
+ bf16 = False,
417
+ fp16 = False,
418
+ fp16_opt_level = 'O1',
419
+ half_precision_backend = 'auto',
420
+ bf16_full_eval = False,
421
+ fp16_full_eval = False,
422
+ tf32 = None,
423
+ local_rank = -1,
424
+ ddp_backend = None,
425
+ tpu_num_cores = None,
426
+ tpu_metrics_debug = False,
427
+ debug = '',
428
+ dataloader_drop_last = False,
429
+ eval_steps = None,
430
+ dataloader_num_workers = 0,
431
+ dataloader_prefetch_factor = None,
432
+ past_index = -1,
433
+ run_name = None,
434
+ disable_tqdm = None,
435
+ remove_unused_columns = False,
436
+ label_names = None,
437
+ load_best_model_at_end = False,
438
+ metric_for_best_model = None,
439
+ greater_is_better = None,
440
+ ignore_data_skip = False,
441
+ fsdp = '',
442
+ fsdp_min_num_params = 0,
443
+ fsdp_config = None,
444
+ fsdp_transformer_layer_cls_to_wrap = None,
445
+ accelerator_config = None,
446
+ deepspeed = None,
447
+ label_smoothing_factor = 0.0,
448
+ optim = 'adamw_8bit',
449
+ optim_args = None,
450
+ adafactor = False,
451
+ group_by_length = False,
452
+ length_column_name = 'length',
453
+ report_to = None,
454
+ ddp_find_unused_parameters = None,
455
+ ddp_bucket_cap_mb = None,
456
+ ddp_broadcast_buffers = None,
457
+ dataloader_pin_memory = True,
458
+ dataloader_persistent_workers = False,
459
+ skip_memory_metrics = True,
460
+ use_legacy_prediction_loop = False,
461
+ push_to_hub = False,
462
+ resume_from_checkpoint = None,
463
+ hub_model_id = None,
464
+ hub_strategy = 'every_save',
465
+ hub_token = None,
466
+ hub_private_repo = None,
467
+ hub_always_push = False,
468
+ gradient_checkpointing = False,
469
+ gradient_checkpointing_kwargs = None,
470
+ include_inputs_for_metrics = False,
471
+ eval_do_concat_batches = True,
472
+ fp16_backend = 'auto',
473
+ evaluation_strategy = None,
474
+ push_to_hub_model_id = None,
475
+ push_to_hub_organization = None,
476
+ push_to_hub_token = None,
477
+ mp_parameters = '',
478
+ auto_find_batch_size = False,
479
+ full_determinism = False,
480
+ torchdynamo = None,
481
+ ray_scope = 'last',
482
+ ddp_timeout = 1800,
483
+ torch_compile = False,
484
+ torch_compile_backend = None,
485
+ torch_compile_mode = None,
486
+ dispatch_batches = None,
487
+ split_batches = None,
488
+ include_tokens_per_second = False,
489
+ include_num_input_tokens_seen = False,
490
+ neftune_noise_alpha = None,
491
+ optim_target_modules = None,
492
+ batch_eval_metrics = False,
493
+ eval_on_start = False,
494
+ use_liger_kernel = False,
495
+ eval_use_gather_object = False,
496
+ average_tokens_across_devices = False,
497
+ model_init_kwargs = None,
498
+ max_prompt_length = 512,
499
+ num_generations = 8,
500
+ temperature = 0.9,
501
+ max_completion_length = 256,
502
+ ds3_gather_for_generation = True,
503
+ use_vllm = False,
504
+ vllm_device = 'auto',
505
+ vllm_gpu_memory_utilization = 0.9,
506
+ vllm_dtype = 'auto',
507
+ vllm_max_model_len = None,
508
+ beta = 0.04,
509
+ reward_weights = None,
510
+ sync_ref_model = False,
511
+ ref_model_mixup_alpha = 0.9,
512
+ ref_model_sync_steps = 64,
513
+ log_completions = False,
514
+ vllm_sampling_params = None,
515
+ unsloth_num_chunks = -1,
516
+ **kwargs,
517
+ ):
518
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
519
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
520
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
521
+ output_dir = 'unsloth_training_checkpoints'
522
+ save_strategy = 'no'
523
+ div = per_device_train_batch_size // num_generations
524
+ if div * num_generations != per_device_train_batch_size:
525
+ print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
526
+ per_device_train_batch_size = num_generations
527
+
528
+ super().__init__(
529
+ output_dir = output_dir,
530
+ overwrite_output_dir = overwrite_output_dir,
531
+ do_train = do_train,
532
+ do_eval = do_eval,
533
+ do_predict = do_predict,
534
+ eval_strategy = eval_strategy,
535
+ prediction_loss_only = prediction_loss_only,
536
+ per_device_train_batch_size = per_device_train_batch_size,
537
+ per_device_eval_batch_size = per_device_eval_batch_size,
538
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
539
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
540
+ gradient_accumulation_steps = gradient_accumulation_steps,
541
+ eval_accumulation_steps = eval_accumulation_steps,
542
+ eval_delay = eval_delay,
543
+ torch_empty_cache_steps = torch_empty_cache_steps,
544
+ learning_rate = learning_rate,
545
+ weight_decay = weight_decay,
546
+ adam_beta1 = adam_beta1,
547
+ adam_beta2 = adam_beta2,
548
+ adam_epsilon = adam_epsilon,
549
+ max_grad_norm = max_grad_norm,
550
+ num_train_epochs = num_train_epochs,
551
+ max_steps = max_steps,
552
+ lr_scheduler_type = lr_scheduler_type,
553
+ warmup_ratio = warmup_ratio,
554
+ warmup_steps = warmup_steps,
555
+ log_level = log_level,
556
+ log_level_replica = log_level_replica,
557
+ log_on_each_node = log_on_each_node,
558
+ logging_dir = logging_dir,
559
+ logging_strategy = logging_strategy,
560
+ logging_first_step = logging_first_step,
561
+ logging_steps = logging_steps,
562
+ logging_nan_inf_filter = logging_nan_inf_filter,
563
+ save_strategy = save_strategy,
564
+ save_steps = save_steps,
565
+ save_total_limit = save_total_limit,
566
+ save_safetensors = save_safetensors,
567
+ save_on_each_node = save_on_each_node,
568
+ save_only_model = save_only_model,
569
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
570
+ no_cuda = no_cuda,
571
+ use_cpu = use_cpu,
572
+ use_mps_device = use_mps_device,
573
+ seed = seed,
574
+ data_seed = data_seed,
575
+ jit_mode_eval = jit_mode_eval,
576
+ use_ipex = use_ipex,
577
+ bf16 = bf16,
578
+ fp16 = fp16,
579
+ fp16_opt_level = fp16_opt_level,
580
+ half_precision_backend = half_precision_backend,
581
+ bf16_full_eval = bf16_full_eval,
582
+ fp16_full_eval = fp16_full_eval,
583
+ tf32 = tf32,
584
+ local_rank = local_rank,
585
+ ddp_backend = ddp_backend,
586
+ tpu_num_cores = tpu_num_cores,
587
+ tpu_metrics_debug = tpu_metrics_debug,
588
+ debug = debug,
589
+ dataloader_drop_last = dataloader_drop_last,
590
+ eval_steps = eval_steps,
591
+ dataloader_num_workers = dataloader_num_workers,
592
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
593
+ past_index = past_index,
594
+ run_name = run_name,
595
+ disable_tqdm = disable_tqdm,
596
+ remove_unused_columns = remove_unused_columns,
597
+ label_names = label_names,
598
+ load_best_model_at_end = load_best_model_at_end,
599
+ metric_for_best_model = metric_for_best_model,
600
+ greater_is_better = greater_is_better,
601
+ ignore_data_skip = ignore_data_skip,
602
+ fsdp = fsdp,
603
+ fsdp_min_num_params = fsdp_min_num_params,
604
+ fsdp_config = fsdp_config,
605
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
606
+ accelerator_config = accelerator_config,
607
+ deepspeed = deepspeed,
608
+ label_smoothing_factor = label_smoothing_factor,
609
+ optim = optim,
610
+ optim_args = optim_args,
611
+ adafactor = adafactor,
612
+ group_by_length = group_by_length,
613
+ length_column_name = length_column_name,
614
+ report_to = report_to,
615
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
616
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
617
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
618
+ dataloader_pin_memory = dataloader_pin_memory,
619
+ dataloader_persistent_workers = dataloader_persistent_workers,
620
+ skip_memory_metrics = skip_memory_metrics,
621
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
622
+ push_to_hub = push_to_hub,
623
+ resume_from_checkpoint = resume_from_checkpoint,
624
+ hub_model_id = hub_model_id,
625
+ hub_strategy = hub_strategy,
626
+ hub_token = hub_token,
627
+ hub_private_repo = hub_private_repo,
628
+ hub_always_push = hub_always_push,
629
+ gradient_checkpointing = gradient_checkpointing,
630
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
631
+ include_inputs_for_metrics = include_inputs_for_metrics,
632
+ eval_do_concat_batches = eval_do_concat_batches,
633
+ fp16_backend = fp16_backend,
634
+ evaluation_strategy = evaluation_strategy,
635
+ push_to_hub_model_id = push_to_hub_model_id,
636
+ push_to_hub_organization = push_to_hub_organization,
637
+ push_to_hub_token = push_to_hub_token,
638
+ mp_parameters = mp_parameters,
639
+ auto_find_batch_size = auto_find_batch_size,
640
+ full_determinism = full_determinism,
641
+ torchdynamo = torchdynamo,
642
+ ray_scope = ray_scope,
643
+ ddp_timeout = ddp_timeout,
644
+ torch_compile = torch_compile,
645
+ torch_compile_backend = torch_compile_backend,
646
+ torch_compile_mode = torch_compile_mode,
647
+ dispatch_batches = dispatch_batches,
648
+ split_batches = split_batches,
649
+ include_tokens_per_second = include_tokens_per_second,
650
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
651
+ neftune_noise_alpha = neftune_noise_alpha,
652
+ optim_target_modules = optim_target_modules,
653
+ batch_eval_metrics = batch_eval_metrics,
654
+ eval_on_start = eval_on_start,
655
+ use_liger_kernel = use_liger_kernel,
656
+ eval_use_gather_object = eval_use_gather_object,
657
+ average_tokens_across_devices = average_tokens_across_devices,
658
+ model_init_kwargs = model_init_kwargs,
659
+ max_prompt_length = max_prompt_length,
660
+ num_generations = num_generations,
661
+ temperature = temperature,
662
+ max_completion_length = max_completion_length,
663
+ ds3_gather_for_generation = ds3_gather_for_generation,
664
+ use_vllm = use_vllm,
665
+ vllm_device = vllm_device,
666
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
667
+ vllm_dtype = vllm_dtype,
668
+ vllm_max_model_len = vllm_max_model_len,
669
+ beta = beta,
670
+ reward_weights = reward_weights,
671
+ sync_ref_model = sync_ref_model,
672
+ ref_model_mixup_alpha = ref_model_mixup_alpha,
673
+ ref_model_sync_steps = ref_model_sync_steps,
674
+ log_completions = log_completions,**kwargs)
675
+ self.vllm_sampling_params = vllm_sampling_params
676
+ self.unsloth_num_chunks = unsloth_num_chunks
677
+ pass
678
+
679
+ class _UnslothGRPOTrainer(Trainer):
680
+ """"""
681
+
682
+ _tag_names = ["trl", "grpo"]
683
+
684
+ def __init__(
685
+ self,
686
+ model: Union[str, PreTrainedModel],
687
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
688
+ args: GRPOConfig = None,
689
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
690
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
691
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
692
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
693
+ callbacks: Optional[list[TrainerCallback]] = None,
694
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
695
+ peft_config: Optional["PeftConfig"] = None,
696
+ ):
697
+
698
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
699
+ # Args
700
+ if args is None:
701
+ model_name = model if isinstance(model, str) else model.config._name_or_path
702
+ model_name = model_name.split("/")[-1]
703
+ args = GRPOConfig(f"{model_name}-GRPO")
704
+
705
+ # Models
706
+ # Trained model
707
+ model_init_kwargs = args.model_init_kwargs or {}
708
+ if isinstance(model, str):
709
+ model_id = model
710
+ torch_dtype = model_init_kwargs.get("torch_dtype")
711
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
712
+ pass # torch_dtype is already a torch.dtype or "auto" or None
713
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
714
+ torch_dtype = getattr(torch, torch_dtype)
715
+ model_init_kwargs["torch_dtype"] = torch_dtype
716
+ else:
717
+ raise ValueError(
718
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
719
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
720
+ )
721
+ # Disable caching if gradient checkpointing is enabled (not supported)
722
+ model_init_kwargs["use_cache"] = (
723
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
724
+ )
725
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
726
+ else:
727
+ model_id = model.config._name_or_path
728
+ if args.model_init_kwargs is not None:
729
+ raise ValueError(
730
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
731
+ "This argument can only be used when the `model` argument is a string."
732
+ )
733
+
734
+ if False:
735
+ model = model
736
+
737
+ # Reference model
738
+ if is_deepspeed_zero3_enabled():
739
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
740
+ elif not is_peft_model(model):
741
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
742
+ self.ref_model = create_reference_model(model)
743
+ else:
744
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
745
+ # to revert to the initial model.
746
+ self.ref_model = None
747
+
748
+ # Processing class
749
+ if processing_class is None:
750
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
751
+
752
+ # Reward functions
753
+ if not isinstance(reward_funcs, list):
754
+ reward_funcs = [reward_funcs]
755
+ for i, reward_func in enumerate(reward_funcs):
756
+ if isinstance(reward_func, str):
757
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
758
+ reward_func, num_labels=1, **model_init_kwargs
759
+ )
760
+ self.reward_funcs = reward_funcs
761
+
762
+ # Reward weights
763
+ if args.reward_weights is not None:
764
+ if len(args.reward_weights) != len(reward_funcs):
765
+ raise ValueError(
766
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
767
+ f"functions ({len(reward_funcs)})"
768
+ )
769
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
770
+ else:
771
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
772
+
773
+ # Reward processing class
774
+ if reward_processing_classes is None:
775
+ reward_processing_classes = [None] * len(reward_funcs)
776
+ elif not isinstance(reward_processing_classes, list):
777
+ reward_processing_classes = [reward_processing_classes]
778
+ else:
779
+ if len(reward_processing_classes) != len(reward_funcs):
780
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
781
+
782
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
783
+ if isinstance(reward_func, PreTrainedModel):
784
+ if reward_processing_class is None:
785
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
786
+ if reward_processing_class.pad_token_id is None:
787
+ reward_processing_class.pad_token = reward_processing_class.eos_token
788
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
789
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
790
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
791
+ reward_processing_classes[i] = reward_processing_class
792
+ self.reward_processing_classes = reward_processing_classes
793
+
794
+ # Data collator
795
+ def data_collator(features): # No data collation is needed in GRPO
796
+ return features
797
+
798
+ # Training arguments
799
+ self.max_prompt_length = args.max_prompt_length
800
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
801
+ self.num_generations = args.num_generations # = G in the GRPO paper
802
+ self.use_vllm = args.use_vllm
803
+
804
+ self.beta = args.beta
805
+
806
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
807
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
808
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
809
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
810
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
811
+ # This acts as a flag to indicate that the warning has already been issued.
812
+ model.warnings_issued["estimate_tokens"] = True
813
+
814
+ # Initialize the metrics
815
+ self._metrics = defaultdict(list)
816
+ self.log_completions = args.log_completions
817
+
818
+ super().__init__(
819
+ model=model,
820
+ args=args,
821
+ data_collator=data_collator,
822
+ train_dataset=train_dataset,
823
+ eval_dataset=eval_dataset,
824
+ processing_class=processing_class,
825
+ callbacks=callbacks,
826
+ optimizers=optimizers,
827
+ )
828
+
829
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
830
+ num_processes = self.accelerator.num_processes
831
+ global_batch_size = args.per_device_train_batch_size * num_processes
832
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
833
+ if self.num_generations not in possible_values:
834
+ raise ValueError(
835
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
836
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
837
+ f"batch size, the valid values for the number of generations are: {possible_values}."
838
+ )
839
+ if self.args.eval_strategy != "no":
840
+ global_batch_size = args.per_device_eval_batch_size * num_processes
841
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
842
+ if self.num_generations not in possible_values:
843
+ raise ValueError(
844
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
845
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
846
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
847
+ )
848
+
849
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
850
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
851
+ # it's safer to set it in all cases.
852
+ set_seed(args.seed, device_specific=True)
853
+
854
+ if self.use_vllm:
855
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
856
+ temperature=args.temperature,
857
+ max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
858
+ else:
859
+ self.generation_config = GenerationConfig(
860
+ max_new_tokens=self.max_completion_length,
861
+ do_sample=True,
862
+ temperature=args.temperature,
863
+ pad_token_id=processing_class.pad_token_id,
864
+ )
865
+
866
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
867
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
868
+ # self.model_accepts_loss_kwargs to False to enable scaling.
869
+ self.model_accepts_loss_kwargs = False
870
+
871
+ # Add tags to the model
872
+ self.model.add_model_tags(self._tag_names)
873
+
874
+ if self.ref_model is not None:
875
+ if self.is_deepspeed_enabled:
876
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
877
+ else:
878
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
879
+
880
+ if args.sync_ref_model:
881
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
882
+
883
+ for i, reward_func in enumerate(self.reward_funcs):
884
+ if isinstance(reward_func, PreTrainedModel):
885
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
886
+
887
+ def _set_signature_columns_if_needed(self):
888
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
889
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
890
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
891
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
892
+ if self._signature_columns is None:
893
+ self._signature_columns = ["prompt"]
894
+
895
+ def _get_train_sampler(self) -> Sampler:
896
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
897
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
898
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
899
+ # preventing discrepancies in group formation.
900
+ return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
901
+
902
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
903
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
904
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
905
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
906
+ # preventing discrepancies in group formation.
907
+ return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
908
+
909
+ # Get the per-token log probabilities for the completions for the model and the reference model
910
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
911
+ if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
912
+ return None # Unsloth efficient GRPO
913
+ # Otherwise, calculate normally:
914
+ if not hasattr(self, '_autocast_dtype'):
915
+ self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
916
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
917
+ with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
918
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
919
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
920
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
921
+
922
+ input_ids = input_ids[:, -logits_to_keep:]
923
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
924
+ # See https://github.com/huggingface/trl/issues/2770
925
+ logits = logits[:, -logits_to_keep:]
926
+ return logits
927
+ # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
928
+ pass
929
+
930
+ def _move_model_to_vllm(self, *args, **kwargs): return None
931
+
932
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
933
+ device = self.accelerator.device
934
+ prompts = [x["prompt"] for x in inputs]
935
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
936
+ prompt_inputs = self.processing_class(
937
+ prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
938
+ )
939
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
940
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
941
+
942
+ if self.max_prompt_length is not None:
943
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
944
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
945
+
946
+ # Generate completions using either vLLM or regular generation
947
+ if self.args.use_vllm:
948
+ # First, have main process load weights if needed
949
+ if self.state.global_step != self._last_loaded_step:
950
+ self._move_model_to_vllm()
951
+ self._last_loaded_step = self.state.global_step
952
+
953
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
954
+ all_prompts_text = gather_object(prompts_text)
955
+ if self.accelerator.is_main_process:
956
+ outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
957
+ completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
958
+ else:
959
+ completion_ids = [None] * len(all_prompts_text)
960
+ # Broadcast the completions from the main process to all processes, ensuring each process receives its
961
+ # corresponding slice.
962
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
963
+ process_slice = slice(
964
+ self.accelerator.process_index * len(prompts),
965
+ (self.accelerator.process_index + 1) * len(prompts),
966
+ )
967
+ completion_ids = completion_ids[process_slice]
968
+
969
+ # Pad the completions, and concatenate them with the prompts
970
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
971
+ completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
972
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
973
+ else:
974
+ # Regular generation path
975
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
976
+ prompt_completion_ids = unwrapped_model.generate(
977
+ prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
978
+ )
979
+
980
+ # Compute prompt length and extract completion ids
981
+ prompt_length = prompt_ids.size(1)
982
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
983
+ completion_ids = prompt_completion_ids[:, prompt_length:]
984
+
985
+ # Mask everything after the first EOS token
986
+ is_eos = completion_ids == self.processing_class.eos_token_id
987
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
988
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
989
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
990
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
991
+
992
+ # Concatenate prompt_mask with completion_mask for logit computation
993
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
994
+
995
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
996
+
997
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
998
+ if self.ref_model is not None:
999
+ ref_per_token_logps = self._get_per_token_logps(
1000
+ self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
1001
+ )
1002
+ else:
1003
+ with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
1004
+ ref_per_token_logps = self._get_per_token_logps(
1005
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
1006
+ )
1007
+
1008
+ # Decode the generated completions
1009
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1010
+ if is_conversational(inputs[0]):
1011
+ completions = []
1012
+ for prompt, completion in zip(prompts, completions_text):
1013
+ bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
1014
+ completions.append([{"role": "assistant", "content": bootstrap + completion}])
1015
+ else:
1016
+ completions = completions_text
1017
+
1018
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
1019
+ for i, (reward_func, reward_processing_class) in enumerate(
1020
+ zip(self.reward_funcs, self.reward_processing_classes)
1021
+ ):
1022
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1023
+ if is_conversational(inputs[0]):
1024
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
1025
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
1026
+ else:
1027
+ texts = [p + c for p, c in zip(prompts, completions)]
1028
+ reward_inputs = reward_processing_class(
1029
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
1030
+ )
1031
+ reward_inputs = super()._prepare_inputs(reward_inputs)
1032
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
1033
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
1034
+ else:
1035
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
1036
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
1037
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
1038
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
1039
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
1040
+
1041
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
1042
+ # completions may be distributed across processes
1043
+ rewards_per_func = gather(rewards_per_func)
1044
+
1045
+ # Apply weights to each reward function's output and sum
1046
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
1047
+
1048
+ # Compute grouped-wise rewards
1049
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
1050
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
1051
+
1052
+ # Normalize the rewards to compute the advantages
1053
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1054
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1055
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
1056
+
1057
+ # Slice to keep only the local part of the data
1058
+ process_slice = slice(
1059
+ self.accelerator.process_index * len(prompts),
1060
+ (self.accelerator.process_index + 1) * len(prompts),
1061
+ )
1062
+ advantages = advantages[process_slice]
1063
+
1064
+ # Log the metrics
1065
+ reward_per_func = rewards_per_func.mean(0)
1066
+ for i, reward_func in enumerate(self.reward_funcs):
1067
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1068
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
1069
+ else:
1070
+ reward_func_name = reward_func.__name__
1071
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
1072
+
1073
+ self._metrics["reward"].append(rewards.mean().item())
1074
+ self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
1075
+
1076
+ if (
1077
+ self.log_completions
1078
+ and self.state.global_step % self.args.logging_steps == 0
1079
+ and "wandb" in self.args.report_to
1080
+ ):
1081
+ import pandas as pd
1082
+
1083
+ # For logging
1084
+ table = {
1085
+ "step": [str(self.state.global_step)] * len(rewards),
1086
+ "prompt": gather_object(prompts_text),
1087
+ "completion": gather_object(completions_text),
1088
+ "reward": rewards.tolist(),
1089
+ }
1090
+ df = pd.DataFrame(table)
1091
+
1092
+ if wandb.run is not None and self.accelerator.is_main_process:
1093
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1094
+
1095
+ return {
1096
+ "prompt_ids": prompt_ids,
1097
+ "prompt_mask": prompt_mask,
1098
+ "completion_ids": completion_ids,
1099
+ "completion_mask": completion_mask,
1100
+ "ref_per_token_logps": ref_per_token_logps,
1101
+ "advantages": advantages,
1102
+ }
1103
+
1104
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
1105
+ if return_outputs:
1106
+ raise ValueError("The GRPOTrainer does not support returning outputs")
1107
+ # Compute the per-token log probabilities for the model
1108
+
1109
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
1110
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
1111
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
1112
+ bsz, qlen = input_ids.shape
1113
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
1114
+ # attention_mask = None
1115
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
1116
+ _input_ids = input_ids
1117
+ _logits_to_keep = logits_to_keep
1118
+
1119
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
1120
+
1121
+ # Compute the KL divergence between the model and the reference model
1122
+ ref_per_token_logps = inputs["ref_per_token_logps"]
1123
+ # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
1124
+
1125
+ # x - x.detach() allows for preserving gradients from x
1126
+ advantages = inputs["advantages"]
1127
+ # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
1128
+ # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
1129
+ # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1130
+ input_ids = input_ids[:, -logits_to_keep:]
1131
+ if per_token_logps is not None:
1132
+ loss, completion_length, mean_kl = grpo_compute_loss_slow(
1133
+ ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
1134
+ )
1135
+ else:
1136
+ loss, completion_length, mean_kl = grpo_accumulated_loss(
1137
+ self, _input_ids, logits_to_keep, completion_mask, advantages,
1138
+ n_chunks = self.args.unsloth_num_chunks,
1139
+ )
1140
+
1141
+ # Log the metrics
1142
+ # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
1143
+
1144
+ # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1145
+ # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1146
+
1147
+ if "train" in self._metrics:
1148
+ mode = "eval" if self.control.should_evaluate else "train"
1149
+ self._metrics[mode]["completion_length"].append(completion_length.item())
1150
+ self._metrics[mode]["kl"].append(mean_kl.item())
1151
+ else:
1152
+ self._metrics["completion_length"].append(completion_length.item())
1153
+ self._metrics["kl"].append(mean_kl.item())
1154
+ return loss
1155
+
1156
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
1157
+ inputs = self._prepare_inputs(inputs)
1158
+ with torch.no_grad():
1159
+ with self.compute_loss_context_manager():
1160
+ loss = self.compute_loss(model, inputs)
1161
+ loss = loss.mean().detach()
1162
+ return loss, None, None
1163
+
1164
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1165
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1166
+
1167
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1168
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1169
+ if next(iter(logs.keys())).startswith("eval_"):
1170
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1171
+
1172
+ logs = {**logs, **metrics}
1173
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1174
+ super().log(logs, start_time)
1175
+ else: # transformers<=4.46
1176
+ super().log(logs)
1177
+ self._metrics.clear()
1178
+
1179
+ def create_model_card(
1180
+ self,
1181
+ model_name: Optional[str] = None,
1182
+ dataset_name: Optional[str] = None,
1183
+ tags: Union[str, list[str], None] = None,
1184
+ ):
1185
+ """
1186
+ Creates a draft of a model card using the information available to the `Trainer`.
1187
+
1188
+ Args:
1189
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1190
+ Name of the model.
1191
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1192
+ Name of the dataset used for training.
1193
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1194
+ Tags to be associated with the model card.
1195
+ """
1196
+ if not self.is_world_process_zero():
1197
+ return
1198
+
1199
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1200
+ base_model = self.model.config._name_or_path
1201
+ else:
1202
+ base_model = None
1203
+
1204
+ tags = tags or []
1205
+ if isinstance(tags, str):
1206
+ tags = [tags]
1207
+
1208
+ if hasattr(self.model.config, "unsloth_version"):
1209
+ tags.append("unsloth")
1210
+
1211
+ citation = textwrap.dedent(
1212
+ """\
1213
+ @article{zhihong2024deepseekmath,
1214
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
1215
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
1216
+ year = 2024,
1217
+ eprint = {arXiv:2402.03300},
1218
+ }
1219
+ """
1220
+ )
1221
+
1222
+ model_card = generate_model_card(
1223
+ base_model=base_model,
1224
+ model_name=model_name,
1225
+ hub_model_id=self.hub_model_id,
1226
+ dataset_name=dataset_name,
1227
+ tags=tags,
1228
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1229
+ comet_url=get_comet_experiment_url(),
1230
+ trainer_name="GRPO",
1231
+ trainer_citation=citation,
1232
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
1233
+ paper_id="2402.03300",
1234
+ )
1235
+
1236
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1237
+ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
1238
+ """
1239
+
1240
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
1241
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
1242
+
1243
+ Example:
1244
+
1245
+ ```python
1246
+ from datasets import load_dataset
1247
+ from trl import GRPOTrainer
1248
+
1249
+ dataset = load_dataset("trl-lib/tldr", split="train")
1250
+
1251
+ def reward_func(completions, **kwargs):
1252
+ # Dummy reward function that rewards completions with more unique letters.
1253
+ return [float(len(set(completion))) for completion in completions]
1254
+
1255
+ trainer = GRPOTrainer(
1256
+ model="Qwen/Qwen2-0.5B-Instruct",
1257
+ reward_funcs=reward_func,
1258
+ train_dataset=dataset,
1259
+ )
1260
+
1261
+ trainer.train()
1262
+ ```
1263
+
1264
+ Args:
1265
+ model (`Union[str, PreTrainedModel]`):
1266
+ Model to be trained. Can be either:
1267
+
1268
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
1269
+ a path to a *directory* containing model weights saved using
1270
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
1271
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
1272
+ in `args.model_init_kwargs`.
1273
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
1274
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
1275
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
1276
+ functions with the prompts and completions and sum the rewards. Can be either:
1277
+
1278
+ - A single reward function, such as:
1279
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
1280
+ path to a *directory* containing model weights saved using
1281
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
1282
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
1283
+ keyword arguments in `args.model_init_kwargs`.
1284
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
1285
+ - A custom reward function: The function is provided with the prompts and the generated completions,
1286
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
1287
+ [Using a custom reward function](#using-a-custom-reward-function).
1288
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
1289
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
1290
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
1291
+ Configuration for this trainer. If `None`, a default configuration is used.
1292
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
1293
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
1294
+ ignored. The format of the samples can be either:
1295
+
1296
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
1297
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
1298
+ and content).
1299
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
1300
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
1301
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
1302
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
1303
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
1304
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
1305
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
1306
+
1307
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
1308
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
1309
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
1310
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
1311
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
1312
+ the corresponding entries in `reward_processing_classes` are ignored.
1313
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
1314
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
1315
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
1316
+
1317
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
1318
+ method.
1319
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
1320
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
1321
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
1322
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
1323
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
1324
+
1325
+ """
1326
+ def __init__(
1327
+ self,
1328
+ model,
1329
+ reward_funcs,
1330
+ args = None,
1331
+ train_dataset = None,
1332
+ eval_dataset = None,
1333
+ processing_class = None,
1334
+ reward_processing_classes = None,
1335
+ callbacks = None,
1336
+ peft_config = None,
1337
+ **kwargs
1338
+ ):
1339
+ if args is None: args = UnslothGRPOConfig()
1340
+ use_bf16 = getattr(args, 'bf16', False)
1341
+ use_fp16 = getattr(args, 'fp16', False)
1342
+ force_float32 = False
1343
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1344
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1345
+ force_float32 = True
1346
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1347
+ dtype = getattr(model.config, 'torch_dtype', None)
1348
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1349
+ from unsloth_zoo.utils import _get_dtype
1350
+ dtype = _get_dtype(dtype)
1351
+ float16 = dtype == torch.float16
1352
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1353
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1354
+ if force_float32:
1355
+ args.fp16 = False
1356
+ args.bf16 = False
1357
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1358
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1359
+ args.fp16 = float16
1360
+ args.bf16 = not float16
1361
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1362
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1363
+ args.eval_strategy = 'steps'
1364
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1365
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1366
+ if ga_steps is not None and ga_steps > 1:
1367
+ from transformers import __version__ as transformers_version
1368
+ if Version(transformers_version) <= Version('4.45.2'):
1369
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1370
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1371
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1372
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1373
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1374
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1375
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1376
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1377
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1378
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1379
+ if force_float32:
1380
+ args.bf16_full_eval = False
1381
+ args.fp16_full_eval = False
1382
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1383
+ args.bf16_full_eval = True
1384
+ args.fp16_full_eval = False
1385
+ elif not bf16_full_eval and not fp16_full_eval:
1386
+ args.bf16_full_eval = args.bf16
1387
+ args.fp16_full_eval = args.fp16
1388
+ _output_logits = False
1389
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1390
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1391
+ if _output_logits:
1392
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1393
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1394
+ pass
1395
+ else:
1396
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1397
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1398
+ if args_max_seq_length is None and model_max_seq_length is not None:
1399
+ max_seq_length = model.max_seq_length
1400
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1401
+ if model is not None and hasattr(model, 'for_training'):
1402
+ model.for_training()
1403
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1404
+ if 'processing_class' in locals():
1405
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1406
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1407
+ other_metrics = []
1408
+ if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
1409
+ else: _reward_funcs = reward_funcs
1410
+ for reward_func in _reward_funcs:
1411
+ try:
1412
+ reward_func_name = reward_func.__name__
1413
+ other_metrics.append(f'rewards/{reward_func_name}')
1414
+ except: pass
1415
+
1416
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1417
+ PatchRLStatistics('grpo_trainer', other_metrics)
1418
+
1419
+ super().__init__(
1420
+ model = model,
1421
+ reward_funcs = reward_funcs,
1422
+ args = args,
1423
+ train_dataset = train_dataset,
1424
+ eval_dataset = eval_dataset,
1425
+ processing_class = processing_class,
1426
+ reward_processing_classes = reward_processing_classes,
1427
+ callbacks = callbacks,
1428
+ peft_config = peft_config,**kwargs)
1429
+ if hasattr(self, 'neftune_hook_handle'):
1430
+ self.neftune_hook_handle.remove()
1431
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1432
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1433
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1434
+ pass
1435
+
1436
+ pass
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
@@ -0,0 +1,1838 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothKTOConfig(KTOConfig):
44
+ """
45
+
46
+ Configuration class for the [`KTOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model.
67
+ loss_type (`str`, *optional*, defaults to `"kto"`):
68
+ Type of loss to use. Possible values are:
69
+
70
+ - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
71
+ - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
72
+
73
+ desirable_weight (`float`, *optional*, defaults to `1.0`):
74
+ Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
75
+ undesirable_weight (`float`, *optional*, defaults to `1.0`):
76
+ Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
77
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
78
+ Label pad token id. This argument is required if you want to use the default data collator.
79
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
80
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
81
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
82
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
83
+ This argument is required if you want to use the default data collator.
84
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
85
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
86
+ evaluation.
87
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
88
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
89
+ you need to specify if the model returned by the callable is an encoder-decoder model.
90
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
91
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
92
+ useful when training without the reference model to reduce the total GPU memory needed.
93
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
94
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
95
+ string.
96
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
98
+ from a string.
99
+ dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+ disable_dropout (`bool`, *optional*, defaults to `True`):
102
+ Whether to disable dropout in the model and reference model.
103
+
104
+ """
105
+ vllm_sampling_params: Optional[Any] = field(
106
+ default = None,
107
+ metadata = {'help': 'vLLM SamplingParams'},
108
+ )
109
+ unsloth_num_chunks : Optional[int] = field(
110
+ default = -1,
111
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
112
+ )
113
+ def __init__(
114
+ self,
115
+ output_dir = None,
116
+ overwrite_output_dir = None,
117
+ do_train = False,
118
+ do_eval = False,
119
+ do_predict = False,
120
+ eval_strategy = 'no',
121
+ prediction_loss_only = False,
122
+ per_device_train_batch_size = 4,
123
+ per_device_eval_batch_size = 4,
124
+ per_gpu_train_batch_size = None,
125
+ per_gpu_eval_batch_size = None,
126
+ gradient_accumulation_steps = 2,
127
+ eval_accumulation_steps = 2,
128
+ eval_delay = 0,
129
+ torch_empty_cache_steps = 250,
130
+ learning_rate = 5e-05,
131
+ weight_decay = 0.01,
132
+ adam_beta1 = 0.9,
133
+ adam_beta2 = 0.999,
134
+ adam_epsilon = 1e-08,
135
+ max_grad_norm = 1.0,
136
+ num_train_epochs = 3.0,
137
+ max_steps = -1,
138
+ lr_scheduler_type = 'linear',
139
+ warmup_ratio = 0.1,
140
+ warmup_steps = 0,
141
+ log_level = 'passive',
142
+ log_level_replica = 'warning',
143
+ log_on_each_node = True,
144
+ logging_dir = None,
145
+ logging_strategy = 'steps',
146
+ logging_first_step = False,
147
+ logging_steps = 1,
148
+ logging_nan_inf_filter = False,
149
+ save_strategy = 'steps',
150
+ save_steps = 500,
151
+ save_total_limit = None,
152
+ save_safetensors = True,
153
+ save_on_each_node = False,
154
+ save_only_model = False,
155
+ restore_callback_states_from_checkpoint = False,
156
+ no_cuda = False,
157
+ use_cpu = False,
158
+ use_mps_device = False,
159
+ seed = 3407,
160
+ data_seed = 3407,
161
+ jit_mode_eval = False,
162
+ use_ipex = False,
163
+ bf16 = False,
164
+ fp16 = False,
165
+ fp16_opt_level = 'O1',
166
+ half_precision_backend = 'auto',
167
+ bf16_full_eval = False,
168
+ fp16_full_eval = False,
169
+ tf32 = None,
170
+ local_rank = -1,
171
+ ddp_backend = None,
172
+ tpu_num_cores = None,
173
+ tpu_metrics_debug = False,
174
+ debug = '',
175
+ dataloader_drop_last = False,
176
+ eval_steps = None,
177
+ dataloader_num_workers = 0,
178
+ dataloader_prefetch_factor = None,
179
+ past_index = -1,
180
+ run_name = None,
181
+ disable_tqdm = None,
182
+ remove_unused_columns = True,
183
+ label_names = None,
184
+ load_best_model_at_end = False,
185
+ metric_for_best_model = None,
186
+ greater_is_better = None,
187
+ ignore_data_skip = False,
188
+ fsdp = '',
189
+ fsdp_min_num_params = 0,
190
+ fsdp_config = None,
191
+ fsdp_transformer_layer_cls_to_wrap = None,
192
+ accelerator_config = None,
193
+ deepspeed = None,
194
+ label_smoothing_factor = 0.0,
195
+ optim = 'adamw_8bit',
196
+ optim_args = None,
197
+ adafactor = False,
198
+ group_by_length = False,
199
+ length_column_name = 'length',
200
+ report_to = None,
201
+ ddp_find_unused_parameters = None,
202
+ ddp_bucket_cap_mb = None,
203
+ ddp_broadcast_buffers = None,
204
+ dataloader_pin_memory = True,
205
+ dataloader_persistent_workers = False,
206
+ skip_memory_metrics = True,
207
+ use_legacy_prediction_loop = False,
208
+ push_to_hub = False,
209
+ resume_from_checkpoint = None,
210
+ hub_model_id = None,
211
+ hub_strategy = 'every_save',
212
+ hub_token = None,
213
+ hub_private_repo = None,
214
+ hub_always_push = False,
215
+ gradient_checkpointing = False,
216
+ gradient_checkpointing_kwargs = None,
217
+ include_inputs_for_metrics = False,
218
+ eval_do_concat_batches = True,
219
+ fp16_backend = 'auto',
220
+ evaluation_strategy = None,
221
+ push_to_hub_model_id = None,
222
+ push_to_hub_organization = None,
223
+ push_to_hub_token = None,
224
+ mp_parameters = '',
225
+ auto_find_batch_size = False,
226
+ full_determinism = False,
227
+ torchdynamo = None,
228
+ ray_scope = 'last',
229
+ ddp_timeout = 1800,
230
+ torch_compile = False,
231
+ torch_compile_backend = None,
232
+ torch_compile_mode = None,
233
+ dispatch_batches = None,
234
+ split_batches = None,
235
+ include_tokens_per_second = False,
236
+ include_num_input_tokens_seen = False,
237
+ neftune_noise_alpha = None,
238
+ optim_target_modules = None,
239
+ batch_eval_metrics = False,
240
+ eval_on_start = False,
241
+ use_liger_kernel = False,
242
+ eval_use_gather_object = False,
243
+ average_tokens_across_devices = False,
244
+ max_length = 1024,
245
+ max_prompt_length = 512,
246
+ max_completion_length = None,
247
+ beta = 0.1,
248
+ loss_type = 'kto',
249
+ desirable_weight = 1.0,
250
+ undesirable_weight = 1.0,
251
+ label_pad_token_id = -100,
252
+ padding_value = None,
253
+ truncation_mode = 'keep_end',
254
+ generate_during_eval = False,
255
+ is_encoder_decoder = None,
256
+ disable_dropout = True,
257
+ precompute_ref_log_probs = False,
258
+ model_init_kwargs = None,
259
+ ref_model_init_kwargs = None,
260
+ dataset_num_proc = None,
261
+ vllm_sampling_params = None,
262
+ unsloth_num_chunks = -1,
263
+ **kwargs,
264
+ ):
265
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
266
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
267
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
268
+ output_dir = 'unsloth_training_checkpoints'
269
+ save_strategy = 'no'
270
+ if dataset_num_proc is None:
271
+ from multiprocessing import cpu_count
272
+ dataset_num_proc = cpu_count()
273
+
274
+ super().__init__(
275
+ output_dir = output_dir,
276
+ overwrite_output_dir = overwrite_output_dir,
277
+ do_train = do_train,
278
+ do_eval = do_eval,
279
+ do_predict = do_predict,
280
+ eval_strategy = eval_strategy,
281
+ prediction_loss_only = prediction_loss_only,
282
+ per_device_train_batch_size = per_device_train_batch_size,
283
+ per_device_eval_batch_size = per_device_eval_batch_size,
284
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
285
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
286
+ gradient_accumulation_steps = gradient_accumulation_steps,
287
+ eval_accumulation_steps = eval_accumulation_steps,
288
+ eval_delay = eval_delay,
289
+ torch_empty_cache_steps = torch_empty_cache_steps,
290
+ learning_rate = learning_rate,
291
+ weight_decay = weight_decay,
292
+ adam_beta1 = adam_beta1,
293
+ adam_beta2 = adam_beta2,
294
+ adam_epsilon = adam_epsilon,
295
+ max_grad_norm = max_grad_norm,
296
+ num_train_epochs = num_train_epochs,
297
+ max_steps = max_steps,
298
+ lr_scheduler_type = lr_scheduler_type,
299
+ warmup_ratio = warmup_ratio,
300
+ warmup_steps = warmup_steps,
301
+ log_level = log_level,
302
+ log_level_replica = log_level_replica,
303
+ log_on_each_node = log_on_each_node,
304
+ logging_dir = logging_dir,
305
+ logging_strategy = logging_strategy,
306
+ logging_first_step = logging_first_step,
307
+ logging_steps = logging_steps,
308
+ logging_nan_inf_filter = logging_nan_inf_filter,
309
+ save_strategy = save_strategy,
310
+ save_steps = save_steps,
311
+ save_total_limit = save_total_limit,
312
+ save_safetensors = save_safetensors,
313
+ save_on_each_node = save_on_each_node,
314
+ save_only_model = save_only_model,
315
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
316
+ no_cuda = no_cuda,
317
+ use_cpu = use_cpu,
318
+ use_mps_device = use_mps_device,
319
+ seed = seed,
320
+ data_seed = data_seed,
321
+ jit_mode_eval = jit_mode_eval,
322
+ use_ipex = use_ipex,
323
+ bf16 = bf16,
324
+ fp16 = fp16,
325
+ fp16_opt_level = fp16_opt_level,
326
+ half_precision_backend = half_precision_backend,
327
+ bf16_full_eval = bf16_full_eval,
328
+ fp16_full_eval = fp16_full_eval,
329
+ tf32 = tf32,
330
+ local_rank = local_rank,
331
+ ddp_backend = ddp_backend,
332
+ tpu_num_cores = tpu_num_cores,
333
+ tpu_metrics_debug = tpu_metrics_debug,
334
+ debug = debug,
335
+ dataloader_drop_last = dataloader_drop_last,
336
+ eval_steps = eval_steps,
337
+ dataloader_num_workers = dataloader_num_workers,
338
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
339
+ past_index = past_index,
340
+ run_name = run_name,
341
+ disable_tqdm = disable_tqdm,
342
+ remove_unused_columns = remove_unused_columns,
343
+ label_names = label_names,
344
+ load_best_model_at_end = load_best_model_at_end,
345
+ metric_for_best_model = metric_for_best_model,
346
+ greater_is_better = greater_is_better,
347
+ ignore_data_skip = ignore_data_skip,
348
+ fsdp = fsdp,
349
+ fsdp_min_num_params = fsdp_min_num_params,
350
+ fsdp_config = fsdp_config,
351
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
352
+ accelerator_config = accelerator_config,
353
+ deepspeed = deepspeed,
354
+ label_smoothing_factor = label_smoothing_factor,
355
+ optim = optim,
356
+ optim_args = optim_args,
357
+ adafactor = adafactor,
358
+ group_by_length = group_by_length,
359
+ length_column_name = length_column_name,
360
+ report_to = report_to,
361
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
362
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
363
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
364
+ dataloader_pin_memory = dataloader_pin_memory,
365
+ dataloader_persistent_workers = dataloader_persistent_workers,
366
+ skip_memory_metrics = skip_memory_metrics,
367
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
368
+ push_to_hub = push_to_hub,
369
+ resume_from_checkpoint = resume_from_checkpoint,
370
+ hub_model_id = hub_model_id,
371
+ hub_strategy = hub_strategy,
372
+ hub_token = hub_token,
373
+ hub_private_repo = hub_private_repo,
374
+ hub_always_push = hub_always_push,
375
+ gradient_checkpointing = gradient_checkpointing,
376
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
377
+ include_inputs_for_metrics = include_inputs_for_metrics,
378
+ eval_do_concat_batches = eval_do_concat_batches,
379
+ fp16_backend = fp16_backend,
380
+ evaluation_strategy = evaluation_strategy,
381
+ push_to_hub_model_id = push_to_hub_model_id,
382
+ push_to_hub_organization = push_to_hub_organization,
383
+ push_to_hub_token = push_to_hub_token,
384
+ mp_parameters = mp_parameters,
385
+ auto_find_batch_size = auto_find_batch_size,
386
+ full_determinism = full_determinism,
387
+ torchdynamo = torchdynamo,
388
+ ray_scope = ray_scope,
389
+ ddp_timeout = ddp_timeout,
390
+ torch_compile = torch_compile,
391
+ torch_compile_backend = torch_compile_backend,
392
+ torch_compile_mode = torch_compile_mode,
393
+ dispatch_batches = dispatch_batches,
394
+ split_batches = split_batches,
395
+ include_tokens_per_second = include_tokens_per_second,
396
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
397
+ neftune_noise_alpha = neftune_noise_alpha,
398
+ optim_target_modules = optim_target_modules,
399
+ batch_eval_metrics = batch_eval_metrics,
400
+ eval_on_start = eval_on_start,
401
+ use_liger_kernel = use_liger_kernel,
402
+ eval_use_gather_object = eval_use_gather_object,
403
+ average_tokens_across_devices = average_tokens_across_devices,
404
+ max_length = max_length,
405
+ max_prompt_length = max_prompt_length,
406
+ max_completion_length = max_completion_length,
407
+ beta = beta,
408
+ loss_type = loss_type,
409
+ desirable_weight = desirable_weight,
410
+ undesirable_weight = undesirable_weight,
411
+ label_pad_token_id = label_pad_token_id,
412
+ padding_value = padding_value,
413
+ truncation_mode = truncation_mode,
414
+ generate_during_eval = generate_during_eval,
415
+ is_encoder_decoder = is_encoder_decoder,
416
+ disable_dropout = disable_dropout,
417
+ precompute_ref_log_probs = precompute_ref_log_probs,
418
+ model_init_kwargs = model_init_kwargs,
419
+ ref_model_init_kwargs = ref_model_init_kwargs,
420
+ dataset_num_proc = dataset_num_proc,**kwargs)
421
+ self.vllm_sampling_params = vllm_sampling_params
422
+ self.unsloth_num_chunks = unsloth_num_chunks
423
+ pass
424
+
425
+ class _UnslothKTOTrainer(Trainer):
426
+ r""""""
427
+
428
+ _tag_names = ["trl", "kto"]
429
+
430
+ def __init__(
431
+ self,
432
+ model: Union[PreTrainedModel, nn.Module, str] = None,
433
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
434
+ args: KTOConfig = None,
435
+ train_dataset: Optional[Dataset] = None,
436
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
437
+ processing_class: Optional[
438
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
439
+ ] = None,
440
+ data_collator: Optional[DataCollator] = None,
441
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
442
+ callbacks: Optional[list[TrainerCallback]] = None,
443
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
444
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
445
+ peft_config: Optional[dict] = None,
446
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
447
+ model_adapter_name: Optional[str] = None,
448
+ ref_adapter_name: Optional[str] = None,
449
+ ):
450
+ if type(args) is TrainingArguments:
451
+ raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
452
+
453
+ if not isinstance(model, str) and ref_model is model:
454
+ raise ValueError(
455
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
456
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
457
+ )
458
+
459
+ if args.model_init_kwargs is None:
460
+ model_init_kwargs = {}
461
+ elif not isinstance(model, str):
462
+ raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
463
+ else:
464
+ model_init_kwargs = args.model_init_kwargs
465
+ torch_dtype = model_init_kwargs.get("torch_dtype")
466
+ if torch_dtype is not None:
467
+ # Convert to `torch.dtype` if an str is passed
468
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
469
+ torch_dtype = getattr(torch, torch_dtype)
470
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
471
+ raise ValueError(
472
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
473
+ )
474
+ model_init_kwargs["torch_dtype"] = torch_dtype
475
+
476
+ if args.ref_model_init_kwargs is None:
477
+ ref_model_init_kwargs = {}
478
+ elif not isinstance(ref_model, str):
479
+ raise ValueError(
480
+ "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
481
+ )
482
+ else:
483
+ ref_model_init_kwargs = args.ref_model_init_kwargs
484
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
485
+ if torch_dtype is not None:
486
+ # Convert to `torch.dtype` if an str is passed
487
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
488
+ torch_dtype = getattr(torch, torch_dtype)
489
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
490
+ raise ValueError(
491
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
492
+ )
493
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
494
+
495
+ if isinstance(model, str):
496
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
497
+
498
+ if isinstance(ref_model, str):
499
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
500
+
501
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
502
+ # has been called in order to properly call autocast if needed.
503
+ self._peft_has_been_casted_to_bf16 = False
504
+
505
+ if not is_peft_available() and peft_config is not None:
506
+ raise ValueError(
507
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
508
+ )
509
+ elif is_peft_available() and peft_config is not None:
510
+ # if model is a peft model and we have a peft_config, we merge and unload it first
511
+ if isinstance(model, PeftModel):
512
+ model = model.merge_and_unload()
513
+
514
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
515
+ _support_gc_kwargs = hasattr(
516
+ args, "gradient_checkpointing_kwargs"
517
+ ) and "gradient_checkpointing_kwargs" in list(
518
+ inspect.signature(prepare_model_for_kbit_training).parameters
519
+ )
520
+
521
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
522
+
523
+ if _support_gc_kwargs:
524
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
525
+
526
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
527
+ elif getattr(args, "gradient_checkpointing", False):
528
+ # For backward compatibility with older versions of transformers
529
+ if hasattr(model, "enable_input_require_grads"):
530
+ model.enable_input_require_grads()
531
+ else:
532
+
533
+ def make_inputs_require_grad(module, input, output):
534
+ output.requires_grad_(True)
535
+
536
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
537
+
538
+ # get peft model with the given config
539
+ model = model
540
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
541
+ peft_module_casting_to_bf16(model)
542
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
543
+ self._peft_has_been_casted_to_bf16 = True
544
+
545
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
546
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
547
+ # fail or completely fail.
548
+ elif getattr(args, "gradient_checkpointing", False):
549
+ # For backward compatibility with older versions of transformers
550
+ if hasattr(model, "enable_input_require_grads"):
551
+ model.enable_input_require_grads()
552
+ else:
553
+
554
+ def make_inputs_require_grad(module, input, output):
555
+ output.requires_grad_(True)
556
+
557
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
558
+
559
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
560
+ raise ValueError(
561
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
562
+ " Please install `wandb` or `comet-ml` to resolve."
563
+ )
564
+
565
+ if model is not None:
566
+ self.is_encoder_decoder = model.config.is_encoder_decoder
567
+ elif args.is_encoder_decoder is None:
568
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
569
+ else:
570
+ self.is_encoder_decoder = args.is_encoder_decoder
571
+
572
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
573
+ self.model_adapter_name = model_adapter_name
574
+ self.ref_adapter_name = ref_adapter_name
575
+
576
+ if ref_model:
577
+ self.ref_model = ref_model
578
+ elif self.is_peft_model or args.precompute_ref_log_probs:
579
+ # The `model` with adapters turned off will be used as the reference model
580
+ self.ref_model = None
581
+ else:
582
+ self.ref_model = create_reference_model(model)
583
+
584
+ if processing_class is None:
585
+ raise ValueError(
586
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
587
+ )
588
+ if args.max_length is None:
589
+ warnings.warn(
590
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
591
+ " it will be set to `512` by default, but you should do it yourself in the future.",
592
+ UserWarning,
593
+ )
594
+ max_length = 512
595
+ if args.max_length is not None:
596
+ max_length = args.max_length
597
+
598
+ if args.max_prompt_length is None:
599
+ warnings.warn(
600
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
601
+ " it will be set to `128` by default, but you should do it yourself in the future.",
602
+ UserWarning,
603
+ )
604
+ max_prompt_length = 128
605
+ if args.max_prompt_length is not None:
606
+ max_prompt_length = args.max_prompt_length
607
+
608
+ max_completion_length = None
609
+ if args.max_completion_length is None and self.is_encoder_decoder:
610
+ warnings.warn(
611
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
612
+ " it will be set to `128` by default, but you should do it yourself in the future.",
613
+ UserWarning,
614
+ )
615
+ max_completion_length = 128
616
+ if args.max_completion_length is not None and self.is_encoder_decoder:
617
+ max_completion_length = args.max_completion_length
618
+
619
+ if data_collator is None:
620
+ data_collator = DPODataCollatorWithPadding(
621
+ pad_token_id=processing_class.pad_token_id,
622
+ label_pad_token_id=args.label_pad_token_id,
623
+ is_encoder_decoder=self.is_encoder_decoder,
624
+ )
625
+
626
+ if args.remove_unused_columns:
627
+ args.remove_unused_columns = False
628
+ # warn users
629
+ warnings.warn(
630
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
631
+ " we have set it for you, but you should do it yourself in the future.",
632
+ UserWarning,
633
+ )
634
+
635
+ self.use_dpo_data_collator = True
636
+ else:
637
+ self.use_dpo_data_collator = False
638
+
639
+ # Disable dropout in the model and reference model
640
+ if args.disable_dropout:
641
+ disable_dropout_in_model(model)
642
+ if self.ref_model is not None:
643
+ disable_dropout_in_model(self.ref_model)
644
+
645
+ self.loss_type = args.loss_type
646
+ self.max_length = max_length
647
+ self.generate_during_eval = args.generate_during_eval
648
+ self.label_pad_token_id = args.label_pad_token_id
649
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
650
+ self.max_prompt_length = max_prompt_length
651
+ self.truncation_mode = args.truncation_mode
652
+ self.max_completion_length = max_completion_length
653
+ self.processing_class = processing_class
654
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
655
+
656
+ # Not all losses require a KL calculation
657
+ self.calculate_KL = True
658
+ if self.loss_type in ["apo_zero_unpaired"]:
659
+ self.calculate_KL = False
660
+
661
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
662
+ # keep track of first called to avoid computation of future calls
663
+ self._precomputed_train_ref_log_probs = False
664
+ self._precomputed_eval_ref_log_probs = False
665
+
666
+ # metric
667
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
668
+
669
+ # KTO parameter
670
+ self.beta = args.beta
671
+ self.desirable_weight = args.desirable_weight
672
+ self.undesirable_weight = args.undesirable_weight
673
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
674
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
675
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
676
+ warnings.warn(
677
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
678
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
679
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
680
+ "loss.",
681
+ UserWarning,
682
+ )
683
+
684
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
685
+ # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
686
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
687
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
688
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
689
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
690
+ # issued.
691
+ model.warnings_issued["estimate_tokens"] = True
692
+
693
+ # Compute that only on the main process for faster data processing.
694
+ # see: https://github.com/huggingface/trl/pull/1255
695
+ with PartialState().local_main_process_first():
696
+ # Extract the prompt if needed
697
+ train_dataset = train_dataset.map(
698
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
699
+ )
700
+ # Unpair the dataset if needed
701
+ train_dataset = maybe_unpair_preference_dataset(
702
+ train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
703
+ )
704
+ # Apply the chat template if needed
705
+ train_dataset = train_dataset.map(
706
+ maybe_apply_chat_template,
707
+ fn_kwargs={"tokenizer": processing_class},
708
+ num_proc=args.dataset_num_proc,
709
+ desc="Applying chat template to train dataset",
710
+ )
711
+ if eval_dataset is not None:
712
+ eval_dataset = eval_dataset.map(
713
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
714
+ )
715
+ eval_dataset = maybe_unpair_preference_dataset(
716
+ eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
717
+ )
718
+ eval_dataset = eval_dataset.map(
719
+ maybe_apply_chat_template,
720
+ fn_kwargs={"tokenizer": processing_class},
721
+ num_proc=args.dataset_num_proc,
722
+ desc="Applying chat template to eval dataset",
723
+ )
724
+
725
+ # Tokenize and prepare the training datasets
726
+ train_dataset = train_dataset.map(
727
+ _tokenize,
728
+ batched=True,
729
+ fn_kwargs={"tokenizer": self.processing_class},
730
+ num_proc=args.dataset_num_proc,
731
+ desc="Tokenizing train dataset",
732
+ )
733
+
734
+ fn_kwargs = {
735
+ "prefix": "",
736
+ "is_encoder_decoder": self.is_encoder_decoder,
737
+ "tokenizer": self.processing_class,
738
+ "max_length": self.max_length,
739
+ "truncation_mode": self.truncation_mode,
740
+ "label_pad_token_id": self.label_pad_token_id,
741
+ "max_prompt_length": self.max_prompt_length,
742
+ "max_completion_length": self.max_completion_length,
743
+ }
744
+
745
+ train_dataset = train_dataset.map(
746
+ _process_tokens,
747
+ fn_kwargs=fn_kwargs,
748
+ num_proc=args.dataset_num_proc,
749
+ desc="Processing tokenized train dataset",
750
+ )
751
+
752
+ # Tokenize and prepare the eval datasets
753
+ if eval_dataset is not None:
754
+ eval_dataset = eval_dataset.map(
755
+ _tokenize,
756
+ fn_kwargs={"tokenizer": self.processing_class},
757
+ batched=True,
758
+ num_proc=args.dataset_num_proc,
759
+ desc="Tokenizing eval dataset",
760
+ )
761
+
762
+ eval_dataset = eval_dataset.map(
763
+ _process_tokens,
764
+ fn_kwargs=fn_kwargs,
765
+ num_proc=args.dataset_num_proc,
766
+ desc="Processing tokenized eval dataset",
767
+ )
768
+
769
+ # Get KL datasets if needed
770
+ if self.calculate_KL:
771
+ if args.per_device_train_batch_size <= 1:
772
+ raise ValueError(
773
+ "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
774
+ )
775
+
776
+ # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
777
+ # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
778
+ train_kl_dataset = train_dataset.map(
779
+ _get_kl_dataset,
780
+ batched=True,
781
+ batch_size=args.per_device_train_batch_size,
782
+ num_proc=args.dataset_num_proc,
783
+ desc="Extracting KL train dataset",
784
+ )
785
+
786
+ fn_kwargs["prefix"] = "KL_"
787
+ train_kl_dataset = train_kl_dataset.map(
788
+ _process_tokens,
789
+ fn_kwargs=fn_kwargs,
790
+ num_proc=args.dataset_num_proc,
791
+ remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
792
+ desc="Processing tokenized train KL dataset",
793
+ )
794
+
795
+ # merge the datasets
796
+ train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
797
+
798
+ if eval_dataset is not None:
799
+ # Get KL dataset
800
+ eval_kl_dataset = eval_dataset.map(
801
+ _get_kl_dataset,
802
+ batched=True,
803
+ batch_size=args.per_device_train_batch_size,
804
+ num_proc=args.dataset_num_proc,
805
+ desc="Extracting eval KL dataset",
806
+ )
807
+
808
+ eval_kl_dataset = eval_kl_dataset.map(
809
+ _process_tokens,
810
+ fn_kwargs=fn_kwargs,
811
+ num_proc=args.dataset_num_proc,
812
+ remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
813
+ desc="Processing tokenized eval KL dataset",
814
+ )
815
+
816
+ # merge the datasets
817
+ eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
818
+
819
+ # calculate dataset desirability balance
820
+ num_desirable = max(sum(train_dataset["label"]), 1)
821
+ num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
822
+
823
+ if num_desirable != num_undesirable:
824
+ # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
825
+ des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
826
+ des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
827
+ und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
828
+ und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
829
+
830
+ des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
831
+ und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
832
+
833
+ if not (des_weight_in_range or und_weight_in_range):
834
+ warnings.warn(
835
+ "You have different amounts of desirable/positive and undesirable/negative examples but the "
836
+ "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
837
+ f"on your data, we recommend EITHER "
838
+ f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
839
+ f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
840
+ "See the documentation on how to optimally set these weights.",
841
+ UserWarning,
842
+ )
843
+
844
+ super().__init__(
845
+ model=model,
846
+ args=args,
847
+ data_collator=data_collator,
848
+ train_dataset=train_dataset,
849
+ eval_dataset=eval_dataset,
850
+ processing_class=processing_class,
851
+ model_init=model_init,
852
+ compute_metrics=compute_metrics,
853
+ callbacks=callbacks,
854
+ optimizers=optimizers,
855
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
856
+ )
857
+
858
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
859
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
860
+ # self.model_accepts_loss_kwargs to False to enable scaling.
861
+ self.model_accepts_loss_kwargs = False
862
+
863
+ # Add tags for models that have been loaded with the correct transformers version
864
+ if hasattr(self.model, "add_model_tags"):
865
+ self.model.add_model_tags(self._tag_names)
866
+
867
+ if not hasattr(self, "accelerator"):
868
+ raise AttributeError(
869
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
870
+ )
871
+
872
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
873
+ if self.is_deepspeed_enabled:
874
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
875
+ raise ValueError(
876
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
877
+ )
878
+
879
+ if self.ref_model is None:
880
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
881
+ raise ValueError(
882
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
883
+ )
884
+ else:
885
+ if self.is_deepspeed_enabled:
886
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
887
+ else:
888
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
889
+
890
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
891
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
892
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
893
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
894
+
895
+ if model is not None:
896
+ if hasattr(model, "config"):
897
+ hidden_size = (
898
+ max(model.config.hidden_sizes)
899
+ if getattr(model.config, "hidden_sizes", None)
900
+ else getattr(model.config, "hidden_size", None)
901
+ )
902
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
903
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
904
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
905
+ config_kwargs.update(
906
+ {
907
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
908
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
909
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
910
+ }
911
+ )
912
+
913
+ # If ZeRO-3 is used, we shard both the active and reference model.
914
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
915
+ if config_kwargs["zero_optimization"]["stage"] != 3:
916
+ config_kwargs["zero_optimization"]["stage"] = 0
917
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
918
+ model.eval()
919
+ return model
920
+
921
+ @contextmanager
922
+ def null_ref_context(self):
923
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
924
+ with (
925
+ self.accelerator.unwrap_model(self.model).disable_adapter()
926
+ if self.is_peft_model and not self.ref_adapter_name
927
+ else nullcontext()
928
+ ):
929
+ if self.ref_adapter_name:
930
+ self.model.set_adapter(self.ref_adapter_name)
931
+ yield
932
+ if self.ref_adapter_name:
933
+ self.model.set_adapter(self.model_adapter_name or "default")
934
+
935
+ def get_train_dataloader(self) -> DataLoader:
936
+ """
937
+ Returns the training [`~torch.utils.data.DataLoader`].
938
+
939
+ Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
940
+ """
941
+
942
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
943
+ dataloader_params = {
944
+ "batch_size": self.args.per_device_train_batch_size,
945
+ "collate_fn": self.data_collator,
946
+ "num_workers": self.args.dataloader_num_workers,
947
+ "pin_memory": self.args.dataloader_pin_memory,
948
+ "shuffle": False,
949
+ }
950
+
951
+ # prepare dataloader
952
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
953
+ reference_completion_logps = []
954
+ reference_KL_logps = []
955
+
956
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
957
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
958
+
959
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
960
+ reference_completion_logps.append(reference_completion_logp.cpu())
961
+
962
+ if self.calculate_KL:
963
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
964
+ reference_KL_logps.append(reference_KL_logp.cpu())
965
+
966
+ self.train_dataset = self.train_dataset.add_column(
967
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
968
+ )
969
+
970
+ if self.calculate_KL:
971
+ self.train_dataset = self.train_dataset.add_column(
972
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
973
+ )
974
+
975
+ self._precomputed_train_ref_log_probs = True
976
+
977
+ return super().get_train_dataloader()
978
+
979
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
980
+ """
981
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
982
+
983
+ Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
984
+
985
+ Args:
986
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
987
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
988
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
989
+ """
990
+ if eval_dataset is None and self.eval_dataset is None:
991
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
992
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
993
+
994
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
995
+ dataloader_params = {
996
+ "batch_size": self.args.per_device_eval_batch_size,
997
+ "collate_fn": self.data_collator,
998
+ "num_workers": self.args.dataloader_num_workers,
999
+ "pin_memory": self.args.dataloader_pin_memory,
1000
+ "shuffle": False,
1001
+ }
1002
+
1003
+ # prepare dataloader
1004
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1005
+
1006
+ reference_completion_logps = []
1007
+ reference_KL_logps = []
1008
+
1009
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1010
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
1011
+
1012
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1013
+ reference_completion_logps.append(reference_completion_logp.cpu())
1014
+
1015
+ if self.calculate_KL:
1016
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
1017
+ reference_KL_logps.append(reference_KL_logp.cpu())
1018
+
1019
+ eval_dataset = eval_dataset.add_column(
1020
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1021
+ )
1022
+ if self.calculate_KL:
1023
+ eval_dataset = eval_dataset.add_column(
1024
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
1025
+ )
1026
+
1027
+ # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
1028
+ if self.eval_dataset is not None:
1029
+ self.eval_dataset = eval_dataset
1030
+ self._precomputed_eval_ref_log_probs = True
1031
+
1032
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1033
+
1034
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1035
+ """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
1036
+ with torch.no_grad():
1037
+ if self.ref_model is None:
1038
+ with self.null_ref_context():
1039
+ if self.is_encoder_decoder:
1040
+ completion_logits = self.model(
1041
+ padded_batch["prompt_input_ids"],
1042
+ attention_mask=padded_batch["prompt_attention_mask"],
1043
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1044
+ labels=padded_batch["completion_labels"],
1045
+ ).logits
1046
+
1047
+ if self.calculate_KL:
1048
+ KL_logits = self.model(
1049
+ padded_batch["KL_prompt_input_ids"],
1050
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1051
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1052
+ labels=padded_batch["KL_completion_labels"],
1053
+ ).logits
1054
+ else:
1055
+ completion_logits = self.model(
1056
+ padded_batch["completion_input_ids"],
1057
+ attention_mask=padded_batch["completion_attention_mask"],
1058
+ ).logits
1059
+
1060
+ if self.calculate_KL:
1061
+ KL_logits = self.model(
1062
+ padded_batch["KL_completion_input_ids"],
1063
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1064
+ ).logits
1065
+ else:
1066
+ if self.is_encoder_decoder:
1067
+ completion_logits = self.ref_model(
1068
+ padded_batch["prompt_input_ids"],
1069
+ attention_mask=padded_batch["prompt_attention_mask"],
1070
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1071
+ labels=padded_batch["completion_labels"],
1072
+ ).logits
1073
+
1074
+ if self.calculate_KL:
1075
+ KL_logits = self.ref_model(
1076
+ padded_batch["KL_prompt_input_ids"],
1077
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1078
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1079
+ labels=padded_batch["KL_completion_labels"],
1080
+ ).logits
1081
+ else:
1082
+ completion_logits = self.ref_model(
1083
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1084
+ ).logits
1085
+
1086
+ if self.calculate_KL:
1087
+ KL_logits = self.ref_model(
1088
+ padded_batch["KL_completion_input_ids"],
1089
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1090
+ ).logits
1091
+
1092
+ completion_logps = self.get_batch_logps(
1093
+ completion_logits,
1094
+ padded_batch["completion_labels"],
1095
+ average_log_prob=False,
1096
+ is_encoder_decoder=self.is_encoder_decoder,
1097
+ label_pad_token_id=self.label_pad_token_id,
1098
+ )
1099
+
1100
+ if self.calculate_KL:
1101
+ KL_logps = self.get_batch_logps(
1102
+ KL_logits,
1103
+ padded_batch["KL_completion_labels"],
1104
+ average_log_prob=False,
1105
+ is_encoder_decoder=self.is_encoder_decoder,
1106
+ label_pad_token_id=self.label_pad_token_id,
1107
+ )
1108
+ else:
1109
+ KL_logps = None
1110
+
1111
+ return completion_logps, KL_logps
1112
+
1113
+ @staticmethod
1114
+ def get_batch_logps(
1115
+ logits: torch.FloatTensor,
1116
+ labels: torch.LongTensor,
1117
+ average_log_prob: bool = False,
1118
+ label_pad_token_id: int = -100,
1119
+ is_encoder_decoder: bool = False,
1120
+ ) -> torch.FloatTensor:
1121
+ """Compute the log probabilities of the given labels under the given logits.
1122
+
1123
+ Args:
1124
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1125
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1126
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1127
+
1128
+ Returns:
1129
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1130
+ """
1131
+ if logits.shape[:-1] != labels.shape:
1132
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1133
+
1134
+ if not is_encoder_decoder:
1135
+ labels = labels[:, 1:].clone()
1136
+ logits = logits[:, :-1, :]
1137
+ else:
1138
+ # Fixes end-dec RuntimeError
1139
+ labels = labels.clone()
1140
+
1141
+ loss_mask = labels != label_pad_token_id
1142
+
1143
+ # dummy token; we'll ignore the losses on these tokens later
1144
+ labels[labels == label_pad_token_id] = 0
1145
+
1146
+ per_token_logps = selective_log_softmax(logits, labels)
1147
+
1148
+ if average_log_prob:
1149
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1150
+ else:
1151
+ return (per_token_logps * loss_mask).sum(-1)
1152
+
1153
+ def forward(
1154
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1155
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1156
+ if self.calculate_KL:
1157
+ KL_logps = None
1158
+ KL_model_kwargs = (
1159
+ {
1160
+ "input_ids": batch["KL_prompt_input_ids"],
1161
+ "attention_mask": batch["KL_prompt_attention_mask"],
1162
+ "labels": batch["KL_completion_labels"],
1163
+ "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
1164
+ }
1165
+ if self.is_encoder_decoder
1166
+ else {
1167
+ "input_ids": batch["KL_completion_input_ids"],
1168
+ "attention_mask": batch["KL_completion_attention_mask"],
1169
+ }
1170
+ )
1171
+ with torch.no_grad():
1172
+ KL_logits = model(
1173
+ **KL_model_kwargs,
1174
+ ).logits
1175
+
1176
+ KL_logps = self.get_batch_logps(
1177
+ KL_logits,
1178
+ batch["KL_completion_labels"],
1179
+ average_log_prob=False,
1180
+ is_encoder_decoder=self.is_encoder_decoder,
1181
+ label_pad_token_id=self.label_pad_token_id,
1182
+ )
1183
+ else:
1184
+ KL_logps = None
1185
+
1186
+ model_kwargs = (
1187
+ {
1188
+ "labels": batch["completion_labels"],
1189
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1190
+ }
1191
+ if self.is_encoder_decoder
1192
+ else {}
1193
+ )
1194
+ if self.aux_loss_enabled:
1195
+ model_kwargs["output_router_logits"] = True
1196
+
1197
+ outputs = model(
1198
+ batch["completion_input_ids"],
1199
+ attention_mask=batch["completion_attention_mask"],
1200
+ **model_kwargs,
1201
+ )
1202
+ completion_logits = outputs.logits
1203
+
1204
+ completion_logps = self.get_batch_logps(
1205
+ completion_logits,
1206
+ batch["completion_labels"],
1207
+ average_log_prob=False,
1208
+ is_encoder_decoder=self.is_encoder_decoder,
1209
+ label_pad_token_id=self.label_pad_token_id,
1210
+ )
1211
+
1212
+ if completion_logps.shape[0] != len(batch["label"]):
1213
+ raise ValueError(
1214
+ "There is a mismatch between the number of examples in this batch and the number of "
1215
+ "examples for which an output sequence was predicted."
1216
+ )
1217
+
1218
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1219
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1220
+
1221
+ chosen_logps = completion_logps[chosen_idx, ...]
1222
+ rejected_logps = completion_logps[rejected_idx, ...]
1223
+
1224
+ chosen_logits = completion_logits[chosen_idx, ...]
1225
+ rejected_logits = completion_logits[rejected_idx, ...]
1226
+
1227
+ if self.aux_loss_enabled:
1228
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
1229
+ else:
1230
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
1231
+
1232
+ def kto_loss(
1233
+ self,
1234
+ policy_chosen_logps: torch.FloatTensor,
1235
+ policy_rejected_logps: torch.FloatTensor,
1236
+ policy_KL_logps: torch.FloatTensor,
1237
+ reference_chosen_logps: torch.FloatTensor,
1238
+ reference_rejected_logps: torch.FloatTensor,
1239
+ reference_KL_logps: torch.FloatTensor,
1240
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1241
+ """Compute the KTO loss for a batch of policy and reference model log probabilities.
1242
+
1243
+ Args:
1244
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1245
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1246
+ policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
1247
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1248
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1249
+ reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
1250
+
1251
+ Returns:
1252
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
1253
+ The losses tensor contains the KTO loss for each example in the batch.
1254
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1255
+ The KL tensor contains the detached KL divergence estimate between the policy and reference models.
1256
+ """
1257
+ if self.calculate_KL:
1258
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
1259
+ kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
1260
+ else:
1261
+ kl = torch.zeros(1).to(policy_chosen_logps.device)
1262
+
1263
+ # Chosen losses
1264
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1265
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1266
+
1267
+ if self.loss_type == "kto":
1268
+ # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
1269
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
1270
+ elif self.loss_type == "apo_zero_unpaired":
1271
+ # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
1272
+ # Use this loss when you believe the chosen outputs are better than your model's default output
1273
+ chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
1274
+
1275
+ chosen_rewards = self.beta * chosen_logratios.detach()
1276
+
1277
+ else:
1278
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1279
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1280
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1281
+
1282
+ # Rejected losses
1283
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1284
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1285
+
1286
+ if self.loss_type == "kto":
1287
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
1288
+ elif self.loss_type == "apo_zero_unpaired":
1289
+ rejected_losses = F.sigmoid(self.beta * rejected_logratios)
1290
+
1291
+ rejected_rewards = self.beta * rejected_logratios.detach()
1292
+ else:
1293
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1294
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1295
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1296
+
1297
+ losses = torch.cat(
1298
+ (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
1299
+ 0,
1300
+ )
1301
+
1302
+ return losses, chosen_rewards, rejected_rewards, kl
1303
+
1304
+ def get_batch_loss_metrics(
1305
+ self,
1306
+ model,
1307
+ batch: dict[str, Union[list, torch.LongTensor]],
1308
+ ):
1309
+ """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1310
+ metrics = {}
1311
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1312
+
1313
+ forward_output = self.forward(model, batch)
1314
+ (
1315
+ policy_chosen_logps,
1316
+ policy_rejected_logps,
1317
+ policy_chosen_logits,
1318
+ policy_rejected_logits,
1319
+ policy_KL_logps,
1320
+ ) = forward_output[:5]
1321
+ if self.aux_loss_enabled:
1322
+ aux_loss = forward_output[5]
1323
+
1324
+ # if reference_logps in batch use them, otherwise use the reference model
1325
+ if "reference_logps" in batch:
1326
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1327
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1328
+
1329
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1330
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1331
+ if self.calculate_KL:
1332
+ reference_KL_logps = batch["reference_KL_logps"]
1333
+ else:
1334
+ reference_KL_logps = None
1335
+ else:
1336
+ with torch.no_grad():
1337
+ if self.ref_model is None:
1338
+ with self.null_ref_context():
1339
+ (
1340
+ reference_chosen_logps,
1341
+ reference_rejected_logps,
1342
+ _,
1343
+ _,
1344
+ reference_KL_logps,
1345
+ ) = self.forward(self.model, batch)[:5]
1346
+ else:
1347
+ (
1348
+ reference_chosen_logps,
1349
+ reference_rejected_logps,
1350
+ _,
1351
+ _,
1352
+ reference_KL_logps,
1353
+ ) = self.forward(self.ref_model, batch)[:5]
1354
+
1355
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1356
+ policy_chosen_logps,
1357
+ policy_rejected_logps,
1358
+ policy_KL_logps,
1359
+ reference_chosen_logps,
1360
+ reference_rejected_logps,
1361
+ reference_KL_logps,
1362
+ )
1363
+ metrics["kl"] = kl.item()
1364
+
1365
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1366
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1367
+
1368
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1369
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1370
+
1371
+ if all_num_chosen > 0:
1372
+ metrics["rewards/chosen_sum"] = (
1373
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1374
+ )
1375
+ metrics["logps/chosen_sum"] = (
1376
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1377
+ )
1378
+ metrics["logits/chosen_sum"] = (
1379
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1380
+ )
1381
+ metrics["count/chosen"] = all_num_chosen
1382
+
1383
+ if all_num_rejected > 0:
1384
+ metrics["rewards/rejected_sum"] = (
1385
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1386
+ )
1387
+ metrics["logps/rejected_sum"] = (
1388
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1389
+ )
1390
+ metrics["logits/rejected_sum"] = (
1391
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1392
+ )
1393
+ metrics["count/rejected"] = all_num_rejected
1394
+
1395
+ loss = losses.nanmean()
1396
+ if self.aux_loss_enabled:
1397
+ loss += self.aux_loss_coef * aux_loss
1398
+
1399
+ return loss, metrics
1400
+
1401
+ def compute_loss(
1402
+ self,
1403
+ model: Union[PreTrainedModel, nn.Module],
1404
+ inputs: dict[str, Union[torch.Tensor, Any]],
1405
+ return_outputs=False,
1406
+ num_items_in_batch=None,
1407
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1408
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1409
+
1410
+ with compute_loss_context_manager:
1411
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1412
+
1413
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1414
+ loss = loss.to(self.args.device)
1415
+ # force log the metrics
1416
+ if self.accelerator.is_main_process:
1417
+ self.store_metrics(metrics, train_eval="train")
1418
+
1419
+ if return_outputs:
1420
+ return (loss, metrics)
1421
+ return loss
1422
+
1423
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1424
+ for key, value in metrics.items():
1425
+ self._stored_metrics[train_eval][key].append(value)
1426
+
1427
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1428
+ if self.train_dataset is None or not has_length(self.train_dataset):
1429
+ return None
1430
+ return SequentialSampler(self.train_dataset)
1431
+
1432
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1433
+ """Generate samples from the model and reference model for the given batch of inputs."""
1434
+
1435
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1436
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1437
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1438
+
1439
+ with generate_context_manager:
1440
+ policy_output = model.generate(
1441
+ input_ids=batch["prompt_input_ids"],
1442
+ attention_mask=batch["prompt_attention_mask"],
1443
+ max_length=self.max_length,
1444
+ do_sample=True,
1445
+ pad_token_id=self.processing_class.pad_token_id,
1446
+ )
1447
+
1448
+ # if reference_output in batch use that otherwise use the reference model
1449
+ if "reference_output" in batch:
1450
+ reference_output = batch["reference_output"]
1451
+ else:
1452
+ if self.ref_model is None:
1453
+ with self.null_ref_context():
1454
+ reference_output = self.model.generate(
1455
+ input_ids=batch["prompt_input_ids"],
1456
+ attention_mask=batch["prompt_attention_mask"],
1457
+ max_length=self.max_length,
1458
+ do_sample=True,
1459
+ pad_token_id=self.processing_class.pad_token_id,
1460
+ )
1461
+ else:
1462
+ reference_output = self.ref_model.generate(
1463
+ input_ids=batch["prompt_input_ids"],
1464
+ attention_mask=batch["prompt_attention_mask"],
1465
+ max_length=self.max_length,
1466
+ do_sample=True,
1467
+ pad_token_id=self.processing_class.pad_token_id,
1468
+ )
1469
+
1470
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1471
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1472
+
1473
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1474
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1475
+
1476
+ return policy_output_decoded, reference_output_decoded
1477
+
1478
+ def prediction_step(
1479
+ self,
1480
+ model: Union[PreTrainedModel, nn.Module],
1481
+ inputs: dict[str, Union[torch.Tensor, Any]],
1482
+ prediction_loss_only: bool,
1483
+ ignore_keys: Optional[list[str]] = None,
1484
+ ):
1485
+ if ignore_keys is None:
1486
+ if hasattr(model, "config"):
1487
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1488
+ else:
1489
+ ignore_keys = []
1490
+
1491
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1492
+ with torch.no_grad(), prediction_context_manager:
1493
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1494
+
1495
+ # force log the metrics
1496
+ if self.accelerator.is_main_process:
1497
+ self.store_metrics(metrics, train_eval="eval")
1498
+
1499
+ if prediction_loss_only:
1500
+ return (loss.detach(), None, None)
1501
+
1502
+ # logits for the chosen and rejected samples from model
1503
+ logits_dict = {
1504
+ "eval_logits/chosen": metrics["logits/chosen"],
1505
+ "eval_logits/rejected": metrics["logits/rejected"],
1506
+ }
1507
+ logits = torch.tensor(
1508
+ [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
1509
+ )
1510
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1511
+
1512
+ return (loss.detach(), logits, labels)
1513
+
1514
+ def evaluation_loop(
1515
+ self,
1516
+ dataloader: DataLoader,
1517
+ description: str,
1518
+ prediction_loss_only: Optional[bool] = None,
1519
+ ignore_keys: Optional[list[str]] = None,
1520
+ metric_key_prefix: str = "eval",
1521
+ ) -> EvalLoopOutput:
1522
+ """
1523
+ Overriding built-in evaluation loop to store metrics for each batch.
1524
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1525
+
1526
+ Works both with or without labels.
1527
+ """
1528
+
1529
+ # Sample and save to game log if requested (for one batch to save time)
1530
+ if self.generate_during_eval:
1531
+ # Generate random indices within the range of the total number of samples
1532
+ num_samples = len(dataloader.dataset)
1533
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1534
+
1535
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1536
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1537
+ random_batch = self.data_collator(random_batch_dataset)
1538
+ random_batch = self._prepare_inputs(random_batch)
1539
+
1540
+ target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1541
+ target_batch = {
1542
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
1543
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
1544
+ "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
1545
+ }
1546
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1547
+
1548
+ table = pd.DataFrame(
1549
+ columns=["Prompt", "Policy", "Ref Model"],
1550
+ data=[
1551
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1552
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1553
+ ],
1554
+ )
1555
+ if "wandb" in self.args.report_to:
1556
+ wandb.log({"game_log": wandb.Table(data=table)})
1557
+
1558
+ if "comet_ml" in self.args.report_to:
1559
+ log_table_to_comet_experiment(
1560
+ name="game_log.csv",
1561
+ table=table,
1562
+ )
1563
+
1564
+ # Base evaluation
1565
+ initial_output = super().evaluation_loop(
1566
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1567
+ )
1568
+
1569
+ return initial_output
1570
+
1571
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1572
+ """
1573
+ Log `logs` on the various objects watching training, including stored metrics.
1574
+
1575
+ Args:
1576
+ logs (`dict[str, float]`):
1577
+ The values to log.
1578
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1579
+ Start time of the training.
1580
+ """
1581
+ # logs either has 'loss' or 'eval_loss'
1582
+ train_eval = "train" if "loss" in logs else "eval"
1583
+ # train metrics should have no prefix, eval should have 'eval_'
1584
+ prefix = "eval_" if train_eval == "eval" else ""
1585
+ # accumulate average metrics from sums and lengths
1586
+ for split in ["chosen", "rejected"]:
1587
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1588
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1589
+ for metric in ["rewards", "logps", "logits"]:
1590
+ logs[f"{prefix}{metric}/{split}"] = (
1591
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1592
+ / count_sum
1593
+ )
1594
+ # delete obsolete metric
1595
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1596
+ del self._stored_metrics[train_eval][f"count/{split}"]
1597
+ # calculate reward margin
1598
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1599
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1600
+ # Add averaged stored metrics to logs
1601
+ for key, metrics in self._stored_metrics[train_eval].items():
1602
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1603
+ del self._stored_metrics[train_eval]
1604
+
1605
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1606
+ return super().log(logs, start_time)
1607
+ else: # transformers<=4.46
1608
+ return super().log(logs)
1609
+
1610
+ def create_model_card(
1611
+ self,
1612
+ model_name: Optional[str] = None,
1613
+ dataset_name: Optional[str] = None,
1614
+ tags: Union[str, list[str], None] = None,
1615
+ ):
1616
+ """
1617
+ Creates a draft of a model card using the information available to the `Trainer`.
1618
+
1619
+ Args:
1620
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1621
+ Name of the model.
1622
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1623
+ Name of the dataset used for training.
1624
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1625
+ Tags to be associated with the model card.
1626
+ """
1627
+ if not self.is_world_process_zero():
1628
+ return
1629
+
1630
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1631
+ base_model = self.model.config._name_or_path
1632
+ else:
1633
+ base_model = None
1634
+
1635
+ tags = tags or []
1636
+ if isinstance(tags, str):
1637
+ tags = [tags]
1638
+
1639
+ if hasattr(self.model.config, "unsloth_version"):
1640
+ tags.append("unsloth")
1641
+
1642
+ citation = textwrap.dedent("""\
1643
+ @article{ethayarajh2024kto,
1644
+ title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
1645
+ author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
1646
+ year = 2024,
1647
+ eprint = {arXiv:2402.01306},
1648
+ }""")
1649
+
1650
+ model_card = generate_model_card(
1651
+ base_model=base_model,
1652
+ model_name=model_name,
1653
+ hub_model_id=self.hub_model_id,
1654
+ dataset_name=dataset_name,
1655
+ tags=tags,
1656
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1657
+ comet_url=get_comet_experiment_url(),
1658
+ trainer_name="KTO",
1659
+ trainer_citation=citation,
1660
+ paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
1661
+ paper_id="2402.01306",
1662
+ )
1663
+
1664
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1665
+ class UnslothKTOTrainer(_UnslothKTOTrainer):
1666
+ """
1667
+
1668
+ Initialize KTOTrainer.
1669
+
1670
+ Args:
1671
+ model (`transformers.PreTrainedModel`):
1672
+ The model to train, preferably an `AutoModelForSequenceClassification`.
1673
+ ref_model (`PreTrainedModelWrapper`):
1674
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
1675
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1676
+ args (`KTOConfig`):
1677
+ The arguments to use for training.
1678
+ train_dataset (`datasets.Dataset`):
1679
+ The dataset to use for training.
1680
+ eval_dataset (`datasets.Dataset`):
1681
+ The dataset to use for evaluation.
1682
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1683
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1684
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1685
+ reuse the fine-tuned model.
1686
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1687
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1688
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1689
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1690
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1691
+ callbacks (`list[transformers.TrainerCallback]`):
1692
+ The callbacks to use for training.
1693
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1694
+ The optimizer and scheduler to use for training.
1695
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1696
+ The function to use to preprocess the logits before computing the metrics.
1697
+ peft_config (`dict`, defaults to `None`):
1698
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1699
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1700
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1701
+ a dictionary string to metric values.
1702
+ model_adapter_name (`str`, defaults to `None`):
1703
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1704
+ ref_adapter_name (`str`, defaults to `None`):
1705
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1706
+
1707
+ """
1708
+ def __init__(
1709
+ self,
1710
+ model = None,
1711
+ ref_model = None,
1712
+ args = None,
1713
+ train_dataset = None,
1714
+ eval_dataset = None,
1715
+ processing_class = None,
1716
+ data_collator = None,
1717
+ model_init = None,
1718
+ callbacks = None,
1719
+ preprocess_logits_for_metrics = None,
1720
+ peft_config = None,
1721
+ compute_metrics = None,
1722
+ model_adapter_name = None,
1723
+ ref_adapter_name = None,
1724
+ **kwargs
1725
+ ):
1726
+ if args is None: args = UnslothKTOConfig()
1727
+ use_bf16 = getattr(args, 'bf16', False)
1728
+ use_fp16 = getattr(args, 'fp16', False)
1729
+ force_float32 = False
1730
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1731
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1732
+ force_float32 = True
1733
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1734
+ dtype = getattr(model.config, 'torch_dtype', None)
1735
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1736
+ from unsloth_zoo.utils import _get_dtype
1737
+ dtype = _get_dtype(dtype)
1738
+ float16 = dtype == torch.float16
1739
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1740
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1741
+ if force_float32:
1742
+ args.fp16 = False
1743
+ args.bf16 = False
1744
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1745
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1746
+ args.fp16 = float16
1747
+ args.bf16 = not float16
1748
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1749
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1750
+ args.eval_strategy = 'steps'
1751
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1752
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1753
+ if ga_steps is not None and ga_steps > 1:
1754
+ from transformers import __version__ as transformers_version
1755
+ if Version(transformers_version) <= Version('4.45.2'):
1756
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1757
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1758
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1759
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1760
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1761
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1762
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1763
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1764
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1765
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1766
+ if force_float32:
1767
+ args.bf16_full_eval = False
1768
+ args.fp16_full_eval = False
1769
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1770
+ args.bf16_full_eval = True
1771
+ args.fp16_full_eval = False
1772
+ elif not bf16_full_eval and not fp16_full_eval:
1773
+ args.bf16_full_eval = args.bf16
1774
+ args.fp16_full_eval = args.fp16
1775
+ _output_logits = False
1776
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1777
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1778
+ if _output_logits:
1779
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1780
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1781
+ pass
1782
+ else:
1783
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1784
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1785
+ if args_max_seq_length is None and model_max_seq_length is not None:
1786
+ max_seq_length = model.max_seq_length
1787
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1788
+ if model is not None and hasattr(model, 'for_training'):
1789
+ model.for_training()
1790
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1791
+ if 'processing_class' in locals():
1792
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1793
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1794
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1795
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1796
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1797
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1798
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1799
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1800
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1801
+ else:
1802
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1803
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1804
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1805
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1806
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1807
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1808
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1809
+ else:
1810
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1811
+ other_metrics = []
1812
+
1813
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1814
+ PatchRLStatistics('kto_trainer', other_metrics)
1815
+
1816
+ super().__init__(
1817
+ model = model,
1818
+ ref_model = ref_model,
1819
+ args = args,
1820
+ train_dataset = train_dataset,
1821
+ eval_dataset = eval_dataset,
1822
+ processing_class = processing_class,
1823
+ data_collator = data_collator,
1824
+ model_init = model_init,
1825
+ callbacks = callbacks,
1826
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1827
+ peft_config = peft_config,
1828
+ compute_metrics = compute_metrics,
1829
+ model_adapter_name = model_adapter_name,
1830
+ ref_adapter_name = ref_adapter_name,**kwargs)
1831
+ if hasattr(self, 'neftune_hook_handle'):
1832
+ self.neftune_hook_handle.remove()
1833
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1834
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1835
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1836
+ pass
1837
+
1838
+ pass
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothNashMDConfig(NashMDConfig):
44
+ """
45
+
46
+ Configuration class for the [`NashMDTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
49
+
50
+ Parameters:
51
+ mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
52
+ Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
53
+ mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
54
+ epochs.
55
+
56
+ """
57
+ vllm_sampling_params: Optional[Any] = field(
58
+ default = None,
59
+ metadata = {'help': 'vLLM SamplingParams'},
60
+ )
61
+ unsloth_num_chunks : Optional[int] = field(
62
+ default = -1,
63
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
64
+ )
65
+ def __init__(
66
+ self,
67
+ output_dir = None,
68
+ overwrite_output_dir = None,
69
+ do_train = False,
70
+ do_eval = False,
71
+ do_predict = False,
72
+ eval_strategy = 'no',
73
+ prediction_loss_only = False,
74
+ per_device_train_batch_size = 4,
75
+ per_device_eval_batch_size = 4,
76
+ per_gpu_train_batch_size = None,
77
+ per_gpu_eval_batch_size = None,
78
+ gradient_accumulation_steps = 2,
79
+ eval_accumulation_steps = 2,
80
+ eval_delay = 0,
81
+ torch_empty_cache_steps = 250,
82
+ learning_rate = 5e-05,
83
+ weight_decay = 0.01,
84
+ adam_beta1 = 0.9,
85
+ adam_beta2 = 0.999,
86
+ adam_epsilon = 1e-08,
87
+ max_grad_norm = 1.0,
88
+ num_train_epochs = 3.0,
89
+ max_steps = -1,
90
+ lr_scheduler_type = 'linear',
91
+ warmup_ratio = 0.1,
92
+ warmup_steps = 0,
93
+ log_level = 'passive',
94
+ log_level_replica = 'warning',
95
+ log_on_each_node = True,
96
+ logging_dir = None,
97
+ logging_strategy = 'steps',
98
+ logging_first_step = False,
99
+ logging_steps = 1,
100
+ logging_nan_inf_filter = False,
101
+ save_strategy = 'steps',
102
+ save_steps = 500,
103
+ save_total_limit = None,
104
+ save_safetensors = True,
105
+ save_on_each_node = False,
106
+ save_only_model = False,
107
+ restore_callback_states_from_checkpoint = False,
108
+ no_cuda = False,
109
+ use_cpu = False,
110
+ use_mps_device = False,
111
+ seed = 3407,
112
+ data_seed = 3407,
113
+ jit_mode_eval = False,
114
+ use_ipex = False,
115
+ bf16 = False,
116
+ fp16 = False,
117
+ fp16_opt_level = 'O1',
118
+ half_precision_backend = 'auto',
119
+ bf16_full_eval = False,
120
+ fp16_full_eval = False,
121
+ tf32 = None,
122
+ local_rank = -1,
123
+ ddp_backend = None,
124
+ tpu_num_cores = None,
125
+ tpu_metrics_debug = False,
126
+ debug = '',
127
+ dataloader_drop_last = False,
128
+ eval_steps = None,
129
+ dataloader_num_workers = 0,
130
+ dataloader_prefetch_factor = None,
131
+ past_index = -1,
132
+ run_name = None,
133
+ disable_tqdm = None,
134
+ remove_unused_columns = True,
135
+ label_names = None,
136
+ load_best_model_at_end = False,
137
+ metric_for_best_model = None,
138
+ greater_is_better = None,
139
+ ignore_data_skip = False,
140
+ fsdp = '',
141
+ fsdp_min_num_params = 0,
142
+ fsdp_config = None,
143
+ fsdp_transformer_layer_cls_to_wrap = None,
144
+ accelerator_config = None,
145
+ deepspeed = None,
146
+ label_smoothing_factor = 0.0,
147
+ optim = 'adamw_8bit',
148
+ optim_args = None,
149
+ adafactor = False,
150
+ group_by_length = False,
151
+ length_column_name = 'length',
152
+ report_to = None,
153
+ ddp_find_unused_parameters = None,
154
+ ddp_bucket_cap_mb = None,
155
+ ddp_broadcast_buffers = None,
156
+ dataloader_pin_memory = True,
157
+ dataloader_persistent_workers = False,
158
+ skip_memory_metrics = True,
159
+ use_legacy_prediction_loop = False,
160
+ push_to_hub = False,
161
+ resume_from_checkpoint = None,
162
+ hub_model_id = None,
163
+ hub_strategy = 'every_save',
164
+ hub_token = None,
165
+ hub_private_repo = None,
166
+ hub_always_push = False,
167
+ gradient_checkpointing = False,
168
+ gradient_checkpointing_kwargs = None,
169
+ include_inputs_for_metrics = False,
170
+ eval_do_concat_batches = True,
171
+ fp16_backend = 'auto',
172
+ evaluation_strategy = None,
173
+ push_to_hub_model_id = None,
174
+ push_to_hub_organization = None,
175
+ push_to_hub_token = None,
176
+ mp_parameters = '',
177
+ auto_find_batch_size = False,
178
+ full_determinism = False,
179
+ torchdynamo = None,
180
+ ray_scope = 'last',
181
+ ddp_timeout = 1800,
182
+ torch_compile = False,
183
+ torch_compile_backend = None,
184
+ torch_compile_mode = None,
185
+ dispatch_batches = None,
186
+ split_batches = None,
187
+ include_tokens_per_second = False,
188
+ include_num_input_tokens_seen = False,
189
+ neftune_noise_alpha = None,
190
+ optim_target_modules = None,
191
+ batch_eval_metrics = False,
192
+ eval_on_start = False,
193
+ use_liger_kernel = False,
194
+ eval_use_gather_object = False,
195
+ average_tokens_across_devices = False,
196
+ reward_model_path = None,
197
+ judge = None,
198
+ max_new_tokens = 64,
199
+ max_length = 512,
200
+ temperature = 0.9,
201
+ missing_eos_penalty = None,
202
+ loss_type = 'sigmoid',
203
+ dataset_num_proc = None,
204
+ disable_dropout = True,
205
+ use_vllm = False,
206
+ ds3_gather_for_generation = True,
207
+ vllm_sampling_params = None,
208
+ unsloth_num_chunks = -1,
209
+ **kwargs,
210
+ ):
211
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
212
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
213
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
214
+ output_dir = 'unsloth_training_checkpoints'
215
+ save_strategy = 'no'
216
+ if dataset_num_proc is None:
217
+ from multiprocessing import cpu_count
218
+ dataset_num_proc = cpu_count()
219
+
220
+ super().__init__(
221
+ output_dir = output_dir,
222
+ overwrite_output_dir = overwrite_output_dir,
223
+ do_train = do_train,
224
+ do_eval = do_eval,
225
+ do_predict = do_predict,
226
+ eval_strategy = eval_strategy,
227
+ prediction_loss_only = prediction_loss_only,
228
+ per_device_train_batch_size = per_device_train_batch_size,
229
+ per_device_eval_batch_size = per_device_eval_batch_size,
230
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
231
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
232
+ gradient_accumulation_steps = gradient_accumulation_steps,
233
+ eval_accumulation_steps = eval_accumulation_steps,
234
+ eval_delay = eval_delay,
235
+ torch_empty_cache_steps = torch_empty_cache_steps,
236
+ learning_rate = learning_rate,
237
+ weight_decay = weight_decay,
238
+ adam_beta1 = adam_beta1,
239
+ adam_beta2 = adam_beta2,
240
+ adam_epsilon = adam_epsilon,
241
+ max_grad_norm = max_grad_norm,
242
+ num_train_epochs = num_train_epochs,
243
+ max_steps = max_steps,
244
+ lr_scheduler_type = lr_scheduler_type,
245
+ warmup_ratio = warmup_ratio,
246
+ warmup_steps = warmup_steps,
247
+ log_level = log_level,
248
+ log_level_replica = log_level_replica,
249
+ log_on_each_node = log_on_each_node,
250
+ logging_dir = logging_dir,
251
+ logging_strategy = logging_strategy,
252
+ logging_first_step = logging_first_step,
253
+ logging_steps = logging_steps,
254
+ logging_nan_inf_filter = logging_nan_inf_filter,
255
+ save_strategy = save_strategy,
256
+ save_steps = save_steps,
257
+ save_total_limit = save_total_limit,
258
+ save_safetensors = save_safetensors,
259
+ save_on_each_node = save_on_each_node,
260
+ save_only_model = save_only_model,
261
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
262
+ no_cuda = no_cuda,
263
+ use_cpu = use_cpu,
264
+ use_mps_device = use_mps_device,
265
+ seed = seed,
266
+ data_seed = data_seed,
267
+ jit_mode_eval = jit_mode_eval,
268
+ use_ipex = use_ipex,
269
+ bf16 = bf16,
270
+ fp16 = fp16,
271
+ fp16_opt_level = fp16_opt_level,
272
+ half_precision_backend = half_precision_backend,
273
+ bf16_full_eval = bf16_full_eval,
274
+ fp16_full_eval = fp16_full_eval,
275
+ tf32 = tf32,
276
+ local_rank = local_rank,
277
+ ddp_backend = ddp_backend,
278
+ tpu_num_cores = tpu_num_cores,
279
+ tpu_metrics_debug = tpu_metrics_debug,
280
+ debug = debug,
281
+ dataloader_drop_last = dataloader_drop_last,
282
+ eval_steps = eval_steps,
283
+ dataloader_num_workers = dataloader_num_workers,
284
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
285
+ past_index = past_index,
286
+ run_name = run_name,
287
+ disable_tqdm = disable_tqdm,
288
+ remove_unused_columns = remove_unused_columns,
289
+ label_names = label_names,
290
+ load_best_model_at_end = load_best_model_at_end,
291
+ metric_for_best_model = metric_for_best_model,
292
+ greater_is_better = greater_is_better,
293
+ ignore_data_skip = ignore_data_skip,
294
+ fsdp = fsdp,
295
+ fsdp_min_num_params = fsdp_min_num_params,
296
+ fsdp_config = fsdp_config,
297
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
298
+ accelerator_config = accelerator_config,
299
+ deepspeed = deepspeed,
300
+ label_smoothing_factor = label_smoothing_factor,
301
+ optim = optim,
302
+ optim_args = optim_args,
303
+ adafactor = adafactor,
304
+ group_by_length = group_by_length,
305
+ length_column_name = length_column_name,
306
+ report_to = report_to,
307
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
308
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
309
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
310
+ dataloader_pin_memory = dataloader_pin_memory,
311
+ dataloader_persistent_workers = dataloader_persistent_workers,
312
+ skip_memory_metrics = skip_memory_metrics,
313
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
314
+ push_to_hub = push_to_hub,
315
+ resume_from_checkpoint = resume_from_checkpoint,
316
+ hub_model_id = hub_model_id,
317
+ hub_strategy = hub_strategy,
318
+ hub_token = hub_token,
319
+ hub_private_repo = hub_private_repo,
320
+ hub_always_push = hub_always_push,
321
+ gradient_checkpointing = gradient_checkpointing,
322
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
323
+ include_inputs_for_metrics = include_inputs_for_metrics,
324
+ eval_do_concat_batches = eval_do_concat_batches,
325
+ fp16_backend = fp16_backend,
326
+ evaluation_strategy = evaluation_strategy,
327
+ push_to_hub_model_id = push_to_hub_model_id,
328
+ push_to_hub_organization = push_to_hub_organization,
329
+ push_to_hub_token = push_to_hub_token,
330
+ mp_parameters = mp_parameters,
331
+ auto_find_batch_size = auto_find_batch_size,
332
+ full_determinism = full_determinism,
333
+ torchdynamo = torchdynamo,
334
+ ray_scope = ray_scope,
335
+ ddp_timeout = ddp_timeout,
336
+ torch_compile = torch_compile,
337
+ torch_compile_backend = torch_compile_backend,
338
+ torch_compile_mode = torch_compile_mode,
339
+ dispatch_batches = dispatch_batches,
340
+ split_batches = split_batches,
341
+ include_tokens_per_second = include_tokens_per_second,
342
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
343
+ neftune_noise_alpha = neftune_noise_alpha,
344
+ optim_target_modules = optim_target_modules,
345
+ batch_eval_metrics = batch_eval_metrics,
346
+ eval_on_start = eval_on_start,
347
+ use_liger_kernel = use_liger_kernel,
348
+ eval_use_gather_object = eval_use_gather_object,
349
+ average_tokens_across_devices = average_tokens_across_devices,
350
+ reward_model_path = reward_model_path,
351
+ judge = judge,
352
+ max_new_tokens = max_new_tokens,
353
+ max_length = max_length,
354
+ temperature = temperature,
355
+ missing_eos_penalty = missing_eos_penalty,
356
+ loss_type = loss_type,
357
+ dataset_num_proc = dataset_num_proc,
358
+ disable_dropout = disable_dropout,
359
+ use_vllm = use_vllm,
360
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
361
+ self.vllm_sampling_params = vllm_sampling_params
362
+ self.unsloth_num_chunks = unsloth_num_chunks
363
+ pass
364
+
365
+ class _UnslothNashMDTrainer(OnlineDPOTrainer):
366
+ r""""""
367
+
368
+ _tag_names = ["trl", "nash-md"]
369
+
370
+ def __init__(
371
+ self,
372
+ model: Union[PreTrainedModel, nn.Module] = None,
373
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
374
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
375
+ judge: Optional[BasePairwiseJudge] = None,
376
+ args: Optional[NashMDConfig] = None,
377
+ data_collator: Optional[Callable] = None,
378
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
379
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
380
+ processing_class: Optional[
381
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
382
+ ] = None,
383
+ peft_config: Optional[dict] = None,
384
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
385
+ callbacks: Optional[list[TrainerCallback]] = None,
386
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
387
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
388
+ ) -> None:
389
+ super().__init__(
390
+ model=model,
391
+ ref_model=ref_model,
392
+ reward_model=reward_model,
393
+ judge=judge,
394
+ args=args,
395
+ data_collator=data_collator,
396
+ train_dataset=train_dataset,
397
+ eval_dataset=eval_dataset,
398
+ processing_class=processing_class,
399
+ reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
400
+ peft_config=peft_config,
401
+ compute_metrics=compute_metrics,
402
+ callbacks=callbacks,
403
+ optimizers=optimizers,
404
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
405
+ )
406
+
407
+ self._mixture_coef = self.args.mixture_coef
408
+
409
+ # Overwrite the stats dictionary to include NashMD specific statistics
410
+ self.stats = {
411
+ # Remove "non_score_reward", "rlhf_reward", "scores_margin"
412
+ # Add "mixture_coef"
413
+ "loss/kl": [],
414
+ "objective/entropy": [],
415
+ "loss/score": [],
416
+ "rewards/probabilities": [],
417
+ "rewards/accuracies": [],
418
+ "rewards/margins": [],
419
+ "logps/chosen": [],
420
+ "logps/rejected": [],
421
+ "val/model_contain_eos_token": [],
422
+ "val/ref_contain_eos_token": [],
423
+ "beta": [],
424
+ "mixture_coef": [],
425
+ }
426
+ if self.reward_model is not None:
427
+ self.stats["rewards/chosen"] = []
428
+ self.stats["rewards/rejected"] = []
429
+
430
+ @property
431
+ def mixture_coef(self):
432
+ if isinstance(self._mixture_coef, list):
433
+ epoch = self.state.epoch
434
+ return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
435
+ else:
436
+ return self._mixture_coef
437
+
438
+ def _generate_completions(self, model, prompts):
439
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
440
+ model_output = unwrapped_model.generate(
441
+ input_ids=prompts["input_ids"],
442
+ attention_mask=prompts["attention_mask"],
443
+ generation_config=self.generation_config,
444
+ )
445
+
446
+ ref_model = model if self.ref_model is None else self.ref_model
447
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
448
+ mixture_model = GeometricMixtureWrapper(
449
+ model=unwrapped_model,
450
+ ref_model=unwrapped_ref_model,
451
+ generation_config=self.generation_config,
452
+ mixture_coef=self.mixture_coef,
453
+ device=self.accelerator.device,
454
+ )
455
+
456
+ mixture_output = mixture_model.generate(
457
+ input_ids=prompts["input_ids"],
458
+ attention_mask=prompts["attention_mask"],
459
+ generation_config=self.generation_config,
460
+ )
461
+
462
+ return model_output, mixture_output
463
+
464
+ def _process_completions(self, model_output, mixture_output, prompts):
465
+ context_length = prompts["input_ids"].shape[1]
466
+
467
+ # Process model completions
468
+ model_completion_ids = model_output[:, context_length:]
469
+ model_completion_ids, model_completion_mask = truncate_right(
470
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
471
+ )
472
+ model_data = {
473
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
474
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
475
+ "raw": prompts["raw"],
476
+ }
477
+
478
+ # Process reference model completions
479
+ mixture_completion_ids = mixture_output[:, context_length:]
480
+ mixture_completion_ids, mixture_completion_mask = truncate_right(
481
+ mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
482
+ )
483
+ mixture_data = {
484
+ "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
485
+ "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
486
+ "raw": prompts["raw"],
487
+ }
488
+
489
+ return model_data, mixture_data
490
+
491
+ def _compute_rewards(self, model_data, mixture_data, context_length):
492
+ with torch.no_grad():
493
+ _, model_scores, _ = get_reward(
494
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
495
+ )
496
+ _, mixture_scores, _ = get_reward(
497
+ self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
498
+ )
499
+
500
+ # Apply EOS penalty if needed
501
+ if self.args.missing_eos_penalty is not None:
502
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
503
+ mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
504
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
505
+ mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
506
+
507
+ return model_scores, mixture_scores
508
+
509
+ def _compute_judge(self, model_data, mixture_data, context_length):
510
+ prompts = model_data["raw"]
511
+ model_data_completions = self.processing_class.batch_decode(
512
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
513
+ )
514
+ model_data_completions = [completion.strip() for completion in model_data_completions]
515
+
516
+ mixture_data_completions = self.processing_class.batch_decode(
517
+ mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
518
+ )
519
+ mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
520
+ if is_conversational({"prompt": prompts[0]}):
521
+ model_data_completions = [
522
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
523
+ ]
524
+ environment = jinja2.Environment()
525
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
526
+ prompts = [template.render(messages=message) for message in prompts]
527
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
528
+
529
+ mixture_data_completions = [
530
+ [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
531
+ ]
532
+ mixture_data_completions = [
533
+ template.render(messages=completion) for completion in mixture_data_completions
534
+ ]
535
+
536
+ probability = self.judge.judge(
537
+ prompts,
538
+ list(zip(model_data_completions, mixture_data_completions)),
539
+ return_scores=True,
540
+ )
541
+ return torch.tensor(probability, device=model_data["input_ids"].device)
542
+
543
+ def _compute_logprobs(self, model, model_data, context_length):
544
+ def compute_logprobs_for_data(m, data):
545
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
546
+ logits = output.logits[:, context_length - 1 : -1]
547
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
548
+ return token_logprobs
549
+
550
+ # Compute logprobs for model completions under the model
551
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
552
+
553
+ # Compute logprobs of model completions under the reference model
554
+ with torch.no_grad():
555
+ if self.ref_model is None:
556
+ with model.disable_adapter():
557
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
558
+ else:
559
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
560
+
561
+ # Mask padding tokens
562
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
563
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
564
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
565
+
566
+ return (model_logprobs_model_data, ref_logprobs_model_data)
567
+
568
+ def _compute_losses(
569
+ self,
570
+ model_logprobs_model_data,
571
+ ref_logprobs_model_data,
572
+ probability,
573
+ ):
574
+ # reinforce score where 0.5 is a control variate
575
+ score = (probability - 0.5) * model_logprobs_model_data.sum(1)
576
+
577
+ # kl divergence via reinforce
578
+ with torch.no_grad():
579
+ log_ratio = model_logprobs_model_data - ref_logprobs_model_data
580
+ kl_div_log = log_ratio.sum(1)
581
+ kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
582
+
583
+ # final loss
584
+ loss = self.beta * kl_div_loss - score
585
+
586
+ return loss.mean(), score, kl_div_log
587
+
588
+ def _log_statistics(
589
+ self,
590
+ model_data,
591
+ mixture_data,
592
+ model_logprobs_model_data,
593
+ ref_logprobs_model_data,
594
+ probability,
595
+ score,
596
+ kl_div,
597
+ context_length,
598
+ model_scores=None,
599
+ mixture_scores=None,
600
+ ):
601
+ # Helper function to gather and compute mean
602
+ def gather_mean(tensor):
603
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
604
+
605
+ # Log score
606
+ self.stats["loss/score"].append(gather_mean(score))
607
+ # Log KL divergence
608
+ self.stats["loss/kl"].append(gather_mean(kl_div))
609
+
610
+ # Log logprobs
611
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
612
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
613
+
614
+ self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
615
+ self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
616
+
617
+ # Log rewards
618
+ if self.reward_model is not None:
619
+ self.stats["rewards/chosen"].append(gather_mean(model_scores))
620
+ self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
621
+
622
+ # Log probabilities
623
+ self.stats["rewards/probabilities"].append(gather_mean(probability))
624
+
625
+ # Calculate entropy for model data
626
+ entropy_model_data = -model_logprobs_model_data.sum(1)
627
+ self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
628
+
629
+ # Calculate margins
630
+ margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
631
+ self.stats["rewards/margins"].append(gather_mean(margin))
632
+
633
+ # Calculate accuracy
634
+ accuracy = (margin > 0).float()
635
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy))
636
+
637
+ # Log EOS token statistics
638
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
639
+ mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
640
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
641
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
642
+
643
+ # Log beta and mixture coef
644
+ self.stats["beta"].append(self.beta)
645
+ self.stats["mixture_coef"].append(self.mixture_coef)
646
+
647
+ def training_step(
648
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
649
+ ) -> torch.Tensor:
650
+ model.train()
651
+
652
+ # Apply chat template and tokenize the input
653
+ batch_size = len(next(iter(inputs.values())))
654
+ prompts = inputs["prompt"]
655
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
656
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
657
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
658
+ inputs = self.data_collator(inputs)
659
+
660
+ # need the prompt_ only
661
+ inputs = self._prepare_inputs(inputs)
662
+ context_length = inputs["prompt_input_ids"].shape[1]
663
+ prompts = {
664
+ "input_ids": inputs["prompt_input_ids"],
665
+ "attention_mask": inputs["prompt_attention_mask"],
666
+ "raw": prompts,
667
+ }
668
+ del inputs
669
+
670
+ # Sample completions from both the model and the reference model
671
+ model_output, mixture_output = self._generate_completions(model, prompts)
672
+
673
+ # Process model completions
674
+ model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
675
+
676
+ # Compute rewards
677
+ if self.reward_model is not None:
678
+ model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
679
+ # probability of the model data vs the mixture data
680
+ probability = F.sigmoid(model_scores - mixture_scores)
681
+ else:
682
+ model_scores, mixture_scores = None, None
683
+ probability = self._compute_judge(model_data, mixture_data, context_length)
684
+
685
+ # Compute logprobs
686
+ model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
687
+
688
+ # Compute loss
689
+ loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
690
+
691
+ # Log everything
692
+ self._log_statistics(
693
+ model_data,
694
+ mixture_data,
695
+ model_logprobs_model_data.detach(),
696
+ ref_logprobs_model_data,
697
+ probability,
698
+ score.detach(),
699
+ kl_div.detach(),
700
+ context_length,
701
+ model_scores,
702
+ mixture_scores,
703
+ )
704
+
705
+ if (
706
+ self.args.torch_empty_cache_steps is not None
707
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
708
+ ):
709
+ empty_cache()
710
+
711
+ kwargs = {}
712
+ # For LOMO optimizers you need to explicitly use the learning rate
713
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
714
+ kwargs["learning_rate"] = self._get_learning_rate()
715
+
716
+ if self.args.n_gpu > 1:
717
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
718
+
719
+ if self.use_apex:
720
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
721
+ scaled_loss.backward()
722
+ else:
723
+ self.accelerator.backward(loss, **kwargs)
724
+
725
+ return loss.detach() / self.args.gradient_accumulation_steps
726
+
727
+ def create_model_card(
728
+ self,
729
+ model_name: Optional[str] = None,
730
+ dataset_name: Optional[str] = None,
731
+ tags: Union[str, list[str], None] = None,
732
+ ):
733
+ """
734
+ Creates a draft of a model card using the information available to the `Trainer`.
735
+
736
+ Args:
737
+ model_name (`str` or `None`, *optional*, defaults to `None`):
738
+ Name of the model.
739
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
740
+ Name of the dataset used for training.
741
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
742
+ Tags to be associated with the model card.
743
+ """
744
+ if not self.is_world_process_zero():
745
+ return
746
+
747
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
748
+ base_model = self.model.config._name_or_path
749
+ else:
750
+ base_model = None
751
+
752
+ tags = tags or []
753
+ if isinstance(tags, str):
754
+ tags = [tags]
755
+
756
+ if hasattr(self.model.config, "unsloth_version"):
757
+ tags.append("unsloth")
758
+
759
+ citation = textwrap.dedent("""\
760
+ @inproceedings{munos2024nash,
761
+ title = {{Nash Learning from Human Feedback}},
762
+ author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
763
+ year = 2024,
764
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
765
+ publisher = {OpenReview.net},
766
+ url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
767
+ }""")
768
+
769
+ model_card = generate_model_card(
770
+ base_model=base_model,
771
+ model_name=model_name,
772
+ hub_model_id=self.hub_model_id,
773
+ dataset_name=dataset_name,
774
+ tags=tags,
775
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
776
+ comet_url=get_comet_experiment_url(),
777
+ trainer_name="Nash-MD",
778
+ trainer_citation=citation,
779
+ paper_title="Nash Learning from Human Feedback",
780
+ paper_id="2312.00886",
781
+ )
782
+
783
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
784
+ class UnslothNashMDTrainer(_UnslothNashMDTrainer):
785
+ """
786
+
787
+ Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
788
+
789
+ Args:
790
+ model (`transformers.PreTrainedModel`):
791
+ The model to train, preferably an `AutoModelForCausalLM`.
792
+ ref_model (`PreTrainedModelWrapper`):
793
+ Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
794
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
795
+ reward_model (`transformers.PreTrainedModel`):
796
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
797
+ judge (`BasePairwiseJudge`):
798
+ The judge to use for pairwise comparison of model completions.
799
+ args (`NashMDConfig`):
800
+ The NashMD config arguments to use for training.
801
+ data_collator (`transformers.DataCollator`):
802
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
803
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
804
+ train_dataset (`datasets.Dataset`):
805
+ The dataset to use for training.
806
+ eval_dataset (`datasets.Dataset`):
807
+ The dataset to use for evaluation.
808
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
809
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
810
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
811
+ reuse the fine-tuned model.
812
+ peft_config (`dict`):
813
+ The peft config to use for training.
814
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
815
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
816
+ a dictionary string to metric values.
817
+ callbacks (`list[transformers.TrainerCallback]`):
818
+ The callbacks to use for training.
819
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
820
+ The optimizer and scheduler to use for training.
821
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
822
+ The function to use to preprocess the logits before computing the metrics.
823
+
824
+ """
825
+ def __init__(
826
+ self,
827
+ model = None,
828
+ ref_model = None,
829
+ reward_model = None,
830
+ judge = None,
831
+ args = None,
832
+ data_collator = None,
833
+ train_dataset = None,
834
+ eval_dataset = None,
835
+ processing_class = None,
836
+ peft_config = None,
837
+ compute_metrics = None,
838
+ callbacks = None,
839
+ preprocess_logits_for_metrics = None,
840
+ **kwargs
841
+ ):
842
+ if args is None: args = UnslothNashMDConfig()
843
+ use_bf16 = getattr(args, 'bf16', False)
844
+ use_fp16 = getattr(args, 'fp16', False)
845
+ force_float32 = False
846
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
847
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
848
+ force_float32 = True
849
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
850
+ dtype = getattr(model.config, 'torch_dtype', None)
851
+ if dtype is None: dtype = model.get_input_embeddings().dtype
852
+ from unsloth_zoo.utils import _get_dtype
853
+ dtype = _get_dtype(dtype)
854
+ float16 = dtype == torch.float16
855
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
856
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
857
+ if force_float32:
858
+ args.fp16 = False
859
+ args.bf16 = False
860
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
861
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
862
+ args.fp16 = float16
863
+ args.bf16 = not float16
864
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
865
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
866
+ args.eval_strategy = 'steps'
867
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
868
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
869
+ if ga_steps is not None and ga_steps > 1:
870
+ from transformers import __version__ as transformers_version
871
+ if Version(transformers_version) <= Version('4.45.2'):
872
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
873
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
874
+ if getattr(args, 'eval_strategy', 'no') != 'no':
875
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
876
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
877
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
878
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
879
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
880
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
881
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
882
+ if force_float32:
883
+ args.bf16_full_eval = False
884
+ args.fp16_full_eval = False
885
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
886
+ args.bf16_full_eval = True
887
+ args.fp16_full_eval = False
888
+ elif not bf16_full_eval and not fp16_full_eval:
889
+ args.bf16_full_eval = args.bf16
890
+ args.fp16_full_eval = args.fp16
891
+ _output_logits = False
892
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
893
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
894
+ if _output_logits:
895
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
896
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
897
+ pass
898
+ else:
899
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
900
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
901
+ if args_max_seq_length is None and model_max_seq_length is not None:
902
+ max_seq_length = model.max_seq_length
903
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
904
+ if model is not None and hasattr(model, 'for_training'):
905
+ model.for_training()
906
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
907
+ if 'processing_class' in locals():
908
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
909
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
910
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
911
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
912
+ if not isinstance(data_collator, UnslothVisionDataCollator):
913
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
914
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
915
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
916
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
917
+ else:
918
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
919
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
920
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
921
+ if not isinstance(data_collator, UnslothVisionDataCollator):
922
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
923
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
924
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
925
+ else:
926
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
927
+ other_metrics = []
928
+
929
+ from unsloth_zoo.logging_utils import PatchRLStatistics
930
+ PatchRLStatistics('nash_md_trainer', other_metrics)
931
+
932
+ super().__init__(
933
+ model = model,
934
+ ref_model = ref_model,
935
+ reward_model = reward_model,
936
+ judge = judge,
937
+ args = args,
938
+ data_collator = data_collator,
939
+ train_dataset = train_dataset,
940
+ eval_dataset = eval_dataset,
941
+ processing_class = processing_class,
942
+ peft_config = peft_config,
943
+ compute_metrics = compute_metrics,
944
+ callbacks = callbacks,
945
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
946
+ if hasattr(self, 'neftune_hook_handle'):
947
+ self.neftune_hook_handle.remove()
948
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
949
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
950
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
951
+ pass
952
+
953
+ pass
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothORPOConfig(ORPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`ORPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
66
+ it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
67
+ disable_dropout (`bool`, *optional*, defaults to `True`):
68
+ Whether to disable dropout in the model.
69
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
70
+ Label pad token id. This argument is required if you want to use the default data collator.
71
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
72
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
73
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
74
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
75
+ This argument is required if you want to use the default data collator.
76
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
77
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
78
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
79
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
80
+ you need to specify if the model returned by the callable is an encoder-decoder model.
81
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
83
+ string.
84
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
+ Number of processes to use for processing the dataset.
86
+
87
+ """
88
+ vllm_sampling_params: Optional[Any] = field(
89
+ default = None,
90
+ metadata = {'help': 'vLLM SamplingParams'},
91
+ )
92
+ unsloth_num_chunks : Optional[int] = field(
93
+ default = -1,
94
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
95
+ )
96
+ def __init__(
97
+ self,
98
+ output_dir = None,
99
+ overwrite_output_dir = None,
100
+ do_train = False,
101
+ do_eval = False,
102
+ do_predict = False,
103
+ eval_strategy = 'no',
104
+ prediction_loss_only = False,
105
+ per_device_train_batch_size = 4,
106
+ per_device_eval_batch_size = 4,
107
+ per_gpu_train_batch_size = None,
108
+ per_gpu_eval_batch_size = None,
109
+ gradient_accumulation_steps = 2,
110
+ eval_accumulation_steps = 2,
111
+ eval_delay = 0,
112
+ torch_empty_cache_steps = 250,
113
+ learning_rate = 5e-05,
114
+ weight_decay = 0.01,
115
+ adam_beta1 = 0.9,
116
+ adam_beta2 = 0.999,
117
+ adam_epsilon = 1e-08,
118
+ max_grad_norm = 1.0,
119
+ num_train_epochs = 3.0,
120
+ max_steps = -1,
121
+ lr_scheduler_type = 'linear',
122
+ warmup_ratio = 0.1,
123
+ warmup_steps = 0,
124
+ log_level = 'passive',
125
+ log_level_replica = 'warning',
126
+ log_on_each_node = True,
127
+ logging_dir = None,
128
+ logging_strategy = 'steps',
129
+ logging_first_step = False,
130
+ logging_steps = 1,
131
+ logging_nan_inf_filter = False,
132
+ save_strategy = 'steps',
133
+ save_steps = 500,
134
+ save_total_limit = None,
135
+ save_safetensors = True,
136
+ save_on_each_node = False,
137
+ save_only_model = False,
138
+ restore_callback_states_from_checkpoint = False,
139
+ no_cuda = False,
140
+ use_cpu = False,
141
+ use_mps_device = False,
142
+ seed = 3407,
143
+ data_seed = 3407,
144
+ jit_mode_eval = False,
145
+ use_ipex = False,
146
+ bf16 = False,
147
+ fp16 = False,
148
+ fp16_opt_level = 'O1',
149
+ half_precision_backend = 'auto',
150
+ bf16_full_eval = False,
151
+ fp16_full_eval = False,
152
+ tf32 = None,
153
+ local_rank = -1,
154
+ ddp_backend = None,
155
+ tpu_num_cores = None,
156
+ tpu_metrics_debug = False,
157
+ debug = '',
158
+ dataloader_drop_last = False,
159
+ eval_steps = None,
160
+ dataloader_num_workers = 0,
161
+ dataloader_prefetch_factor = None,
162
+ past_index = -1,
163
+ run_name = None,
164
+ disable_tqdm = None,
165
+ remove_unused_columns = True,
166
+ label_names = None,
167
+ load_best_model_at_end = False,
168
+ metric_for_best_model = None,
169
+ greater_is_better = None,
170
+ ignore_data_skip = False,
171
+ fsdp = '',
172
+ fsdp_min_num_params = 0,
173
+ fsdp_config = None,
174
+ fsdp_transformer_layer_cls_to_wrap = None,
175
+ accelerator_config = None,
176
+ deepspeed = None,
177
+ label_smoothing_factor = 0.0,
178
+ optim = 'adamw_8bit',
179
+ optim_args = None,
180
+ adafactor = False,
181
+ group_by_length = False,
182
+ length_column_name = 'length',
183
+ report_to = None,
184
+ ddp_find_unused_parameters = None,
185
+ ddp_bucket_cap_mb = None,
186
+ ddp_broadcast_buffers = None,
187
+ dataloader_pin_memory = True,
188
+ dataloader_persistent_workers = False,
189
+ skip_memory_metrics = True,
190
+ use_legacy_prediction_loop = False,
191
+ push_to_hub = False,
192
+ resume_from_checkpoint = None,
193
+ hub_model_id = None,
194
+ hub_strategy = 'every_save',
195
+ hub_token = None,
196
+ hub_private_repo = None,
197
+ hub_always_push = False,
198
+ gradient_checkpointing = False,
199
+ gradient_checkpointing_kwargs = None,
200
+ include_inputs_for_metrics = False,
201
+ eval_do_concat_batches = True,
202
+ fp16_backend = 'auto',
203
+ evaluation_strategy = None,
204
+ push_to_hub_model_id = None,
205
+ push_to_hub_organization = None,
206
+ push_to_hub_token = None,
207
+ mp_parameters = '',
208
+ auto_find_batch_size = False,
209
+ full_determinism = False,
210
+ torchdynamo = None,
211
+ ray_scope = 'last',
212
+ ddp_timeout = 1800,
213
+ torch_compile = False,
214
+ torch_compile_backend = None,
215
+ torch_compile_mode = None,
216
+ dispatch_batches = None,
217
+ split_batches = None,
218
+ include_tokens_per_second = False,
219
+ include_num_input_tokens_seen = False,
220
+ neftune_noise_alpha = None,
221
+ optim_target_modules = None,
222
+ batch_eval_metrics = False,
223
+ eval_on_start = False,
224
+ use_liger_kernel = False,
225
+ eval_use_gather_object = False,
226
+ average_tokens_across_devices = False,
227
+ max_length = 1024,
228
+ max_prompt_length = 512,
229
+ max_completion_length = None,
230
+ beta = 0.1,
231
+ disable_dropout = True,
232
+ label_pad_token_id = -100,
233
+ padding_value = None,
234
+ truncation_mode = 'keep_end',
235
+ generate_during_eval = False,
236
+ is_encoder_decoder = None,
237
+ model_init_kwargs = None,
238
+ dataset_num_proc = None,
239
+ vllm_sampling_params = None,
240
+ unsloth_num_chunks = -1,
241
+ **kwargs,
242
+ ):
243
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
244
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
245
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
246
+ output_dir = 'unsloth_training_checkpoints'
247
+ save_strategy = 'no'
248
+ if dataset_num_proc is None:
249
+ from multiprocessing import cpu_count
250
+ dataset_num_proc = cpu_count()
251
+
252
+ super().__init__(
253
+ output_dir = output_dir,
254
+ overwrite_output_dir = overwrite_output_dir,
255
+ do_train = do_train,
256
+ do_eval = do_eval,
257
+ do_predict = do_predict,
258
+ eval_strategy = eval_strategy,
259
+ prediction_loss_only = prediction_loss_only,
260
+ per_device_train_batch_size = per_device_train_batch_size,
261
+ per_device_eval_batch_size = per_device_eval_batch_size,
262
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
263
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
264
+ gradient_accumulation_steps = gradient_accumulation_steps,
265
+ eval_accumulation_steps = eval_accumulation_steps,
266
+ eval_delay = eval_delay,
267
+ torch_empty_cache_steps = torch_empty_cache_steps,
268
+ learning_rate = learning_rate,
269
+ weight_decay = weight_decay,
270
+ adam_beta1 = adam_beta1,
271
+ adam_beta2 = adam_beta2,
272
+ adam_epsilon = adam_epsilon,
273
+ max_grad_norm = max_grad_norm,
274
+ num_train_epochs = num_train_epochs,
275
+ max_steps = max_steps,
276
+ lr_scheduler_type = lr_scheduler_type,
277
+ warmup_ratio = warmup_ratio,
278
+ warmup_steps = warmup_steps,
279
+ log_level = log_level,
280
+ log_level_replica = log_level_replica,
281
+ log_on_each_node = log_on_each_node,
282
+ logging_dir = logging_dir,
283
+ logging_strategy = logging_strategy,
284
+ logging_first_step = logging_first_step,
285
+ logging_steps = logging_steps,
286
+ logging_nan_inf_filter = logging_nan_inf_filter,
287
+ save_strategy = save_strategy,
288
+ save_steps = save_steps,
289
+ save_total_limit = save_total_limit,
290
+ save_safetensors = save_safetensors,
291
+ save_on_each_node = save_on_each_node,
292
+ save_only_model = save_only_model,
293
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
294
+ no_cuda = no_cuda,
295
+ use_cpu = use_cpu,
296
+ use_mps_device = use_mps_device,
297
+ seed = seed,
298
+ data_seed = data_seed,
299
+ jit_mode_eval = jit_mode_eval,
300
+ use_ipex = use_ipex,
301
+ bf16 = bf16,
302
+ fp16 = fp16,
303
+ fp16_opt_level = fp16_opt_level,
304
+ half_precision_backend = half_precision_backend,
305
+ bf16_full_eval = bf16_full_eval,
306
+ fp16_full_eval = fp16_full_eval,
307
+ tf32 = tf32,
308
+ local_rank = local_rank,
309
+ ddp_backend = ddp_backend,
310
+ tpu_num_cores = tpu_num_cores,
311
+ tpu_metrics_debug = tpu_metrics_debug,
312
+ debug = debug,
313
+ dataloader_drop_last = dataloader_drop_last,
314
+ eval_steps = eval_steps,
315
+ dataloader_num_workers = dataloader_num_workers,
316
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
317
+ past_index = past_index,
318
+ run_name = run_name,
319
+ disable_tqdm = disable_tqdm,
320
+ remove_unused_columns = remove_unused_columns,
321
+ label_names = label_names,
322
+ load_best_model_at_end = load_best_model_at_end,
323
+ metric_for_best_model = metric_for_best_model,
324
+ greater_is_better = greater_is_better,
325
+ ignore_data_skip = ignore_data_skip,
326
+ fsdp = fsdp,
327
+ fsdp_min_num_params = fsdp_min_num_params,
328
+ fsdp_config = fsdp_config,
329
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
330
+ accelerator_config = accelerator_config,
331
+ deepspeed = deepspeed,
332
+ label_smoothing_factor = label_smoothing_factor,
333
+ optim = optim,
334
+ optim_args = optim_args,
335
+ adafactor = adafactor,
336
+ group_by_length = group_by_length,
337
+ length_column_name = length_column_name,
338
+ report_to = report_to,
339
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
340
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
341
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
342
+ dataloader_pin_memory = dataloader_pin_memory,
343
+ dataloader_persistent_workers = dataloader_persistent_workers,
344
+ skip_memory_metrics = skip_memory_metrics,
345
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
346
+ push_to_hub = push_to_hub,
347
+ resume_from_checkpoint = resume_from_checkpoint,
348
+ hub_model_id = hub_model_id,
349
+ hub_strategy = hub_strategy,
350
+ hub_token = hub_token,
351
+ hub_private_repo = hub_private_repo,
352
+ hub_always_push = hub_always_push,
353
+ gradient_checkpointing = gradient_checkpointing,
354
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
355
+ include_inputs_for_metrics = include_inputs_for_metrics,
356
+ eval_do_concat_batches = eval_do_concat_batches,
357
+ fp16_backend = fp16_backend,
358
+ evaluation_strategy = evaluation_strategy,
359
+ push_to_hub_model_id = push_to_hub_model_id,
360
+ push_to_hub_organization = push_to_hub_organization,
361
+ push_to_hub_token = push_to_hub_token,
362
+ mp_parameters = mp_parameters,
363
+ auto_find_batch_size = auto_find_batch_size,
364
+ full_determinism = full_determinism,
365
+ torchdynamo = torchdynamo,
366
+ ray_scope = ray_scope,
367
+ ddp_timeout = ddp_timeout,
368
+ torch_compile = torch_compile,
369
+ torch_compile_backend = torch_compile_backend,
370
+ torch_compile_mode = torch_compile_mode,
371
+ dispatch_batches = dispatch_batches,
372
+ split_batches = split_batches,
373
+ include_tokens_per_second = include_tokens_per_second,
374
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
375
+ neftune_noise_alpha = neftune_noise_alpha,
376
+ optim_target_modules = optim_target_modules,
377
+ batch_eval_metrics = batch_eval_metrics,
378
+ eval_on_start = eval_on_start,
379
+ use_liger_kernel = use_liger_kernel,
380
+ eval_use_gather_object = eval_use_gather_object,
381
+ average_tokens_across_devices = average_tokens_across_devices,
382
+ max_length = max_length,
383
+ max_prompt_length = max_prompt_length,
384
+ max_completion_length = max_completion_length,
385
+ beta = beta,
386
+ disable_dropout = disable_dropout,
387
+ label_pad_token_id = label_pad_token_id,
388
+ padding_value = padding_value,
389
+ truncation_mode = truncation_mode,
390
+ generate_during_eval = generate_during_eval,
391
+ is_encoder_decoder = is_encoder_decoder,
392
+ model_init_kwargs = model_init_kwargs,
393
+ dataset_num_proc = dataset_num_proc,**kwargs)
394
+ self.vllm_sampling_params = vllm_sampling_params
395
+ self.unsloth_num_chunks = unsloth_num_chunks
396
+ pass
397
+
398
+ class _UnslothORPOTrainer(Trainer):
399
+ r""""""
400
+
401
+ _tag_names = ["trl", "orpo"]
402
+
403
+ def __init__(
404
+ self,
405
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
406
+ args: Optional[ORPOConfig] = None,
407
+ data_collator: Optional[DataCollator] = None,
408
+ train_dataset: Optional[Dataset] = None,
409
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
410
+ processing_class: Optional[
411
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
412
+ ] = None,
413
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
414
+ callbacks: Optional[list[TrainerCallback]] = None,
415
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
416
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
417
+ peft_config: Optional[dict] = None,
418
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
419
+ ):
420
+ if args.model_init_kwargs is None:
421
+ model_init_kwargs = {}
422
+ elif not isinstance(model, str):
423
+ raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
424
+ else:
425
+ model_init_kwargs = args.model_init_kwargs
426
+ torch_dtype = model_init_kwargs.get("torch_dtype")
427
+ if torch_dtype is not None:
428
+ # Convert to `torch.dtype` if an str is passed
429
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
430
+ torch_dtype = getattr(torch, torch_dtype)
431
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
432
+ raise ValueError(
433
+ f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
434
+ )
435
+ model_init_kwargs["torch_dtype"] = torch_dtype
436
+
437
+ if isinstance(model, str):
438
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
439
+
440
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
441
+ # has been called in order to properly call autocast if needed.
442
+ self._peft_has_been_casted_to_bf16 = False
443
+
444
+ if not is_peft_available() and peft_config is not None:
445
+ raise ValueError(
446
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
447
+ )
448
+ elif is_peft_available() and peft_config is not None:
449
+ # if model is a peft model and we have a peft_config, we merge and unload it first
450
+ if isinstance(model, PeftModel):
451
+ model = model.merge_and_unload()
452
+
453
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
454
+ _support_gc_kwargs = hasattr(
455
+ args, "gradient_checkpointing_kwargs"
456
+ ) and "gradient_checkpointing_kwargs" in list(
457
+ inspect.signature(prepare_model_for_kbit_training).parameters
458
+ )
459
+
460
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
461
+
462
+ if _support_gc_kwargs:
463
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
464
+
465
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
466
+ elif getattr(args, "gradient_checkpointing", False):
467
+ # For backward compatibility with older versions of transformers
468
+ if hasattr(model, "enable_input_require_grads"):
469
+ model.enable_input_require_grads()
470
+ else:
471
+
472
+ def make_inputs_require_grad(module, input, output):
473
+ output.requires_grad_(True)
474
+
475
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
476
+
477
+ # get peft model with the given config
478
+ model = model
479
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
480
+ peft_module_casting_to_bf16(model)
481
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
482
+ self._peft_has_been_casted_to_bf16 = True
483
+
484
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
485
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
486
+ # fail or completely fail.
487
+ elif getattr(args, "gradient_checkpointing", False):
488
+ # For backward compatibility with older versions of transformers
489
+ if hasattr(model, "enable_input_require_grads"):
490
+ model.enable_input_require_grads()
491
+ else:
492
+
493
+ def make_inputs_require_grad(module, input, output):
494
+ output.requires_grad_(True)
495
+
496
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
497
+
498
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
499
+ raise ValueError(
500
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
501
+ " Please install `wandb` or `comet-ml` to resolve."
502
+ )
503
+
504
+ if model is not None:
505
+ self.is_encoder_decoder = model.config.is_encoder_decoder
506
+ elif args.is_encoder_decoder is None:
507
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
508
+ else:
509
+ self.is_encoder_decoder = args.is_encoder_decoder
510
+
511
+ if self.is_encoder_decoder:
512
+ self.decoder_start_token_id = model.config.decoder_start_token_id
513
+ self.pad_token_id = model.config.pad_token_id
514
+
515
+ if processing_class is None:
516
+ raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
517
+ if args.max_length is None:
518
+ warnings.warn(
519
+ "`max_length` is not set in the ORPOConfig's init"
520
+ " it will default to `512` by default, but you should do it yourself in the future.",
521
+ UserWarning,
522
+ )
523
+ max_length = 512
524
+ else:
525
+ max_length = args.max_length
526
+ if args.max_prompt_length is None:
527
+ warnings.warn(
528
+ "`max_prompt_length` is not set in the ORPOConfig's init"
529
+ " it will default to `128` by default, but you should do it yourself in the future.",
530
+ UserWarning,
531
+ )
532
+ max_prompt_length = 128
533
+ else:
534
+ max_prompt_length = args.max_prompt_length
535
+
536
+ if args.max_completion_length is None and self.is_encoder_decoder:
537
+ warnings.warn(
538
+ "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
539
+ " it will default to `128` by default, but you should do it yourself in the future.",
540
+ UserWarning,
541
+ )
542
+ self.max_completion_length = 128
543
+ else:
544
+ self.max_completion_length = args.max_completion_length
545
+
546
+ if data_collator is None:
547
+ data_collator = DPODataCollatorWithPadding(
548
+ pad_token_id=processing_class.pad_token_id,
549
+ label_pad_token_id=args.label_pad_token_id,
550
+ is_encoder_decoder=self.is_encoder_decoder,
551
+ )
552
+
553
+ if args.remove_unused_columns:
554
+ args.remove_unused_columns = False
555
+ # warn users
556
+ warnings.warn(
557
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
558
+ " we have set it for you, but you should do it yourself in the future.",
559
+ UserWarning,
560
+ )
561
+
562
+ self.use_dpo_data_collator = True
563
+ else:
564
+ self.use_dpo_data_collator = False
565
+
566
+ # Disable dropout in the model and reference model
567
+ if args.disable_dropout:
568
+ disable_dropout_in_model(model)
569
+
570
+ self.max_length = max_length
571
+ self.generate_during_eval = args.generate_during_eval
572
+ self.label_pad_token_id = args.label_pad_token_id
573
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
574
+ self.max_prompt_length = max_prompt_length
575
+ self.truncation_mode = args.truncation_mode
576
+ self.processing_class = processing_class
577
+
578
+ self.beta = args.beta
579
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
580
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
581
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
582
+ warnings.warn(
583
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
584
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
585
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
586
+ "loss.",
587
+ UserWarning,
588
+ )
589
+
590
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
591
+
592
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
593
+ # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
594
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
595
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
596
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
597
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
598
+ # that the warning has already been issued.
599
+ model.warnings_issued["estimate_tokens"] = True
600
+
601
+ # Compute that only on the main process for faster data processing.
602
+ # see: https://github.com/huggingface/trl/pull/1255
603
+ with PartialState().local_main_process_first():
604
+ # Extract the prompt if needed, and apply the chat template if needed
605
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
606
+ train_dataset = train_dataset.map(
607
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
608
+ )
609
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
610
+ if eval_dataset is not None:
611
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
612
+ eval_dataset = eval_dataset.map(
613
+ maybe_apply_chat_template,
614
+ fn_kwargs={"tokenizer": processing_class},
615
+ num_proc=args.dataset_num_proc,
616
+ )
617
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
618
+
619
+ super().__init__(
620
+ model=model,
621
+ args=args,
622
+ data_collator=data_collator,
623
+ train_dataset=train_dataset,
624
+ eval_dataset=eval_dataset,
625
+ processing_class=processing_class,
626
+ model_init=model_init,
627
+ compute_metrics=compute_metrics,
628
+ callbacks=callbacks,
629
+ optimizers=optimizers,
630
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
631
+ )
632
+
633
+ # Add tags for models that have been loaded with the correct transformers version
634
+ if hasattr(self.model, "add_model_tags"):
635
+ self.model.add_model_tags(self._tag_names)
636
+
637
+ if not hasattr(self, "accelerator"):
638
+ raise AttributeError(
639
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
640
+ )
641
+
642
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
643
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
644
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
645
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
646
+
647
+ if model is not None:
648
+ if hasattr(model, "config"):
649
+ hidden_size = (
650
+ max(model.config.hidden_sizes)
651
+ if getattr(model.config, "hidden_sizes", None)
652
+ else getattr(model.config, "hidden_size", None)
653
+ )
654
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
655
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
656
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
657
+ config_kwargs.update(
658
+ {
659
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
660
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
661
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
662
+ }
663
+ )
664
+
665
+ # If ZeRO-3 is used, we shard both the active and reference model.
666
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
667
+ if config_kwargs["zero_optimization"]["stage"] != 3:
668
+ config_kwargs["zero_optimization"]["stage"] = 0
669
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
670
+ model.eval()
671
+ return model
672
+
673
+ def build_tokenized_answer(self, prompt, answer):
674
+ """
675
+ Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
676
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
677
+ Reference:
678
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
679
+ """
680
+
681
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
682
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
683
+
684
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
685
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
686
+
687
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
688
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
689
+
690
+ # Prepare input tokens for token by token comparison
691
+ full_input_ids = np.array(full_tokenized["input_ids"])
692
+
693
+ if len(full_input_ids) != len(full_concat_input_ids):
694
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
695
+
696
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
697
+ # can be merged together when tokenizing prompt+answer. This could result
698
+ # on the last token from the prompt being different when tokenized on its own
699
+ # vs when done as prompt+answer.
700
+ response_token_ids_start_idx = len(prompt_input_ids)
701
+
702
+ # If tokenized prompt is different than both prompt+answer, then it means the
703
+ # last token has changed due to merging.
704
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
705
+ response_token_ids_start_idx -= 1
706
+
707
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
708
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
709
+
710
+ if len(prompt_input_ids) != len(prompt_attention_mask):
711
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
712
+
713
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
714
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
715
+
716
+ return dict(
717
+ prompt_input_ids=prompt_input_ids,
718
+ prompt_attention_mask=prompt_attention_mask,
719
+ input_ids=answer_input_ids,
720
+ attention_mask=answer_attention_mask,
721
+ )
722
+
723
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
724
+ """Tokenize a single row from a ORPO specific dataset.
725
+
726
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
727
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
728
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
729
+
730
+ We also create the labels for the chosen/rejected responses, which are of length equal to
731
+ the sum of the length of the prompt and the chosen/rejected response, with
732
+ label_pad_token_id for the prompt tokens.
733
+ """
734
+ batch = {}
735
+ prompt = feature["prompt"]
736
+ chosen = feature["chosen"]
737
+ rejected = feature["rejected"]
738
+
739
+ if not self.is_encoder_decoder:
740
+ # Check issues below for more details
741
+ # 1. https://github.com/huggingface/trl/issues/907
742
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
743
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
744
+
745
+ if not isinstance(prompt, str):
746
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
747
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
748
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
749
+
750
+ if not isinstance(chosen, str):
751
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
752
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
753
+
754
+ if not isinstance(rejected, str):
755
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
756
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
757
+
758
+ # Last prompt token might get merged by tokenizer and
759
+ # it should not be included for generation if that happens
760
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
761
+
762
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
763
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
764
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
765
+
766
+ for k, v in prompt_tokens.items():
767
+ prompt_tokens[k] = v[:prompt_len_input_ids]
768
+
769
+ # Make sure prompts only have one different token at most an
770
+ # and length only differs by 1 at most
771
+ num_diff_tokens = sum(
772
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
773
+ )
774
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
775
+ if num_diff_tokens > 1 or num_diff_len > 1:
776
+ raise ValueError(
777
+ "Chosen and rejected prompt_input_ids might only differ on the "
778
+ "last token due to tokenizer merge ops."
779
+ )
780
+
781
+ # add BOS token to head of prompt. Avoid adding if it's already there
782
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
783
+ self.processing_class.bos_token_id,
784
+ prompt_len_input_ids,
785
+ prompt_tokens,
786
+ chosen_prompt_len_input_ids,
787
+ chosen_tokens,
788
+ rejected_prompt_len_input_ids,
789
+ rejected_tokens,
790
+ )
791
+
792
+ # add EOS token to end of answer. Avoid adding if it's already there
793
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
794
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
795
+ )
796
+
797
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
798
+
799
+ # if combined sequence is too long, truncate the prompt
800
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
801
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
802
+ if self.truncation_mode == "keep_start":
803
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
804
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
805
+ elif self.truncation_mode == "keep_end":
806
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
807
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
808
+ else:
809
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
810
+
811
+ # if that's still too long, truncate the response
812
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
813
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
814
+ for k in ["input_ids", "attention_mask"]:
815
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
816
+
817
+ # Create labels
818
+ chosen_sequence_tokens = {
819
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
820
+ }
821
+ rejected_sequence_tokens = {
822
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
823
+ }
824
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
825
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
826
+ self.label_pad_token_id
827
+ ] * len(chosen_tokens["prompt_input_ids"])
828
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
829
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
830
+ self.label_pad_token_id
831
+ ] * len(rejected_tokens["prompt_input_ids"])
832
+
833
+ for k, toks in {
834
+ "chosen_": chosen_sequence_tokens,
835
+ "rejected_": rejected_sequence_tokens,
836
+ "": prompt_tokens,
837
+ }.items():
838
+ for type_key, tokens in toks.items():
839
+ if type_key == "token_type_ids":
840
+ continue
841
+ batch[f"{k}{type_key}"] = tokens
842
+
843
+ else:
844
+ chosen_tokens = self.processing_class(
845
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
846
+ )
847
+ rejected_tokens = self.processing_class(
848
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
849
+ )
850
+ prompt_tokens = self.processing_class(
851
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
852
+ )
853
+
854
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
855
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
856
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
857
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
858
+
859
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
860
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
861
+ labels=torch.tensor(batch["rejected_labels"])
862
+ )
863
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
864
+ labels=torch.tensor(batch["chosen_labels"])
865
+ )
866
+
867
+ if is_torch_xla_available():
868
+ # Pad the sequences to global max_length to avoid TorchXLA recompilation
869
+ for k in batch:
870
+ if "labels" in k or self.is_encoder_decoder:
871
+ pad_value = self.label_pad_token_id
872
+ elif k.endswith("_input_ids"):
873
+ pad_value = self.padding_value
874
+ elif k.endswith("_attention_mask"):
875
+ pad_value = 0
876
+ batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
877
+ return batch
878
+
879
+ @staticmethod
880
+ def concatenated_inputs(
881
+ batch: dict[str, Union[list, torch.LongTensor]],
882
+ is_encoder_decoder: bool = False,
883
+ label_pad_token_id: int = -100,
884
+ padding_value: int = 0,
885
+ device: Optional[torch.device] = None,
886
+ ) -> dict[str, torch.LongTensor]:
887
+ """Concatenate the chosen and rejected inputs into a single tensor.
888
+
889
+ Args:
890
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
891
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
892
+ label_pad_token_id: The label pad token id.
893
+ padding_value: The padding value to use for the concatenated inputs_ids.
894
+ device: The device for the concatenated inputs.
895
+
896
+ Returns:
897
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
898
+ """
899
+ concatenated_batch = {}
900
+
901
+ if is_encoder_decoder:
902
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
903
+ else:
904
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
905
+
906
+ for k in batch:
907
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
908
+ if "labels" in k or is_encoder_decoder:
909
+ pad_value = label_pad_token_id
910
+ elif k.endswith("_input_ids"):
911
+ pad_value = padding_value
912
+ elif k.endswith("_attention_mask"):
913
+ pad_value = 0
914
+ concatenated_key = k.replace("chosen", "concatenated")
915
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
916
+ for k in batch:
917
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
918
+ if "labels" in k or is_encoder_decoder:
919
+ pad_value = label_pad_token_id
920
+ elif k.endswith("_input_ids"):
921
+ pad_value = padding_value
922
+ elif k.endswith("_attention_mask"):
923
+ pad_value = 0
924
+ concatenated_key = k.replace("rejected", "concatenated")
925
+ concatenated_batch[concatenated_key] = torch.cat(
926
+ (
927
+ concatenated_batch[concatenated_key],
928
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
929
+ ),
930
+ dim=0,
931
+ ).to(device=device)
932
+
933
+ if is_encoder_decoder:
934
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
935
+ concatenated_batch["concatenated_attention_mask"] = (
936
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
937
+ )
938
+
939
+ return concatenated_batch
940
+
941
+ def odds_ratio_loss(
942
+ self,
943
+ policy_chosen_logps: torch.FloatTensor,
944
+ policy_rejected_logps: torch.FloatTensor,
945
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
946
+ """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
947
+
948
+ Args:
949
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
950
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
951
+
952
+ Returns:
953
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
954
+ The losses tensor contains the ORPO loss for each example in the batch.
955
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
956
+ The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
957
+ The `log(sigmoid(log_odds_chosen))` for logging purposes.
958
+ """
959
+
960
+ # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
961
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
962
+ torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
963
+ )
964
+ ratio = F.logsigmoid(log_odds)
965
+ losses = self.beta * ratio
966
+
967
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
968
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
969
+
970
+ return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
971
+
972
+ @staticmethod
973
+ def get_batch_logps(
974
+ logits: torch.FloatTensor,
975
+ labels: torch.LongTensor,
976
+ average_log_prob: bool = False,
977
+ label_pad_token_id: int = -100,
978
+ is_encoder_decoder: bool = False,
979
+ ) -> torch.FloatTensor:
980
+ """Compute the log probabilities of the given labels under the given logits.
981
+
982
+ Args:
983
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
984
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
985
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
986
+ label_pad_token_id: The label pad token id.
987
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
988
+
989
+ Returns:
990
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
991
+ """
992
+ if logits.shape[:-1] != labels.shape:
993
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
994
+
995
+ if not is_encoder_decoder:
996
+ labels = labels[:, 1:].clone()
997
+ logits = logits[:, :-1, :]
998
+ loss_mask = labels != label_pad_token_id
999
+
1000
+ # dummy token; we'll ignore the losses on these tokens later
1001
+ labels = torch.where(labels == label_pad_token_id, 0, labels)
1002
+
1003
+ per_token_logps = selective_log_softmax(logits, labels)
1004
+
1005
+ if average_log_prob:
1006
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1007
+ else:
1008
+ return (per_token_logps * loss_mask).sum(-1)
1009
+
1010
+ def concatenated_forward(
1011
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1012
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1013
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1014
+
1015
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1016
+ """
1017
+ concatenated_batch = self.concatenated_inputs(
1018
+ batch,
1019
+ is_encoder_decoder=self.is_encoder_decoder,
1020
+ label_pad_token_id=self.label_pad_token_id,
1021
+ padding_value=self.padding_value,
1022
+ device=self.accelerator.device,
1023
+ )
1024
+ len_chosen = batch["chosen_labels"].shape[0]
1025
+
1026
+ model_kwargs = (
1027
+ {
1028
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1029
+ }
1030
+ if self.is_encoder_decoder
1031
+ else {}
1032
+ )
1033
+
1034
+ if self.aux_loss_enabled:
1035
+ model_kwargs["output_router_logits"] = True
1036
+
1037
+ outputs = model(
1038
+ concatenated_batch["concatenated_input_ids"],
1039
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1040
+ use_cache=False,
1041
+ **model_kwargs,
1042
+ )
1043
+ all_logits = outputs.logits
1044
+
1045
+ def cross_entropy_loss(logits, labels):
1046
+ if not self.is_encoder_decoder:
1047
+ # Shift so that tokens < n predict n
1048
+ logits = logits[..., :-1, :].contiguous()
1049
+ labels = labels[..., 1:].contiguous()
1050
+ # Flatten the tokens
1051
+ loss_fct = nn.CrossEntropyLoss()
1052
+ logits = logits.view(-1, logits.shape[-1])
1053
+ labels = labels.view(-1)
1054
+ # Enable model parallelism
1055
+ labels = labels.to(logits.device)
1056
+ loss = loss_fct(logits, labels)
1057
+ return loss
1058
+
1059
+ if self.is_encoder_decoder:
1060
+ labels = concatenated_batch["concatenated_labels"].clone()
1061
+ else:
1062
+ labels = concatenated_batch["concatenated_input_ids"].clone()
1063
+ attention_mask = concatenated_batch["concatenated_attention_mask"]
1064
+ labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1065
+ # orpo chosen nll loss is computed over the full prompt and response
1066
+ chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1067
+
1068
+ all_logps = self.get_batch_logps(
1069
+ all_logits,
1070
+ concatenated_batch["concatenated_labels"],
1071
+ average_log_prob=True,
1072
+ is_encoder_decoder=self.is_encoder_decoder,
1073
+ label_pad_token_id=self.label_pad_token_id,
1074
+ )
1075
+
1076
+ chosen_logps = all_logps[:len_chosen]
1077
+ rejected_logps = all_logps[len_chosen:]
1078
+
1079
+ if not self.is_encoder_decoder:
1080
+ chosen_logits = all_logits[:len_chosen, :-1, :]
1081
+ rejected_logits = all_logits[len_chosen:, :-1, :]
1082
+ else:
1083
+ chosen_logits = all_logits[:len_chosen]
1084
+ rejected_logits = all_logits[len_chosen:]
1085
+
1086
+ if self.aux_loss_enabled:
1087
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1088
+
1089
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1090
+
1091
+ def get_batch_loss_metrics(
1092
+ self,
1093
+ model,
1094
+ batch: dict[str, Union[list, torch.LongTensor]],
1095
+ train_eval: Literal["train", "eval"] = "train",
1096
+ ):
1097
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1098
+ metrics = {}
1099
+
1100
+ forward_output = self.concatenated_forward(model, batch)
1101
+ (
1102
+ policy_chosen_logps,
1103
+ policy_rejected_logps,
1104
+ policy_chosen_logits,
1105
+ policy_rejected_logits,
1106
+ policy_nll_loss,
1107
+ ) = forward_output[:5]
1108
+ if self.aux_loss_enabled:
1109
+ aux_loss = forward_output[5]
1110
+
1111
+ losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1112
+ policy_chosen_logps, policy_rejected_logps
1113
+ )
1114
+ # full ORPO loss
1115
+ loss = policy_nll_loss - losses.mean()
1116
+
1117
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1118
+
1119
+ prefix = "eval_" if train_eval == "eval" else ""
1120
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1121
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1122
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1123
+ metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1124
+ chosen_rewards - rejected_rewards
1125
+ ).mean()
1126
+ metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1127
+ metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1128
+ metrics[f"{prefix}logits/rejected"] = (
1129
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1130
+ )
1131
+ metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1132
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1133
+ metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1134
+ metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1135
+ if is_torch_xla_available():
1136
+ xm.mark_step() # needed because .item() calls
1137
+ for k, v in metrics.items():
1138
+ metrics[k] = v.item()
1139
+ if self.aux_loss_enabled:
1140
+ loss += self.aux_loss_coef * aux_loss
1141
+
1142
+ return loss, metrics
1143
+
1144
+ def compute_loss(
1145
+ self,
1146
+ model: Union[PreTrainedModel, nn.Module],
1147
+ inputs: dict[str, Union[torch.Tensor, Any]],
1148
+ return_outputs=False,
1149
+ num_items_in_batch=None,
1150
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1151
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1152
+
1153
+ with compute_loss_context_manager:
1154
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1155
+
1156
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1157
+ loss = loss.to(self.args.device)
1158
+
1159
+ # force log the metrics
1160
+ self.store_metrics(metrics, train_eval="train")
1161
+
1162
+ if return_outputs:
1163
+ return (loss, metrics)
1164
+ return loss
1165
+
1166
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1167
+ """Generate samples from the model and reference model for the given batch of inputs."""
1168
+
1169
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1170
+ # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1171
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1172
+
1173
+ with generate_context_manager:
1174
+ policy_output = model.generate(
1175
+ input_ids=batch["prompt_input_ids"],
1176
+ attention_mask=batch["prompt_attention_mask"],
1177
+ max_length=self.max_length,
1178
+ do_sample=True,
1179
+ pad_token_id=self.processing_class.pad_token_id,
1180
+ )
1181
+
1182
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1183
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1184
+
1185
+ return policy_output_decoded
1186
+
1187
+ def prediction_step(
1188
+ self,
1189
+ model: Union[PreTrainedModel, nn.Module],
1190
+ inputs: dict[str, Union[torch.Tensor, Any]],
1191
+ prediction_loss_only: bool,
1192
+ ignore_keys: Optional[list[str]] = None,
1193
+ ):
1194
+ if not self.use_dpo_data_collator:
1195
+ warnings.warn(
1196
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1197
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1198
+ )
1199
+ if ignore_keys is None:
1200
+ if hasattr(model, "config"):
1201
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1202
+ else:
1203
+ ignore_keys = []
1204
+
1205
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1206
+
1207
+ with torch.no_grad(), prediction_context_manager:
1208
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1209
+
1210
+ # force log the metrics
1211
+ self.store_metrics(metrics, train_eval="eval")
1212
+
1213
+ if prediction_loss_only:
1214
+ return (loss.detach(), None, None)
1215
+
1216
+ # logits for the chosen and rejected samples from model
1217
+ logits_dict = {
1218
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1219
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1220
+ }
1221
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1222
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1223
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1224
+
1225
+ return (loss.detach(), logits, labels)
1226
+
1227
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1228
+ for key, value in metrics.items():
1229
+ self._stored_metrics[train_eval][key].append(value)
1230
+
1231
+ def evaluation_loop(
1232
+ self,
1233
+ dataloader: DataLoader,
1234
+ description: str,
1235
+ prediction_loss_only: Optional[bool] = None,
1236
+ ignore_keys: Optional[list[str]] = None,
1237
+ metric_key_prefix: str = "eval",
1238
+ ) -> EvalLoopOutput:
1239
+ """
1240
+ Overriding built-in evaluation loop to store metrics for each batch.
1241
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1242
+
1243
+ Works both with or without labels.
1244
+ """
1245
+
1246
+ # Sample and save to game log if requested (for one batch to save time)
1247
+ if self.generate_during_eval:
1248
+ # Generate random indices within the range of the total number of samples
1249
+ num_samples = len(dataloader.dataset)
1250
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1251
+
1252
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1253
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1254
+ random_batch = self.data_collator(random_batch_dataset)
1255
+ random_batch = self._prepare_inputs(random_batch)
1256
+
1257
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1258
+
1259
+ table = pd.DataFrame(
1260
+ columns=["Prompt", "Policy"],
1261
+ data=[
1262
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1263
+ ],
1264
+ )
1265
+ if "wandb" in self.args.report_to:
1266
+ wandb.log({"game_log": wandb.Table(data=table)})
1267
+
1268
+ if "comet_ml" in self.args.report_to:
1269
+ log_table_to_comet_experiment(
1270
+ name="game_log.csv",
1271
+ table=table,
1272
+ )
1273
+
1274
+ # Base evaluation
1275
+ initial_output = super().evaluation_loop(
1276
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1277
+ )
1278
+
1279
+ return initial_output
1280
+
1281
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1282
+ """
1283
+ Log `logs` on the various objects watching training, including stored metrics.
1284
+
1285
+ Args:
1286
+ logs (`dict[str, float]`):
1287
+ The values to log.
1288
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1289
+ Start time of the training.
1290
+ """
1291
+ # logs either has 'loss' or 'eval_loss'
1292
+ train_eval = "train" if "loss" in logs else "eval"
1293
+ # Add averaged stored metrics to logs
1294
+ for key, metrics in self._stored_metrics[train_eval].items():
1295
+ logs[key] = torch.tensor(metrics).mean().item()
1296
+ del self._stored_metrics[train_eval]
1297
+
1298
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1299
+ return super().log(logs, start_time)
1300
+ else: # transformers<=4.46
1301
+ return super().log(logs)
1302
+
1303
+ def _shift_right(self, input_ids):
1304
+ if self.decoder_start_token_id is None:
1305
+ raise ValueError(
1306
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1307
+ )
1308
+
1309
+ # shift inputs to the right
1310
+ if is_torch_fx_proxy(input_ids):
1311
+ # Item assignment is not supported natively for proxies.
1312
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1313
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1314
+ else:
1315
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1316
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1317
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1318
+
1319
+ if self.pad_token_id is None:
1320
+ raise ValueError("model.config.pad_token_id has to be defined.")
1321
+ # replace possible -100 values in labels by `pad_token_id`
1322
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1323
+
1324
+ return shifted_input_ids
1325
+
1326
+ def create_model_card(
1327
+ self,
1328
+ model_name: Optional[str] = None,
1329
+ dataset_name: Optional[str] = None,
1330
+ tags: Union[str, list[str], None] = None,
1331
+ ):
1332
+ """
1333
+ Creates a draft of a model card using the information available to the `Trainer`.
1334
+
1335
+ Args:
1336
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1337
+ Name of the model.
1338
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1339
+ Name of the dataset used for training.
1340
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1341
+ Tags to be associated with the model card.
1342
+ """
1343
+ if not self.is_world_process_zero():
1344
+ return
1345
+
1346
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1347
+ base_model = self.model.config._name_or_path
1348
+ else:
1349
+ base_model = None
1350
+
1351
+ tags = tags or []
1352
+ if isinstance(tags, str):
1353
+ tags = [tags]
1354
+
1355
+ if hasattr(self.model.config, "unsloth_version"):
1356
+ tags.append("unsloth")
1357
+
1358
+ citation = textwrap.dedent("""\
1359
+ @article{hong2024orpo,
1360
+ title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
1361
+ author = {Jiwoo Hong and Noah Lee and James Thorne},
1362
+ year = 2024,
1363
+ eprint = {arXiv:2403.07691}
1364
+ }""")
1365
+
1366
+ model_card = generate_model_card(
1367
+ base_model=base_model,
1368
+ model_name=model_name,
1369
+ hub_model_id=self.hub_model_id,
1370
+ dataset_name=dataset_name,
1371
+ tags=tags,
1372
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1373
+ comet_url=get_comet_experiment_url(),
1374
+ trainer_name="ORPO",
1375
+ trainer_citation=citation,
1376
+ paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
1377
+ paper_id="2403.07691",
1378
+ )
1379
+
1380
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1381
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1382
+ """
1383
+
1384
+ Initialize ORPOTrainer.
1385
+
1386
+ Args:
1387
+ model (`transformers.PreTrainedModel`):
1388
+ The model to train, preferably an `AutoModelForSequenceClassification`.
1389
+ args (`ORPOConfig`):
1390
+ The ORPO config arguments to use for training.
1391
+ data_collator (`transformers.DataCollator`):
1392
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1393
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1394
+ train_dataset (`datasets.Dataset`):
1395
+ The dataset to use for training.
1396
+ eval_dataset (`datasets.Dataset`):
1397
+ The dataset to use for evaluation.
1398
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1399
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1400
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1401
+ reuse the fine-tuned model.
1402
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1403
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1404
+ callbacks (`list[transformers.TrainerCallback]`):
1405
+ The callbacks to use for training.
1406
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1407
+ The optimizer and scheduler to use for training.
1408
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1409
+ The function to use to preprocess the logits before computing the metrics.
1410
+ peft_config (`dict`, defaults to `None`):
1411
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1412
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1413
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1414
+ a dictionary string to metric values.
1415
+
1416
+ """
1417
+ def __init__(
1418
+ self,
1419
+ model = None,
1420
+ args = None,
1421
+ data_collator = None,
1422
+ train_dataset = None,
1423
+ eval_dataset = None,
1424
+ processing_class = None,
1425
+ model_init = None,
1426
+ callbacks = None,
1427
+ preprocess_logits_for_metrics = None,
1428
+ peft_config = None,
1429
+ compute_metrics = None,
1430
+ **kwargs
1431
+ ):
1432
+ if args is None: args = UnslothORPOConfig()
1433
+ use_bf16 = getattr(args, 'bf16', False)
1434
+ use_fp16 = getattr(args, 'fp16', False)
1435
+ force_float32 = False
1436
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1437
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1438
+ force_float32 = True
1439
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1440
+ dtype = getattr(model.config, 'torch_dtype', None)
1441
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1442
+ from unsloth_zoo.utils import _get_dtype
1443
+ dtype = _get_dtype(dtype)
1444
+ float16 = dtype == torch.float16
1445
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1446
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1447
+ if force_float32:
1448
+ args.fp16 = False
1449
+ args.bf16 = False
1450
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1451
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1452
+ args.fp16 = float16
1453
+ args.bf16 = not float16
1454
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1455
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1456
+ args.eval_strategy = 'steps'
1457
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1458
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1459
+ if ga_steps is not None and ga_steps > 1:
1460
+ from transformers import __version__ as transformers_version
1461
+ if Version(transformers_version) <= Version('4.45.2'):
1462
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1463
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1464
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1465
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1466
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1467
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1468
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1469
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1470
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1471
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1472
+ if force_float32:
1473
+ args.bf16_full_eval = False
1474
+ args.fp16_full_eval = False
1475
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1476
+ args.bf16_full_eval = True
1477
+ args.fp16_full_eval = False
1478
+ elif not bf16_full_eval and not fp16_full_eval:
1479
+ args.bf16_full_eval = args.bf16
1480
+ args.fp16_full_eval = args.fp16
1481
+ _output_logits = False
1482
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1483
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1484
+ if _output_logits:
1485
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1486
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1487
+ pass
1488
+ else:
1489
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1490
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1491
+ if args_max_seq_length is None and model_max_seq_length is not None:
1492
+ max_seq_length = model.max_seq_length
1493
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1494
+ if model is not None and hasattr(model, 'for_training'):
1495
+ model.for_training()
1496
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1497
+ if 'processing_class' in locals():
1498
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1499
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1500
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1501
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1502
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1503
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1504
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1505
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1506
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1507
+ else:
1508
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1509
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1510
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1511
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1512
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1513
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1514
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1515
+ else:
1516
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1517
+ other_metrics = []
1518
+
1519
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1520
+ PatchRLStatistics('orpo_trainer', other_metrics)
1521
+
1522
+ super().__init__(
1523
+ model = model,
1524
+ args = args,
1525
+ data_collator = data_collator,
1526
+ train_dataset = train_dataset,
1527
+ eval_dataset = eval_dataset,
1528
+ processing_class = processing_class,
1529
+ model_init = model_init,
1530
+ callbacks = callbacks,
1531
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1532
+ peft_config = peft_config,
1533
+ compute_metrics = compute_metrics,**kwargs)
1534
+ if hasattr(self, 'neftune_hook_handle'):
1535
+ self.neftune_hook_handle.remove()
1536
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1537
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1538
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1539
+ pass
1540
+
1541
+ pass
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
@@ -0,0 +1,1267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ def vLLMSamplingParams(**kwargs):
43
+ from vllm import SamplingParams
44
+ sampling_params = SamplingParams(**kwargs)
45
+ sampling_params._set_kwargs = kwargs
46
+ return sampling_params
47
+ @dataclass
48
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
49
+ """
50
+
51
+ Configuration class for the [`OnlineDPOTrainer`].
52
+
53
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
54
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
55
+ command line.
56
+
57
+ Parameters:
58
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
59
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
60
+ [`~transformers.TrainingArguments`].
61
+ reward_model_path (`str` or `None`, *optional*, defaults to `None`):
62
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
63
+ judge (`str` or `None`, *optional*, defaults to `None`):
64
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
65
+ max_new_tokens (`int`, *optional*, defaults to `64`):
66
+ Maximum number of tokens to generate per completion.
67
+ max_length (`int`, *optional*, defaults to `256`):
68
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
69
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
70
+ possible.
71
+ temperature (`float`, *optional*, defaults to `0.9`):
72
+ Temperature for sampling. The higher the temperature, the more random the completions.
73
+ missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
74
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
75
+ to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
76
+ value.
77
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
78
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
79
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
80
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
81
+ selected for each new epoch and the last β is used for the rest of the epochs.
82
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
83
+ Type of loss to use. Possible values are:
84
+
85
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
86
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
+
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ disable_dropout (`bool`, *optional*, defaults to `True`):
91
+ Whether to disable dropout in the model and reference model.
92
+ use_vllm (`bool`, *optional*, defaults to `False`):
93
+ Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
94
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
95
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
96
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
97
+ capacity of a single GPU, albeit at the cost of slower generation.
98
+
99
+ """
100
+ vllm_sampling_params: Optional[Any] = field(
101
+ default = None,
102
+ metadata = {'help': 'vLLM SamplingParams'},
103
+ )
104
+ unsloth_num_chunks : Optional[int] = field(
105
+ default = -1,
106
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
107
+ )
108
+ def __init__(
109
+ self,
110
+ output_dir = None,
111
+ overwrite_output_dir = None,
112
+ do_train = False,
113
+ do_eval = False,
114
+ do_predict = False,
115
+ eval_strategy = 'no',
116
+ prediction_loss_only = False,
117
+ per_device_train_batch_size = 4,
118
+ per_device_eval_batch_size = 4,
119
+ per_gpu_train_batch_size = None,
120
+ per_gpu_eval_batch_size = None,
121
+ gradient_accumulation_steps = 2,
122
+ eval_accumulation_steps = 2,
123
+ eval_delay = 0,
124
+ torch_empty_cache_steps = 250,
125
+ learning_rate = 5e-05,
126
+ weight_decay = 0.01,
127
+ adam_beta1 = 0.9,
128
+ adam_beta2 = 0.999,
129
+ adam_epsilon = 1e-08,
130
+ max_grad_norm = 1.0,
131
+ num_train_epochs = 3.0,
132
+ max_steps = -1,
133
+ lr_scheduler_type = 'linear',
134
+ warmup_ratio = 0.1,
135
+ warmup_steps = 0,
136
+ log_level = 'passive',
137
+ log_level_replica = 'warning',
138
+ log_on_each_node = True,
139
+ logging_dir = None,
140
+ logging_strategy = 'steps',
141
+ logging_first_step = False,
142
+ logging_steps = 1,
143
+ logging_nan_inf_filter = False,
144
+ save_strategy = 'steps',
145
+ save_steps = 500,
146
+ save_total_limit = None,
147
+ save_safetensors = True,
148
+ save_on_each_node = False,
149
+ save_only_model = False,
150
+ restore_callback_states_from_checkpoint = False,
151
+ no_cuda = False,
152
+ use_cpu = False,
153
+ use_mps_device = False,
154
+ seed = 3407,
155
+ data_seed = 3407,
156
+ jit_mode_eval = False,
157
+ use_ipex = False,
158
+ bf16 = False,
159
+ fp16 = False,
160
+ fp16_opt_level = 'O1',
161
+ half_precision_backend = 'auto',
162
+ bf16_full_eval = False,
163
+ fp16_full_eval = False,
164
+ tf32 = None,
165
+ local_rank = -1,
166
+ ddp_backend = None,
167
+ tpu_num_cores = None,
168
+ tpu_metrics_debug = False,
169
+ debug = '',
170
+ dataloader_drop_last = False,
171
+ eval_steps = None,
172
+ dataloader_num_workers = 0,
173
+ dataloader_prefetch_factor = None,
174
+ past_index = -1,
175
+ run_name = None,
176
+ disable_tqdm = None,
177
+ remove_unused_columns = True,
178
+ label_names = None,
179
+ load_best_model_at_end = False,
180
+ metric_for_best_model = None,
181
+ greater_is_better = None,
182
+ ignore_data_skip = False,
183
+ fsdp = '',
184
+ fsdp_min_num_params = 0,
185
+ fsdp_config = None,
186
+ fsdp_transformer_layer_cls_to_wrap = None,
187
+ accelerator_config = None,
188
+ deepspeed = None,
189
+ label_smoothing_factor = 0.0,
190
+ optim = 'adamw_8bit',
191
+ optim_args = None,
192
+ adafactor = False,
193
+ group_by_length = False,
194
+ length_column_name = 'length',
195
+ report_to = None,
196
+ ddp_find_unused_parameters = None,
197
+ ddp_bucket_cap_mb = None,
198
+ ddp_broadcast_buffers = None,
199
+ dataloader_pin_memory = True,
200
+ dataloader_persistent_workers = False,
201
+ skip_memory_metrics = True,
202
+ use_legacy_prediction_loop = False,
203
+ push_to_hub = False,
204
+ resume_from_checkpoint = None,
205
+ hub_model_id = None,
206
+ hub_strategy = 'every_save',
207
+ hub_token = None,
208
+ hub_private_repo = None,
209
+ hub_always_push = False,
210
+ gradient_checkpointing = False,
211
+ gradient_checkpointing_kwargs = None,
212
+ include_inputs_for_metrics = False,
213
+ eval_do_concat_batches = True,
214
+ fp16_backend = 'auto',
215
+ evaluation_strategy = None,
216
+ push_to_hub_model_id = None,
217
+ push_to_hub_organization = None,
218
+ push_to_hub_token = None,
219
+ mp_parameters = '',
220
+ auto_find_batch_size = False,
221
+ full_determinism = False,
222
+ torchdynamo = None,
223
+ ray_scope = 'last',
224
+ ddp_timeout = 1800,
225
+ torch_compile = False,
226
+ torch_compile_backend = None,
227
+ torch_compile_mode = None,
228
+ dispatch_batches = None,
229
+ split_batches = None,
230
+ include_tokens_per_second = False,
231
+ include_num_input_tokens_seen = False,
232
+ neftune_noise_alpha = None,
233
+ optim_target_modules = None,
234
+ batch_eval_metrics = False,
235
+ eval_on_start = False,
236
+ use_liger_kernel = False,
237
+ eval_use_gather_object = False,
238
+ average_tokens_across_devices = False,
239
+ reward_model_path = None,
240
+ judge = None,
241
+ max_new_tokens = 64,
242
+ max_length = 512,
243
+ temperature = 0.9,
244
+ missing_eos_penalty = None,
245
+ loss_type = 'sigmoid',
246
+ dataset_num_proc = None,
247
+ disable_dropout = True,
248
+ use_vllm = False,
249
+ ds3_gather_for_generation = True,
250
+ vllm_sampling_params = None,
251
+ unsloth_num_chunks = -1,
252
+ **kwargs,
253
+ ):
254
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
255
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
256
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
257
+ output_dir = 'unsloth_training_checkpoints'
258
+ save_strategy = 'no'
259
+ if dataset_num_proc is None:
260
+ from multiprocessing import cpu_count
261
+ dataset_num_proc = cpu_count()
262
+
263
+ super().__init__(
264
+ output_dir = output_dir,
265
+ overwrite_output_dir = overwrite_output_dir,
266
+ do_train = do_train,
267
+ do_eval = do_eval,
268
+ do_predict = do_predict,
269
+ eval_strategy = eval_strategy,
270
+ prediction_loss_only = prediction_loss_only,
271
+ per_device_train_batch_size = per_device_train_batch_size,
272
+ per_device_eval_batch_size = per_device_eval_batch_size,
273
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
274
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
275
+ gradient_accumulation_steps = gradient_accumulation_steps,
276
+ eval_accumulation_steps = eval_accumulation_steps,
277
+ eval_delay = eval_delay,
278
+ torch_empty_cache_steps = torch_empty_cache_steps,
279
+ learning_rate = learning_rate,
280
+ weight_decay = weight_decay,
281
+ adam_beta1 = adam_beta1,
282
+ adam_beta2 = adam_beta2,
283
+ adam_epsilon = adam_epsilon,
284
+ max_grad_norm = max_grad_norm,
285
+ num_train_epochs = num_train_epochs,
286
+ max_steps = max_steps,
287
+ lr_scheduler_type = lr_scheduler_type,
288
+ warmup_ratio = warmup_ratio,
289
+ warmup_steps = warmup_steps,
290
+ log_level = log_level,
291
+ log_level_replica = log_level_replica,
292
+ log_on_each_node = log_on_each_node,
293
+ logging_dir = logging_dir,
294
+ logging_strategy = logging_strategy,
295
+ logging_first_step = logging_first_step,
296
+ logging_steps = logging_steps,
297
+ logging_nan_inf_filter = logging_nan_inf_filter,
298
+ save_strategy = save_strategy,
299
+ save_steps = save_steps,
300
+ save_total_limit = save_total_limit,
301
+ save_safetensors = save_safetensors,
302
+ save_on_each_node = save_on_each_node,
303
+ save_only_model = save_only_model,
304
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
305
+ no_cuda = no_cuda,
306
+ use_cpu = use_cpu,
307
+ use_mps_device = use_mps_device,
308
+ seed = seed,
309
+ data_seed = data_seed,
310
+ jit_mode_eval = jit_mode_eval,
311
+ use_ipex = use_ipex,
312
+ bf16 = bf16,
313
+ fp16 = fp16,
314
+ fp16_opt_level = fp16_opt_level,
315
+ half_precision_backend = half_precision_backend,
316
+ bf16_full_eval = bf16_full_eval,
317
+ fp16_full_eval = fp16_full_eval,
318
+ tf32 = tf32,
319
+ local_rank = local_rank,
320
+ ddp_backend = ddp_backend,
321
+ tpu_num_cores = tpu_num_cores,
322
+ tpu_metrics_debug = tpu_metrics_debug,
323
+ debug = debug,
324
+ dataloader_drop_last = dataloader_drop_last,
325
+ eval_steps = eval_steps,
326
+ dataloader_num_workers = dataloader_num_workers,
327
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
328
+ past_index = past_index,
329
+ run_name = run_name,
330
+ disable_tqdm = disable_tqdm,
331
+ remove_unused_columns = remove_unused_columns,
332
+ label_names = label_names,
333
+ load_best_model_at_end = load_best_model_at_end,
334
+ metric_for_best_model = metric_for_best_model,
335
+ greater_is_better = greater_is_better,
336
+ ignore_data_skip = ignore_data_skip,
337
+ fsdp = fsdp,
338
+ fsdp_min_num_params = fsdp_min_num_params,
339
+ fsdp_config = fsdp_config,
340
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
341
+ accelerator_config = accelerator_config,
342
+ deepspeed = deepspeed,
343
+ label_smoothing_factor = label_smoothing_factor,
344
+ optim = optim,
345
+ optim_args = optim_args,
346
+ adafactor = adafactor,
347
+ group_by_length = group_by_length,
348
+ length_column_name = length_column_name,
349
+ report_to = report_to,
350
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
351
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
352
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
353
+ dataloader_pin_memory = dataloader_pin_memory,
354
+ dataloader_persistent_workers = dataloader_persistent_workers,
355
+ skip_memory_metrics = skip_memory_metrics,
356
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
357
+ push_to_hub = push_to_hub,
358
+ resume_from_checkpoint = resume_from_checkpoint,
359
+ hub_model_id = hub_model_id,
360
+ hub_strategy = hub_strategy,
361
+ hub_token = hub_token,
362
+ hub_private_repo = hub_private_repo,
363
+ hub_always_push = hub_always_push,
364
+ gradient_checkpointing = gradient_checkpointing,
365
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
366
+ include_inputs_for_metrics = include_inputs_for_metrics,
367
+ eval_do_concat_batches = eval_do_concat_batches,
368
+ fp16_backend = fp16_backend,
369
+ evaluation_strategy = evaluation_strategy,
370
+ push_to_hub_model_id = push_to_hub_model_id,
371
+ push_to_hub_organization = push_to_hub_organization,
372
+ push_to_hub_token = push_to_hub_token,
373
+ mp_parameters = mp_parameters,
374
+ auto_find_batch_size = auto_find_batch_size,
375
+ full_determinism = full_determinism,
376
+ torchdynamo = torchdynamo,
377
+ ray_scope = ray_scope,
378
+ ddp_timeout = ddp_timeout,
379
+ torch_compile = torch_compile,
380
+ torch_compile_backend = torch_compile_backend,
381
+ torch_compile_mode = torch_compile_mode,
382
+ dispatch_batches = dispatch_batches,
383
+ split_batches = split_batches,
384
+ include_tokens_per_second = include_tokens_per_second,
385
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
386
+ neftune_noise_alpha = neftune_noise_alpha,
387
+ optim_target_modules = optim_target_modules,
388
+ batch_eval_metrics = batch_eval_metrics,
389
+ eval_on_start = eval_on_start,
390
+ use_liger_kernel = use_liger_kernel,
391
+ eval_use_gather_object = eval_use_gather_object,
392
+ average_tokens_across_devices = average_tokens_across_devices,
393
+ reward_model_path = reward_model_path,
394
+ judge = judge,
395
+ max_new_tokens = max_new_tokens,
396
+ max_length = max_length,
397
+ temperature = temperature,
398
+ missing_eos_penalty = missing_eos_penalty,
399
+ loss_type = loss_type,
400
+ dataset_num_proc = dataset_num_proc,
401
+ disable_dropout = disable_dropout,
402
+ use_vllm = use_vllm,
403
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
404
+ self.vllm_sampling_params = vllm_sampling_params
405
+ self.unsloth_num_chunks = unsloth_num_chunks
406
+ pass
407
+
408
+ class _UnslothOnlineDPOTrainer(Trainer):
409
+ r""""""
410
+
411
+ _tag_names = ["trl", "online-dpo"]
412
+
413
+ def __init__(
414
+ self,
415
+ model: Union[PreTrainedModel, nn.Module],
416
+ ref_model: Union[PreTrainedModel, nn.Module, None] = None,
417
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
418
+ judge: Optional[BasePairwiseJudge] = None,
419
+ args: Optional[OnlineDPOConfig] = None,
420
+ data_collator: Optional[DataCollator] = None,
421
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
422
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
423
+ processing_class: Optional[
424
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
425
+ ] = None,
426
+ reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
427
+ peft_config: Optional[dict] = None,
428
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
429
+ callbacks: Optional[list[TrainerCallback]] = None,
430
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
431
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
432
+ ) -> None:
433
+
434
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
435
+ if ref_model is model:
436
+ raise ValueError(
437
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
438
+ "same as `model`, either omit the `ref_model` argument or pass `None`."
439
+ )
440
+
441
+ self.ref_model = ref_model
442
+
443
+ if reward_model is not None and judge is not None:
444
+ warnings.warn(
445
+ "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
446
+ "Ignoring `judge` and using `reward_model`.",
447
+ UserWarning,
448
+ )
449
+ judge = None
450
+ elif reward_model is None and judge is None:
451
+ raise ValueError("Either `reward_model` or `judge` must be provided.")
452
+
453
+ self.reward_model = reward_model
454
+ self.reward_processing_class = reward_processing_class
455
+ self.judge = judge
456
+
457
+ if args.missing_eos_penalty is not None and judge is not None:
458
+ raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
459
+
460
+ if args is None:
461
+ raise ValueError("`args` must be provided.")
462
+
463
+ # Check that the processing_class is provided
464
+ if processing_class is None:
465
+ raise ValueError("`processing_class` must be provided.")
466
+
467
+ # Convert to PEFT model if peft_config is provided
468
+ if False:
469
+ # Check if PEFT is available
470
+ if not is_peft_available():
471
+ raise ImportError(
472
+ "PEFT is not available and passed `peft_config`. Please install PEFT with "
473
+ "`pip install peft` to use it."
474
+ )
475
+
476
+ # If the model is already a PeftModel, we need to merge and unload it.
477
+ # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
478
+ if isinstance(model, PeftModel):
479
+ model = model.merge_and_unload()
480
+
481
+ # Get peft model with the given config
482
+ model = model
483
+
484
+ # Disable dropout in the model and reference model
485
+ if args.disable_dropout:
486
+ disable_dropout_in_model(model)
487
+ if self.ref_model is not None:
488
+ disable_dropout_in_model(self.ref_model)
489
+
490
+ # Handle the ref_model
491
+ # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
492
+ # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
493
+ # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
494
+ if ref_model is None: # No ref model provided, the most common case
495
+ if False:
496
+ self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
497
+ else:
498
+ self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
499
+ else: # rare case, the user provided a ref model
500
+ self.ref_model = ref_model
501
+ self.ref_model.eval()
502
+
503
+ # Disable the gradient and set the reward model in eval mode
504
+ if self.reward_model is not None:
505
+ self.reward_model.eval()
506
+
507
+ # Define the collator is not provided
508
+ if data_collator is None:
509
+ data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
510
+
511
+ self.max_length = args.max_length
512
+
513
+ self.stats = {
514
+ "objective/kl": [],
515
+ "objective/entropy": [],
516
+ "objective/non_score_reward": [],
517
+ "rewards/chosen": [],
518
+ "rewards/rejected": [],
519
+ "rewards/accuracies": [],
520
+ "rewards/margins": [],
521
+ "logps/chosen": [],
522
+ "logps/rejected": [],
523
+ "val/contain_eos_token": [],
524
+ "beta": [],
525
+ }
526
+ if self.reward_model is not None:
527
+ self.stats["objective/rlhf_reward"] = []
528
+ self.stats["objective/scores_margin"] = []
529
+ self.stats["objective/scores"] = []
530
+
531
+ if args.use_vllm:
532
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
533
+ n=2, max_tokens=args.max_new_tokens,
534
+ temperature=args.temperature,
535
+ top_k=50,
536
+ top_p=1.0,
537
+ detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
538
+ else:
539
+ self.generation_config = GenerationConfig(
540
+ max_new_tokens=args.max_new_tokens,
541
+ temperature=args.temperature,
542
+ top_k=50,
543
+ top_p=1.0,
544
+ do_sample=True,
545
+ use_cache=False if args.gradient_checkpointing else True,
546
+ )
547
+
548
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
549
+ # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
550
+ # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
551
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
552
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
553
+ # that the warning has already been issued.
554
+ model.warnings_issued["estimate_tokens"] = True
555
+
556
+ super().__init__(
557
+ model=model,
558
+ args=args,
559
+ data_collator=data_collator,
560
+ train_dataset=train_dataset,
561
+ eval_dataset=eval_dataset,
562
+ processing_class=processing_class,
563
+ compute_metrics=compute_metrics,
564
+ callbacks=callbacks,
565
+ optimizers=optimizers,
566
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
567
+ )
568
+
569
+ # Add tags for models that have been loaded with the correct transformers version
570
+ if hasattr(self.model, "add_model_tags"):
571
+ self.model.add_model_tags(self._tag_names)
572
+
573
+ self._beta = args.beta
574
+
575
+ # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
576
+ if self.is_deepspeed_enabled:
577
+ if self.reward_model is not None:
578
+ self.reward_model = prepare_deepspeed(
579
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
580
+ )
581
+ if self.ref_model is not None:
582
+ self.ref_model = prepare_deepspeed(
583
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
584
+ )
585
+ else:
586
+ if self.ref_model is not None:
587
+ self.ref_model = self.ref_model.to(self.accelerator.device)
588
+ if self.reward_model is not None:
589
+ self.reward_model = self.reward_model.to(self.accelerator.device)
590
+
591
+ @property
592
+ def beta(self):
593
+ if isinstance(self._beta, list):
594
+ epoch = self.state.epoch
595
+ return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
596
+ else:
597
+ return self._beta
598
+
599
+ @staticmethod
600
+ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
601
+ """Tokenize a single row from a DPO specific dataset."""
602
+ if not is_encoder_decoder:
603
+ batch = tokenizer(feature["prompt"], add_special_tokens=False)
604
+ # Add BOS token to head of prompt. Avoid adding if it's already there
605
+ if tokenizer.bos_token_id is not None:
606
+ prompt_len_input_ids = len(batch["input_ids"])
607
+ if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
608
+ batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
609
+ batch["attention_mask"] = [1] + batch["attention_mask"]
610
+ else:
611
+ batch = tokenizer(feature["prompt"], add_special_tokens=True)
612
+ batch = {f"prompt_{key}": value for key, value in batch.items()}
613
+ return batch
614
+
615
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
616
+ @wraps(Trainer.get_train_dataloader)
617
+ def get_train_dataloader(self) -> DataLoader:
618
+ if self.train_dataset is None:
619
+ raise ValueError("Trainer: training requires a train_dataset.")
620
+
621
+ train_dataset = self.train_dataset
622
+ data_collator = self.data_collator
623
+ dataloader_params = {
624
+ "batch_size": self._train_batch_size,
625
+ "collate_fn": data_collator,
626
+ "num_workers": self.args.dataloader_num_workers,
627
+ "pin_memory": self.args.dataloader_pin_memory,
628
+ "persistent_workers": self.args.dataloader_persistent_workers,
629
+ }
630
+
631
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
632
+ dataloader_params["sampler"] = self._get_train_sampler()
633
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
634
+ dataloader_params["worker_init_fn"] = seed_worker
635
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
636
+
637
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
638
+
639
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
640
+ @wraps(Trainer.get_eval_dataloader)
641
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
642
+ if eval_dataset is None and self.eval_dataset is None:
643
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
644
+
645
+ # If we have persistent workers, don't do a fork bomb especially as eval datasets
646
+ # don't change during training
647
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
648
+ if (
649
+ hasattr(self, "_eval_dataloaders")
650
+ and dataloader_key in self._eval_dataloaders
651
+ and self.args.dataloader_persistent_workers
652
+ ):
653
+ return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
654
+
655
+ eval_dataset = (
656
+ self.eval_dataset[eval_dataset]
657
+ if isinstance(eval_dataset, str)
658
+ else eval_dataset
659
+ if eval_dataset is not None
660
+ else self.eval_dataset
661
+ )
662
+ data_collator = self.data_collator
663
+
664
+ dataloader_params = {
665
+ "batch_size": self.args.eval_batch_size,
666
+ "collate_fn": data_collator,
667
+ "num_workers": self.args.dataloader_num_workers,
668
+ "pin_memory": self.args.dataloader_pin_memory,
669
+ "persistent_workers": self.args.dataloader_persistent_workers,
670
+ }
671
+
672
+ if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
673
+ dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
674
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
675
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
676
+
677
+ # accelerator.free_memory() will destroy the references, so
678
+ # we need to store the non-prepared version
679
+ eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
680
+ if self.args.dataloader_persistent_workers:
681
+ if hasattr(self, "_eval_dataloaders"):
682
+ self._eval_dataloaders[dataloader_key] = eval_dataloader
683
+ else:
684
+ self._eval_dataloaders = {dataloader_key: eval_dataloader}
685
+
686
+ return self.accelerator.prepare(eval_dataloader)
687
+
688
+ def _generate_vllm(self, model, prompts):
689
+ eos_token_id = self.processing_class.eos_token_id
690
+ pad_token_id = self.processing_class.pad_token_id
691
+
692
+ # Load the latest weights
693
+
694
+ pass
695
+
696
+ pass
697
+
698
+ if is_conversational({"prompt": prompts[0]}):
699
+ outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
700
+ else:
701
+ outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
702
+
703
+ completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
704
+ prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
705
+
706
+ # Create mask and pad the prompt and completion
707
+ max_prompt_length = max(len(ids) for ids in prompt_ids)
708
+ prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
709
+ prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
710
+ max_tokens = self.generation_config.max_tokens
711
+ completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
712
+ completion_ids = [
713
+ ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
714
+ for ids in completion_ids
715
+ ]
716
+ completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
717
+
718
+ # Convert to tensors
719
+ prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
720
+ prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
721
+ completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
722
+ completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
723
+
724
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
725
+
726
+ def _generate(self, model, prompts):
727
+ eos_token_id = self.processing_class.eos_token_id
728
+ pad_token_id = self.processing_class.pad_token_id
729
+
730
+ # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
731
+ # policies with different tokenizers / chat templates.
732
+ inputs = [{"prompt": prompt} for prompt in prompts]
733
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
734
+ inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
735
+ inputs = self.data_collator(inputs)
736
+
737
+ # Sample 2 completions per prompt of size `max_new_tokens` from the model
738
+ inputs = self._prepare_inputs(inputs)
739
+ prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
740
+ prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
741
+ with unwrap_model_for_generation(
742
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
743
+ ) as unwrapped_model:
744
+ output = unwrapped_model.generate(
745
+ input_ids=prompt_ids,
746
+ attention_mask=prompt_mask,
747
+ generation_config=self.generation_config,
748
+ )
749
+
750
+ completion_ids = output[:, prompt_ids.size(1) :]
751
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
752
+
753
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
754
+
755
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
756
+ # Get the number of tokens to truncate from prompt
757
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
758
+
759
+ # Truncate left to avoid oom
760
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
761
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
762
+
763
+ # Concat the prompt and completion
764
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
765
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
766
+
767
+ # Get the logprobs of the completions from the model
768
+ output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
769
+
770
+ # There is 1 offset, because the model predict the next token
771
+ logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
772
+
773
+ # Take the completion tokens logprob
774
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
775
+ return logprobs
776
+
777
+ def training_step(
778
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
779
+ ) -> torch.Tensor:
780
+ model.train()
781
+
782
+ prompts = inputs["prompt"]
783
+ batch_size = len(prompts)
784
+
785
+ if self.args.use_vllm:
786
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
787
+ else:
788
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
789
+
790
+ contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
791
+
792
+ logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
793
+ with torch.no_grad():
794
+ if self.ref_model is not None:
795
+ ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
796
+ else: # peft case: we just need to disable the adapter
797
+ with self.model.disable_adapter():
798
+ ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
799
+
800
+ # Decode the completions, and format them if the input is conversational
801
+ device = logprobs.device
802
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
803
+ if is_conversational({"prompt": prompts[0]}):
804
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
805
+
806
+ # Get the reward from the reward model or judge
807
+ if self.judge is not None:
808
+ # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
809
+ # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
810
+ # independent of the model's chat template, we use the raw conversation data, and apply our own chat
811
+ # template to it.
812
+ if is_conversational({"prompt": prompts[0]}):
813
+ environment = jinja2.Environment()
814
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
815
+ prompts = [template.render(messages=prompt) for prompt in prompts]
816
+ completions = [template.render(messages=completion) for completion in completions]
817
+
818
+ ranks_of_first_completion = self.judge.judge(
819
+ prompts, list(zip(completions[:batch_size], completions[batch_size:]))
820
+ )
821
+
822
+ # convert ranks to a True/False mask:
823
+ # when rank == 0, it means the first completion is the best
824
+ # when rank == 1, it means the second completion is the best
825
+ mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
826
+ else:
827
+ # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
828
+ # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
829
+ prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
830
+ if is_conversational({"prompt": prompts[0]}):
831
+ examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
832
+ examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
833
+ prompts = [example["prompt"] for example in examples]
834
+ completions = [example["completion"] for example in examples]
835
+
836
+ # Tokenize the prompts
837
+ prompts_ids = self.reward_processing_class(
838
+ prompts, padding=True, return_tensors="pt", padding_side="left"
839
+ )["input_ids"].to(device)
840
+ context_length = prompts_ids.shape[1]
841
+
842
+ # Tokenize the completions
843
+ completions_ids = self.reward_processing_class(
844
+ completions, padding=True, return_tensors="pt", padding_side="right"
845
+ )["input_ids"].to(device)
846
+
847
+ # Concatenate the prompts and completions and get the reward
848
+ prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
849
+ with torch.inference_mode():
850
+ _, scores, _ = get_reward(
851
+ self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
852
+ )
853
+
854
+ # Filter completion. Ensure that the sample contains stop_token_id
855
+ # Completions not passing that filter will receive a lower score.
856
+ if self.args.missing_eos_penalty is not None:
857
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
858
+
859
+ # Split the scores in 2 (the prompts of the first half are the same as the second half)
860
+ first_half, second_half = scores.split(batch_size)
861
+
862
+ # Get the indices of the chosen and rejected examples
863
+ mask = first_half >= second_half
864
+
865
+ batch_range = torch.arange(batch_size, device=device)
866
+ chosen_indices = batch_range + (~mask * batch_size)
867
+ rejected_indices = batch_range + (mask * batch_size)
868
+
869
+ # Build tensor so that the first half is the chosen examples and the second half the rejected examples
870
+ cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
871
+ cr_logprobs = logprobs[cr_indices]
872
+ cr_ref_logprobs = ref_logprobs[cr_indices]
873
+
874
+ # mask out the padding tokens
875
+ padding_mask = ~completion_mask.bool()
876
+ cr_padding_mask = padding_mask[cr_indices]
877
+
878
+ cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
879
+ cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
880
+
881
+ # Split the chosen and rejected examples
882
+ chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
883
+ chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
884
+ pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
885
+ ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
886
+
887
+ logits = pi_logratios - ref_logratios
888
+
889
+ if self.args.loss_type == "sigmoid":
890
+ losses = -F.logsigmoid(self.beta * logits)
891
+ elif self.args.loss_type == "ipo":
892
+ losses = (logits - 1 / (2 * self.beta)) ** 2
893
+ else:
894
+ raise NotImplementedError(f"invalid loss type {self.loss_type}")
895
+
896
+ loss = losses.mean()
897
+
898
+ # Log everything
899
+ if self.reward_model is not None:
900
+ scores_margin = scores[chosen_indices] - scores[rejected_indices]
901
+ self.stats["objective/scores_margin"].append(
902
+ self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
903
+ )
904
+ self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
905
+ self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
906
+ self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
907
+ self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
908
+
909
+ kl = logprobs - ref_logprobs
910
+ mean_kl = kl.sum(1).mean()
911
+ self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
912
+ non_score_reward = (-self.beta * kl).sum(1)
913
+ mean_non_score_reward = non_score_reward.mean()
914
+ self.stats["objective/non_score_reward"].append(
915
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
916
+ )
917
+ if self.reward_model is not None:
918
+ rlhf_reward = scores + non_score_reward
919
+ self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
920
+ mean_entropy = -logprobs.sum(1).mean()
921
+ self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
922
+ chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
923
+ gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
924
+ self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
925
+ rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
926
+ gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
927
+ self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
928
+ margin = gathered_chosen_rewards - gathered_rejected_rewards
929
+ self.stats["rewards/margins"].append(margin.mean().item())
930
+ accuracy = margin > 0
931
+ self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
932
+ self.stats["beta"].append(self.beta)
933
+
934
+ if (
935
+ self.args.torch_empty_cache_steps is not None
936
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
937
+ ):
938
+ empty_cache()
939
+
940
+ kwargs = {}
941
+
942
+ # For LOMO optimizers you need to explicitly use the learnign rate
943
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
944
+ kwargs["learning_rate"] = self._get_learning_rate()
945
+
946
+ if self.args.n_gpu > 1:
947
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
948
+
949
+ if self.use_apex:
950
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
951
+ scaled_loss.backward()
952
+ else:
953
+ self.accelerator.backward(loss, **kwargs)
954
+
955
+ return loss.detach() / self.args.gradient_accumulation_steps
956
+
957
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
958
+ # start_time defaults to None to allow compatibility with transformers<=4.46
959
+ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
960
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
961
+ logs: dict[str, float] = {}
962
+
963
+ # all_gather + mean() to get average loss over all processes
964
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
965
+
966
+ # reset tr_loss to zero
967
+ tr_loss -= tr_loss
968
+
969
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
970
+ if grad_norm is not None:
971
+ logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
972
+ logs["learning_rate"] = self._get_learning_rate()
973
+
974
+ # Add our metrics
975
+ for key, val in self.stats.items():
976
+ logs[key] = sum(val) / len(val)
977
+ self.stats = {key: [] for key in self.stats} # reset stats
978
+
979
+ self._total_loss_scalar += tr_loss_scalar
980
+ self._globalstep_last_logged = self.state.global_step
981
+ self.store_flos()
982
+
983
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
984
+ self.log(logs, start_time)
985
+ else: # transformers<=4.46
986
+ self.log(logs)
987
+
988
+ metrics = None
989
+ if self.control.should_evaluate:
990
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
991
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
992
+
993
+ if self.args.save_strategy == "best":
994
+ self.control.should_save = is_new_best_metric
995
+
996
+ if self.control.should_save:
997
+ self._save_checkpoint(model, trial)
998
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
999
+
1000
+ # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
1001
+ # This can be removed once the minimum transformers version is updated to 4.47.
1002
+ # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
1003
+ def _determine_best_metric(self, metrics, trial):
1004
+ """
1005
+ Determine if the model should be saved based on the evaluation metrics.
1006
+ If args.metric_for_best_model is not set, the loss is used.
1007
+ Returns:
1008
+ bool: True if a new best metric was found, else False
1009
+ """
1010
+ is_new_best_metric = False
1011
+
1012
+ if self.args.metric_for_best_model is not None:
1013
+ metric_to_check = self.args.metric_for_best_model
1014
+
1015
+ if not metric_to_check.startswith("eval_"):
1016
+ metric_to_check = f"eval_{metric_to_check}"
1017
+
1018
+ try:
1019
+ metric_value = metrics[metric_to_check]
1020
+ except KeyError as exc:
1021
+ raise KeyError(
1022
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
1023
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
1024
+ ) from exc
1025
+
1026
+ operator = np.greater if self.args.greater_is_better else np.less
1027
+
1028
+ if self.state.best_metric is None:
1029
+ self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
1030
+
1031
+ if operator(metric_value, self.state.best_metric):
1032
+ run_dir = self._get_output_dir(trial=trial)
1033
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
1034
+ output_dir = os.path.join(run_dir, checkpoint_folder)
1035
+ self.state.best_metric = metric_value
1036
+ self.state.best_model_checkpoint = output_dir
1037
+
1038
+ is_new_best_metric = True
1039
+
1040
+ return is_new_best_metric
1041
+
1042
+ def create_model_card(
1043
+ self,
1044
+ model_name: Optional[str] = None,
1045
+ dataset_name: Optional[str] = None,
1046
+ tags: Union[str, list[str], None] = None,
1047
+ ):
1048
+ """
1049
+ Creates a draft of a model card using the information available to the `Trainer`.
1050
+
1051
+ Args:
1052
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1053
+ Name of the model.
1054
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1055
+ Name of the dataset used for training.
1056
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1057
+ Tags to be associated with the model card.
1058
+ """
1059
+ if not self.is_world_process_zero():
1060
+ return
1061
+
1062
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1063
+ base_model = self.model.config._name_or_path
1064
+ else:
1065
+ base_model = None
1066
+
1067
+ tags = tags or []
1068
+ if isinstance(tags, str):
1069
+ tags = [tags]
1070
+
1071
+ if hasattr(self.model.config, "unsloth_version"):
1072
+ tags.append("unsloth")
1073
+
1074
+ citation = textwrap.dedent("""\
1075
+ @article{guo2024direct,
1076
+ title = {{Direct Language Model Alignment from Online AI Feedback}},
1077
+ author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
1078
+ year = 2024,
1079
+ eprint = {arXiv:2402.04792}
1080
+ }""")
1081
+
1082
+ model_card = generate_model_card(
1083
+ base_model=base_model,
1084
+ model_name=model_name,
1085
+ hub_model_id=self.hub_model_id,
1086
+ dataset_name=dataset_name,
1087
+ tags=tags,
1088
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1089
+ comet_url=get_comet_experiment_url(),
1090
+ trainer_name="Online DPO",
1091
+ trainer_citation=citation,
1092
+ paper_title="Direct Language Model Alignment from Online AI Feedback",
1093
+ paper_id="2402.04792",
1094
+ )
1095
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1096
+ class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
1097
+ """
1098
+
1099
+ Initialize OnlineDPOTrainer.
1100
+
1101
+ Args:
1102
+ model (`transformers.PreTrainedModel` or `torch.nn.Module`):
1103
+ The model to train, preferably an `AutoModelForCausalLM`.
1104
+ ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1105
+ The reference model to use for training. If None is specified, the reference model will be created from
1106
+ the model.
1107
+ reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1108
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
1109
+ judge (`BasePairwiseJudge`):
1110
+ The judge to use for pairwise comparison of model completions.
1111
+ args (`OnlineDPOConfig`):
1112
+ The online DPO config arguments to use for training.
1113
+ data_collator (`transformers.DataCollator`):
1114
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1115
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1116
+ train_dataset (`datasets.Dataset`):
1117
+ The dataset to use for training.
1118
+ eval_dataset (`datasets.Dataset`):
1119
+ The dataset to use for evaluation.
1120
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1121
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1122
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1123
+ reuse the fine-tuned model.
1124
+ peft_config (`dict`):
1125
+ The peft config to use for training.
1126
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1127
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1128
+ a dictionary string to metric values.
1129
+ callbacks (`list[transformers.TrainerCallback]`):
1130
+ The callbacks to use for training.
1131
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1132
+ The optimizer and scheduler to use for training.
1133
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1134
+ The function to use to preprocess the logits before computing the metrics.
1135
+
1136
+ """
1137
+ def __init__(
1138
+ self,
1139
+ model,
1140
+ ref_model = None,
1141
+ reward_model = None,
1142
+ judge = None,
1143
+ args = None,
1144
+ data_collator = None,
1145
+ train_dataset = None,
1146
+ eval_dataset = None,
1147
+ processing_class = None,
1148
+ reward_processing_class = None,
1149
+ peft_config = None,
1150
+ compute_metrics = None,
1151
+ callbacks = None,
1152
+ preprocess_logits_for_metrics = None,
1153
+ **kwargs
1154
+ ):
1155
+ if args is None: args = UnslothOnlineDPOConfig()
1156
+ use_bf16 = getattr(args, 'bf16', False)
1157
+ use_fp16 = getattr(args, 'fp16', False)
1158
+ force_float32 = False
1159
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1160
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1161
+ force_float32 = True
1162
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1163
+ dtype = getattr(model.config, 'torch_dtype', None)
1164
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1165
+ from unsloth_zoo.utils import _get_dtype
1166
+ dtype = _get_dtype(dtype)
1167
+ float16 = dtype == torch.float16
1168
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1169
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1170
+ if force_float32:
1171
+ args.fp16 = False
1172
+ args.bf16 = False
1173
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1174
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1175
+ args.fp16 = float16
1176
+ args.bf16 = not float16
1177
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1178
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1179
+ args.eval_strategy = 'steps'
1180
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1181
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1182
+ if ga_steps is not None and ga_steps > 1:
1183
+ from transformers import __version__ as transformers_version
1184
+ if Version(transformers_version) <= Version('4.45.2'):
1185
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1186
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1187
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1188
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1189
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1190
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1191
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1192
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1193
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1194
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1195
+ if force_float32:
1196
+ args.bf16_full_eval = False
1197
+ args.fp16_full_eval = False
1198
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1199
+ args.bf16_full_eval = True
1200
+ args.fp16_full_eval = False
1201
+ elif not bf16_full_eval and not fp16_full_eval:
1202
+ args.bf16_full_eval = args.bf16
1203
+ args.fp16_full_eval = args.fp16
1204
+ _output_logits = False
1205
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1206
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1207
+ if _output_logits:
1208
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1209
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1210
+ pass
1211
+ else:
1212
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1213
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1214
+ if args_max_seq_length is None and model_max_seq_length is not None:
1215
+ max_seq_length = model.max_seq_length
1216
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1217
+ if model is not None and hasattr(model, 'for_training'):
1218
+ model.for_training()
1219
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1220
+ if 'processing_class' in locals():
1221
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1222
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1223
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1224
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1225
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1226
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1227
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1228
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1229
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1230
+ else:
1231
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1232
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1233
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1234
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1235
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1236
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1237
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1238
+ else:
1239
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1240
+ other_metrics = []
1241
+
1242
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1243
+ PatchRLStatistics('online_dpo_trainer', other_metrics)
1244
+
1245
+ super().__init__(
1246
+ model = model,
1247
+ ref_model = ref_model,
1248
+ reward_model = reward_model,
1249
+ judge = judge,
1250
+ args = args,
1251
+ data_collator = data_collator,
1252
+ train_dataset = train_dataset,
1253
+ eval_dataset = eval_dataset,
1254
+ processing_class = processing_class,
1255
+ reward_processing_class = reward_processing_class,
1256
+ peft_config = peft_config,
1257
+ compute_metrics = compute_metrics,
1258
+ callbacks = callbacks,
1259
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1260
+ if hasattr(self, 'neftune_hook_handle'):
1261
+ self.neftune_hook_handle.remove()
1262
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1263
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1264
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1265
+ pass
1266
+
1267
+ pass
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothPPOConfig(PPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`PPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
54
+ Name of this experiment.
55
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
+ Path to the reward model.
57
+ model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
58
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
59
+ ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
60
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
61
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
62
+ Number of epochs to train.
63
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
64
+ Whether to whiten the rewards.
65
+ kl_coef (`float`, *optional*, defaults to `0.05`):
66
+ KL coefficient.
67
+ cliprange (`float`, *optional*, defaults to `0.2`):
68
+ Clip range.
69
+ vf_coef (`float`, *optional*, defaults to `0.1`):
70
+ Value function coefficient.
71
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
72
+ Clip range for the value function.
73
+ gamma (`float`, *optional*, defaults to `1.0`):
74
+ Discount factor.
75
+ lam (`float`, *optional*, defaults to `0.95`):
76
+ Lambda value for GAE.
77
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
78
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
79
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
80
+ capacity of a single GPU, albeit at the cost of slower generation.
81
+
82
+ """
83
+ vllm_sampling_params: Optional[Any] = field(
84
+ default = None,
85
+ metadata = {'help': 'vLLM SamplingParams'},
86
+ )
87
+ unsloth_num_chunks : Optional[int] = field(
88
+ default = -1,
89
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
90
+ )
91
+ def __init__(
92
+ self,
93
+ output_dir = None,
94
+ overwrite_output_dir = None,
95
+ do_train = False,
96
+ do_eval = False,
97
+ do_predict = False,
98
+ eval_strategy = 'no',
99
+ prediction_loss_only = False,
100
+ per_device_train_batch_size = 4,
101
+ per_device_eval_batch_size = 4,
102
+ per_gpu_train_batch_size = None,
103
+ per_gpu_eval_batch_size = None,
104
+ gradient_accumulation_steps = 2,
105
+ eval_accumulation_steps = 2,
106
+ eval_delay = 0,
107
+ torch_empty_cache_steps = 250,
108
+ learning_rate = 5e-05,
109
+ weight_decay = 0.01,
110
+ adam_beta1 = 0.9,
111
+ adam_beta2 = 0.999,
112
+ adam_epsilon = 1e-08,
113
+ max_grad_norm = 1.0,
114
+ num_train_epochs = 3.0,
115
+ max_steps = -1,
116
+ lr_scheduler_type = 'linear',
117
+ warmup_ratio = 0.1,
118
+ warmup_steps = 0,
119
+ log_level = 'passive',
120
+ log_level_replica = 'warning',
121
+ log_on_each_node = True,
122
+ logging_dir = None,
123
+ logging_strategy = 'steps',
124
+ logging_first_step = False,
125
+ logging_steps = 1,
126
+ logging_nan_inf_filter = False,
127
+ save_strategy = 'steps',
128
+ save_steps = 500,
129
+ save_total_limit = None,
130
+ save_safetensors = True,
131
+ save_on_each_node = False,
132
+ save_only_model = False,
133
+ restore_callback_states_from_checkpoint = False,
134
+ no_cuda = False,
135
+ use_cpu = False,
136
+ use_mps_device = False,
137
+ seed = 3407,
138
+ data_seed = 3407,
139
+ jit_mode_eval = False,
140
+ use_ipex = False,
141
+ bf16 = False,
142
+ fp16 = False,
143
+ fp16_opt_level = 'O1',
144
+ half_precision_backend = 'auto',
145
+ bf16_full_eval = False,
146
+ fp16_full_eval = False,
147
+ tf32 = None,
148
+ local_rank = -1,
149
+ ddp_backend = None,
150
+ tpu_num_cores = None,
151
+ tpu_metrics_debug = False,
152
+ debug = '',
153
+ dataloader_drop_last = False,
154
+ eval_steps = None,
155
+ dataloader_num_workers = 0,
156
+ dataloader_prefetch_factor = None,
157
+ past_index = -1,
158
+ run_name = None,
159
+ disable_tqdm = None,
160
+ remove_unused_columns = True,
161
+ label_names = None,
162
+ load_best_model_at_end = False,
163
+ metric_for_best_model = None,
164
+ greater_is_better = None,
165
+ ignore_data_skip = False,
166
+ fsdp = '',
167
+ fsdp_min_num_params = 0,
168
+ fsdp_config = None,
169
+ fsdp_transformer_layer_cls_to_wrap = None,
170
+ accelerator_config = None,
171
+ deepspeed = None,
172
+ label_smoothing_factor = 0.0,
173
+ optim = 'adamw_8bit',
174
+ optim_args = None,
175
+ adafactor = False,
176
+ group_by_length = False,
177
+ length_column_name = 'length',
178
+ report_to = None,
179
+ ddp_find_unused_parameters = None,
180
+ ddp_bucket_cap_mb = None,
181
+ ddp_broadcast_buffers = None,
182
+ dataloader_pin_memory = True,
183
+ dataloader_persistent_workers = False,
184
+ skip_memory_metrics = True,
185
+ use_legacy_prediction_loop = False,
186
+ push_to_hub = False,
187
+ resume_from_checkpoint = None,
188
+ hub_model_id = None,
189
+ hub_strategy = 'every_save',
190
+ hub_token = None,
191
+ hub_private_repo = None,
192
+ hub_always_push = False,
193
+ gradient_checkpointing = False,
194
+ gradient_checkpointing_kwargs = None,
195
+ include_inputs_for_metrics = False,
196
+ eval_do_concat_batches = True,
197
+ fp16_backend = 'auto',
198
+ evaluation_strategy = None,
199
+ push_to_hub_model_id = None,
200
+ push_to_hub_organization = None,
201
+ push_to_hub_token = None,
202
+ mp_parameters = '',
203
+ auto_find_batch_size = False,
204
+ full_determinism = False,
205
+ torchdynamo = None,
206
+ ray_scope = 'last',
207
+ ddp_timeout = 1800,
208
+ torch_compile = False,
209
+ torch_compile_backend = None,
210
+ torch_compile_mode = None,
211
+ dispatch_batches = None,
212
+ split_batches = None,
213
+ include_tokens_per_second = False,
214
+ include_num_input_tokens_seen = False,
215
+ neftune_noise_alpha = None,
216
+ optim_target_modules = None,
217
+ batch_eval_metrics = False,
218
+ eval_on_start = False,
219
+ use_liger_kernel = False,
220
+ eval_use_gather_object = False,
221
+ average_tokens_across_devices = False,
222
+ dataset_num_proc = None,
223
+ num_mini_batches = 1,
224
+ total_episodes = None,
225
+ local_rollout_forward_batch_size = 64,
226
+ num_sample_generations = 10,
227
+ response_length = 53,
228
+ stop_token = None,
229
+ stop_token_id = None,
230
+ temperature = 0.7,
231
+ missing_eos_penalty = None,
232
+ sft_model_path = 'EleutherAI/pythia-160m',
233
+ world_size = None,
234
+ num_total_batches = None,
235
+ micro_batch_size = None,
236
+ local_batch_size = None,
237
+ batch_size = None,
238
+ local_mini_batch_size = None,
239
+ mini_batch_size = None,
240
+ exp_name = 'ppo_config',
241
+ reward_model_path = 'EleutherAI/pythia-160m',
242
+ model_adapter_name = None,
243
+ ref_adapter_name = None,
244
+ num_ppo_epochs = 4,
245
+ whiten_rewards = False,
246
+ kl_coef = 0.05,
247
+ cliprange = 0.2,
248
+ vf_coef = 0.1,
249
+ cliprange_value = 0.2,
250
+ gamma = 1.0,
251
+ lam = 0.95,
252
+ ds3_gather_for_generation = True,
253
+ vllm_sampling_params = None,
254
+ unsloth_num_chunks = -1,
255
+ **kwargs,
256
+ ):
257
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
258
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
259
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
260
+ output_dir = 'unsloth_training_checkpoints'
261
+ save_strategy = 'no'
262
+ if dataset_num_proc is None:
263
+ from multiprocessing import cpu_count
264
+ dataset_num_proc = cpu_count()
265
+
266
+ super().__init__(
267
+ output_dir = output_dir,
268
+ overwrite_output_dir = overwrite_output_dir,
269
+ do_train = do_train,
270
+ do_eval = do_eval,
271
+ do_predict = do_predict,
272
+ eval_strategy = eval_strategy,
273
+ prediction_loss_only = prediction_loss_only,
274
+ per_device_train_batch_size = per_device_train_batch_size,
275
+ per_device_eval_batch_size = per_device_eval_batch_size,
276
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
277
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
278
+ gradient_accumulation_steps = gradient_accumulation_steps,
279
+ eval_accumulation_steps = eval_accumulation_steps,
280
+ eval_delay = eval_delay,
281
+ torch_empty_cache_steps = torch_empty_cache_steps,
282
+ learning_rate = learning_rate,
283
+ weight_decay = weight_decay,
284
+ adam_beta1 = adam_beta1,
285
+ adam_beta2 = adam_beta2,
286
+ adam_epsilon = adam_epsilon,
287
+ max_grad_norm = max_grad_norm,
288
+ num_train_epochs = num_train_epochs,
289
+ max_steps = max_steps,
290
+ lr_scheduler_type = lr_scheduler_type,
291
+ warmup_ratio = warmup_ratio,
292
+ warmup_steps = warmup_steps,
293
+ log_level = log_level,
294
+ log_level_replica = log_level_replica,
295
+ log_on_each_node = log_on_each_node,
296
+ logging_dir = logging_dir,
297
+ logging_strategy = logging_strategy,
298
+ logging_first_step = logging_first_step,
299
+ logging_steps = logging_steps,
300
+ logging_nan_inf_filter = logging_nan_inf_filter,
301
+ save_strategy = save_strategy,
302
+ save_steps = save_steps,
303
+ save_total_limit = save_total_limit,
304
+ save_safetensors = save_safetensors,
305
+ save_on_each_node = save_on_each_node,
306
+ save_only_model = save_only_model,
307
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
308
+ no_cuda = no_cuda,
309
+ use_cpu = use_cpu,
310
+ use_mps_device = use_mps_device,
311
+ seed = seed,
312
+ data_seed = data_seed,
313
+ jit_mode_eval = jit_mode_eval,
314
+ use_ipex = use_ipex,
315
+ bf16 = bf16,
316
+ fp16 = fp16,
317
+ fp16_opt_level = fp16_opt_level,
318
+ half_precision_backend = half_precision_backend,
319
+ bf16_full_eval = bf16_full_eval,
320
+ fp16_full_eval = fp16_full_eval,
321
+ tf32 = tf32,
322
+ local_rank = local_rank,
323
+ ddp_backend = ddp_backend,
324
+ tpu_num_cores = tpu_num_cores,
325
+ tpu_metrics_debug = tpu_metrics_debug,
326
+ debug = debug,
327
+ dataloader_drop_last = dataloader_drop_last,
328
+ eval_steps = eval_steps,
329
+ dataloader_num_workers = dataloader_num_workers,
330
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
331
+ past_index = past_index,
332
+ run_name = run_name,
333
+ disable_tqdm = disable_tqdm,
334
+ remove_unused_columns = remove_unused_columns,
335
+ label_names = label_names,
336
+ load_best_model_at_end = load_best_model_at_end,
337
+ metric_for_best_model = metric_for_best_model,
338
+ greater_is_better = greater_is_better,
339
+ ignore_data_skip = ignore_data_skip,
340
+ fsdp = fsdp,
341
+ fsdp_min_num_params = fsdp_min_num_params,
342
+ fsdp_config = fsdp_config,
343
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
344
+ accelerator_config = accelerator_config,
345
+ deepspeed = deepspeed,
346
+ label_smoothing_factor = label_smoothing_factor,
347
+ optim = optim,
348
+ optim_args = optim_args,
349
+ adafactor = adafactor,
350
+ group_by_length = group_by_length,
351
+ length_column_name = length_column_name,
352
+ report_to = report_to,
353
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
354
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
355
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
356
+ dataloader_pin_memory = dataloader_pin_memory,
357
+ dataloader_persistent_workers = dataloader_persistent_workers,
358
+ skip_memory_metrics = skip_memory_metrics,
359
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
360
+ push_to_hub = push_to_hub,
361
+ resume_from_checkpoint = resume_from_checkpoint,
362
+ hub_model_id = hub_model_id,
363
+ hub_strategy = hub_strategy,
364
+ hub_token = hub_token,
365
+ hub_private_repo = hub_private_repo,
366
+ hub_always_push = hub_always_push,
367
+ gradient_checkpointing = gradient_checkpointing,
368
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
369
+ include_inputs_for_metrics = include_inputs_for_metrics,
370
+ eval_do_concat_batches = eval_do_concat_batches,
371
+ fp16_backend = fp16_backend,
372
+ evaluation_strategy = evaluation_strategy,
373
+ push_to_hub_model_id = push_to_hub_model_id,
374
+ push_to_hub_organization = push_to_hub_organization,
375
+ push_to_hub_token = push_to_hub_token,
376
+ mp_parameters = mp_parameters,
377
+ auto_find_batch_size = auto_find_batch_size,
378
+ full_determinism = full_determinism,
379
+ torchdynamo = torchdynamo,
380
+ ray_scope = ray_scope,
381
+ ddp_timeout = ddp_timeout,
382
+ torch_compile = torch_compile,
383
+ torch_compile_backend = torch_compile_backend,
384
+ torch_compile_mode = torch_compile_mode,
385
+ dispatch_batches = dispatch_batches,
386
+ split_batches = split_batches,
387
+ include_tokens_per_second = include_tokens_per_second,
388
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
389
+ neftune_noise_alpha = neftune_noise_alpha,
390
+ optim_target_modules = optim_target_modules,
391
+ batch_eval_metrics = batch_eval_metrics,
392
+ eval_on_start = eval_on_start,
393
+ use_liger_kernel = use_liger_kernel,
394
+ eval_use_gather_object = eval_use_gather_object,
395
+ average_tokens_across_devices = average_tokens_across_devices,
396
+ dataset_num_proc = dataset_num_proc,
397
+ num_mini_batches = num_mini_batches,
398
+ total_episodes = total_episodes,
399
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
400
+ num_sample_generations = num_sample_generations,
401
+ response_length = response_length,
402
+ stop_token = stop_token,
403
+ stop_token_id = stop_token_id,
404
+ temperature = temperature,
405
+ missing_eos_penalty = missing_eos_penalty,
406
+ sft_model_path = sft_model_path,
407
+ world_size = world_size,
408
+ num_total_batches = num_total_batches,
409
+ micro_batch_size = micro_batch_size,
410
+ local_batch_size = local_batch_size,
411
+ batch_size = batch_size,
412
+ local_mini_batch_size = local_mini_batch_size,
413
+ mini_batch_size = mini_batch_size,
414
+ exp_name = exp_name,
415
+ reward_model_path = reward_model_path,
416
+ model_adapter_name = model_adapter_name,
417
+ ref_adapter_name = ref_adapter_name,
418
+ num_ppo_epochs = num_ppo_epochs,
419
+ whiten_rewards = whiten_rewards,
420
+ kl_coef = kl_coef,
421
+ cliprange = cliprange,
422
+ vf_coef = vf_coef,
423
+ cliprange_value = cliprange_value,
424
+ gamma = gamma,
425
+ lam = lam,
426
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
427
+ self.vllm_sampling_params = vllm_sampling_params
428
+ self.unsloth_num_chunks = unsloth_num_chunks
429
+ pass
430
+
431
+ class _UnslothPPOTrainer(Trainer):
432
+ _tag_names = ["trl", "ppo"]
433
+
434
+ def __init__(
435
+ self,
436
+ args: PPOConfig,
437
+ processing_class: Optional[
438
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
439
+ ],
440
+ model: nn.Module,
441
+ ref_model: Optional[nn.Module],
442
+ reward_model: nn.Module,
443
+ train_dataset: Dataset,
444
+ value_model: Optional[nn.Module] = None,
445
+ data_collator: Optional[DataCollatorWithPadding] = None,
446
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
447
+ # less commonly used
448
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
449
+ callbacks: Optional[list[TrainerCallback]] = None,
450
+ peft_config: Optional["PeftConfig"] = None,
451
+ ) -> None:
452
+ if ref_model is model:
453
+ raise ValueError(
454
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
455
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
456
+ )
457
+
458
+ self.args = args
459
+ self.processing_class = processing_class
460
+ self.policy_model = model
461
+
462
+ # Define the collator if not provided
463
+ if data_collator is None:
464
+ data_collator = DataCollatorWithPadding(self.processing_class)
465
+
466
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
467
+ if args.stop_token and args.stop_token_id:
468
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
469
+ elif args.stop_token:
470
+ if args.stop_token == "eos":
471
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
472
+ else:
473
+ raise ValueError(
474
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
475
+ )
476
+ else:
477
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
478
+
479
+ # peft support
480
+ if not is_peft_available() and peft_config is not None:
481
+ raise ImportError(
482
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
483
+ )
484
+ elif is_peft_available() and peft_config is not None:
485
+ # if model is a peft model and we have a peft_confg, we merge and unload it first
486
+ if isinstance(self.policy_model, PeftModel):
487
+ self.policy_model = self.policy_model.merge_and_unload()
488
+
489
+ # get peft model with the given config
490
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
491
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
492
+ peft_module_casting_to_bf16(self.policy_model)
493
+
494
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
495
+ self.model_adapter_name = args.model_adapter_name
496
+ self.ref_adapter_name = args.ref_adapter_name
497
+
498
+ if ref_model:
499
+ self.ref_model = ref_model
500
+ elif self.is_peft_model:
501
+ self.ref_model = None
502
+ else:
503
+ self.ref_model = create_reference_model(self.policy_model)
504
+
505
+ self.reward_model = reward_model
506
+ self.train_dataset = train_dataset
507
+ self.train_dataset_len = len(train_dataset)
508
+ self.value_model = value_model
509
+ self.data_collator = data_collator
510
+ self.eval_dataset = eval_dataset
511
+ self.optimizer, self.lr_scheduler = optimizers
512
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
513
+
514
+ #########
515
+ # calculate various batch sizes
516
+ #########
517
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
518
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
519
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
520
+ self.accelerator = accelerator
521
+ args.world_size = accelerator.num_processes
522
+ args.local_batch_size = (
523
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
524
+ )
525
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
526
+ args.batch_size = int(args.local_batch_size * args.world_size)
527
+ args.mini_batch_size = exact_div(
528
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
529
+ )
530
+ args.local_mini_batch_size = exact_div(
531
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
532
+ )
533
+ if args.whiten_rewards:
534
+ assert (
535
+ args.local_mini_batch_size >= 8
536
+ ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
537
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
538
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
539
+ args.num_total_batches = math.ceil(
540
+ args.total_episodes / args.batch_size
541
+ ) # we may train for more than `total_episodes`
542
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
543
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
544
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
545
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
546
+ if args.num_sample_generations > 0:
547
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
548
+ self.local_dataloader_batch_size = args.local_batch_size
549
+
550
+ #########
551
+ # setup model, optimizer, and others
552
+ #########
553
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
554
+ if module is not None:
555
+ disable_dropout_in_model(module)
556
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
557
+ self.model.config = self.policy_model.config # needed for pushing to hub
558
+ self.create_optimizer_and_scheduler(
559
+ num_training_steps=args.num_total_batches
560
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
561
+
562
+ #########
563
+ ### trainer specifics
564
+ #########
565
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
566
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
567
+ self.callback_handler = CallbackHandler(
568
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
569
+ )
570
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
571
+ self.control = TrainerControl()
572
+ self.state = OnlineTrainerState(
573
+ is_local_process_zero=self.is_local_process_zero(),
574
+ is_world_process_zero=self.is_world_process_zero(),
575
+ stateful_callbacks=[
576
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
577
+ ],
578
+ )
579
+ self.current_flos = 0
580
+ self.hp_search_backend = None
581
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
582
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
583
+ # Create distant repo and output directory if needed
584
+ self.hub_model_id = None
585
+ if self.args.push_to_hub:
586
+ self.init_hf_repo()
587
+ if self.args.should_save:
588
+ os.makedirs(self.args.output_dir, exist_ok=True)
589
+
590
+ # Add tags for models that have been loaded with the correct transformers version
591
+ if hasattr(self.model, "add_model_tags"):
592
+ self.model.add_model_tags(self._tag_names)
593
+
594
+ #########
595
+ ### setup dataloader
596
+ #########
597
+ self.dataloader = DataLoader(
598
+ self.train_dataset,
599
+ batch_size=self.local_dataloader_batch_size,
600
+ shuffle=True,
601
+ collate_fn=self.data_collator,
602
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
603
+ )
604
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
605
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
606
+ torch.manual_seed(args.seed)
607
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
608
+ torch.manual_seed(self.local_seed) # reset the local seed again
609
+
610
+ self.eval_dataloader = DataLoader(
611
+ self.eval_dataset,
612
+ batch_size=args.per_device_eval_batch_size,
613
+ collate_fn=self.data_collator,
614
+ drop_last=True,
615
+ ) # no need to shuffle eval dataset
616
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
617
+
618
+ if self.is_deepspeed_enabled:
619
+ self.reward_model = prepare_deepspeed(
620
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
621
+ )
622
+
623
+ if self.ref_model is None:
624
+ if not self.is_peft_model:
625
+ raise ValueError("No reference model and model is not a Peft model.")
626
+ else:
627
+ self.ref_model = prepare_deepspeed(
628
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
629
+ )
630
+ else:
631
+ if self.ref_model is None:
632
+ if not self.is_peft_model:
633
+ raise ValueError("No reference model and model is not a Peft model.")
634
+ else:
635
+ self.ref_model = self.ref_model.to(self.accelerator.device)
636
+ self.reward_model = self.reward_model.to(self.accelerator.device)
637
+
638
+ def get_train_dataloader(self) -> DataLoader:
639
+ return self.dataloader
640
+
641
+ def get_eval_dataloader(self) -> DataLoader:
642
+ return self.eval_dataloader
643
+
644
+ @contextmanager
645
+ def null_ref_context(self):
646
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
647
+ with (
648
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
649
+ if self.is_peft_model and not self.ref_adapter_name
650
+ else nullcontext()
651
+ ):
652
+ if self.ref_adapter_name:
653
+ self.model.policy.set_adapter(self.ref_adapter_name)
654
+ yield
655
+ if self.ref_adapter_name:
656
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
657
+
658
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
659
+ backup_model = self.model
660
+ self.model = self.model.policy # save only the policy
661
+
662
+ if self.is_deepspeed_enabled:
663
+ backup_deepspeed = self.deepspeed
664
+ self.deepspeed = self.model
665
+
666
+ super().save_model(output_dir, _internal_call)
667
+
668
+ self.model = backup_model
669
+
670
+ if self.is_deepspeed_enabled:
671
+ self.deepspeed = backup_deepspeed
672
+
673
+ def train(self):
674
+ args = self.args
675
+ accelerator = self.accelerator
676
+ optimizer = self.optimizer
677
+ model = self.model
678
+ ref_policy = self.ref_model
679
+ reward_model = self.reward_model
680
+ processing_class = self.processing_class
681
+ dataloader = self.dataloader
682
+ device = accelerator.device
683
+
684
+ def repeat_generator():
685
+ while True:
686
+ yield from dataloader
687
+
688
+ iter_dataloader = iter(repeat_generator())
689
+ generation_config = GenerationConfig(
690
+ max_new_tokens=args.response_length,
691
+ temperature=(args.temperature + 1e-7),
692
+ top_k=0.0,
693
+ top_p=1.0,
694
+ do_sample=True,
695
+ )
696
+
697
+ accelerator.print("===training policy===")
698
+ start_time = time.time()
699
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
700
+ approxkl_stats = torch.zeros(stats_shape, device=device)
701
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
702
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
703
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
704
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
705
+ entropy_stats = torch.zeros(stats_shape, device=device)
706
+ ratio_stats = torch.zeros(stats_shape, device=device)
707
+ model.train()
708
+
709
+ # trainer state initialization
710
+ self.state.global_step = 0
711
+ self.state.episode = 0
712
+ self.state.max_steps = args.num_total_batches * args.num_mini_batches
713
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
714
+ # Compute absolute values for logging, eval, and save if given as ratio
715
+ if args.logging_steps is not None:
716
+ if args.logging_steps < 1:
717
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
718
+ else:
719
+ self.state.logging_steps = args.logging_steps
720
+ if args.eval_steps is not None:
721
+ if args.eval_steps < 1:
722
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
723
+ else:
724
+ self.state.eval_steps = args.eval_steps
725
+ if args.save_steps is not None:
726
+ if args.save_steps < 1:
727
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
728
+ else:
729
+ self.state.save_steps = args.save_steps
730
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
731
+
732
+ # backward compatibility
733
+ if self.is_deepspeed_enabled:
734
+ self.deepspeed = self.model
735
+ self.model_wrapped = self.model
736
+
737
+ for update in range(1, args.num_total_batches + 1):
738
+ self.state.episode += 1 * args.batch_size
739
+ data = next(iter_dataloader)
740
+ with torch.no_grad():
741
+ queries = data["input_ids"].to(device)
742
+ context_length = queries.shape[1]
743
+ responses = []
744
+ postprocessed_responses = []
745
+ logprobs = []
746
+ ref_logprobs = []
747
+ scores = []
748
+ sequence_lengths = []
749
+ values = []
750
+ with unwrap_model_for_generation(
751
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
752
+ ) as unwrapped_model:
753
+ query_responses, logitss = batch_generation(
754
+ unwrapped_model.policy,
755
+ queries,
756
+ args.local_rollout_forward_batch_size,
757
+ processing_class.pad_token_id,
758
+ generation_config,
759
+ )
760
+
761
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
762
+ query = queries[i : i + args.local_rollout_forward_batch_size]
763
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
764
+ response = query_response[:, context_length:]
765
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
766
+ logprob = selective_log_softmax(logits, response)
767
+ del logits
768
+ torch.cuda.empty_cache()
769
+
770
+ if ref_policy is None:
771
+ with self.null_ref_context():
772
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
773
+ else:
774
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
775
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
776
+ ref_logits /= args.temperature + 1e-7
777
+ ref_logprob = selective_log_softmax(ref_logits, response)
778
+ del ref_output, ref_logits
779
+ torch.cuda.empty_cache()
780
+
781
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
782
+ postprocessed_response = response
783
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
784
+ postprocessed_response = truncate_response(
785
+ self.stop_token_id, processing_class.pad_token_id, response
786
+ )
787
+
788
+ # Response Processing 2. run reward model on the truncated responses
789
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
790
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
791
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
792
+ full_value, _, _ = get_reward(
793
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
794
+ )
795
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
796
+ _, score, _ = get_reward(
797
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
798
+ )
799
+
800
+ responses.append(response)
801
+ postprocessed_responses.append(postprocessed_response)
802
+ logprobs.append(logprob)
803
+ ref_logprobs.append(ref_logprob)
804
+ sequence_lengths.append(sequence_length)
805
+ scores.append(score)
806
+ values.append(value)
807
+ responses = torch.cat(responses, 0)
808
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
809
+ logprobs = torch.cat(logprobs, 0)
810
+ ref_logprobs = torch.cat(ref_logprobs, 0)
811
+ sequence_lengths = torch.cat(sequence_lengths, 0)
812
+ scores = torch.cat(scores, 0)
813
+ values = torch.cat(values, 0)
814
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
815
+ torch.cuda.empty_cache()
816
+ gc.collect()
817
+
818
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
819
+ # Completions not passing that filter will receive a lower score.
820
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
821
+ if self.args.missing_eos_penalty is not None:
822
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
823
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
824
+
825
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
826
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
827
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
828
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
829
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
830
+ sequence_lengths_p1 = sequence_lengths + 1
831
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
832
+ values = torch.masked_fill(values, padding_mask_p1, 0)
833
+
834
+ # 4. compute rewards
835
+ kl = logprobs - ref_logprobs
836
+ non_score_reward = -args.kl_coef * kl
837
+ rewards = non_score_reward.clone()
838
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
839
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
840
+ rewards[[actual_start, actual_end]] += scores
841
+
842
+ # 5. whiten rewards
843
+ if args.whiten_rewards:
844
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
845
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
846
+
847
+ # 6. compute advantages and returns
848
+ lastgaelam = 0
849
+ advantages_reversed = []
850
+ gen_length = responses.shape[1]
851
+ for t in reversed(range(gen_length)):
852
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
853
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
854
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
855
+ advantages_reversed.append(lastgaelam)
856
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
857
+ returns = advantages + values
858
+ advantages = masked_whiten(advantages, ~padding_mask)
859
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
860
+ torch.cuda.empty_cache()
861
+
862
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
863
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
864
+ b_inds = np.random.permutation(args.local_batch_size)
865
+ minibatch_idx = 0
866
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
867
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
868
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
869
+ gradient_accumulation_idx = 0
870
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
871
+ with accelerator.accumulate(model):
872
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
873
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
874
+ mb_advantage = advantages[micro_batch_inds]
875
+ mb_responses = responses[micro_batch_inds]
876
+ mb_query_responses = query_responses[micro_batch_inds]
877
+ mb_logprobs = logprobs[micro_batch_inds]
878
+ mb_return = returns[micro_batch_inds]
879
+ mb_values = values[micro_batch_inds]
880
+
881
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
882
+ logits = output.logits[:, context_length - 1 : -1]
883
+ logits /= args.temperature + 1e-7
884
+ new_logprobs = selective_log_softmax(logits, mb_responses)
885
+ new_logprobs = torch.masked_fill(
886
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
887
+ )
888
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
889
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
890
+ vpredclipped = torch.clamp(
891
+ vpred,
892
+ mb_values - args.cliprange_value,
893
+ mb_values + args.cliprange_value,
894
+ )
895
+ vf_losses1 = torch.square(vpred - mb_return)
896
+ vf_losses2 = torch.square(vpredclipped - mb_return)
897
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
898
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
899
+ vf_clipfrac = masked_mean(
900
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
901
+ )
902
+ logprobs_diff = new_logprobs - mb_logprobs
903
+ ratio = torch.exp(logprobs_diff)
904
+ pg_losses = -mb_advantage * ratio
905
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
906
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
907
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
908
+ loss = pg_loss + args.vf_coef * vf_loss
909
+ accelerator.backward(loss)
910
+ optimizer.step()
911
+ optimizer.zero_grad()
912
+ with torch.no_grad():
913
+ pg_clipfrac = masked_mean(
914
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
915
+ )
916
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
917
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
918
+ approxkl = 0.5 * (logprobs_diff**2).mean()
919
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
920
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
921
+ pg_clipfrac
922
+ )
923
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
924
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
925
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
926
+ vf_clipfrac
927
+ )
928
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
929
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
930
+ gradient_accumulation_idx += 1
931
+ minibatch_idx += 1
932
+ # del everything and empty cache
933
+ # fmt: off
934
+ del (
935
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
936
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
937
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
938
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
939
+ )
940
+ # fmt: on
941
+ torch.cuda.empty_cache()
942
+ with torch.no_grad():
943
+ mean_kl = kl.sum(1).mean()
944
+ mean_entropy = (-logprobs).sum(1).mean()
945
+ mean_non_score_reward = non_score_reward.sum(1).mean()
946
+ rlhf_reward = mean_non_score_reward + scores.mean()
947
+ eps = int(self.state.episode / (time.time() - start_time))
948
+ metrics = {}
949
+ metrics["eps"] = eps
950
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
951
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
952
+ metrics["objective/non_score_reward"] = (
953
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
954
+ )
955
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
956
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
957
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
958
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
959
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
960
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
961
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
962
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
963
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
964
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
965
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
966
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
967
+ metrics["episode"] = self.state.episode
968
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
969
+ self.state.global_step += 1
970
+ self.log(metrics)
971
+
972
+ self.lr_scheduler.step()
973
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
974
+ if self.control.should_save:
975
+ self._save_checkpoint(model, trial=None)
976
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
977
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
978
+ torch.cuda.empty_cache()
979
+ gc.collect()
980
+
981
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
982
+ self.generate_completions(sampling=True)
983
+ torch.cuda.empty_cache()
984
+ del (
985
+ query_responses,
986
+ responses,
987
+ postprocessed_responses,
988
+ logprobs,
989
+ ref_logprobs,
990
+ values,
991
+ sequence_lengths,
992
+ contain_eos_token,
993
+ sequence_lengths_p1,
994
+ response_idxs,
995
+ padding_mask,
996
+ padding_mask_p1,
997
+ rewards,
998
+ actual_start,
999
+ actual_end,
1000
+ advantages,
1001
+ returns,
1002
+ )
1003
+ torch.cuda.empty_cache()
1004
+
1005
+ # HF trainer specifics
1006
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
1007
+ if self.control.should_save:
1008
+ self._save_checkpoint(model, trial=None, metrics=None)
1009
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1010
+
1011
+ def generate_completions(self, sampling: bool = False):
1012
+ args = self.args
1013
+ processing_class = self.processing_class
1014
+ generation_config = GenerationConfig(
1015
+ max_new_tokens=self.args.response_length,
1016
+ temperature=(0.01 + 1e-7),
1017
+ top_k=0.0,
1018
+ top_p=1.0,
1019
+ do_sample=True,
1020
+ )
1021
+
1022
+ table = defaultdict(list)
1023
+ with unwrap_model_for_generation(
1024
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1025
+ ) as unwrapped_model:
1026
+ for batch in self.eval_dataloader:
1027
+ query = batch["input_ids"]
1028
+ with torch.no_grad():
1029
+ context_length = query.shape[1]
1030
+ query_response, _ = batch_generation(
1031
+ unwrapped_model.policy,
1032
+ query,
1033
+ query.shape[0],
1034
+ processing_class.pad_token_id,
1035
+ generation_config,
1036
+ )
1037
+ response = query_response[:, context_length:]
1038
+ postprocessed_response = response
1039
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
1040
+ postprocessed_response = truncate_response(
1041
+ self.stop_token_id, processing_class.pad_token_id, response
1042
+ )
1043
+ table["query"].extend(
1044
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
1045
+ )
1046
+ table["model response"].extend(
1047
+ gather_object(processing_class.batch_decode(postprocessed_response))
1048
+ )
1049
+
1050
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
1051
+ _, score, _ = get_reward(
1052
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
1053
+ )
1054
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
1055
+
1056
+ if sampling:
1057
+ break
1058
+ df = pd.DataFrame(table)
1059
+
1060
+ if self.accelerator.is_main_process:
1061
+ print_rich_table(df.iloc[0 : 0 + 5])
1062
+ if "wandb" in args.report_to:
1063
+ import wandb
1064
+
1065
+ if wandb.run is not None:
1066
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1067
+
1068
+ if "comet_ml" in args.report_to:
1069
+ log_table_to_comet_experiment(
1070
+ name="completions.csv",
1071
+ table=df,
1072
+ )
1073
+
1074
+ def create_model_card(
1075
+ self,
1076
+ model_name: Optional[str] = None,
1077
+ dataset_name: Optional[str] = None,
1078
+ tags: Union[str, list[str], None] = None,
1079
+ ):
1080
+ """
1081
+ Creates a draft of a model card using the information available to the `Trainer`.
1082
+
1083
+ Args:
1084
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1085
+ Name of the model.
1086
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1087
+ Name of the dataset used for training.
1088
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1089
+ Tags to be associated with the model card.
1090
+ """
1091
+ if not self.is_world_process_zero():
1092
+ return
1093
+
1094
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1095
+ base_model = self.model.config._name_or_path
1096
+ else:
1097
+ base_model = None
1098
+
1099
+ tags = tags or []
1100
+ if isinstance(tags, str):
1101
+ tags = [tags]
1102
+
1103
+ if hasattr(self.model.config, "unsloth_version"):
1104
+ tags.append("unsloth")
1105
+
1106
+ citation = textwrap.dedent("""\
1107
+ @article{mziegler2019fine-tuning,
1108
+ title = {{Fine-Tuning Language Models from Human Preferences}},
1109
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
1110
+ year = 2019,
1111
+ eprint = {arXiv:1909.08593}
1112
+ }""")
1113
+
1114
+ model_card = generate_model_card(
1115
+ base_model=base_model,
1116
+ model_name=model_name,
1117
+ hub_model_id=self.hub_model_id,
1118
+ dataset_name=dataset_name,
1119
+ tags=tags,
1120
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1121
+ comet_url=get_comet_experiment_url(),
1122
+ trainer_name="PPO",
1123
+ trainer_citation=citation,
1124
+ paper_title="Fine-Tuning Language Models from Human Preferences",
1125
+ paper_id="1909.08593",
1126
+ )
1127
+
1128
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1129
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
1130
+ """
1131
+
1132
+ """
1133
+ def __init__(
1134
+ self,
1135
+ args,
1136
+ processing_class,
1137
+ model,
1138
+ ref_model,
1139
+ reward_model,
1140
+ train_dataset,
1141
+ value_model = None,
1142
+ data_collator = None,
1143
+ eval_dataset = None,
1144
+ callbacks = None,
1145
+ peft_config = None,
1146
+ **kwargs
1147
+ ):
1148
+ if args is None: args = UnslothPPOConfig()
1149
+ use_bf16 = getattr(args, 'bf16', False)
1150
+ use_fp16 = getattr(args, 'fp16', False)
1151
+ force_float32 = False
1152
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1153
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1154
+ force_float32 = True
1155
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1156
+ dtype = getattr(model.config, 'torch_dtype', None)
1157
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1158
+ from unsloth_zoo.utils import _get_dtype
1159
+ dtype = _get_dtype(dtype)
1160
+ float16 = dtype == torch.float16
1161
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1162
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1163
+ if force_float32:
1164
+ args.fp16 = False
1165
+ args.bf16 = False
1166
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1167
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1168
+ args.fp16 = float16
1169
+ args.bf16 = not float16
1170
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1171
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1172
+ args.eval_strategy = 'steps'
1173
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1174
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1175
+ if ga_steps is not None and ga_steps > 1:
1176
+ from transformers import __version__ as transformers_version
1177
+ if Version(transformers_version) <= Version('4.45.2'):
1178
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1179
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1180
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1181
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1182
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1183
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1184
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1185
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1186
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1187
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1188
+ if force_float32:
1189
+ args.bf16_full_eval = False
1190
+ args.fp16_full_eval = False
1191
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1192
+ args.bf16_full_eval = True
1193
+ args.fp16_full_eval = False
1194
+ elif not bf16_full_eval and not fp16_full_eval:
1195
+ args.bf16_full_eval = args.bf16
1196
+ args.fp16_full_eval = args.fp16
1197
+ _output_logits = False
1198
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1199
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1200
+ if _output_logits:
1201
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1202
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1203
+ pass
1204
+ else:
1205
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1206
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1207
+ if args_max_seq_length is None and model_max_seq_length is not None:
1208
+ max_seq_length = model.max_seq_length
1209
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1210
+ if model is not None and hasattr(model, 'for_training'):
1211
+ model.for_training()
1212
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1213
+ if 'processing_class' in locals():
1214
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1215
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1216
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1217
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1218
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1219
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1220
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1221
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1222
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1223
+ else:
1224
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1225
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1226
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1227
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1228
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1229
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1230
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1231
+ else:
1232
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1233
+ other_metrics = []
1234
+
1235
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1236
+ PatchRLStatistics('ppo_trainer', other_metrics)
1237
+
1238
+ super().__init__(
1239
+ args = args,
1240
+ processing_class = processing_class,
1241
+ model = model,
1242
+ ref_model = ref_model,
1243
+ reward_model = reward_model,
1244
+ train_dataset = train_dataset,
1245
+ value_model = value_model,
1246
+ data_collator = data_collator,
1247
+ eval_dataset = eval_dataset,
1248
+ callbacks = callbacks,
1249
+ peft_config = peft_config,**kwargs)
1250
+ if hasattr(self, 'neftune_hook_handle'):
1251
+ self.neftune_hook_handle.remove()
1252
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1253
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1254
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1255
+ pass
1256
+
1257
+ pass
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2025.3.13
3
+ 2025.3.15
4
+ 4.48.3
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothPRMConfig(PRMConfig):
44
+ """
45
+
46
+ Configuration class for the [`PRMTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-5`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) used for truncation.
58
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
59
+ Maximum length of the prompt used for truncation.
60
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
61
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
62
+ disable_dropout (`bool`, *optional*, defaults to `True`):
63
+ Whether to disable dropout in the model.
64
+ step_separator (`str`, *optional*, defaults to `"\n"`):
65
+ Separator used to separate each step of the reasoning process.
66
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
67
+ Whether to train only on the last step.
68
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
69
+ Number of processes to use for processing the dataset.
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ fsdp_transformer_layer_cls_to_wrap = None,
159
+ accelerator_config = None,
160
+ deepspeed = None,
161
+ label_smoothing_factor = 0.0,
162
+ optim = 'adamw_8bit',
163
+ optim_args = None,
164
+ adafactor = False,
165
+ group_by_length = False,
166
+ length_column_name = 'length',
167
+ report_to = None,
168
+ ddp_find_unused_parameters = None,
169
+ ddp_bucket_cap_mb = None,
170
+ ddp_broadcast_buffers = None,
171
+ dataloader_pin_memory = True,
172
+ dataloader_persistent_workers = False,
173
+ skip_memory_metrics = True,
174
+ use_legacy_prediction_loop = False,
175
+ push_to_hub = False,
176
+ resume_from_checkpoint = None,
177
+ hub_model_id = None,
178
+ hub_strategy = 'every_save',
179
+ hub_token = None,
180
+ hub_private_repo = None,
181
+ hub_always_push = False,
182
+ gradient_checkpointing = False,
183
+ gradient_checkpointing_kwargs = None,
184
+ include_inputs_for_metrics = False,
185
+ eval_do_concat_batches = True,
186
+ fp16_backend = 'auto',
187
+ evaluation_strategy = None,
188
+ push_to_hub_model_id = None,
189
+ push_to_hub_organization = None,
190
+ push_to_hub_token = None,
191
+ mp_parameters = '',
192
+ auto_find_batch_size = False,
193
+ full_determinism = False,
194
+ torchdynamo = None,
195
+ ray_scope = 'last',
196
+ ddp_timeout = 1800,
197
+ torch_compile = False,
198
+ torch_compile_backend = None,
199
+ torch_compile_mode = None,
200
+ dispatch_batches = None,
201
+ split_batches = None,
202
+ include_tokens_per_second = False,
203
+ include_num_input_tokens_seen = False,
204
+ neftune_noise_alpha = None,
205
+ optim_target_modules = None,
206
+ batch_eval_metrics = False,
207
+ eval_on_start = False,
208
+ use_liger_kernel = False,
209
+ eval_use_gather_object = False,
210
+ average_tokens_across_devices = False,
211
+ max_length = 1024,
212
+ max_prompt_length = 512,
213
+ max_completion_length = None,
214
+ disable_dropout = True,
215
+ step_separator = '\
216
+ ',
217
+ train_on_last_step_only = False,
218
+ dataset_num_proc = None,
219
+ vllm_sampling_params = None,
220
+ unsloth_num_chunks = -1,
221
+ **kwargs,
222
+ ):
223
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
224
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
225
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
226
+ output_dir = 'unsloth_training_checkpoints'
227
+ save_strategy = 'no'
228
+ if dataset_num_proc is None:
229
+ from multiprocessing import cpu_count
230
+ dataset_num_proc = cpu_count()
231
+
232
+ super().__init__(
233
+ output_dir = output_dir,
234
+ overwrite_output_dir = overwrite_output_dir,
235
+ do_train = do_train,
236
+ do_eval = do_eval,
237
+ do_predict = do_predict,
238
+ eval_strategy = eval_strategy,
239
+ prediction_loss_only = prediction_loss_only,
240
+ per_device_train_batch_size = per_device_train_batch_size,
241
+ per_device_eval_batch_size = per_device_eval_batch_size,
242
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
243
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
244
+ gradient_accumulation_steps = gradient_accumulation_steps,
245
+ eval_accumulation_steps = eval_accumulation_steps,
246
+ eval_delay = eval_delay,
247
+ torch_empty_cache_steps = torch_empty_cache_steps,
248
+ learning_rate = learning_rate,
249
+ weight_decay = weight_decay,
250
+ adam_beta1 = adam_beta1,
251
+ adam_beta2 = adam_beta2,
252
+ adam_epsilon = adam_epsilon,
253
+ max_grad_norm = max_grad_norm,
254
+ num_train_epochs = num_train_epochs,
255
+ max_steps = max_steps,
256
+ lr_scheduler_type = lr_scheduler_type,
257
+ warmup_ratio = warmup_ratio,
258
+ warmup_steps = warmup_steps,
259
+ log_level = log_level,
260
+ log_level_replica = log_level_replica,
261
+ log_on_each_node = log_on_each_node,
262
+ logging_dir = logging_dir,
263
+ logging_strategy = logging_strategy,
264
+ logging_first_step = logging_first_step,
265
+ logging_steps = logging_steps,
266
+ logging_nan_inf_filter = logging_nan_inf_filter,
267
+ save_strategy = save_strategy,
268
+ save_steps = save_steps,
269
+ save_total_limit = save_total_limit,
270
+ save_safetensors = save_safetensors,
271
+ save_on_each_node = save_on_each_node,
272
+ save_only_model = save_only_model,
273
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
274
+ no_cuda = no_cuda,
275
+ use_cpu = use_cpu,
276
+ use_mps_device = use_mps_device,
277
+ seed = seed,
278
+ data_seed = data_seed,
279
+ jit_mode_eval = jit_mode_eval,
280
+ use_ipex = use_ipex,
281
+ bf16 = bf16,
282
+ fp16 = fp16,
283
+ fp16_opt_level = fp16_opt_level,
284
+ half_precision_backend = half_precision_backend,
285
+ bf16_full_eval = bf16_full_eval,
286
+ fp16_full_eval = fp16_full_eval,
287
+ tf32 = tf32,
288
+ local_rank = local_rank,
289
+ ddp_backend = ddp_backend,
290
+ tpu_num_cores = tpu_num_cores,
291
+ tpu_metrics_debug = tpu_metrics_debug,
292
+ debug = debug,
293
+ dataloader_drop_last = dataloader_drop_last,
294
+ eval_steps = eval_steps,
295
+ dataloader_num_workers = dataloader_num_workers,
296
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
297
+ past_index = past_index,
298
+ run_name = run_name,
299
+ disable_tqdm = disable_tqdm,
300
+ remove_unused_columns = remove_unused_columns,
301
+ label_names = label_names,
302
+ load_best_model_at_end = load_best_model_at_end,
303
+ metric_for_best_model = metric_for_best_model,
304
+ greater_is_better = greater_is_better,
305
+ ignore_data_skip = ignore_data_skip,
306
+ fsdp = fsdp,
307
+ fsdp_min_num_params = fsdp_min_num_params,
308
+ fsdp_config = fsdp_config,
309
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
310
+ accelerator_config = accelerator_config,
311
+ deepspeed = deepspeed,
312
+ label_smoothing_factor = label_smoothing_factor,
313
+ optim = optim,
314
+ optim_args = optim_args,
315
+ adafactor = adafactor,
316
+ group_by_length = group_by_length,
317
+ length_column_name = length_column_name,
318
+ report_to = report_to,
319
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
320
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
321
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
322
+ dataloader_pin_memory = dataloader_pin_memory,
323
+ dataloader_persistent_workers = dataloader_persistent_workers,
324
+ skip_memory_metrics = skip_memory_metrics,
325
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
326
+ push_to_hub = push_to_hub,
327
+ resume_from_checkpoint = resume_from_checkpoint,
328
+ hub_model_id = hub_model_id,
329
+ hub_strategy = hub_strategy,
330
+ hub_token = hub_token,
331
+ hub_private_repo = hub_private_repo,
332
+ hub_always_push = hub_always_push,
333
+ gradient_checkpointing = gradient_checkpointing,
334
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
335
+ include_inputs_for_metrics = include_inputs_for_metrics,
336
+ eval_do_concat_batches = eval_do_concat_batches,
337
+ fp16_backend = fp16_backend,
338
+ evaluation_strategy = evaluation_strategy,
339
+ push_to_hub_model_id = push_to_hub_model_id,
340
+ push_to_hub_organization = push_to_hub_organization,
341
+ push_to_hub_token = push_to_hub_token,
342
+ mp_parameters = mp_parameters,
343
+ auto_find_batch_size = auto_find_batch_size,
344
+ full_determinism = full_determinism,
345
+ torchdynamo = torchdynamo,
346
+ ray_scope = ray_scope,
347
+ ddp_timeout = ddp_timeout,
348
+ torch_compile = torch_compile,
349
+ torch_compile_backend = torch_compile_backend,
350
+ torch_compile_mode = torch_compile_mode,
351
+ dispatch_batches = dispatch_batches,
352
+ split_batches = split_batches,
353
+ include_tokens_per_second = include_tokens_per_second,
354
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
355
+ neftune_noise_alpha = neftune_noise_alpha,
356
+ optim_target_modules = optim_target_modules,
357
+ batch_eval_metrics = batch_eval_metrics,
358
+ eval_on_start = eval_on_start,
359
+ use_liger_kernel = use_liger_kernel,
360
+ eval_use_gather_object = eval_use_gather_object,
361
+ average_tokens_across_devices = average_tokens_across_devices,
362
+ max_length = max_length,
363
+ max_prompt_length = max_prompt_length,
364
+ max_completion_length = max_completion_length,
365
+ disable_dropout = disable_dropout,
366
+ step_separator = step_separator,
367
+ train_on_last_step_only = train_on_last_step_only,
368
+ dataset_num_proc = dataset_num_proc,**kwargs)
369
+ self.vllm_sampling_params = vllm_sampling_params
370
+ self.unsloth_num_chunks = unsloth_num_chunks
371
+ pass
372
+
373
+ class _UnslothPRMTrainer(Trainer):
374
+ """"""
375
+
376
+ _tag_names = ["trl", "prm"]
377
+
378
+ def __init__(
379
+ self,
380
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
381
+ args: Optional[PRMConfig] = None,
382
+ data_collator: Optional[DataCollator] = None,
383
+ train_dataset: Optional[Dataset] = None,
384
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
385
+ processing_class: Optional[
386
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
387
+ ] = None,
388
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
389
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
390
+ callbacks: Optional[list[TrainerCallback]] = None,
391
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
392
+ None,
393
+ None,
394
+ ),
395
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
396
+ peft_config: Optional[dict] = None,
397
+ ):
398
+ if not is_peft_available() and peft_config is not None:
399
+ raise ValueError(
400
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
401
+ )
402
+ elif is_peft_available() and peft_config is not None:
403
+ if not isinstance(model, PeftModel):
404
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
405
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
406
+ inspect.signature(prepare_model_for_kbit_training).parameters
407
+ )
408
+
409
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
410
+
411
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
412
+ warnings.warn(
413
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
414
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
415
+ )
416
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
417
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
418
+
419
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
420
+
421
+ model = model
422
+
423
+ # Disable dropout in the model
424
+ if args.disable_dropout:
425
+ disable_dropout_in_model(model)
426
+
427
+ if compute_metrics is None:
428
+ compute_metrics = compute_accuracy
429
+
430
+ if data_collator is None:
431
+ if processing_class is None:
432
+ raise ValueError(
433
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
434
+ )
435
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
436
+
437
+ if "input_ids" not in train_dataset.column_names:
438
+ with PartialState().local_main_process_first():
439
+ fn_kwargs = {
440
+ "tokenizer": processing_class,
441
+ "step_separator": args.step_separator,
442
+ "max_length": args.max_length,
443
+ "max_prompt_length": args.max_prompt_length,
444
+ "max_completion_length": args.max_completion_length,
445
+ "train_on_last_step_only": args.train_on_last_step_only,
446
+ }
447
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
448
+ train_dataset = train_dataset.map(
449
+ self.tokenize_row,
450
+ fn_kwargs=train_fn_kwargs,
451
+ num_proc=args.dataset_num_proc,
452
+ remove_columns=train_dataset.features,
453
+ desc="Tokenizing train dataset",
454
+ features=features.Features( # needed to avoid map to cast labels to bool
455
+ {
456
+ "labels": features.Sequence(features.Value("int64")),
457
+ "input_ids": features.Sequence(features.Value("int64")),
458
+ }
459
+ ),
460
+ )
461
+
462
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
463
+ if eval_dataset is not None:
464
+ eval_dataset = eval_dataset.map(
465
+ self.tokenize_row,
466
+ fn_kwargs=eval_fn_kwargs,
467
+ num_proc=args.dataset_num_proc,
468
+ remove_columns=eval_dataset.features,
469
+ desc="Tokenizing eval dataset",
470
+ features=features.Features( # needed to avoid map to cast labels to bool
471
+ {
472
+ "labels": features.Sequence(features.Value("int64")),
473
+ "input_ids": features.Sequence(features.Value("int64")),
474
+ }
475
+ ),
476
+ )
477
+
478
+ super().__init__(
479
+ model=model,
480
+ args=args,
481
+ data_collator=data_collator,
482
+ train_dataset=train_dataset,
483
+ eval_dataset=eval_dataset,
484
+ processing_class=processing_class,
485
+ model_init=model_init,
486
+ compute_metrics=compute_metrics,
487
+ callbacks=callbacks,
488
+ optimizers=optimizers,
489
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
490
+ )
491
+
492
+ # Add tags for models that have been loaded with the correct transformers version
493
+ if hasattr(self.model, "add_model_tags"):
494
+ self.model.add_model_tags(self._tag_names)
495
+
496
+ @staticmethod
497
+ def tokenize_row(
498
+ features,
499
+ tokenizer,
500
+ step_separator,
501
+ max_length,
502
+ max_prompt_length,
503
+ max_completion_length,
504
+ train_on_last_step_only,
505
+ is_eval,
506
+ ):
507
+ r"""
508
+ Tokenize a row of the dataset.
509
+
510
+ Args:
511
+ features (`dict[str, str]`):
512
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
513
+ tokenizer (`PreTrainedTokenizerBase`):
514
+ Tokenizer used to process the data.
515
+ step_separator (`str`):
516
+ Separator between steps in the completion.
517
+ max_length (`int` or `None`):
518
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
519
+ max_prompt_length (`int` or `None`):
520
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
521
+ max_completion_length (`int` or `None`):
522
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
523
+ train_on_last_step_only (`bool`):
524
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
525
+ token of the completion.
526
+ is_eval (`bool`):
527
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
528
+
529
+ Returns:
530
+ `dict[str, list[int]]`:
531
+ Tokenized sequences with the keys `"input_ids"`, and `"labels".
532
+
533
+ Example:
534
+ ```python
535
+ >>> from transformers import AutoTokenizer
536
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
537
+ >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
538
+ ... "completions": ["11 is greater than 8.",
539
+ ... "Hence, 9.11 > 9.8."],
540
+ ... "labels": [True, False]}
541
+ >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
542
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
543
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
544
+ ```
545
+ """
546
+ # Tokenize the prompt and completions
547
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
548
+ completions_ids = [
549
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
550
+ ]
551
+ if train_on_last_step_only and not is_eval:
552
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
553
+ else:
554
+ labels = [int(label) for label in features["labels"]]
555
+
556
+ # Get the ID of the separator token and add it to the completions
557
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
558
+ completions_ids = [completion + separator_ids for completion in completions_ids]
559
+
560
+ # Create the label
561
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
562
+
563
+ # Join the completions and labels steps
564
+ completion_ids = list(chain(*completions_ids))
565
+ labels = list(chain(*labels))
566
+
567
+ if tokenizer.bos_token_id is not None:
568
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
569
+
570
+ # Truncate prompt and completion sequences
571
+ if max_prompt_length is not None:
572
+ prompt_ids = prompt_ids[-max_prompt_length:]
573
+ if max_completion_length is not None:
574
+ completion_ids = completion_ids[:max_completion_length]
575
+ labels = labels[:max_completion_length]
576
+
577
+ input_ids = prompt_ids + completion_ids
578
+ labels = [-100] * len(prompt_ids) + labels
579
+
580
+ if max_length is not None:
581
+ input_ids = input_ids[:max_length]
582
+ labels = labels[:max_length]
583
+
584
+ return {"input_ids": input_ids, "labels": labels}
585
+
586
+ def create_model_card(
587
+ self,
588
+ model_name: Optional[str] = None,
589
+ dataset_name: Optional[str] = None,
590
+ tags: Union[str, list[str], None] = None,
591
+ ):
592
+ """
593
+ Creates a draft of a model card using the information available to the `Trainer`.
594
+
595
+ Args:
596
+ model_name (`str` or `None`, *optional*, defaults to `None`):
597
+ Name of the model.
598
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
599
+ Name of the dataset used for training.
600
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
601
+ Tags to be associated with the model card.
602
+ """
603
+ if not self.is_world_process_zero():
604
+ return
605
+
606
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
607
+ base_model = self.model.config._name_or_path
608
+ else:
609
+ base_model = None
610
+
611
+ tags = tags or []
612
+ if isinstance(tags, str):
613
+ tags = [tags]
614
+
615
+ if hasattr(self.model.config, "unsloth_version"):
616
+ tags.append("unsloth")
617
+
618
+ citation = textwrap.dedent("""\
619
+ @article{uesato2022solving,
620
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
621
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
622
+ year = 2022,
623
+ journal = {arXiv preprint arXiv:2211.14275}
624
+ }""")
625
+
626
+ model_card = generate_model_card(
627
+ base_model=base_model,
628
+ model_name=model_name,
629
+ hub_model_id=self.hub_model_id,
630
+ dataset_name=dataset_name,
631
+ tags=tags,
632
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
633
+ trainer_name="PRM",
634
+ trainer_citation=citation,
635
+ paper_title="Solving math word problems with process-and outcome-based feedback",
636
+ )
637
+
638
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
639
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
640
+ """
641
+
642
+ Initialize PRMTrainer.
643
+
644
+ Args:
645
+ model (`transformers.PreTrainedModel`):
646
+ The model to train, preferably an `AutoModelForTokenClassification`.
647
+ args (`PRMConfig`):
648
+ The arguments to use for training.
649
+ data_collator (`transformers.DataCollator`):
650
+ The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
651
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
652
+ train_dataset (`datasets.Dataset`):
653
+ The dataset to use for training.
654
+ eval_dataset (`datasets.Dataset`):
655
+ The dataset to use for evaluation.
656
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
657
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
658
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
659
+ reuse the fine-tuned model.
660
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
661
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
662
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
663
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
664
+ callbacks (`list[transformers.TrainerCallback]`):
665
+ The callbacks to use for training.
666
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
667
+ The optimizer and scheduler to use for training.
668
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
669
+ The function to use to preprocess the logits before computing the metrics.
670
+ peft_config (`dict`, defaults to `None`):
671
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
672
+
673
+ """
674
+ def __init__(
675
+ self,
676
+ model = None,
677
+ args = None,
678
+ data_collator = None,
679
+ train_dataset = None,
680
+ eval_dataset = None,
681
+ processing_class = None,
682
+ model_init = None,
683
+ compute_metrics = None,
684
+ callbacks = None,
685
+ preprocess_logits_for_metrics = None,
686
+ peft_config = None,
687
+ **kwargs
688
+ ):
689
+ if args is None: args = UnslothPRMConfig()
690
+ use_bf16 = getattr(args, 'bf16', False)
691
+ use_fp16 = getattr(args, 'fp16', False)
692
+ force_float32 = False
693
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
694
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
695
+ force_float32 = True
696
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
697
+ dtype = getattr(model.config, 'torch_dtype', None)
698
+ if dtype is None: dtype = model.get_input_embeddings().dtype
699
+ from unsloth_zoo.utils import _get_dtype
700
+ dtype = _get_dtype(dtype)
701
+ float16 = dtype == torch.float16
702
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
703
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
704
+ if force_float32:
705
+ args.fp16 = False
706
+ args.bf16 = False
707
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
708
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
709
+ args.fp16 = float16
710
+ args.bf16 = not float16
711
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
712
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
713
+ args.eval_strategy = 'steps'
714
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
715
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
716
+ if ga_steps is not None and ga_steps > 1:
717
+ from transformers import __version__ as transformers_version
718
+ if Version(transformers_version) <= Version('4.45.2'):
719
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
720
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
721
+ if getattr(args, 'eval_strategy', 'no') != 'no':
722
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
723
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
724
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
725
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
726
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
727
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
728
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
729
+ if force_float32:
730
+ args.bf16_full_eval = False
731
+ args.fp16_full_eval = False
732
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
733
+ args.bf16_full_eval = True
734
+ args.fp16_full_eval = False
735
+ elif not bf16_full_eval and not fp16_full_eval:
736
+ args.bf16_full_eval = args.bf16
737
+ args.fp16_full_eval = args.fp16
738
+ _output_logits = False
739
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
740
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
741
+ if _output_logits:
742
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
743
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
744
+ pass
745
+ else:
746
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
747
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
748
+ if args_max_seq_length is None and model_max_seq_length is not None:
749
+ max_seq_length = model.max_seq_length
750
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
751
+ if model is not None and hasattr(model, 'for_training'):
752
+ model.for_training()
753
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
754
+ if 'processing_class' in locals():
755
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
756
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
757
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
758
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
759
+ if not isinstance(data_collator, UnslothVisionDataCollator):
760
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
761
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
762
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
763
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
764
+ else:
765
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
766
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
767
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
768
+ if not isinstance(data_collator, UnslothVisionDataCollator):
769
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
770
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
771
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
772
+ else:
773
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
774
+ other_metrics = []
775
+
776
+ from unsloth_zoo.logging_utils import PatchRLStatistics
777
+ PatchRLStatistics('prm_trainer', other_metrics)
778
+
779
+ super().__init__(
780
+ model = model,
781
+ args = args,
782
+ data_collator = data_collator,
783
+ train_dataset = train_dataset,
784
+ eval_dataset = eval_dataset,
785
+ processing_class = processing_class,
786
+ model_init = model_init,
787
+ compute_metrics = compute_metrics,
788
+ callbacks = callbacks,
789
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
790
+ peft_config = peft_config,**kwargs)
791
+ if hasattr(self, 'neftune_hook_handle'):
792
+ self.neftune_hook_handle.remove()
793
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
794
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
795
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
796
+ pass
797
+
798
+ pass