Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .config/.last_opt_in_prompt.yaml +1 -0
- .config/.last_survey_prompt.yaml +1 -0
- .config/.last_update_check.json +1 -0
- .config/active_config +1 -0
- .config/config_sentinel +0 -0
- .config/configurations/config_default +6 -0
- .config/default_configs.db +0 -0
- .config/gce +1 -0
- .config/hidden_gcloud_config_universe_descriptor_data_cache_configs.db +0 -0
- .config/logs/2025.03.14/13.31.36.734686.log +765 -0
- .config/logs/2025.03.14/13.32.03.025824.log +5 -0
- .config/logs/2025.03.14/13.32.11.932574.log +153 -0
- .config/logs/2025.03.14/13.32.16.153180.log +5 -0
- .config/logs/2025.03.14/13.32.25.046318.log +8 -0
- .config/logs/2025.03.14/13.32.25.746375.log +8 -0
- .gitattributes +3 -0
- .gradio/certificate.pem +31 -0
- Gemma-Finetune/.gitignore +192 -0
- Gemma-Finetune/Gemma3_(4B).ipynb +0 -0
- Gemma-Finetune/LICENSE +21 -0
- Gemma-Finetune/README.md +41 -0
- Gemma-Finetune/main.py +295 -0
- Gemma-Finetune/requirements.txt +9 -0
- Gemma-Finetune/utils/__pycache__/check_dataset.cpython-311.pyc +0 -0
- Gemma-Finetune/utils/__pycache__/model.cpython-311.pyc +0 -0
- Gemma-Finetune/utils/__pycache__/sample_dataset.cpython-311.pyc +0 -0
- Gemma-Finetune/utils/check_dataset.py +272 -0
- Gemma-Finetune/utils/model.py +552 -0
- Gemma-Finetune/utils/sample_dataset.py +105 -0
- README.md +2 -8
- requirements.txt +6 -0
- sample_data/README.md +19 -0
- sample_data/anscombe.json +49 -0
- sample_data/california_housing_test.csv +0 -0
- sample_data/california_housing_train.csv +0 -0
- sample_data/mnist_test.csv +3 -0
- sample_data/mnist_train_small.csv +3 -0
- unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
- unsloth_compiled_cache/UnslothBCOTrainer.py +1822 -0
- unsloth_compiled_cache/UnslothCPOTrainer.py +1555 -0
- unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
- unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothGKDTrainer.py +861 -0
- unsloth_compiled_cache/UnslothGRPOTrainer.py +1436 -0
- unsloth_compiled_cache/UnslothKTOTrainer.py +1838 -0
- unsloth_compiled_cache/UnslothNashMDTrainer.py +953 -0
- unsloth_compiled_cache/UnslothORPOTrainer.py +1541 -0
- unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1267 -0
- unsloth_compiled_cache/UnslothPPOTrainer.py +1257 -0
- unsloth_compiled_cache/UnslothPRMTrainer.py +798 -0
.config/.last_opt_in_prompt.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
.config/.last_survey_prompt.yaml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
last_prompt_time: 1741959131.3035824
|
.config/.last_update_check.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"last_update_check_time": 1741959135.630771, "last_update_check_revision": 20250307152352, "notifications": [], "last_nag_times": {}}
|
.config/active_config
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
default
|
.config/config_sentinel
ADDED
File without changes
|
.config/configurations/config_default
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[component_manager]
|
2 |
+
disable_update_check = true
|
3 |
+
|
4 |
+
[compute]
|
5 |
+
gce_metadata_read_timeout_sec = 0
|
6 |
+
|
.config/default_configs.db
ADDED
Binary file (12.3 kB). View file
|
|
.config/gce
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
False
|
.config/hidden_gcloud_config_universe_descriptor_data_cache_configs.db
ADDED
Binary file (12.3 kB). View file
|
|
.config/logs/2025.03.14/13.31.36.734686.log
ADDED
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:31:48,760 DEBUG root Loaded Command Group: ['gcloud', 'components']
|
2 |
+
2025-03-14 13:31:48,763 DEBUG root Loaded Command Group: ['gcloud', 'components', 'update']
|
3 |
+
2025-03-14 13:31:48,765 DEBUG root Running [gcloud.components.update] with arguments: [--compile-python: "True", --quiet: "True", COMPONENT-IDS:6: "['core', 'gcloud-deps', 'bq', 'gcloud', 'gcloud-crc32c', 'gsutil']"]
|
4 |
+
2025-03-14 13:31:48,766 INFO ___FILE_ONLY___ Beginning update. This process may take several minutes.
|
5 |
+
|
6 |
+
2025-03-14 13:31:48,801 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
7 |
+
2025-03-14 13:31:49,745 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components-2.json HTTP/11" 200 226132
|
8 |
+
2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
|
9 |
+
|
10 |
+
2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
|
11 |
+
Your current Google Cloud CLI version is: 514.0.0
|
12 |
+
|
13 |
+
2025-03-14 13:31:49,757 INFO ___FILE_ONLY___ Installing components from version: 514.0.0
|
14 |
+
|
15 |
+
2025-03-14 13:31:49,757 INFO ___FILE_ONLY___
|
16 |
+
|
17 |
+
2025-03-14 13:31:49,757 DEBUG root Chosen display Format:table[box,title="These components will be removed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
18 |
+
2025-03-14 13:31:49,758 DEBUG root Chosen display Format:table[box,title="These components will be updated."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
19 |
+
2025-03-14 13:31:49,758 DEBUG root Chosen display Format:table[box,title="These components will be installed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
20 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ┌─────────────────────────────────────────────────────────────────────────────┐
|
21 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
|
22 |
+
|
23 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │ These components will be installed. │
|
24 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
|
25 |
+
|
26 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ├─────────────────────────────────────────────────────┬────────────┬──────────┤
|
27 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
|
28 |
+
|
29 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │ Name │ Version │ Size │
|
30 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
|
31 |
+
|
32 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ ├─────────────────────────────────────────────────────┼────────────┼──────────┤
|
33 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___
|
34 |
+
|
35 |
+
2025-03-14 13:31:49,795 INFO ___FILE_ONLY___ │
|
36 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ BigQuery Command Line Tool
|
37 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
38 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
39 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 2.1.14
|
40 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
41 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
42 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 1.8 MiB
|
43 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
44 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
45 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
46 |
+
|
47 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
48 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ BigQuery Command Line Tool (Platform Specific)
|
49 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
50 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
51 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ 2.1.8
|
52 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
53 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
54 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ < 1 MiB
|
55 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
56 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___ │
|
57 |
+
2025-03-14 13:31:49,796 INFO ___FILE_ONLY___
|
58 |
+
|
59 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
60 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ Bundled Python 3.12 (Platform Specific)
|
61 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
62 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
63 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 3.12.8
|
64 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
65 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
66 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 89.2 MiB
|
67 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
68 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
69 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
70 |
+
|
71 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
72 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ Cloud Storage Command Line Tool
|
73 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
74 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
75 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 5.33
|
76 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
77 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
78 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ 11.8 MiB
|
79 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
80 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___ │
|
81 |
+
2025-03-14 13:31:49,797 INFO ___FILE_ONLY___
|
82 |
+
|
83 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
84 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ Cloud Storage Command Line Tool (Platform Specific)
|
85 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
86 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
87 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ 5.30
|
88 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
89 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
90 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ < 1 MiB
|
91 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
92 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
93 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
94 |
+
|
95 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
96 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ Google Cloud CLI Core Libraries (Platform Specific)
|
97 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
98 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
99 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ 2024.08.30
|
100 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
101 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ │
|
102 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___ < 1 MiB
|
103 |
+
2025-03-14 13:31:49,798 INFO ___FILE_ONLY___
|
104 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
105 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
106 |
+
|
107 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
108 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ Google Cloud CRC32C Hash Tool (Platform Specific)
|
109 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
110 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
111 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 1.0.0
|
112 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
113 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
114 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 1.4 MiB
|
115 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
116 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
117 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
118 |
+
|
119 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
120 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ gcloud cli dependencies (Platform Specific)
|
121 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
122 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
123 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ 2021.04.16
|
124 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
125 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ │
|
126 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___ < 1 MiB
|
127 |
+
2025-03-14 13:31:49,799 INFO ___FILE_ONLY___
|
128 |
+
2025-03-14 13:31:49,800 INFO ___FILE_ONLY___ │
|
129 |
+
2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
|
130 |
+
|
131 |
+
2025-03-14 13:31:49,800 INFO ___FILE_ONLY___ └─────────────────────────────────────────────────────┴────────────┴──────────┘
|
132 |
+
2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
|
133 |
+
|
134 |
+
2025-03-14 13:31:49,800 INFO ___FILE_ONLY___
|
135 |
+
|
136 |
+
2025-03-14 13:31:49,803 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
137 |
+
2025-03-14 13:31:49,983 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/RELEASE_NOTES HTTP/11" 200 1377050
|
138 |
+
2025-03-14 13:31:50,443 INFO ___FILE_ONLY___ For the latest full release notes, please visit:
|
139 |
+
https://cloud.google.com/sdk/release_notes
|
140 |
+
|
141 |
+
|
142 |
+
2025-03-14 13:31:50,443 INFO ___FILE_ONLY___ Performing in place update...
|
143 |
+
|
144 |
+
|
145 |
+
2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
146 |
+
|
147 |
+
2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╠═ Downloading: BigQuery Command Line Tool ═╣
|
148 |
+
|
149 |
+
2025-03-14 13:31:50,445 INFO ___FILE_ONLY___ ╚
|
150 |
+
2025-03-14 13:31:50,449 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
151 |
+
2025-03-14 13:31:51,415 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bq-20250228155416.tar.gz HTTP/11" 200 1845321
|
152 |
+
2025-03-14 13:31:51,427 INFO ___FILE_ONLY___ ═
|
153 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
154 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
155 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
156 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
157 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
158 |
+
2025-03-14 13:31:51,428 INFO ___FILE_ONLY___ ═
|
159 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
160 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
161 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
162 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
163 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
164 |
+
2025-03-14 13:31:51,429 INFO ___FILE_ONLY___ ═
|
165 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
166 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
167 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
168 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
169 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
170 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
171 |
+
2025-03-14 13:31:51,430 INFO ___FILE_ONLY___ ═
|
172 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
173 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
174 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
175 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
176 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
177 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
178 |
+
2025-03-14 13:31:51,431 INFO ___FILE_ONLY___ ═
|
179 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
180 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
181 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
182 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
183 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
184 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
185 |
+
2025-03-14 13:31:51,432 INFO ___FILE_ONLY___ ═
|
186 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
187 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
188 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
189 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
190 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
191 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
192 |
+
2025-03-14 13:31:51,433 INFO ___FILE_ONLY___ ═
|
193 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
194 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
195 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
196 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
197 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
198 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
199 |
+
2025-03-14 13:31:51,434 INFO ___FILE_ONLY___ ═
|
200 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
201 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
202 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
203 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
204 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
205 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
206 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
207 |
+
2025-03-14 13:31:51,435 INFO ___FILE_ONLY___ ═
|
208 |
+
2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
|
209 |
+
2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
|
210 |
+
2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
|
211 |
+
2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ═
|
212 |
+
2025-03-14 13:31:51,436 INFO ___FILE_ONLY___ ╝
|
213 |
+
|
214 |
+
2025-03-14 13:31:51,438 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
215 |
+
|
216 |
+
2025-03-14 13:31:51,438 INFO ___FILE_ONLY___ ╠═ Downloading: BigQuery Command Line Tool (Platform Spe... ═╣
|
217 |
+
|
218 |
+
2025-03-14 13:31:51,439 INFO ___FILE_ONLY___ ╚
|
219 |
+
2025-03-14 13:31:51,443 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
220 |
+
2025-03-14 13:31:52,383 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bq-nix-20240830134514.tar.gz HTTP/11" 200 1914
|
221 |
+
2025-03-14 13:31:52,384 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
222 |
+
2025-03-14 13:31:52,384 INFO ___FILE_ONLY___ ╝
|
223 |
+
|
224 |
+
2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
225 |
+
|
226 |
+
2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╠═ Downloading: Bundled Python 3.12 ═╣
|
227 |
+
|
228 |
+
2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╚
|
229 |
+
2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
230 |
+
2025-03-14 13:31:52,386 INFO ___FILE_ONLY___ ╝
|
231 |
+
|
232 |
+
2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════���═══════════════╗
|
233 |
+
|
234 |
+
2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╠═ Downloading: Bundled Python 3.12 (Platform Specific) ═╣
|
235 |
+
|
236 |
+
2025-03-14 13:31:52,388 INFO ___FILE_ONLY___ ╚
|
237 |
+
2025-03-14 13:31:52,392 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
238 |
+
2025-03-14 13:31:53,292 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-bundled-python3-unix-linux-x86_64-20250131143518.tar.gz HTTP/11" 200 93520256
|
239 |
+
2025-03-14 13:31:53,624 INFO ___FILE_ONLY___ ═
|
240 |
+
2025-03-14 13:31:53,626 INFO ___FILE_ONLY___ ═
|
241 |
+
2025-03-14 13:31:53,628 INFO ___FILE_ONLY___ ═
|
242 |
+
2025-03-14 13:31:53,630 INFO ___FILE_ONLY___ ═
|
243 |
+
2025-03-14 13:31:53,632 INFO ___FILE_ONLY___ ═
|
244 |
+
2025-03-14 13:31:53,634 INFO ___FILE_ONLY___ ═
|
245 |
+
2025-03-14 13:31:53,636 INFO ___FILE_ONLY___ ═
|
246 |
+
2025-03-14 13:31:53,638 INFO ___FILE_ONLY___ ═
|
247 |
+
2025-03-14 13:31:53,640 INFO ___FILE_ONLY___ ═
|
248 |
+
2025-03-14 13:31:53,642 INFO ___FILE_ONLY___ ═
|
249 |
+
2025-03-14 13:31:53,644 INFO ___FILE_ONLY___ ═
|
250 |
+
2025-03-14 13:31:53,646 INFO ___FILE_ONLY___ ═
|
251 |
+
2025-03-14 13:31:53,648 INFO ___FILE_ONLY___ ═
|
252 |
+
2025-03-14 13:31:53,650 INFO ___FILE_ONLY___ ═
|
253 |
+
2025-03-14 13:31:53,652 INFO ___FILE_ONLY___ ═
|
254 |
+
2025-03-14 13:31:53,654 INFO ___FILE_ONLY___ ═
|
255 |
+
2025-03-14 13:31:53,656 INFO ___FILE_ONLY___ ═
|
256 |
+
2025-03-14 13:31:53,658 INFO ___FILE_ONLY___ ═
|
257 |
+
2025-03-14 13:31:53,660 INFO ___FILE_ONLY___ ═
|
258 |
+
2025-03-14 13:31:53,662 INFO ___FILE_ONLY___ ═
|
259 |
+
2025-03-14 13:31:53,664 INFO ___FILE_ONLY___ ═
|
260 |
+
2025-03-14 13:31:53,666 INFO ___FILE_ONLY___ ═
|
261 |
+
2025-03-14 13:31:53,668 INFO ___FILE_ONLY___ ═
|
262 |
+
2025-03-14 13:31:53,670 INFO ___FILE_ONLY___ ═
|
263 |
+
2025-03-14 13:31:53,672 INFO ___FILE_ONLY___ ═
|
264 |
+
2025-03-14 13:31:53,674 INFO ___FILE_ONLY___ ═
|
265 |
+
2025-03-14 13:31:53,676 INFO ___FILE_ONLY___ ═
|
266 |
+
2025-03-14 13:31:53,678 INFO ___FILE_ONLY___ ═
|
267 |
+
2025-03-14 13:31:53,680 INFO ___FILE_ONLY___ ═
|
268 |
+
2025-03-14 13:31:53,682 INFO ___FILE_ONLY___ ═
|
269 |
+
2025-03-14 13:31:53,684 INFO ___FILE_ONLY___ ═
|
270 |
+
2025-03-14 13:31:53,686 INFO ___FILE_ONLY___ ═
|
271 |
+
2025-03-14 13:31:53,688 INFO ___FILE_ONLY___ ═
|
272 |
+
2025-03-14 13:31:53,689 INFO ___FILE_ONLY___ ═
|
273 |
+
2025-03-14 13:31:53,691 INFO ___FILE_ONLY___ ═
|
274 |
+
2025-03-14 13:31:53,693 INFO ___FILE_ONLY___ ═
|
275 |
+
2025-03-14 13:31:53,695 INFO ___FILE_ONLY___ ═
|
276 |
+
2025-03-14 13:31:53,697 INFO ___FILE_ONLY___ ═
|
277 |
+
2025-03-14 13:31:53,699 INFO ___FILE_ONLY___ ═
|
278 |
+
2025-03-14 13:31:53,701 INFO ___FILE_ONLY___ ═
|
279 |
+
2025-03-14 13:31:53,703 INFO ___FILE_ONLY___ ═
|
280 |
+
2025-03-14 13:31:53,705 INFO ___FILE_ONLY___ ═
|
281 |
+
2025-03-14 13:31:53,707 INFO ___FILE_ONLY___ ═
|
282 |
+
2025-03-14 13:31:53,709 INFO ___FILE_ONLY___ ═
|
283 |
+
2025-03-14 13:31:53,711 INFO ___FILE_ONLY___ ═
|
284 |
+
2025-03-14 13:31:53,713 INFO ___FILE_ONLY___ ═
|
285 |
+
2025-03-14 13:31:53,715 INFO ___FILE_ONLY___ ═
|
286 |
+
2025-03-14 13:31:53,717 INFO ___FILE_ONLY___ ═
|
287 |
+
2025-03-14 13:31:53,719 INFO ___FILE_ONLY___ ═
|
288 |
+
2025-03-14 13:31:53,721 INFO ___FILE_ONLY___ ═
|
289 |
+
2025-03-14 13:31:53,723 INFO ___FILE_ONLY___ ═
|
290 |
+
2025-03-14 13:31:53,725 INFO ___FILE_ONLY___ ═
|
291 |
+
2025-03-14 13:31:53,727 INFO ___FILE_ONLY___ ═
|
292 |
+
2025-03-14 13:31:53,729 INFO ___FILE_ONLY___ ═
|
293 |
+
2025-03-14 13:31:53,731 INFO ___FILE_ONLY___ ═
|
294 |
+
2025-03-14 13:31:53,732 INFO ___FILE_ONLY___ ═
|
295 |
+
2025-03-14 13:31:53,734 INFO ___FILE_ONLY___ ═
|
296 |
+
2025-03-14 13:31:53,736 INFO ___FILE_ONLY___ ═
|
297 |
+
2025-03-14 13:31:53,738 INFO ___FILE_ONLY___ ═
|
298 |
+
2025-03-14 13:31:53,740 INFO ___FILE_ONLY___ ═
|
299 |
+
2025-03-14 13:31:53,741 INFO ___FILE_ONLY___ ╝
|
300 |
+
|
301 |
+
2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
302 |
+
|
303 |
+
2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╠═ Downloading: Cloud Storage Command Line Tool ═╣
|
304 |
+
|
305 |
+
2025-03-14 13:31:53,744 INFO ___FILE_ONLY___ ╚
|
306 |
+
2025-03-14 13:31:53,747 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
307 |
+
2025-03-14 13:31:53,929 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gsutil-20241213184646.tar.gz HTTP/11" 200 12368759
|
308 |
+
2025-03-14 13:31:53,995 INFO ___FILE_ONLY___ ═
|
309 |
+
2025-03-14 13:31:53,995 INFO ___FILE_ONLY___ ═
|
310 |
+
2025-03-14 13:31:53,996 INFO ___FILE_ONLY___ ═
|
311 |
+
2025-03-14 13:31:53,996 INFO ___FILE_ONLY___ ═
|
312 |
+
2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
|
313 |
+
2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
|
314 |
+
2025-03-14 13:31:53,997 INFO ___FILE_ONLY___ ═
|
315 |
+
2025-03-14 13:31:53,998 INFO ___FILE_ONLY___ ═
|
316 |
+
2025-03-14 13:31:53,998 INFO ___FILE_ONLY___ ═
|
317 |
+
2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
|
318 |
+
2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
|
319 |
+
2025-03-14 13:31:53,999 INFO ___FILE_ONLY___ ═
|
320 |
+
2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
|
321 |
+
2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
|
322 |
+
2025-03-14 13:31:54,000 INFO ___FILE_ONLY___ ═
|
323 |
+
2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
|
324 |
+
2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
|
325 |
+
2025-03-14 13:31:54,001 INFO ___FILE_ONLY___ ═
|
326 |
+
2025-03-14 13:31:54,002 INFO ___FILE_ONLY___ ═
|
327 |
+
2025-03-14 13:31:54,002 INFO ___FILE_ONLY___ ═
|
328 |
+
2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
|
329 |
+
2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
|
330 |
+
2025-03-14 13:31:54,003 INFO ___FILE_ONLY___ ═
|
331 |
+
2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
|
332 |
+
2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
|
333 |
+
2025-03-14 13:31:54,004 INFO ___FILE_ONLY___ ═
|
334 |
+
2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
|
335 |
+
2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
|
336 |
+
2025-03-14 13:31:54,005 INFO ___FILE_ONLY___ ═
|
337 |
+
2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
|
338 |
+
2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
|
339 |
+
2025-03-14 13:31:54,006 INFO ___FILE_ONLY___ ═
|
340 |
+
2025-03-14 13:31:54,007 INFO ___FILE_ONLY___ ═
|
341 |
+
2025-03-14 13:31:54,007 INFO ___FILE_ONLY___ ═
|
342 |
+
2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
|
343 |
+
2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
|
344 |
+
2025-03-14 13:31:54,008 INFO ___FILE_ONLY___ ═
|
345 |
+
2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
|
346 |
+
2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
|
347 |
+
2025-03-14 13:31:54,009 INFO ___FILE_ONLY___ ═
|
348 |
+
2025-03-14 13:31:54,010 INFO ___FILE_ONLY___ ═
|
349 |
+
2025-03-14 13:31:54,010 INFO ___FILE_ONLY___ ═
|
350 |
+
2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
|
351 |
+
2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
|
352 |
+
2025-03-14 13:31:54,011 INFO ___FILE_ONLY___ ═
|
353 |
+
2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
|
354 |
+
2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
|
355 |
+
2025-03-14 13:31:54,012 INFO ___FILE_ONLY___ ═
|
356 |
+
2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
|
357 |
+
2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
|
358 |
+
2025-03-14 13:31:54,013 INFO ___FILE_ONLY___ ═
|
359 |
+
2025-03-14 13:31:54,014 INFO ___FILE_ONLY___ ═
|
360 |
+
2025-03-14 13:31:54,014 INFO ___FILE_ONLY___ ═
|
361 |
+
2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
|
362 |
+
2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
|
363 |
+
2025-03-14 13:31:54,015 INFO ___FILE_ONLY___ ═
|
364 |
+
2025-03-14 13:31:54,016 INFO ___FILE_ONLY___ ═
|
365 |
+
2025-03-14 13:31:54,016 INFO ___FILE_ONLY___ ═
|
366 |
+
2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ═
|
367 |
+
2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ═
|
368 |
+
2025-03-14 13:31:54,017 INFO ___FILE_ONLY___ ╝
|
369 |
+
|
370 |
+
2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
371 |
+
|
372 |
+
2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╠═ Downloading: Cloud Storage Command Line Tool (Platfor... ═╣
|
373 |
+
|
374 |
+
2025-03-14 13:31:54,019 INFO ___FILE_ONLY___ ╚
|
375 |
+
2025-03-14 13:31:54,023 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
376 |
+
2025-03-14 13:31:54,217 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gsutil-nix-20240830134514.tar.gz HTTP/11" 200 1928
|
377 |
+
2025-03-14 13:31:54,217 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
378 |
+
2025-03-14 13:31:54,217 INFO ___FILE_ONLY___ ╝
|
379 |
+
|
380 |
+
2025-03-14 13:31:54,219 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
381 |
+
|
382 |
+
2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╠═ Downloading: Default set of gcloud commands ═╣
|
383 |
+
|
384 |
+
2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╚
|
385 |
+
2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
386 |
+
2025-03-14 13:31:54,220 INFO ___FILE_ONLY___ ╝
|
387 |
+
|
388 |
+
2025-03-14 13:31:54,221 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
389 |
+
|
390 |
+
2025-03-14 13:31:54,222 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CLI Core Libraries (Platfor... ═╣
|
391 |
+
|
392 |
+
2025-03-14 13:31:54,222 INFO ___FILE_ONLY___ ╚
|
393 |
+
2025-03-14 13:31:54,225 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
394 |
+
2025-03-14 13:31:55,133 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-core-nix-20240830134514.tar.gz HTTP/11" 200 2306
|
395 |
+
2025-03-14 13:31:55,134 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
396 |
+
2025-03-14 13:31:55,134 INFO ___FILE_ONLY___ ╝
|
397 |
+
|
398 |
+
2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
399 |
+
|
400 |
+
2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CRC32C Hash Tool ═╣
|
401 |
+
|
402 |
+
2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╚
|
403 |
+
2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
404 |
+
2025-03-14 13:31:55,136 INFO ___FILE_ONLY___ ╝
|
405 |
+
|
406 |
+
2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
407 |
+
|
408 |
+
2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╠═ Downloading: Google Cloud CRC32C Hash Tool (Platform ... ═╣
|
409 |
+
|
410 |
+
2025-03-14 13:31:55,138 INFO ___FILE_ONLY___ ╚
|
411 |
+
2025-03-14 13:31:55,141 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
412 |
+
2025-03-14 13:31:55,315 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gcloud-crc32c-linux-x86_64-20250110133808.tar.gz HTTP/11" 200 1478989
|
413 |
+
2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
|
414 |
+
2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
|
415 |
+
2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
|
416 |
+
2025-03-14 13:31:55,326 INFO ___FILE_ONLY___ ═
|
417 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
418 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
419 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
420 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
421 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
422 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
423 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
424 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
425 |
+
2025-03-14 13:31:55,327 INFO ___FILE_ONLY___ ═
|
426 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
427 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
428 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
429 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
430 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
431 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
432 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
433 |
+
2025-03-14 13:31:55,328 INFO ___FILE_ONLY___ ═
|
434 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
435 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
436 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
437 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
438 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
439 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
440 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
441 |
+
2025-03-14 13:31:55,329 INFO ___FILE_ONLY___ ═
|
442 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
443 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
444 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
445 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
446 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
447 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
448 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
449 |
+
2025-03-14 13:31:55,330 INFO ___FILE_ONLY___ ═
|
450 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
451 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
452 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
453 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
454 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
455 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
456 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
457 |
+
2025-03-14 13:31:55,331 INFO ___FILE_ONLY___ ═
|
458 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
459 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
460 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
461 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
462 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
463 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
464 |
+
2025-03-14 13:31:55,332 INFO ___FILE_ONLY___ ═
|
465 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
466 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
467 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
468 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
469 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
470 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
471 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
472 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ═
|
473 |
+
2025-03-14 13:31:55,333 INFO ___FILE_ONLY___ ╝
|
474 |
+
|
475 |
+
2025-03-14 13:31:55,335 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
476 |
+
|
477 |
+
2025-03-14 13:31:55,336 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud cli dependencies (Platform Specific) ═╣
|
478 |
+
|
479 |
+
2025-03-14 13:31:55,336 INFO ___FILE_ONLY___ ╚
|
480 |
+
2025-03-14 13:31:55,339 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
481 |
+
2025-03-14 13:31:55,513 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-gcloud-deps-linux-x86_64-20210416153011.tar.gz HTTP/11" 200 104
|
482 |
+
2025-03-14 13:31:55,515 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
483 |
+
2025-03-14 13:31:55,515 INFO ___FILE_ONLY___ ╝
|
484 |
+
|
485 |
+
2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
486 |
+
|
487 |
+
2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╠═ Installing: BigQuery Command Line Tool ═╣
|
488 |
+
|
489 |
+
2025-03-14 13:31:55,517 INFO ___FILE_ONLY___ ╚
|
490 |
+
2025-03-14 13:31:55,609 INFO ___FILE_ONLY___ ═
|
491 |
+
2025-03-14 13:31:55,611 INFO ___FILE_ONLY___ ═
|
492 |
+
2025-03-14 13:31:55,614 INFO ___FILE_ONLY___ ═
|
493 |
+
2025-03-14 13:31:55,616 INFO ___FILE_ONLY___ ═
|
494 |
+
2025-03-14 13:31:55,619 INFO ___FILE_ONLY___ ═
|
495 |
+
2025-03-14 13:31:55,621 INFO ___FILE_ONLY___ ═
|
496 |
+
2025-03-14 13:31:55,624 INFO ___FILE_ONLY___ ═
|
497 |
+
2025-03-14 13:31:55,626 INFO ___FILE_ONLY___ ═
|
498 |
+
2025-03-14 13:31:55,629 INFO ___FILE_ONLY___ ═
|
499 |
+
2025-03-14 13:31:55,631 INFO ___FILE_ONLY___ ═
|
500 |
+
2025-03-14 13:31:55,633 INFO ___FILE_ONLY___ ═
|
501 |
+
2025-03-14 13:31:55,636 INFO ___FILE_ONLY___ ═
|
502 |
+
2025-03-14 13:31:55,638 INFO ___FILE_ONLY___ ═
|
503 |
+
2025-03-14 13:31:55,640 INFO ___FILE_ONLY___ ═
|
504 |
+
2025-03-14 13:31:55,643 INFO ___FILE_ONLY___ ═
|
505 |
+
2025-03-14 13:31:55,645 INFO ___FILE_ONLY___ ═
|
506 |
+
2025-03-14 13:31:55,647 INFO ___FILE_ONLY___ ═
|
507 |
+
2025-03-14 13:31:55,649 INFO ___FILE_ONLY___ ═
|
508 |
+
2025-03-14 13:31:55,653 INFO ___FILE_ONLY___ ═
|
509 |
+
2025-03-14 13:31:55,656 INFO ___FILE_ONLY___ ═
|
510 |
+
2025-03-14 13:31:55,658 INFO ___FILE_ONLY___ ═
|
511 |
+
2025-03-14 13:31:55,661 INFO ___FILE_ONLY___ ═
|
512 |
+
2025-03-14 13:31:55,663 INFO ___FILE_ONLY___ ═
|
513 |
+
2025-03-14 13:31:55,665 INFO ___FILE_ONLY___ ═
|
514 |
+
2025-03-14 13:31:55,667 INFO ___FILE_ONLY___ ═
|
515 |
+
2025-03-14 13:31:55,669 INFO ___FILE_ONLY___ ═
|
516 |
+
2025-03-14 13:31:55,671 INFO ___FILE_ONLY___ ═
|
517 |
+
2025-03-14 13:31:55,674 INFO ___FILE_ONLY___ ═
|
518 |
+
2025-03-14 13:31:55,678 INFO ___FILE_ONLY___ ═
|
519 |
+
2025-03-14 13:31:55,680 INFO ___FILE_ONLY___ ═
|
520 |
+
2025-03-14 13:31:55,682 INFO ___FILE_ONLY___ ═
|
521 |
+
2025-03-14 13:31:55,684 INFO ___FILE_ONLY___ ═
|
522 |
+
2025-03-14 13:31:55,688 INFO ___FILE_ONLY___ ═
|
523 |
+
2025-03-14 13:31:55,690 INFO ___FILE_ONLY___ ═
|
524 |
+
2025-03-14 13:31:55,698 INFO ___FILE_ONLY___ ═
|
525 |
+
2025-03-14 13:31:55,703 INFO ___FILE_ONLY___ ═
|
526 |
+
2025-03-14 13:31:55,709 INFO ___FILE_ONLY___ ═
|
527 |
+
2025-03-14 13:31:55,711 INFO ___FILE_ONLY___ ═
|
528 |
+
2025-03-14 13:31:55,713 INFO ___FILE_ONLY___ ═
|
529 |
+
2025-03-14 13:31:55,716 INFO ___FILE_ONLY___ ═
|
530 |
+
2025-03-14 13:31:55,718 INFO ___FILE_ONLY___ ═
|
531 |
+
2025-03-14 13:31:55,721 INFO ___FILE_ONLY___ ═
|
532 |
+
2025-03-14 13:31:55,726 INFO ___FILE_ONLY___ ═
|
533 |
+
2025-03-14 13:31:55,728 INFO ___FILE_ONLY___ ═
|
534 |
+
2025-03-14 13:31:55,731 INFO ___FILE_ONLY___ ═
|
535 |
+
2025-03-14 13:31:55,732 INFO ___FILE_ONLY___ ═
|
536 |
+
2025-03-14 13:31:55,735 INFO ___FILE_ONLY___ ═
|
537 |
+
2025-03-14 13:31:55,737 INFO ___FILE_ONLY___ ═
|
538 |
+
2025-03-14 13:31:55,739 INFO ___FILE_ONLY___ ═
|
539 |
+
2025-03-14 13:31:55,741 INFO ___FILE_ONLY___ ═
|
540 |
+
2025-03-14 13:31:55,744 INFO ___FILE_ONLY___ ═
|
541 |
+
2025-03-14 13:31:55,746 INFO ___FILE_ONLY___ ═
|
542 |
+
2025-03-14 13:31:55,748 INFO ___FILE_ONLY___ ═
|
543 |
+
2025-03-14 13:31:55,750 INFO ___FILE_ONLY___ ═
|
544 |
+
2025-03-14 13:31:55,752 INFO ___FILE_ONLY___ ═
|
545 |
+
2025-03-14 13:31:55,754 INFO ___FILE_ONLY___ ═
|
546 |
+
2025-03-14 13:31:55,756 INFO ___FILE_ONLY___ ═
|
547 |
+
2025-03-14 13:31:55,759 INFO ___FILE_ONLY___ ═
|
548 |
+
2025-03-14 13:31:55,760 INFO ___FILE_ONLY___ ═
|
549 |
+
2025-03-14 13:31:55,763 INFO ___FILE_ONLY___ ═
|
550 |
+
2025-03-14 13:31:55,763 INFO ___FILE_ONLY___ ╝
|
551 |
+
|
552 |
+
2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
553 |
+
|
554 |
+
2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╠═ Installing: BigQuery Command Line Tool (Platform Spec... ═╣
|
555 |
+
|
556 |
+
2025-03-14 13:31:55,771 INFO ___FILE_ONLY___ ╚
|
557 |
+
2025-03-14 13:31:55,772 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
558 |
+
2025-03-14 13:31:55,772 INFO ___FILE_ONLY___ ╝
|
559 |
+
|
560 |
+
2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
561 |
+
|
562 |
+
2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╠═ Installing: Bundled Python 3.12 ═╣
|
563 |
+
|
564 |
+
2025-03-14 13:31:55,776 INFO ___FILE_ONLY___ ╚
|
565 |
+
2025-03-14 13:31:55,780 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
566 |
+
2025-03-14 13:31:55,780 INFO ___FILE_ONLY___ ╝
|
567 |
+
|
568 |
+
2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
569 |
+
|
570 |
+
2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╠═ Installing: Bundled Python 3.12 (Platform Specific) ═╣
|
571 |
+
|
572 |
+
2025-03-14 13:31:55,782 INFO ___FILE_ONLY___ ╚
|
573 |
+
2025-03-14 13:31:58,899 INFO ___FILE_ONLY___ ═
|
574 |
+
2025-03-14 13:31:59,398 INFO ___FILE_ONLY___ ═
|
575 |
+
2025-03-14 13:31:59,419 INFO ___FILE_ONLY___ ═
|
576 |
+
2025-03-14 13:31:59,456 INFO ___FILE_ONLY___ ═
|
577 |
+
2025-03-14 13:31:59,479 INFO ___FILE_ONLY___ ═
|
578 |
+
2025-03-14 13:31:59,496 INFO ___FILE_ONLY___ ═
|
579 |
+
2025-03-14 13:31:59,528 INFO ___FILE_ONLY___ ═
|
580 |
+
2025-03-14 13:31:59,550 INFO ___FILE_ONLY___ ═
|
581 |
+
2025-03-14 13:31:59,569 INFO ___FILE_ONLY___ ═
|
582 |
+
2025-03-14 13:31:59,588 INFO ___FILE_ONLY___ ═
|
583 |
+
2025-03-14 13:31:59,604 INFO ___FILE_ONLY___ ═
|
584 |
+
2025-03-14 13:31:59,632 INFO ___FILE_ONLY___ ═
|
585 |
+
2025-03-14 13:31:59,775 INFO ___FILE_ONLY___ ═
|
586 |
+
2025-03-14 13:31:59,793 INFO ___FILE_ONLY___ ═
|
587 |
+
2025-03-14 13:31:59,810 INFO ___FILE_ONLY___ ═
|
588 |
+
2025-03-14 13:31:59,827 INFO ___FILE_ONLY___ ═
|
589 |
+
2025-03-14 13:31:59,842 INFO ___FILE_ONLY___ ═
|
590 |
+
2025-03-14 13:31:59,861 INFO ___FILE_ONLY___ ═
|
591 |
+
2025-03-14 13:31:59,879 INFO ___FILE_ONLY___ ═
|
592 |
+
2025-03-14 13:31:59,900 INFO ___FILE_ONLY___ ═
|
593 |
+
2025-03-14 13:31:59,917 INFO ___FILE_ONLY___ ═
|
594 |
+
2025-03-14 13:32:00,014 INFO ___FILE_ONLY___ ═
|
595 |
+
2025-03-14 13:32:00,038 INFO ___FILE_ONLY___ ═
|
596 |
+
2025-03-14 13:32:00,539 INFO ___FILE_ONLY___ ═
|
597 |
+
2025-03-14 13:32:00,556 INFO ___FILE_ONLY___ ═
|
598 |
+
2025-03-14 13:32:00,571 INFO ___FILE_ONLY___ ═
|
599 |
+
2025-03-14 13:32:00,584 INFO ___FILE_ONLY___ ═
|
600 |
+
2025-03-14 13:32:00,597 INFO ___FILE_ONLY___ ═
|
601 |
+
2025-03-14 13:32:00,610 INFO ___FILE_ONLY___ ═
|
602 |
+
2025-03-14 13:32:00,623 INFO ___FILE_ONLY___ ═
|
603 |
+
2025-03-14 13:32:00,635 INFO ___FILE_ONLY___ ═
|
604 |
+
2025-03-14 13:32:00,647 INFO ___FILE_ONLY___ ═
|
605 |
+
2025-03-14 13:32:00,659 INFO ___FILE_ONLY___ ═
|
606 |
+
2025-03-14 13:32:00,672 INFO ___FILE_ONLY___ ═
|
607 |
+
2025-03-14 13:32:00,684 INFO ___FILE_ONLY___ ═
|
608 |
+
2025-03-14 13:32:00,697 INFO ___FILE_ONLY___ ═
|
609 |
+
2025-03-14 13:32:00,711 INFO ___FILE_ONLY___ ═
|
610 |
+
2025-03-14 13:32:00,724 INFO ___FILE_ONLY___ ═
|
611 |
+
2025-03-14 13:32:00,737 INFO ___FILE_ONLY___ ═
|
612 |
+
2025-03-14 13:32:00,750 INFO ___FILE_ONLY___ ═
|
613 |
+
2025-03-14 13:32:00,763 INFO ___FILE_ONLY___ ═
|
614 |
+
2025-03-14 13:32:00,775 INFO ___FILE_ONLY___ ═
|
615 |
+
2025-03-14 13:32:00,788 INFO ___FILE_ONLY___ ═
|
616 |
+
2025-03-14 13:32:00,802 INFO ___FILE_ONLY___ ═
|
617 |
+
2025-03-14 13:32:00,815 INFO ___FILE_ONLY___ ═
|
618 |
+
2025-03-14 13:32:00,828 INFO ___FILE_ONLY___ ═
|
619 |
+
2025-03-14 13:32:00,841 INFO ___FILE_ONLY___ ═
|
620 |
+
2025-03-14 13:32:00,854 INFO ___FILE_ONLY___ ═
|
621 |
+
2025-03-14 13:32:00,867 INFO ___FILE_ONLY___ ═
|
622 |
+
2025-03-14 13:32:00,881 INFO ___FILE_ONLY___ ═
|
623 |
+
2025-03-14 13:32:00,893 INFO ___FILE_ONLY___ ═
|
624 |
+
2025-03-14 13:32:00,905 INFO ___FILE_ONLY___ ═
|
625 |
+
2025-03-14 13:32:00,919 INFO ___FILE_ONLY___ ═
|
626 |
+
2025-03-14 13:32:00,932 INFO ___FILE_ONLY___ ═
|
627 |
+
2025-03-14 13:32:00,946 INFO ___FILE_ONLY___ ═
|
628 |
+
2025-03-14 13:32:00,959 INFO ___FILE_ONLY___ ═
|
629 |
+
2025-03-14 13:32:00,972 INFO ___FILE_ONLY___ ═
|
630 |
+
2025-03-14 13:32:00,985 INFO ___FILE_ONLY___ ═
|
631 |
+
2025-03-14 13:32:00,999 INFO ___FILE_ONLY___ ═
|
632 |
+
2025-03-14 13:32:01,012 INFO ___FILE_ONLY___ ═
|
633 |
+
2025-03-14 13:32:01,012 INFO ___FILE_ONLY___ ╝
|
634 |
+
|
635 |
+
2025-03-14 13:32:01,068 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
636 |
+
|
637 |
+
2025-03-14 13:32:01,069 INFO ___FILE_ONLY___ ╠═ Installing: Cloud Storage Command Line Tool ═╣
|
638 |
+
|
639 |
+
2025-03-14 13:32:01,069 INFO ___FILE_ONLY___ ╚
|
640 |
+
2025-03-14 13:32:01,600 INFO ___FILE_ONLY___ ═
|
641 |
+
2025-03-14 13:32:01,617 INFO ___FILE_ONLY___ ═
|
642 |
+
2025-03-14 13:32:01,638 INFO ___FILE_ONLY___ ═
|
643 |
+
2025-03-14 13:32:01,655 INFO ___FILE_ONLY___ ═
|
644 |
+
2025-03-14 13:32:01,671 INFO ___FILE_ONLY___ ═
|
645 |
+
2025-03-14 13:32:01,684 INFO ___FILE_ONLY___ ═
|
646 |
+
2025-03-14 13:32:01,699 INFO ___FILE_ONLY___ ═
|
647 |
+
2025-03-14 13:32:01,711 INFO ___FILE_ONLY___ ═
|
648 |
+
2025-03-14 13:32:01,726 INFO ___FILE_ONLY___ ═
|
649 |
+
2025-03-14 13:32:01,740 INFO ___FILE_ONLY___ ═
|
650 |
+
2025-03-14 13:32:01,754 INFO ___FILE_ONLY___ ═
|
651 |
+
2025-03-14 13:32:01,766 INFO ___FILE_ONLY___ ═
|
652 |
+
2025-03-14 13:32:01,776 INFO ___FILE_ONLY___ ═
|
653 |
+
2025-03-14 13:32:01,786 INFO ___FILE_ONLY___ ═
|
654 |
+
2025-03-14 13:32:01,797 INFO ___FILE_ONLY___ ═
|
655 |
+
2025-03-14 13:32:01,807 INFO ___FILE_ONLY___ ═
|
656 |
+
2025-03-14 13:32:01,818 INFO ___FILE_ONLY___ ═
|
657 |
+
2025-03-14 13:32:01,829 INFO ___FILE_ONLY___ ═
|
658 |
+
2025-03-14 13:32:01,840 INFO ___FILE_ONLY___ ═
|
659 |
+
2025-03-14 13:32:01,851 INFO ___FILE_ONLY___ ═
|
660 |
+
2025-03-14 13:32:01,864 INFO ___FILE_ONLY___ ═
|
661 |
+
2025-03-14 13:32:01,878 INFO ___FILE_ONLY___ ═
|
662 |
+
2025-03-14 13:32:01,893 INFO ___FILE_ONLY___ ═
|
663 |
+
2025-03-14 13:32:01,903 INFO ___FILE_ONLY___ ═
|
664 |
+
2025-03-14 13:32:01,916 INFO ___FILE_ONLY___ ═
|
665 |
+
2025-03-14 13:32:01,931 INFO ___FILE_ONLY___ ═
|
666 |
+
2025-03-14 13:32:01,948 INFO ___FILE_ONLY___ ═
|
667 |
+
2025-03-14 13:32:01,967 INFO ___FILE_ONLY___ ═
|
668 |
+
2025-03-14 13:32:01,983 INFO ___FILE_ONLY___ ═
|
669 |
+
2025-03-14 13:32:02,004 INFO ___FILE_ONLY___ ═
|
670 |
+
2025-03-14 13:32:02,015 INFO ___FILE_ONLY___ ═
|
671 |
+
2025-03-14 13:32:02,027 INFO ___FILE_ONLY___ ═
|
672 |
+
2025-03-14 13:32:02,045 INFO ___FILE_ONLY___ ═
|
673 |
+
2025-03-14 13:32:02,058 INFO ___FILE_ONLY___ ═
|
674 |
+
2025-03-14 13:32:02,069 INFO ___FILE_ONLY___ ═
|
675 |
+
2025-03-14 13:32:02,079 INFO ___FILE_ONLY___ ═
|
676 |
+
2025-03-14 13:32:02,089 INFO ___FILE_ONLY___ ═
|
677 |
+
2025-03-14 13:32:02,101 INFO ___FILE_ONLY___ ═
|
678 |
+
2025-03-14 13:32:02,112 INFO ___FILE_ONLY___ ═
|
679 |
+
2025-03-14 13:32:02,124 INFO ___FILE_ONLY___ ═
|
680 |
+
2025-03-14 13:32:02,134 INFO ___FILE_ONLY___ ═
|
681 |
+
2025-03-14 13:32:02,148 INFO ___FILE_ONLY___ ═
|
682 |
+
2025-03-14 13:32:02,157 INFO ___FILE_ONLY___ ═
|
683 |
+
2025-03-14 13:32:02,172 INFO ___FILE_ONLY___ ═
|
684 |
+
2025-03-14 13:32:02,187 INFO ___FILE_ONLY___ ═
|
685 |
+
2025-03-14 13:32:02,199 INFO ___FILE_ONLY___ ═
|
686 |
+
2025-03-14 13:32:02,209 INFO ___FILE_ONLY___ ═
|
687 |
+
2025-03-14 13:32:02,220 INFO ___FILE_ONLY___ ═
|
688 |
+
2025-03-14 13:32:02,231 INFO ___FILE_ONLY___ ═
|
689 |
+
2025-03-14 13:32:02,246 INFO ___FILE_ONLY___ ═
|
690 |
+
2025-03-14 13:32:02,266 INFO ___FILE_ONLY___ ═
|
691 |
+
2025-03-14 13:32:02,281 INFO ___FILE_ONLY___ ═
|
692 |
+
2025-03-14 13:32:02,300 INFO ___FILE_ONLY___ ═
|
693 |
+
2025-03-14 13:32:02,317 INFO ___FILE_ONLY___ ═
|
694 |
+
2025-03-14 13:32:02,357 INFO ___FILE_ONLY___ ═
|
695 |
+
2025-03-14 13:32:02,368 INFO ___FILE_ONLY___ ═
|
696 |
+
2025-03-14 13:32:02,379 INFO ___FILE_ONLY___ ═
|
697 |
+
2025-03-14 13:32:02,390 INFO ___FILE_ONLY___ ═
|
698 |
+
2025-03-14 13:32:02,402 INFO ___FILE_ONLY___ ═
|
699 |
+
2025-03-14 13:32:02,417 INFO ___FILE_ONLY___ ═
|
700 |
+
2025-03-14 13:32:02,417 INFO ___FILE_ONLY___ ╝
|
701 |
+
|
702 |
+
2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
703 |
+
|
704 |
+
2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╠═ Installing: Cloud Storage Command Line Tool (Platform... ═╣
|
705 |
+
|
706 |
+
2025-03-14 13:32:02,449 INFO ___FILE_ONLY___ ╚
|
707 |
+
2025-03-14 13:32:02,450 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
708 |
+
2025-03-14 13:32:02,450 INFO ___FILE_ONLY___ ╝
|
709 |
+
|
710 |
+
2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
711 |
+
|
712 |
+
2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╠═ Installing: Default set of gcloud commands ═╣
|
713 |
+
|
714 |
+
2025-03-14 13:32:02,454 INFO ___FILE_ONLY___ ╚
|
715 |
+
2025-03-14 13:32:02,457 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
716 |
+
2025-03-14 13:32:02,457 INFO ___FILE_ONLY___ ╝
|
717 |
+
|
718 |
+
2025-03-14 13:32:02,458 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
719 |
+
|
720 |
+
2025-03-14 13:32:02,458 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CLI Core Libraries (Platform... ═╣
|
721 |
+
|
722 |
+
2025-03-14 13:32:02,459 INFO ___FILE_ONLY___ ╚
|
723 |
+
2025-03-14 13:32:02,459 INFO ___FILE_ONLY___ ══════════════════════════════
|
724 |
+
2025-03-14 13:32:02,460 INFO ___FILE_ONLY___ ══════════════════════════════
|
725 |
+
2025-03-14 13:32:02,460 INFO ___FILE_ONLY___ ╝
|
726 |
+
|
727 |
+
2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
728 |
+
|
729 |
+
2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CRC32C Hash Tool ═╣
|
730 |
+
|
731 |
+
2025-03-14 13:32:02,464 INFO ___FILE_ONLY___ ╚
|
732 |
+
2025-03-14 13:32:02,466 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
733 |
+
2025-03-14 13:32:02,466 INFO ___FILE_ONLY___ ╝
|
734 |
+
|
735 |
+
2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
736 |
+
|
737 |
+
2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╠═ Installing: Google Cloud CRC32C Hash Tool (Platform S... ═╣
|
738 |
+
|
739 |
+
2025-03-14 13:32:02,468 INFO ___FILE_ONLY___ ╚
|
740 |
+
2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ══════════════════════════════
|
741 |
+
2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ══════════════════════════════
|
742 |
+
2025-03-14 13:32:02,507 INFO ___FILE_ONLY___ ╝
|
743 |
+
|
744 |
+
2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
745 |
+
|
746 |
+
2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╠═ Installing: gcloud cli dependencies (Platform Specific) ═╣
|
747 |
+
|
748 |
+
2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╚
|
749 |
+
2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
750 |
+
2025-03-14 13:32:02,513 INFO ___FILE_ONLY___ ╝
|
751 |
+
|
752 |
+
2025-03-14 13:32:02,518 DEBUG root Updating notification cache...
|
753 |
+
2025-03-14 13:32:02,518 INFO ___FILE_ONLY___
|
754 |
+
|
755 |
+
2025-03-14 13:32:02,520 INFO ___FILE_ONLY___ Performing post processing steps...
|
756 |
+
2025-03-14 13:32:02,521 DEBUG root Executing command: ['/tools/google-cloud-sdk/bin/gcloud', 'components', 'post-process']
|
757 |
+
2025-03-14 13:32:11,259 DEBUG ___FILE_ONLY___
|
758 |
+
2025-03-14 13:32:11,259 DEBUG ___FILE_ONLY___
|
759 |
+
2025-03-14 13:32:11,299 INFO root descriptor_list: [{'universeDomain': 'googleapis.com', 'universeShortName': '', 'authenticationDomain': 'auth.cloud.google.com', 'projectPrefix': '', 'cloudWebDomain': 'cloud.google.com', 'documentationDomain': 'cloud.google.com', 'version': '1.0.0', 'state': 'primary', 'artifactRegistryDomain': 'pkg.dev'}]
|
760 |
+
2025-03-14 13:32:11,299 INFO ___FILE_ONLY___
|
761 |
+
Update done!
|
762 |
+
|
763 |
+
|
764 |
+
2025-03-14 13:32:11,302 DEBUG root Chosen display Format:none
|
765 |
+
2025-03-14 13:32:11,303 INFO root Display format: "none"
|
.config/logs/2025.03.14/13.32.03.025824.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:32:03,026 DEBUG root Loaded Command Group: ['gcloud', 'components']
|
2 |
+
2025-03-14 13:32:03,028 DEBUG root Loaded Command Group: ['gcloud', 'components', 'post_process']
|
3 |
+
2025-03-14 13:32:03,030 DEBUG root Running [gcloud.components.post-process] with arguments: []
|
4 |
+
2025-03-14 13:32:11,135 DEBUG root Chosen display Format:none
|
5 |
+
2025-03-14 13:32:11,136 INFO root Display format: "none"
|
.config/logs/2025.03.14/13.32.11.932574.log
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:32:11,933 DEBUG root Loaded Command Group: ['gcloud', 'components']
|
2 |
+
2025-03-14 13:32:11,935 DEBUG root Loaded Command Group: ['gcloud', 'components', 'update']
|
3 |
+
2025-03-14 13:32:11,937 DEBUG root Running [gcloud.components.update] with arguments: [--quiet: "True", COMPONENT-IDS:8: "['gcloud', 'core', 'bq', 'gsutil', 'compute', 'preview', 'alpha', 'beta']"]
|
4 |
+
2025-03-14 13:32:11,938 INFO ___FILE_ONLY___ Beginning update. This process may take several minutes.
|
5 |
+
|
6 |
+
2025-03-14 13:32:11,947 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
7 |
+
2025-03-14 13:32:12,903 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components-2.json HTTP/11" 200 226132
|
8 |
+
2025-03-14 13:32:12,918 WARNING root Component [compute] no longer exists.
|
9 |
+
2025-03-14 13:32:12,919 INFO ___FILE_ONLY___
|
10 |
+
|
11 |
+
2025-03-14 13:32:12,919 INFO ___FILE_ONLY___
|
12 |
+
Your current Google Cloud CLI version is: 514.0.0
|
13 |
+
|
14 |
+
2025-03-14 13:32:12,920 INFO ___FILE_ONLY___ Installing components from version: 514.0.0
|
15 |
+
|
16 |
+
2025-03-14 13:32:12,920 INFO ___FILE_ONLY___
|
17 |
+
|
18 |
+
2025-03-14 13:32:12,920 DEBUG root Chosen display Format:table[box,title="These components will be removed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
19 |
+
2025-03-14 13:32:12,920 DEBUG root Chosen display Format:table[box,title="These components will be updated."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
20 |
+
2025-03-14 13:32:12,921 DEBUG root Chosen display Format:table[box,title="These components will be installed."](details.display_name:label=Name:align=left,version.version_string:label=Version:align=right,data.size.size(zero="",min=1048576):label=Size:align=right)
|
21 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ┌────────────────────────────────────────────────┐
|
22 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
23 |
+
|
24 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │ These components will be installed. │
|
25 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
26 |
+
|
27 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ├─────────────────────────┬────────────┬─────────┤
|
28 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
29 |
+
|
30 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │ Name │ Version │ Size │
|
31 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
32 |
+
|
33 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ ├─────────────────────────┼────────────┼─────────┤
|
34 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
35 |
+
|
36 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │
|
37 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ gcloud Alpha Commands
|
38 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___
|
39 |
+
2025-03-14 13:32:12,935 INFO ___FILE_ONLY___ │
|
40 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ 2025.03.07
|
41 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
42 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
43 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ < 1 MiB
|
44 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
45 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
46 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
47 |
+
|
48 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
49 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ gcloud Beta Commands
|
50 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
51 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
52 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ 2025.03.07
|
53 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
54 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
55 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ < 1 MiB
|
56 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
57 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
58 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___
|
59 |
+
|
60 |
+
2025-03-14 13:32:12,936 INFO ___FILE_ONLY___ │
|
61 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ gcloud Preview Commands
|
62 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
63 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
|
64 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
65 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
66 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
|
67 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ < 1 MiB
|
68 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
69 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ │
|
70 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
71 |
+
|
72 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___ └─────────────────────────┴────────────┴─────────┘
|
73 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
74 |
+
|
75 |
+
2025-03-14 13:32:12,937 INFO ___FILE_ONLY___
|
76 |
+
|
77 |
+
2025-03-14 13:32:12,940 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
78 |
+
2025-03-14 13:32:13,829 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/RELEASE_NOTES HTTP/11" 200 1377050
|
79 |
+
2025-03-14 13:32:14,277 INFO ___FILE_ONLY___ For the latest full release notes, please visit:
|
80 |
+
https://cloud.google.com/sdk/release_notes
|
81 |
+
|
82 |
+
|
83 |
+
2025-03-14 13:32:14,277 INFO ___FILE_ONLY___ Performing in place update...
|
84 |
+
|
85 |
+
|
86 |
+
2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
87 |
+
|
88 |
+
2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Alpha Commands ═╣
|
89 |
+
|
90 |
+
2025-03-14 13:32:14,280 INFO ___FILE_ONLY___ ╚
|
91 |
+
2025-03-14 13:32:14,283 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
92 |
+
2025-03-14 13:32:15,248 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-alpha-20250307152352.tar.gz HTTP/11" 200 800
|
93 |
+
2025-03-14 13:32:15,249 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
94 |
+
2025-03-14 13:32:15,249 INFO ___FILE_ONLY___ ╝
|
95 |
+
|
96 |
+
2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
97 |
+
|
98 |
+
2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Beta Commands ═╣
|
99 |
+
|
100 |
+
2025-03-14 13:32:15,251 INFO ___FILE_ONLY___ ╚
|
101 |
+
2025-03-14 13:32:15,254 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
102 |
+
2025-03-14 13:32:15,413 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-beta-20250307152352.tar.gz HTTP/11" 200 797
|
103 |
+
2025-03-14 13:32:15,413 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
104 |
+
2025-03-14 13:32:15,413 INFO ___FILE_ONLY___ ╝
|
105 |
+
|
106 |
+
2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
107 |
+
|
108 |
+
2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╠═ Downloading: gcloud Preview Commands ═╣
|
109 |
+
|
110 |
+
2025-03-14 13:32:15,416 INFO ___FILE_ONLY___ ╚
|
111 |
+
2025-03-14 13:32:15,419 DEBUG urllib3.connectionpool Starting new HTTPS connection (1): dl.google.com:443
|
112 |
+
2025-03-14 13:32:15,610 DEBUG urllib3.connectionpool https://dl.google.com:443 "GET /dl/cloudsdk/channels/rapid/components/google-cloud-sdk-preview-20241115154308.tar.gz HTTP/11" 200 823
|
113 |
+
2025-03-14 13:32:15,611 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
114 |
+
2025-03-14 13:32:15,611 INFO ___FILE_ONLY___ ╝
|
115 |
+
|
116 |
+
2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
117 |
+
|
118 |
+
2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Alpha Commands ═╣
|
119 |
+
|
120 |
+
2025-03-14 13:32:15,613 INFO ___FILE_ONLY___ ╚
|
121 |
+
2025-03-14 13:32:15,614 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
122 |
+
2025-03-14 13:32:15,614 INFO ___FILE_ONLY___ ╝
|
123 |
+
|
124 |
+
2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
125 |
+
|
126 |
+
2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Beta Commands ═╣
|
127 |
+
|
128 |
+
2025-03-14 13:32:15,619 INFO ___FILE_ONLY___ ╚
|
129 |
+
2025-03-14 13:32:15,620 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
130 |
+
2025-03-14 13:32:15,620 INFO ___FILE_ONLY___ ╝
|
131 |
+
|
132 |
+
2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╔════════════════════════════════════════════════════════════╗
|
133 |
+
|
134 |
+
2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╠═ Installing: gcloud Preview Commands ═╣
|
135 |
+
|
136 |
+
2025-03-14 13:32:15,625 INFO ___FILE_ONLY___ ╚
|
137 |
+
2025-03-14 13:32:15,626 INFO ___FILE_ONLY___ ════════════════════════════════════════════════════════════
|
138 |
+
2025-03-14 13:32:15,626 INFO ___FILE_ONLY___ ╝
|
139 |
+
|
140 |
+
2025-03-14 13:32:15,630 DEBUG root Updating notification cache...
|
141 |
+
2025-03-14 13:32:15,631 INFO ___FILE_ONLY___
|
142 |
+
|
143 |
+
2025-03-14 13:32:15,632 INFO ___FILE_ONLY___ Performing post processing steps...
|
144 |
+
2025-03-14 13:32:15,633 DEBUG root Executing command: ['/tools/google-cloud-sdk/bin/gcloud', 'components', 'post-process']
|
145 |
+
2025-03-14 13:32:24,180 DEBUG ___FILE_ONLY___
|
146 |
+
2025-03-14 13:32:24,180 DEBUG ___FILE_ONLY___
|
147 |
+
2025-03-14 13:32:24,406 INFO root descriptor_list: [{'universeDomain': 'googleapis.com', 'universeShortName': '', 'authenticationDomain': 'auth.cloud.google.com', 'projectPrefix': '', 'cloudWebDomain': 'cloud.google.com', 'documentationDomain': 'cloud.google.com', 'version': '1.0.0', 'state': 'primary', 'artifactRegistryDomain': 'pkg.dev'}]
|
148 |
+
2025-03-14 13:32:24,407 INFO ___FILE_ONLY___
|
149 |
+
Update done!
|
150 |
+
|
151 |
+
|
152 |
+
2025-03-14 13:32:24,409 DEBUG root Chosen display Format:none
|
153 |
+
2025-03-14 13:32:24,410 INFO root Display format: "none"
|
.config/logs/2025.03.14/13.32.16.153180.log
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:32:16,154 DEBUG root Loaded Command Group: ['gcloud', 'components']
|
2 |
+
2025-03-14 13:32:16,155 DEBUG root Loaded Command Group: ['gcloud', 'components', 'post_process']
|
3 |
+
2025-03-14 13:32:16,157 DEBUG root Running [gcloud.components.post-process] with arguments: []
|
4 |
+
2025-03-14 13:32:24,057 DEBUG root Chosen display Format:none
|
5 |
+
2025-03-14 13:32:24,058 INFO root Display format: "none"
|
.config/logs/2025.03.14/13.32.25.046318.log
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:32:25,048 DEBUG root Loaded Command Group: ['gcloud', 'config']
|
2 |
+
2025-03-14 13:32:25,095 DEBUG root Loaded Command Group: ['gcloud', 'config', 'set']
|
3 |
+
2025-03-14 13:32:25,098 DEBUG root Running [gcloud.config.set] with arguments: [SECTION/PROPERTY: "component_manager/disable_update_check", VALUE: "true"]
|
4 |
+
2025-03-14 13:32:25,098 INFO ___FILE_ONLY___ Updated property [component_manager/disable_update_check].
|
5 |
+
|
6 |
+
2025-03-14 13:32:25,099 DEBUG root Chosen display Format:default
|
7 |
+
2025-03-14 13:32:25,100 INFO root Display format: "default"
|
8 |
+
2025-03-14 13:32:25,100 DEBUG root SDK update checks are disabled.
|
.config/logs/2025.03.14/13.32.25.746375.log
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2025-03-14 13:32:25,748 DEBUG root Loaded Command Group: ['gcloud', 'config']
|
2 |
+
2025-03-14 13:32:25,794 DEBUG root Loaded Command Group: ['gcloud', 'config', 'set']
|
3 |
+
2025-03-14 13:32:25,796 DEBUG root Running [gcloud.config.set] with arguments: [SECTION/PROPERTY: "compute/gce_metadata_read_timeout_sec", VALUE: "0"]
|
4 |
+
2025-03-14 13:32:25,797 INFO ___FILE_ONLY___ Updated property [compute/gce_metadata_read_timeout_sec].
|
5 |
+
|
6 |
+
2025-03-14 13:32:25,798 DEBUG root Chosen display Format:default
|
7 |
+
2025-03-14 13:32:25,799 INFO root Display format: "default"
|
8 |
+
2025-03-14 13:32:25,799 DEBUG root SDK update checks are disabled.
|
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
sample_data/mnist_test.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
sample_data/mnist_train_small.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
Gemma-Finetune/.gitignore
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
*.egg-info/
|
20 |
+
.installed.cfg
|
21 |
+
*.egg
|
22 |
+
|
23 |
+
# Virtual Environment
|
24 |
+
venv/
|
25 |
+
env/
|
26 |
+
ENV/
|
27 |
+
|
28 |
+
# Model files and datasets
|
29 |
+
models/
|
30 |
+
sample_datasets/
|
31 |
+
*.pt
|
32 |
+
*.pth
|
33 |
+
*.bin
|
34 |
+
*.gguf
|
35 |
+
*.onnx
|
36 |
+
|
37 |
+
# IDE
|
38 |
+
.idea/
|
39 |
+
.vscode/
|
40 |
+
*.swp
|
41 |
+
*.swo
|
42 |
+
|
43 |
+
# Logs and databases
|
44 |
+
*.log
|
45 |
+
*.sqlite
|
46 |
+
wandb/
|
47 |
+
|
48 |
+
# OS generated files
|
49 |
+
.DS_Store
|
50 |
+
.DS_Store?
|
51 |
+
._*
|
52 |
+
.Spotlight-V100
|
53 |
+
.Trashes
|
54 |
+
ehthumbs.db
|
55 |
+
Thumbs.db
|
56 |
+
|
57 |
+
# Installer logs
|
58 |
+
pip-log.txt
|
59 |
+
pip-delete-this-directory.txt
|
60 |
+
|
61 |
+
# Unit test / coverage reports
|
62 |
+
htmlcov/
|
63 |
+
.tox/
|
64 |
+
.nox/
|
65 |
+
.coverage
|
66 |
+
.coverage.*
|
67 |
+
.cache
|
68 |
+
nosetests.xml
|
69 |
+
coverage.xml
|
70 |
+
*.cover
|
71 |
+
*.py,cover
|
72 |
+
.hypothesis/
|
73 |
+
.pytest_cache/
|
74 |
+
cover/
|
75 |
+
|
76 |
+
# Translations
|
77 |
+
*.mo
|
78 |
+
*.pot
|
79 |
+
|
80 |
+
# Django stuff:
|
81 |
+
local_settings.py
|
82 |
+
db.sqlite3
|
83 |
+
db.sqlite3-journal
|
84 |
+
|
85 |
+
# Flask stuff:
|
86 |
+
instance/
|
87 |
+
.webassets-cache
|
88 |
+
|
89 |
+
# Scrapy stuff:
|
90 |
+
.scrapy
|
91 |
+
|
92 |
+
# Sphinx documentation
|
93 |
+
docs/_build/
|
94 |
+
|
95 |
+
# PyBuilder
|
96 |
+
.pybuilder/
|
97 |
+
target/
|
98 |
+
|
99 |
+
# Jupyter Notebook
|
100 |
+
.ipynb_checkpoints
|
101 |
+
|
102 |
+
# IPython
|
103 |
+
profile_default/
|
104 |
+
ipython_config.py
|
105 |
+
|
106 |
+
# pyenv
|
107 |
+
# For a library or package, you might want to ignore these files since the code is
|
108 |
+
# intended to run in multiple environments; otherwise, check them in:
|
109 |
+
# .python-version
|
110 |
+
|
111 |
+
# pipenv
|
112 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
113 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
114 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
115 |
+
# install all needed dependencies.
|
116 |
+
#Pipfile.lock
|
117 |
+
|
118 |
+
# UV
|
119 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
120 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
121 |
+
# commonly ignored for libraries.
|
122 |
+
#uv.lock
|
123 |
+
|
124 |
+
# poetry
|
125 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
126 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
127 |
+
# commonly ignored for libraries.
|
128 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
129 |
+
#poetry.lock
|
130 |
+
|
131 |
+
# pdm
|
132 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
133 |
+
#pdm.lock
|
134 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
135 |
+
# in version control.
|
136 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
137 |
+
.pdm.toml
|
138 |
+
.pdm-python
|
139 |
+
.pdm-build/
|
140 |
+
|
141 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
142 |
+
__pypackages__/
|
143 |
+
|
144 |
+
# Celery stuff
|
145 |
+
celerybeat-schedule
|
146 |
+
celerybeat.pid
|
147 |
+
|
148 |
+
# SageMath parsed files
|
149 |
+
*.sage.py
|
150 |
+
|
151 |
+
# Environments
|
152 |
+
.env
|
153 |
+
.venv
|
154 |
+
env.bak/
|
155 |
+
venv.bak/
|
156 |
+
|
157 |
+
# Spyder project settings
|
158 |
+
.spyderproject
|
159 |
+
.spyproject
|
160 |
+
|
161 |
+
# Rope project settings
|
162 |
+
.ropeproject
|
163 |
+
|
164 |
+
# mkdocs documentation
|
165 |
+
/site
|
166 |
+
|
167 |
+
# mypy
|
168 |
+
.mypy_cache/
|
169 |
+
.dmypy.json
|
170 |
+
dmypy.json
|
171 |
+
|
172 |
+
# Pyre type checker
|
173 |
+
.pyre/
|
174 |
+
|
175 |
+
# pytype static type analyzer
|
176 |
+
.pytype/
|
177 |
+
|
178 |
+
# Cython debug symbols
|
179 |
+
cython_debug/
|
180 |
+
|
181 |
+
# PyCharm
|
182 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
183 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
184 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
185 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
186 |
+
#.idea/
|
187 |
+
|
188 |
+
# Ruff stuff:
|
189 |
+
.ruff_cache/
|
190 |
+
|
191 |
+
# PyPI configuration file
|
192 |
+
.pypirc
|
Gemma-Finetune/Gemma3_(4B).ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Gemma-Finetune/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Dark Coder
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Gemma-Finetune/README.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gemma Fine-tuning UI
|
2 |
+
|
3 |
+
A user-friendly interface for fine-tuning Google's Gemma models using Unsloth optimizations.
|
4 |
+
|
5 |
+
## Features
|
6 |
+
|
7 |
+
- Easy-to-use web interface for model fine-tuning
|
8 |
+
- Support for multiple data formats (CSV, JSONL, TEXT)
|
9 |
+
- Parameter-efficient fine-tuning with LoRA
|
10 |
+
- Real-time training progress visualization
|
11 |
+
- Model export in multiple formats
|
12 |
+
- Integrated text generation testing
|
13 |
+
|
14 |
+
## Installation
|
15 |
+
|
16 |
+
```bash
|
17 |
+
git clone https://github.com/codewithdark-git/Gemma-Finetune.git
|
18 |
+
cd Gemma-Finetune
|
19 |
+
pip install -r requirements.txt
|
20 |
+
```
|
21 |
+
|
22 |
+
## Usage
|
23 |
+
|
24 |
+
1. Run the application:
|
25 |
+
```bash
|
26 |
+
python main.py
|
27 |
+
```
|
28 |
+
|
29 |
+
2. Follow the UI steps:
|
30 |
+
- Upload your dataset
|
31 |
+
- Configure model parameters
|
32 |
+
- Start training
|
33 |
+
- Test and export your model
|
34 |
+
|
35 |
+
## Requirements
|
36 |
+
|
37 |
+
See requirements.txt for detailed dependencies.
|
38 |
+
|
39 |
+
## License
|
40 |
+
|
41 |
+
MIT License
|
Gemma-Finetune/main.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.check_dataset import validate_dataset, generate_dataset_report
|
3 |
+
from utils.sample_dataset import generate_sample_datasets
|
4 |
+
from utils.model import GemmaFineTuning
|
5 |
+
|
6 |
+
class GemmaUI:
    """Gradio front-end for the Gemma fine-tuning workflow.

    Wraps a single ``GemmaFineTuning`` backend instance and exposes four
    tabs: dataset upload/preprocessing, hyperparameter configuration,
    training control, and evaluation/export.
    """

    def __init__(self):
        # One backend instance shared by every UI callback; callbacks
        # close over `self` so state (dataset, trainer, model) persists.
        self.model_instance = GemmaFineTuning()
        self.default_params = self.model_instance.default_params

    def create_ui(self):
        """Build and return the Gradio Blocks application.

        Returns:
            gr.Blocks: the assembled (not yet launched) interface.
        """
        with gr.Blocks(title="Gemma Fine-tuning UI") as app:
            gr.Markdown("# Gemma Model Fine-tuning Interface")
            gr.Markdown("Upload your dataset, configure parameters, and fine-tune a Gemma model")

            with gr.Tabs():
                with gr.TabItem("1. Data Upload & Preprocessing"):
                    with gr.Row():
                        with gr.Column():
                            file_upload = gr.File(label="Upload Dataset")
                            file_format = gr.Radio(
                                ["csv", "jsonl", "text"],
                                label="File Format",
                                value="csv"
                            )
                            preprocess_button = gr.Button("Preprocess Dataset")
                            dataset_info = gr.TextArea(label="Dataset Information", interactive=False)

                with gr.TabItem("2. Model & Hyperparameters"):
                    with gr.Row():
                        with gr.Column():
                            model_name = gr.Dropdown(
                                choices=[
                                    "google/gemma-2b",
                                    "google/gemma-7b",
                                    "google/gemma-2b-it",
                                    "google/gemma-7b-it"
                                ],
                                value=self.default_params["model_name"],
                                label="Model Name",
                                info="Select a Gemma model to fine-tune"
                            )
                            learning_rate = gr.Slider(
                                minimum=1e-6,
                                maximum=5e-4,
                                value=self.default_params["learning_rate"],
                                label="Learning Rate",
                                info="Learning rate for the optimizer"
                            )
                            batch_size = gr.Slider(
                                minimum=1,
                                maximum=32,
                                step=1,
                                value=self.default_params["batch_size"],
                                label="Batch Size",
                                info="Number of samples in each training batch"
                            )
                            epochs = gr.Slider(
                                minimum=1,
                                maximum=10,
                                step=1,
                                value=self.default_params["epochs"],
                                label="Epochs",
                                info="Number of training epochs"
                            )

                        with gr.Column():
                            max_length = gr.Slider(
                                minimum=128,
                                maximum=2048,
                                step=16,
                                value=self.default_params["max_length"],
                                label="Max Sequence Length",
                                info="Maximum token length for inputs"
                            )
                            use_lora = gr.Checkbox(
                                value=self.default_params["use_lora"],
                                label="Use LoRA for Parameter-Efficient Fine-tuning",
                                info="Recommended for faster training and lower memory usage"
                            )
                            # BUG FIX: `visible` must be a bool.  The original passed
                            # `lambda: use_lora.value`, which Gradio treats as an
                            # always-truthy constant, so the sliders never toggled.
                            # Start from the checkbox default; a `use_lora.change`
                            # event below keeps visibility in sync.
                            lora_r = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_r"],
                                label="LoRA Rank (r)",
                                info="Rank of the LoRA update matrices",
                                visible=bool(self.default_params["use_lora"])
                            )
                            lora_alpha = gr.Slider(
                                minimum=4,
                                maximum=64,
                                step=4,
                                value=self.default_params["lora_alpha"],
                                label="LoRA Alpha",
                                info="Scaling factor for LoRA updates",
                                visible=bool(self.default_params["use_lora"])
                            )
                            eval_ratio = gr.Slider(
                                minimum=0.05,
                                maximum=0.3,
                                step=0.05,
                                value=self.default_params["eval_ratio"],
                                label="Validation Split Ratio",
                                info="Portion of data to use for validation"
                            )

                with gr.TabItem("3. Training"):
                    with gr.Row():
                        with gr.Column():
                            start_training_button = gr.Button("Start Fine-tuning")
                            stop_training_button = gr.Button("Stop Training", variant="stop")
                            training_status = gr.Textbox(label="Training Status", interactive=False)

                        with gr.Column():
                            progress_plot = gr.Plot(label="Training Progress")
                            refresh_plot_button = gr.Button("Refresh Plot")

                with gr.TabItem("4. Evaluation & Export"):
                    with gr.Row():
                        with gr.Column():
                            test_prompt = gr.Textbox(
                                label="Test Prompt",
                                placeholder="Enter a prompt to test the model...",
                                lines=3
                            )
                            max_gen_length = gr.Slider(
                                minimum=10,
                                maximum=500,
                                step=10,
                                value=100,
                                label="Max Generation Length"
                            )
                            generate_button = gr.Button("Generate Text")
                            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

                        with gr.Column():
                            export_format = gr.Radio(
                                ["pytorch", "tensorflow", "gguf"],
                                label="Export Format",
                                value="pytorch"
                            )
                            export_button = gr.Button("Export Model")
                            export_status = gr.Textbox(label="Export Status", interactive=False)

            # ---- Callbacks -------------------------------------------------

            def preprocess_data(file, format_type):
                """Load the uploaded file into the backend and summarize it."""
                try:
                    if file is None:
                        return "Please upload a file first."

                    # Assumes prepare_dataset returns a DatasetDict-like
                    # mapping with a "train" split — TODO confirm in utils.model.
                    dataset = self.model_instance.prepare_dataset(file.name, format_type)
                    self.model_instance.dataset = dataset

                    num_samples = len(dataset["train"])

                    # Preview up to three examples, truncated to 100 chars each.
                    examples = dataset["train"].select(range(min(3, num_samples)))
                    sample_text = []
                    for ex in examples:
                        text_key = list(ex.keys())[0] if "text" not in ex else "text"
                        sample = ex[text_key]
                        if isinstance(sample, str):
                            sample_text.append(sample[:100] + "..." if len(sample) > 100 else sample)

                    info = "Dataset loaded successfully!\n"
                    info += f"Number of training examples: {num_samples}\n"
                    info += "Sample data:\n" + "\n---\n".join(sample_text)

                    return info
                except Exception as e:
                    return f"Error preprocessing data: {str(e)}"

            def start_training(
                model_name, learning_rate, batch_size, epochs, max_length,
                use_lora, lora_r, lora_alpha, eval_ratio
            ):
                """Validate inputs and launch training on a background thread."""
                try:
                    if self.model_instance.dataset is None:
                        return "Please preprocess a dataset first."

                    if not model_name:
                        return "Please select a model."

                    # Coerce widget values (Gradio may deliver floats/strings).
                    training_params = {
                        "model_name": str(model_name),
                        "learning_rate": float(learning_rate),
                        "batch_size": int(batch_size),
                        "epochs": int(epochs),
                        "max_length": int(max_length),
                        "use_lora": bool(use_lora),
                        "lora_r": int(lora_r) if use_lora else None,
                        "lora_alpha": int(lora_alpha) if use_lora else None,
                        "eval_ratio": float(eval_ratio),
                        "weight_decay": float(self.default_params["weight_decay"]),
                        "warmup_ratio": float(self.default_params["warmup_ratio"]),
                        "lora_dropout": float(self.default_params["lora_dropout"])
                    }

                    # Run training off the UI thread so the interface stays
                    # responsive.  daemon=True so a stuck training run cannot
                    # keep the process alive after the app exits (fix: the
                    # original thread was non-daemon).
                    import threading

                    def train_thread():
                        self.model_instance.train(training_params)

                    thread = threading.Thread(target=train_thread, daemon=True)
                    thread.start()

                    return "Training started! Monitor the progress in the Training tab."
                except Exception as e:
                    return f"Error starting training: {str(e)}"

            def stop_training():
                """Request a cooperative stop of the active trainer, if any."""
                if self.model_instance.trainer is not None:
                    # The backend is expected to poll this flag between steps.
                    self.model_instance.trainer.stop_training = True
                    return "Training stop signal sent. It may take a moment to complete the current step."
                return "No active training to stop."

            def update_progress_plot():
                """Return the latest training-progress figure (None on failure)."""
                try:
                    return self.model_instance.plot_training_progress()
                except Exception:
                    # Plotting failures (e.g. no history yet) just clear the plot.
                    return None

            def run_text_generation(prompt, max_length):
                """Generate text with the fine-tuned model."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.generate_text(prompt, int(max_length))
                except Exception as e:
                    return f"Error generating text: {str(e)}"

            def export_model_fn(format_type):
                """Export the fine-tuned model in the chosen format."""
                try:
                    if self.model_instance.model is None:
                        return "Please fine-tune a model first."

                    return self.model_instance.export_model(format_type)
                except Exception as e:
                    return f"Error exporting model: {str(e)}"

            # ---- Event wiring ----------------------------------------------

            # Toggle LoRA slider visibility with the checkbox (replaces the
            # non-functional `visible=lambda` of the original).
            use_lora.change(
                lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled)),
                inputs=use_lora,
                outputs=[lora_r, lora_alpha]
            )

            preprocess_button.click(
                preprocess_data,
                inputs=[file_upload, file_format],
                outputs=dataset_info
            )

            start_training_button.click(
                start_training,
                inputs=[
                    model_name, learning_rate, batch_size, epochs, max_length,
                    use_lora, lora_r, lora_alpha, eval_ratio
                ],
                outputs=training_status
            )

            stop_training_button.click(
                stop_training,
                inputs=[],
                outputs=training_status
            )

            refresh_plot_button.click(
                update_progress_plot,
                inputs=[],
                outputs=progress_plot
            )

            generate_button.click(
                run_text_generation,
                inputs=[test_prompt, max_gen_length],
                outputs=generated_output
            )

            export_button.click(
                export_model_fn,
                inputs=[export_format],
                outputs=export_status
            )

        return app
291 |
+
|
292 |
+
if __name__ == '__main__':
    # Assemble the Gradio interface and serve it with a public share link.
    interface = GemmaUI().create_ui()
    interface.launch(share=True)
Gemma-Finetune/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.36.0
|
3 |
+
unsloth>=0.1.0
|
4 |
+
gradio>=4.0.0
|
5 |
+
pandas>=1.5.0
|
6 |
+
numpy>=1.24.0
|
7 |
+
matplotlib>=3.7.0
|
8 |
+
peft>=0.7.0
|
9 |
+
datasets>=2.14.0
|
Gemma-Finetune/utils/__pycache__/check_dataset.cpython-311.pyc
ADDED
Binary file (13.3 kB). View file
|
|
Gemma-Finetune/utils/__pycache__/model.cpython-311.pyc
ADDED
Binary file (28.4 kB). View file
|
|
Gemma-Finetune/utils/__pycache__/sample_dataset.cpython-311.pyc
ADDED
Binary file (7.67 kB). View file
|
|
Gemma-Finetune/utils/check_dataset.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
def validate_dataset(self, file_path, format_type):
|
5 |
+
"""
|
6 |
+
Validate and analyze the dataset format, providing detailed feedback
|
7 |
+
|
8 |
+
Parameters:
|
9 |
+
file_path (str): Path to the dataset file
|
10 |
+
format_type (str): File format (csv, jsonl, text)
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
dict: Validation results including format, structure, and statistics
|
14 |
+
"""
|
15 |
+
import pandas as pd
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
|
20 |
+
validation_results = {
|
21 |
+
"is_valid": False,
|
22 |
+
"format": format_type,
|
23 |
+
"detected_structure": None,
|
24 |
+
"statistics": {},
|
25 |
+
"issues": [],
|
26 |
+
"recommendations": []
|
27 |
+
}
|
28 |
+
|
29 |
+
try:
|
30 |
+
# Check if file exists
|
31 |
+
if not os.path.exists(file_path):
|
32 |
+
validation_results["issues"].append(f"File not found: {file_path}")
|
33 |
+
return validation_results
|
34 |
+
|
35 |
+
# Check file size
|
36 |
+
file_size = os.path.getsize(file_path)
|
37 |
+
validation_results["statistics"]["file_size_bytes"] = file_size
|
38 |
+
validation_results["statistics"]["file_size_mb"] = round(file_size / (1024 * 1024), 2)
|
39 |
+
|
40 |
+
if file_size == 0:
|
41 |
+
validation_results["issues"].append("File is empty")
|
42 |
+
return validation_results
|
43 |
+
|
44 |
+
if format_type == "csv":
|
45 |
+
# Load CSV file
|
46 |
+
try:
|
47 |
+
df = pd.read_csv(file_path)
|
48 |
+
validation_results["statistics"]["total_rows"] = len(df)
|
49 |
+
validation_results["statistics"]["total_columns"] = len(df.columns)
|
50 |
+
validation_results["statistics"]["column_names"] = list(df.columns)
|
51 |
+
|
52 |
+
# Check for null values
|
53 |
+
null_counts = df.isnull().sum().to_dict()
|
54 |
+
validation_results["statistics"]["null_counts"] = null_counts
|
55 |
+
|
56 |
+
if validation_results["statistics"]["total_rows"] == 0:
|
57 |
+
validation_results["issues"].append("CSV file has no rows")
|
58 |
+
return validation_results
|
59 |
+
|
60 |
+
# Detect structure
|
61 |
+
if "instruction" in df.columns and "response" in df.columns:
|
62 |
+
validation_results["detected_structure"] = "instruction-response"
|
63 |
+
validation_results["is_valid"] = True
|
64 |
+
elif "input" in df.columns and "output" in df.columns:
|
65 |
+
validation_results["detected_structure"] = "input-output"
|
66 |
+
validation_results["is_valid"] = True
|
67 |
+
elif "prompt" in df.columns and "completion" in df.columns:
|
68 |
+
validation_results["detected_structure"] = "prompt-completion"
|
69 |
+
validation_results["is_valid"] = True
|
70 |
+
elif "text" in df.columns:
|
71 |
+
validation_results["detected_structure"] = "text-only"
|
72 |
+
validation_results["is_valid"] = True
|
73 |
+
else:
|
74 |
+
# Look for text columns
|
75 |
+
text_columns = [col for col in df.columns if df[col].dtype == 'object']
|
76 |
+
if text_columns:
|
77 |
+
validation_results["detected_structure"] = "custom"
|
78 |
+
validation_results["statistics"]["potential_text_columns"] = text_columns
|
79 |
+
validation_results["is_valid"] = True
|
80 |
+
validation_results["recommendations"].append(
|
81 |
+
f"Consider renaming columns to match standard formats: instruction/response, input/output, prompt/completion, or text"
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
validation_results["issues"].append("No suitable text columns found in CSV")
|
85 |
+
|
86 |
+
# Check for short text
|
87 |
+
if validation_results["detected_structure"] == "instruction-response":
|
88 |
+
short_instructions = (df["instruction"].str.len() < 10).sum()
|
89 |
+
short_responses = (df["response"].str.len() < 10).sum()
|
90 |
+
validation_results["statistics"]["short_instructions"] = short_instructions
|
91 |
+
validation_results["statistics"]["short_responses"] = short_responses
|
92 |
+
|
93 |
+
if short_instructions > 0:
|
94 |
+
validation_results["issues"].append(f"Found {short_instructions} instructions shorter than 10 characters")
|
95 |
+
if short_responses > 0:
|
96 |
+
validation_results["issues"].append(f"Found {short_responses} responses shorter than 10 characters")
|
97 |
+
|
98 |
+
except Exception as e:
|
99 |
+
validation_results["issues"].append(f"Error parsing CSV: {str(e)}")
|
100 |
+
return validation_results
|
101 |
+
|
102 |
+
elif format_type == "jsonl":
|
103 |
+
try:
|
104 |
+
# Load JSONL file
|
105 |
+
data = []
|
106 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
107 |
+
for line_num, line in enumerate(f, 1):
|
108 |
+
line = line.strip()
|
109 |
+
if not line:
|
110 |
+
continue
|
111 |
+
try:
|
112 |
+
json_obj = json.loads(line)
|
113 |
+
data.append(json_obj)
|
114 |
+
except json.JSONDecodeError:
|
115 |
+
validation_results["issues"].append(f"Invalid JSON at line {line_num}")
|
116 |
+
|
117 |
+
validation_results["statistics"]["total_examples"] = len(data)
|
118 |
+
|
119 |
+
if len(data) == 0:
|
120 |
+
validation_results["issues"].append("No valid JSON objects found in file")
|
121 |
+
return validation_results
|
122 |
+
|
123 |
+
# Get sample of keys from first object
|
124 |
+
if data:
|
125 |
+
validation_results["statistics"]["sample_keys"] = list(data[0].keys())
|
126 |
+
|
127 |
+
# Detect structure
|
128 |
+
structures = []
|
129 |
+
for item in data:
|
130 |
+
if "instruction" in item and "response" in item:
|
131 |
+
structures.append("instruction-response")
|
132 |
+
elif "input" in item and "output" in item:
|
133 |
+
structures.append("input-output")
|
134 |
+
elif "prompt" in item and "completion" in item:
|
135 |
+
structures.append("prompt-completion")
|
136 |
+
elif "text" in item:
|
137 |
+
structures.append("text-only")
|
138 |
+
else:
|
139 |
+
structures.append("custom")
|
140 |
+
|
141 |
+
# Count structure types
|
142 |
+
from collections import Counter
|
143 |
+
structure_counts = Counter(structures)
|
144 |
+
validation_results["statistics"]["structure_counts"] = structure_counts
|
145 |
+
|
146 |
+
# Set detected structure to most common
|
147 |
+
if structures:
|
148 |
+
most_common = structure_counts.most_common(1)[0][0]
|
149 |
+
validation_results["detected_structure"] = most_common
|
150 |
+
validation_results["is_valid"] = True
|
151 |
+
|
152 |
+
# Check if mixed
|
153 |
+
if len(structure_counts) > 1:
|
154 |
+
validation_results["issues"].append(f"Mixed structures detected: {dict(structure_counts)}")
|
155 |
+
validation_results["recommendations"].append("Consider standardizing all records to the same structure")
|
156 |
+
|
157 |
+
except Exception as e:
|
158 |
+
validation_results["issues"].append(f"Error parsing JSONL: {str(e)}")
|
159 |
+
return validation_results
|
160 |
+
|
161 |
+
elif format_type == "text":
|
162 |
+
try:
|
163 |
+
# Read text file
|
164 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
165 |
+
content = f.read()
|
166 |
+
|
167 |
+
# Get basic stats
|
168 |
+
total_chars = len(content)
|
169 |
+
total_words = len(content.split())
|
170 |
+
total_lines = content.count('\n') + 1
|
171 |
+
|
172 |
+
validation_results["statistics"]["total_characters"] = total_chars
|
173 |
+
validation_results["statistics"]["total_words"] = total_words
|
174 |
+
validation_results["statistics"]["total_lines"] = total_lines
|
175 |
+
|
176 |
+
# Check if it's a single large document or multiple examples
|
177 |
+
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
|
178 |
+
validation_results["statistics"]["total_paragraphs"] = len(paragraphs)
|
179 |
+
|
180 |
+
# Try to detect structure
|
181 |
+
# Look for common patterns like "Q: ... A: ...", "Input: ... Output: ..."
|
182 |
+
has_qa_pattern = re.search(r"Q:.*?A:", content, re.DOTALL) is not None
|
183 |
+
has_input_output = re.search(r"Input:.*?Output:", content, re.DOTALL) is not None
|
184 |
+
has_prompt_completion = re.search(r"Prompt:.*?Completion:", content, re.DOTALL) is not None
|
185 |
+
|
186 |
+
if has_qa_pattern:
|
187 |
+
validation_results["detected_structure"] = "Q&A-format"
|
188 |
+
elif has_input_output:
|
189 |
+
validation_results["detected_structure"] = "input-output-format"
|
190 |
+
elif has_prompt_completion:
|
191 |
+
validation_results["detected_structure"] = "prompt-completion-format"
|
192 |
+
elif len(paragraphs) > 1:
|
193 |
+
validation_results["detected_structure"] = "paragraphs"
|
194 |
+
else:
|
195 |
+
validation_results["detected_structure"] = "continuous-text"
|
196 |
+
|
197 |
+
validation_results["is_valid"] = True
|
198 |
+
|
199 |
+
if validation_results["detected_structure"] == "continuous-text" and total_chars < 1000:
|
200 |
+
validation_results["issues"].append("Text file is very short for fine-tuning")
|
201 |
+
validation_results["recommendations"].append("Consider adding more content or examples")
|
202 |
+
|
203 |
+
except Exception as e:
|
204 |
+
validation_results["issues"].append(f"Error parsing text file: {str(e)}")
|
205 |
+
return validation_results
|
206 |
+
else:
|
207 |
+
validation_results["issues"].append(f"Unsupported file format: {format_type}")
|
208 |
+
return validation_results
|
209 |
+
|
210 |
+
# General recommendations
|
211 |
+
if validation_results["is_valid"]:
|
212 |
+
if not validation_results["issues"]:
|
213 |
+
validation_results["recommendations"].append("Dataset looks good and ready for fine-tuning!")
|
214 |
+
else:
|
215 |
+
validation_results["recommendations"].append("Address the issues above before proceeding with fine-tuning")
|
216 |
+
|
217 |
+
return validation_results
|
218 |
+
|
219 |
+
except Exception as e:
|
220 |
+
validation_results["issues"].append(f"Unexpected error: {str(e)}")
|
221 |
+
return validation_results
|
222 |
+
|
223 |
+
def generate_dataset_report(validation_results):
    """
    Render validation results as a human-readable markdown report.

    Parameters:
        validation_results (dict): Results from validate_dataset; expected
            keys: is_valid, format, detected_structure, statistics,
            issues, recommendations.

    Returns:
        str: Markdown-formatted report text.
    """
    lines = ["# Dataset Validation Report", ""]

    # Overall verdict.
    if validation_results["is_valid"]:
        lines.append("✅ Dataset is valid and can be used for fine-tuning")
    else:
        lines.append("❌ Dataset has issues that need to be addressed")
    lines.append("")

    # Format summary.
    lines.append(f"**File Format:** {validation_results['format']}")
    lines.append(f"**Detected Structure:** {validation_results['detected_structure']}")
    lines.append("")

    # Statistics: snake_case keys are rendered as Title Case labels.
    lines.append("## Statistics")
    for stat_name, stat_value in validation_results["statistics"].items():
        pretty_name = stat_name.replace("_", " ").title()
        lines.append(f"- {pretty_name}: {stat_value}")
    lines.append("")

    # Issues section is emitted only when there is something to report.
    issues = validation_results["issues"]
    if issues:
        lines.append("## Issues")
        lines.extend(f"- ⚠️ {issue}" for issue in issues)
        lines.append("")

    # Recommendations section, likewise optional.
    recommendations = validation_results["recommendations"]
    if recommendations:
        lines.append("## Recommendations")
        lines.extend(f"- 💡 {rec}" for rec in recommendations)

    return "\n".join(lines)
|
Gemma-Finetune/utils/model.py
ADDED
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
10 |
+
from datetime import datetime
|
11 |
+
from torch.utils.data import Dataset, DataLoader
|
12 |
+
from transformers import (
|
13 |
+
AutoTokenizer,
|
14 |
+
AutoModelForCausalLM,
|
15 |
+
TrainingArguments,
|
16 |
+
Trainer,
|
17 |
+
DataCollatorForLanguageModeling,
|
18 |
+
TrainerCallback
|
19 |
+
)
|
20 |
+
from peft import (
|
21 |
+
LoraConfig,
|
22 |
+
get_peft_model,
|
23 |
+
prepare_model_for_kbit_training
|
24 |
+
)
|
25 |
+
from datasets import load_dataset
|
26 |
+
from unsloth import FastModel
|
27 |
+
|
28 |
+
|
29 |
+
class GemmaFineTuning:
|
30 |
+
def __init__(self):
|
31 |
+
self.model = None
|
32 |
+
self.tokenizer = None
|
33 |
+
self.dataset = None
|
34 |
+
self.trainer = None
|
35 |
+
self.training_history = {"loss": [], "eval_loss": [], "step": []}
|
36 |
+
self.model_save_path = None
|
37 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
38 |
+
|
39 |
+
self.fourbit_models = [
|
40 |
+
"unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
|
41 |
+
"unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
|
42 |
+
"unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
|
43 |
+
"unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
|
44 |
+
]
|
45 |
+
# Default hyperparameters
|
46 |
+
self.default_params = {
|
47 |
+
"model_name": "google/gemma-2b",
|
48 |
+
"learning_rate": 2e-5,
|
49 |
+
"batch_size": 8,
|
50 |
+
"epochs": 3,
|
51 |
+
"max_length": 512,
|
52 |
+
"weight_decay": 0.01,
|
53 |
+
"warmup_ratio": 0.1,
|
54 |
+
"use_lora": True,
|
55 |
+
"lora_r": 16,
|
56 |
+
"lora_alpha": 32,
|
57 |
+
"lora_dropout": 0.05,
|
58 |
+
"eval_ratio": 0.1,
|
59 |
+
}
|
60 |
+
|
61 |
+
def load_model_and_tokenizer(self, model_name: str) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load a Gemma model and tokenizer via Unsloth.

    UI-facing model names are remapped to their Unsloth 4-bit checkpoint
    equivalents; unknown names are passed through unchanged.

    Parameters:
        model_name: Model identifier as shown in the UI.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        ValueError: If loading fails for any reason.
    """
    try:
        # UI names -> quantized Unsloth checkpoints.
        model_mapping = {
            "google/gemma-2b": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
            "google/gemma-7b": "unsloth/gemma-7b-it-unsloth-bnb-4bit",
            "google/gemma-2b-it": "unsloth/gemma-2b-it-unsloth-bnb-4bit",
            "google/gemma-7b-it": "unsloth/gemma-7b-it-unsloth-bnb-4bit"
        }
        resolved_name = model_mapping.get(model_name, model_name)

        model, tokenizer = FastModel.from_pretrained(
            model_name=resolved_name,
            max_seq_length=2048,
            load_in_4bit=True,
            load_in_8bit=False,
            full_finetuning=False,
        )

        # NOTE(review): calling .to(device) on a 4-bit bitsandbytes model is
        # typically a no-op or unsupported — confirm this is intended.
        model = model.to(self.device)
        return model, tokenizer

    except Exception as e:
        raise ValueError(f"Error loading model {model_name}: {str(e)}")
|
88 |
+
|
89 |
+
def prepare_dataset(self, file_path, format_type):
    """
    Prepare and normalize a dataset file into a DatasetDict with a train split.

    All supported layouts are normalized so that every record carries a
    "text" field suitable for causal-LM training.

    Parameters:
        file_path (str): Path to the dataset file
        format_type (str): File format (csv, jsonl, text)

    Returns:
        DatasetDict: Dataset dictionary with a single "train" split.

    Raises:
        ValueError: If the file cannot be parsed or contains no usable data.
    """
    import pandas as pd
    import json
    import os
    from datasets import Dataset, DatasetDict

    try:
        if format_type == "csv":
            df = pd.read_csv(file_path)

            # Normalize known column layouts into a single "text" column.
            if "instruction" in df.columns and "response" in df.columns:
                dataset_format = "instruction-response"
                df = df.dropna(subset=["instruction", "response"])
                df["text"] = df.apply(
                    lambda row: f"<instruction>{row['instruction']}</instruction>\n<response>{row['response']}</response>",
                    axis=1,
                )
            elif "input" in df.columns and "output" in df.columns:
                dataset_format = "input-output"
                df = df.dropna(subset=["input", "output"])
                df["text"] = df.apply(
                    lambda row: f"<input>{row['input']}</input>\n<output>{row['output']}</output>",
                    axis=1,
                )
            elif "prompt" in df.columns and "completion" in df.columns:
                dataset_format = "prompt-completion"
                df = df.dropna(subset=["prompt", "completion"])
                df["text"] = df.apply(
                    lambda row: f"<prompt>{row['prompt']}</prompt>\n<completion>{row['completion']}</completion>",
                    axis=1,
                )
            elif "text" in df.columns:
                dataset_format = "text-only"
                df = df.dropna(subset=["text"])
            else:
                # Fall back to the first object-dtype column as free text.
                text_columns = [col for col in df.columns if df[col].dtype == 'object']
                if len(text_columns) > 0:
                    dataset_format = "inferred"
                    df["text"] = df[text_columns[0]]
                    df = df.dropna(subset=["text"])
                else:
                    raise ValueError("CSV file must contain either 'instruction'/'response', 'input'/'output', 'prompt'/'completion', or 'text' columns")

            dataset = Dataset.from_pandas(df)

        elif format_type == "jsonl":
            with open(file_path, 'r', encoding='utf-8') as f:
                records = [json.loads(line) for line in f if line.strip()]

            # Normalize each record; skip records with no usable text.
            normalized_data = []
            for item in records:
                entry = {}
                if "instruction" in item and "response" in item:
                    entry["text"] = f"<instruction>{item['instruction']}</instruction>\n<response>{item['response']}</response>"
                    entry["instruction"] = item["instruction"]
                    entry["response"] = item["response"]
                elif "input" in item and "output" in item:
                    entry["text"] = f"<input>{item['input']}</input>\n<output>{item['output']}</output>"
                    entry["input"] = item["input"]
                    entry["output"] = item["output"]
                elif "prompt" in item and "completion" in item:
                    entry["text"] = f"<prompt>{item['prompt']}</prompt>\n<completion>{item['completion']}</completion>"
                    entry["prompt"] = item["prompt"]
                    entry["completion"] = item["completion"]
                elif "text" in item:
                    entry["text"] = item["text"]
                else:
                    # Use the first non-empty string field as text, if any.
                    text_keys = [k for k, v in item.items() if isinstance(v, str) and len(v.strip()) > 0]
                    if not text_keys:
                        continue  # nothing usable in this record
                    entry["text"] = item[text_keys[0]]
                normalized_data.append(entry)

            if not normalized_data:
                raise ValueError("No valid data items found in the JSONL file")

            dataset = Dataset.from_list(normalized_data)

        elif format_type == "text":
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Files over 10KB are split into paragraph examples (paragraphs
            # shorter than 20 chars dropped); smaller files become one example.
            if os.path.getsize(file_path) > 10240:
                paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
                paragraphs = [p for p in paragraphs if len(p) >= 20]
                data = [{"text": p} for p in paragraphs]
            else:
                data = [{"text": content}]

            dataset = Dataset.from_list(data)

        else:
            raise ValueError(f"Unsupported file format: {format_type}")

        return DatasetDict({"train": dataset})

    except Exception as e:
        import traceback
        error_msg = f"Error processing dataset: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        raise ValueError(error_msg)
|
216 |
+
|
217 |
+
def chunk_text(self, text: str, chunk_size: int) -> List[str]:
    """
    Split *text* into word-preserving chunks of roughly *chunk_size* characters.

    Words are never split; a chunk is flushed as soon as appending the next
    word (plus its joining space) would exceed the limit.

    Parameters:
        text: Input text to split on whitespace.
        chunk_size: Approximate maximum characters per chunk.

    Returns:
        List of chunk strings; empty list for empty/whitespace-only input.
    """
    chunks: List[str] = []
    buffer: List[str] = []
    buffered_len = 0

    for word in text.split():
        # +1 accounts for the space that would join this word to the buffer.
        if buffer and buffered_len + len(word) + 1 > chunk_size:
            chunks.append(" ".join(buffer))
            buffer = [word]
            buffered_len = len(word)
        else:
            buffer.append(word)
            buffered_len += len(word) + 1

    if buffer:
        chunks.append(" ".join(buffer))

    return chunks
|
237 |
+
|
238 |
+
def preprocess_dataset(self, dataset, tokenizer, max_length):
    """
    Tokenize and format the dataset for causal-LM training.

    Parameters:
        dataset (DatasetDict): Dataset dictionary with train (and optionally
            validation) splits.
        tokenizer: HuggingFace tokenizer.
        max_length (int): Maximum sequence length (pad/truncate target).

    Returns:
        dict: Mapping of split name -> tokenized dataset ready for training.
    """
    def tokenize_function(examples):
        # Prefer an explicit "text" column; otherwise fall back to the
        # first column whose values are all strings.
        if "text" in examples:
            texts = examples["text"]
        else:
            candidate_cols = [
                col for col in examples.keys()
                if isinstance(examples[col], list)
                and all(isinstance(item, str) for item in examples[col])
            ]
            if not candidate_cols:
                raise ValueError("No suitable text columns found in the dataset")
            texts = examples[candidate_cols[0]]

        encoded = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        # Causal LM: labels are the input ids themselves.
        encoded["labels"] = encoded["input_ids"].clone()
        return encoded

    # Tokenize every split, dropping the raw columns.
    return {
        split: ds.map(
            tokenize_function,
            batched=True,
            remove_columns=ds.column_names
        )
        for split, ds in dataset.items()
    }
|
295 |
+
|
296 |
+
def prepare_training_args(self, params: Dict) -> TrainingArguments:
    """
    Build HuggingFace TrainingArguments from the given hyperparameters.

    Also sets self.model_save_path to a timestamped output directory.

    Parameters:
        params (Dict): Hyperparameter overrides; missing keys fall back
            to self.default_params.

    Returns:
        TrainingArguments: Configured training arguments.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    self.model_save_path = f"gemma-finetuned-{timestamp}"

    defaults = self.default_params
    # Evaluation is enabled only when a validation split was requested.
    has_eval = params.get("eval_ratio", 0) > 0

    return TrainingArguments(
        output_dir=self.model_save_path,
        per_device_train_batch_size=params.get("batch_size", defaults["batch_size"]),
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=params.get("batch_size", defaults["batch_size"]),
        learning_rate=params.get("learning_rate", defaults["learning_rate"]),
        num_train_epochs=params.get("epochs", defaults["epochs"]),
        warmup_ratio=params.get("warmup_ratio", defaults["warmup_ratio"]),
        weight_decay=params.get("weight_decay", defaults["weight_decay"]),
        logging_steps=1,
        evaluation_strategy="steps" if has_eval else "no",
        eval_steps=100 if has_eval else None,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        load_best_model_at_end=has_eval,
        report_to="none"
    )
|
320 |
+
|
321 |
+
def train(self, training_params: Dict) -> str:
    """
    Run the complete fine-tuning pipeline and return a status message.

    Steps: ensure model/tokenizer are loaded, optionally carve out a
    validation split, optionally apply LoRA, tokenize, train, and save.

    Parameters:
        training_params (Dict): Hyperparameter overrides; missing keys
            fall back to self.default_params.

    Returns:
        str: Success message with the save path, or an error description.
    """
    try:
        if self.dataset is None:
            return "Error: No dataset loaded. Please preprocess a dataset first."

        # Start each run with a clean metrics record.
        self.training_history = {"loss": [], "eval_loss": [], "step": []}

        # (Re)load the model when none is loaded or the requested name changed.
        requested_model = training_params.get("model_name", self.default_params["model_name"])
        needs_reload = (
            self.model is None
            or self.tokenizer is None
            or getattr(self, '_current_model_name', None) != requested_model
        )
        if needs_reload:
            self.model, self.tokenizer = self.load_model_and_tokenizer(requested_model)
            self._current_model_name = requested_model

        # Carve out a validation split when requested and not already present.
        eval_ratio = float(training_params.get("eval_ratio", self.default_params["eval_ratio"]))
        if eval_ratio > 0 and "validation" not in self.dataset:
            split_dataset = self.dataset["train"].train_test_split(test_size=eval_ratio)
            self.dataset = {
                "train": split_dataset["train"],
                "validation": split_dataset["test"]
            }

        # Parameter-efficient fine-tuning via LoRA, if enabled.
        if training_params.get("use_lora", self.default_params["use_lora"]):
            lora_settings = {
                "lora_r": int(training_params.get("lora_r", self.default_params["lora_r"])),
                "lora_alpha": int(training_params.get("lora_alpha", self.default_params["lora_alpha"])),
                "lora_dropout": float(training_params.get("lora_dropout", self.default_params["lora_dropout"]))
            }
            self.model = self.setup_lora(self.model, lora_settings)

        # Tokenize all splits.
        max_length = int(training_params.get("max_length", self.default_params["max_length"]))
        tokenized_dataset = self.preprocess_dataset(self.dataset, self.tokenizer, max_length)

        # Assemble training arguments with explicit type coercion.
        training_args = self.prepare_training_args({
            "batch_size": int(training_params.get("batch_size", self.default_params["batch_size"])),
            "learning_rate": float(training_params.get("learning_rate", self.default_params["learning_rate"])),
            "epochs": int(training_params.get("epochs", self.default_params["epochs"])),
            "weight_decay": float(training_params.get("weight_decay", self.default_params["weight_decay"])),
            "warmup_ratio": float(training_params.get("warmup_ratio", self.default_params["warmup_ratio"])),
            "eval_ratio": eval_ratio
        })

        self.trainer = self.create_trainer(
            self.model,
            self.tokenizer,
            tokenized_dataset,
            training_args
        )
        self.trainer.train()

        # Persist model + tokenizer to a timestamped directory.
        save_path = f"models/gemma-finetuned-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        os.makedirs(save_path, exist_ok=True)
        self.trainer.save_model(save_path)
        self.tokenizer.save_pretrained(save_path)
        self.model_save_path = save_path

        return f"Training completed successfully! Model saved to {save_path}"

    except Exception as e:
        import traceback
        return f"Error during training: {str(e)}\n{traceback.format_exc()}"
|
392 |
+
|
393 |
+
def setup_lora(self, model, params: Dict) -> torch.nn.Module:
    """
    Wrap *model* with LoRA adapters for parameter-efficient fine-tuning.

    Parameters:
        model: The base (possibly quantized) causal language model.
        params (Dict): Must contain "lora_r", "lora_alpha", "lora_dropout".

    Returns:
        torch.nn.Module: The PEFT-wrapped model, moved to self.device.
    """
    # Quantized (4/8-bit) models need preparation before adapter training.
    if hasattr(model, "is_quantized") and model.is_quantized:
        model = prepare_model_for_kbit_training(model)

    # BUG FIX: the previous version built a LoraConfig from `params` and then
    # discarded it, hard-coding r=8 / lora_alpha=8 / lora_dropout=0 below —
    # the user's LoRA settings from the UI had no effect. Pass them through.
    model = FastModel.get_peft_model(
        model,
        finetune_vision_layers = False,      # text-only fine-tuning
        finetune_language_layers = True,
        finetune_attention_modules = True,
        finetune_mlp_modules = True,

        r = params["lora_r"],                # rank: higher = more capacity
        lora_alpha = params["lora_alpha"],
        lora_dropout = params["lora_dropout"],
        bias = "none",
        random_state = 3407,
    )
    model.print_trainable_parameters()
    model = model.to(self.device)
    return model
|
424 |
+
|
425 |
+
def create_trainer(self, model, tokenizer, dataset, training_args):
    """
    Build a HuggingFace Trainer wired with a causal-LM collator and a
    history-recording callback.

    Parameters:
        model: The (possibly LoRA-wrapped) model to train.
        tokenizer: Tokenizer matching the model.
        dataset (dict): Tokenized splits; must contain "train", may
            contain "validation".
        training_args (TrainingArguments): Prepared training arguments.

    Returns:
        Trainer: Configured trainer instance.
    """
    # Causal LM (mlm=False): collator shifts inputs to build labels.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Custom callback to store training history on the owning app object.
    class TrainingCallback(TrainerCallback):
        def __init__(self, app):
            self.app = app

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs:
                for key in ['loss', 'eval_loss']:
                    if key in logs:
                        self.app.training_history[key].append(logs[key])
                if 'step' in logs:
                    self.app.training_history['step'].append(logs['step'])

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"] if "validation" in dataset else None,
        data_collator=data_collator,
        # BUG FIX: previously the callback CLASS was passed; Trainer would
        # instantiate it without the required `app` argument and fail (and no
        # history would ever be recorded). Pass an instance bound to this app.
        callbacks=[TrainingCallback(self)]
    )

    return trainer
|
457 |
+
|
458 |
+
def plot_training_progress(self):
    """
    Plot recorded training (and optionally validation) loss curves.

    Returns:
        The matplotlib.pyplot module with the figure drawn, or None when
        no loss history has been recorded yet.
    """
    if not self.training_history["loss"]:
        return None

    steps = self.training_history["step"]

    plt.figure(figsize=(10, 6))
    plt.plot(steps, self.training_history["loss"], label="Training Loss")

    eval_losses = self.training_history["eval_loss"]
    if eval_losses:
        # Pair each eval loss with the earliest recorded steps.
        # NOTE(review): this assumes eval entries align with the first
        # logged steps — confirm against how on_log interleaves records.
        plt.plot(steps[:len(eval_losses)], eval_losses, label="Validation Loss", linestyle="--")

    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.title("Training Progress")
    plt.legend()
    plt.grid(True)

    return plt
|
478 |
+
|
479 |
+
def export_model(self, output_format: str) -> str:
|
480 |
+
"""Export the fine-tuned model in various formats"""
|
481 |
+
if self.model is None or self.model_save_path is None:
|
482 |
+
return "No model has been trained yet."
|
483 |
+
|
484 |
+
export_path = f"{self.model_save_path}/exported_{output_format}"
|
485 |
+
os.makedirs(export_path, exist_ok=True)
|
486 |
+
|
487 |
+
if output_format == "pytorch":
|
488 |
+
# Save as PyTorch format
|
489 |
+
self.model.save_pretrained(export_path)
|
490 |
+
self.tokenizer.save_pretrained(export_path)
|
491 |
+
return f"Model exported in PyTorch format to {export_path}"
|
492 |
+
|
493 |
+
elif output_format == "tensorflow":
|
494 |
+
# Convert to TensorFlow format
|
495 |
+
try:
|
496 |
+
from transformers.modeling_tf_utils import convert_pt_to_tf
|
497 |
+
|
498 |
+
# First save the PyTorch model
|
499 |
+
self.model.save_pretrained(export_path)
|
500 |
+
self.tokenizer.save_pretrained(export_path)
|
501 |
+
|
502 |
+
# Then convert to TF SavedModel format
|
503 |
+
tf_model = convert_pt_to_tf(self.model)
|
504 |
+
tf_model.save_pretrained(f"{export_path}/tf_saved_model")
|
505 |
+
|
506 |
+
return f"Model exported in TensorFlow format to {export_path}/tf_saved_model"
|
507 |
+
except Exception as e:
|
508 |
+
return f"Failed to export as TensorFlow model: {str(e)}"
|
509 |
+
|
510 |
+
elif output_format == "gguf":
|
511 |
+
# Export as GGUF format for local inference
|
512 |
+
try:
|
513 |
+
import subprocess
|
514 |
+
|
515 |
+
# First save the model in PyTorch format
|
516 |
+
self.model.save_pretrained(export_path)
|
517 |
+
self.tokenizer.save_pretrained(export_path)
|
518 |
+
|
519 |
+
# Use llama.cpp's conversion script (must be installed)
|
520 |
+
subprocess.run([
|
521 |
+
"python", "-m", "llama_cpp.convert",
|
522 |
+
"--outtype", "gguf",
|
523 |
+
"--outfile", f"{export_path}/model.gguf",
|
524 |
+
export_path
|
525 |
+
])
|
526 |
+
|
527 |
+
return f"Model exported in GGUF format to {export_path}/model.gguf"
|
528 |
+
except Exception as e:
|
529 |
+
return f"Failed to export as GGUF model: {str(e)}"
|
530 |
+
|
531 |
+
else:
|
532 |
+
return f"Unsupported export format: {output_format}"
|
533 |
+
|
534 |
+
def generate_text(self, prompt: str, max_length: int = 100) -> str:
|
535 |
+
"""Generate text using the fine-tuned model"""
|
536 |
+
if self.model is None or self.tokenizer is None:
|
537 |
+
return "No model has been loaded or fine-tuned yet."
|
538 |
+
|
539 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
540 |
+
|
541 |
+
with torch.no_grad():
|
542 |
+
outputs = self.model.generate(
|
543 |
+
**inputs,
|
544 |
+
max_length=max_length + inputs.input_ids.shape[1],
|
545 |
+
temperature=0.7,
|
546 |
+
top_p=0.9,
|
547 |
+
do_sample=True,
|
548 |
+
pad_token_id=self.tokenizer.pad_token_id
|
549 |
+
)
|
550 |
+
|
551 |
+
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
552 |
+
return generated_text
|
Gemma-Finetune/utils/sample_dataset.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
|
6 |
+
def generate_sample_datasets(output_dir="./sample_datasets"):
    """
    Generate sample datasets in CSV, JSONL, and text formats for testing the Gemma fine-tuning UI

    Parameters:
    output_dir (str): Directory to save the sample datasets
    """
    # Make sure the target directory exists before any file is written.
    os.makedirs(output_dir, exist_ok=True)

    # Hand-written seed examples in instruction/response form.
    instruction_data = [
        {"instruction": "Explain the concept of machine learning in simple terms.",
         "response": "Machine learning is when computers learn from examples rather than being explicitly programmed. It's like teaching a child by showing examples instead of giving exact rules."},
        {"instruction": "What are the health benefits of regular exercise?",
         "response": "Regular exercise improves cardiovascular health, strengthens muscles and bones, helps control weight, reduces stress and anxiety, improves sleep quality, and increases overall energy levels."},
        {"instruction": "How do I make a basic pasta sauce?",
         "response": "For a basic tomato pasta sauce, sauté minced garlic and onions in olive oil until soft. Add canned crushed tomatoes, salt, pepper, and Italian herbs. Simmer for 20-30 minutes, stirring occasionally. Add basil at the end for fresh flavor."},
        {"instruction": "Summarize the water cycle.",
         "response": "The water cycle is the continuous movement of water on, above, and below Earth's surface. It involves evaporation (water turning to vapor), condensation (vapor forming clouds), precipitation (rain or snow), and collection (water returning to bodies of water)."},
        {"instruction": "What is the difference between a simile and a metaphor?",
         "response": "A simile compares two things using 'like' or 'as' (e.g., 'as brave as a lion'). A metaphor directly states that one thing is another (e.g., 'he is a lion in battle'). Both are figurative language techniques used to create vivid descriptions."}
    ]

    # Pad the seed set with 20 synthetic instruction/response pairs.
    topics = ["history", "science", "literature", "cooking", "technology", "health", "travel", "sports", "music", "art"]
    question_starters = ["Explain", "Describe", "How to", "What is", "Why does", "Compare", "Summarize", "List ways to", "Define", "Analyze"]

    for _ in range(20):
        topic = random.choice(topics)
        starter = random.choice(question_starters)
        instruction = f"{starter} {topic.lower()} {random.choice(['concepts', 'principles', 'ideas', 'techniques', 'methods'])}"
        response = f"This is a sample response about {topic} that would be more detailed in a real dataset. It would contain multiple sentences explaining {topic} concepts in depth."
        instruction_data.append({"instruction": instruction, "response": response})

    def remap(q_key, a_key):
        # Re-key every instruction/response pair under alternative column names.
        return [{q_key: d["instruction"], a_key: d["response"]} for d in instruction_data]

    # Q/A blocks reused by the text-only CSV and the paragraph text file.
    qa_blocks = [f"Q: {d['instruction']}\nA: {d['response']}" for d in instruction_data]

    # Every dataset variant, keyed by its output filename (insertion order
    # defines the order of the returned file list).
    datasets = {
        "instruction_response.csv": pd.DataFrame(instruction_data),
        "input_output.csv": pd.DataFrame(remap("input", "output")),
        "text_only.csv": pd.DataFrame([{"text": block} for block in qa_blocks]),
        "custom_format.csv": pd.DataFrame(remap("question", "answer")),
        "instruction_response.jsonl": instruction_data,
        "prompt_completion.jsonl": remap("prompt", "completion"),
        "custom_format.jsonl": remap("query", "result"),
        "paragraphs.txt": "\n\n".join(qa_blocks),
        "single_lines.txt": "\n".join(f"{d['instruction']} => {d['response']}" for d in instruction_data),
    }

    # Persist each dataset in the format implied by its extension.
    for filename, payload in datasets.items():
        destination = os.path.join(output_dir, filename)

        if filename.endswith(".csv"):
            payload.to_csv(destination, index=False)
        elif filename.endswith(".jsonl"):
            with open(destination, "w", encoding="utf-8") as handle:
                for record in payload:
                    handle.write(json.dumps(record) + "\n")
        elif filename.endswith(".txt"):
            with open(destination, "w", encoding="utf-8") as handle:
                handle.write(payload)

    print(f"Sample datasets generated in {output_dir}")
    return list(datasets.keys())
|
99 |
+
|
100 |
+
# if __name__ == "__main__":
|
101 |
+
# # Generate sample datasets
|
102 |
+
# generated_files = generate_sample_datasets()
|
103 |
+
# print(f"Generated {len(generated_files)} sample dataset files:")
|
104 |
+
# for file in generated_files:
|
105 |
+
# print(f" - {file}")
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: yellow
|
5 |
-
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.21.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Gemma_Finetuner
|
3 |
+
app_file: /content/Gemma-Finetune/main.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.21.0
|
|
|
|
|
6 |
---
|
|
|
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
transformers
|
3 |
+
unsloth
|
4 |
+
peft
|
5 |
+
datasets
|
6 |
+
torch
|
sample_data/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This directory includes a few sample datasets to get you started.
|
2 |
+
|
3 |
+
* `california_housing_data*.csv` is California housing data from the 1990 US
|
4 |
+
Census; more information is available at:
|
5 |
+
https://docs.google.com/document/d/e/2PACX-1vRhYtsvc5eOR2FWNCwaBiKL6suIOrxJig8LcSBbmCbyYsayia_DvPOOBlXZ4CAlQ5nlDD8kTaIDRwrN/pub
|
6 |
+
|
7 |
+
* `mnist_*.csv` is a small sample of the
|
8 |
+
[MNIST database](https://en.wikipedia.org/wiki/MNIST_database), which is
|
9 |
+
described at: http://yann.lecun.com/exdb/mnist/
|
10 |
+
|
11 |
+
* `anscombe.json` contains a copy of
|
12 |
+
[Anscombe's quartet](https://en.wikipedia.org/wiki/Anscombe%27s_quartet); it
|
13 |
+
was originally described in
|
14 |
+
|
15 |
+
Anscombe, F. J. (1973). 'Graphs in Statistical Analysis'. American
|
16 |
+
Statistician. 27 (1): 17-21. JSTOR 2682899.
|
17 |
+
|
18 |
+
and our copy was prepared by the
|
19 |
+
[vega_datasets library](https://github.com/altair-viz/vega_datasets/blob/4f67bdaad10f45e3549984e17e1b3088c731503d/vega_datasets/_data/anscombe.json).
|
sample_data/anscombe.json
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{"Series":"I", "X":10.0, "Y":8.04},
|
3 |
+
{"Series":"I", "X":8.0, "Y":6.95},
|
4 |
+
{"Series":"I", "X":13.0, "Y":7.58},
|
5 |
+
{"Series":"I", "X":9.0, "Y":8.81},
|
6 |
+
{"Series":"I", "X":11.0, "Y":8.33},
|
7 |
+
{"Series":"I", "X":14.0, "Y":9.96},
|
8 |
+
{"Series":"I", "X":6.0, "Y":7.24},
|
9 |
+
{"Series":"I", "X":4.0, "Y":4.26},
|
10 |
+
{"Series":"I", "X":12.0, "Y":10.84},
|
11 |
+
{"Series":"I", "X":7.0, "Y":4.81},
|
12 |
+
{"Series":"I", "X":5.0, "Y":5.68},
|
13 |
+
|
14 |
+
{"Series":"II", "X":10.0, "Y":9.14},
|
15 |
+
{"Series":"II", "X":8.0, "Y":8.14},
|
16 |
+
{"Series":"II", "X":13.0, "Y":8.74},
|
17 |
+
{"Series":"II", "X":9.0, "Y":8.77},
|
18 |
+
{"Series":"II", "X":11.0, "Y":9.26},
|
19 |
+
{"Series":"II", "X":14.0, "Y":8.10},
|
20 |
+
{"Series":"II", "X":6.0, "Y":6.13},
|
21 |
+
{"Series":"II", "X":4.0, "Y":3.10},
|
22 |
+
{"Series":"II", "X":12.0, "Y":9.13},
|
23 |
+
{"Series":"II", "X":7.0, "Y":7.26},
|
24 |
+
{"Series":"II", "X":5.0, "Y":4.74},
|
25 |
+
|
26 |
+
{"Series":"III", "X":10.0, "Y":7.46},
|
27 |
+
{"Series":"III", "X":8.0, "Y":6.77},
|
28 |
+
{"Series":"III", "X":13.0, "Y":12.74},
|
29 |
+
{"Series":"III", "X":9.0, "Y":7.11},
|
30 |
+
{"Series":"III", "X":11.0, "Y":7.81},
|
31 |
+
{"Series":"III", "X":14.0, "Y":8.84},
|
32 |
+
{"Series":"III", "X":6.0, "Y":6.08},
|
33 |
+
{"Series":"III", "X":4.0, "Y":5.39},
|
34 |
+
{"Series":"III", "X":12.0, "Y":8.15},
|
35 |
+
{"Series":"III", "X":7.0, "Y":6.42},
|
36 |
+
{"Series":"III", "X":5.0, "Y":5.73},
|
37 |
+
|
38 |
+
{"Series":"IV", "X":8.0, "Y":6.58},
|
39 |
+
{"Series":"IV", "X":8.0, "Y":5.76},
|
40 |
+
{"Series":"IV", "X":8.0, "Y":7.71},
|
41 |
+
{"Series":"IV", "X":8.0, "Y":8.84},
|
42 |
+
{"Series":"IV", "X":8.0, "Y":8.47},
|
43 |
+
{"Series":"IV", "X":8.0, "Y":7.04},
|
44 |
+
{"Series":"IV", "X":8.0, "Y":5.25},
|
45 |
+
{"Series":"IV", "X":19.0, "Y":12.50},
|
46 |
+
{"Series":"IV", "X":8.0, "Y":5.56},
|
47 |
+
{"Series":"IV", "X":8.0, "Y":7.91},
|
48 |
+
{"Series":"IV", "X":8.0, "Y":6.89}
|
49 |
+
]
|
sample_data/california_housing_test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sample_data/california_housing_train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sample_data/mnist_test.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51c292478d94ec3a01461bdfa82eb0885d262eb09e615679b2d69dedb6ad09e7
|
3 |
+
size 18289443
|
sample_data/mnist_train_small.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ef64781aa03180f4f5ce504314f058f5d0227277df86060473d973cf43b033e
|
3 |
+
size 36523880
|
unsloth_compiled_cache/UnslothAlignPropTrainer.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
class UnslothAlignPropConfig(AlignPropConfig):
    """
    Configuration class for the [`AlignPropTrainer`].

    A drop-in subclass of trl's `AlignPropConfig` — see that class for the
    meaning of every inherited parameter — extended with two
    Unsloth-specific options:

    - `vllm_sampling_params`: optional vLLM `SamplingParams` object.
    - `unsloth_num_chunks`: chunk size used to reduce memory usage
      (`-1` keeps the most efficient default).

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse)
    arguments that can be specified on the command line.
    """
    # Extra Unsloth-only dataclass fields (not part of the parent config).
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        exp_name = 'main',
        run_name = '',
        seed = 3407,
        log_with = None,
        log_image_freq = 1,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        negative_prompts = None,
        truncated_backprop_rand = True,
        truncated_backprop_timestep = 49,
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Forward every inherited option to AlignPropConfig unchanged.
        super().__init__(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            log_image_freq = log_image_freq,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            negative_prompts = negative_prompts,
            truncated_backprop_rand = truncated_backprop_rand,
            truncated_backprop_timestep = truncated_backprop_timestep,
            push_to_hub = push_to_hub,**kwargs)
        # The Unsloth-only fields are assigned explicitly because the parent
        # initialiser does not know about them.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass
|
199 |
+
|
200 |
+
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
201 |
+
""""""
|
202 |
+
|
203 |
+
_tag_names = ["trl", "alignprop"]
|
204 |
+
|
205 |
+
def __init__(
    self,
    config: AlignPropConfig,
    reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
    prompt_function: Callable[[], tuple[str, Any]],
    sd_pipeline: DDPOStableDiffusionPipeline,
    image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
    """Set up the AlignProp training loop.

    Args:
        config: Trainer configuration (see `AlignPropConfig`).
        reward_function: Maps (images, prompts, prompt_metadata) to a
            differentiable reward tensor.
        prompt_function: Zero-argument callable producing (prompt, metadata).
        sd_pipeline: Stable-Diffusion pipeline whose trainable layers are
            fine-tuned.
        image_samples_hook: Optional callback used to log generated images;
            when None, no images are logged.
    """
    if image_samples_hook is None:
        warn("No image_samples_hook provided; no images will be logged")

    self.prompt_fn = prompt_function
    self.reward_fn = reward_function
    self.config = config
    self.image_samples_callback = image_samples_hook

    accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

    # When resuming from a bare directory, resolve it to the newest
    # checkpoint_N subdirectory and continue numbering from N + 1.
    if self.config.resume_from:
        self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
        if "checkpoint_" not in os.path.basename(self.config.resume_from):
            # get the most recent checkpoint in this directory
            checkpoints = list(
                filter(
                    lambda x: "checkpoint_" in x,
                    os.listdir(self.config.resume_from),
                )
            )
            if len(checkpoints) == 0:
                raise ValueError(f"No checkpoints found in {self.config.resume_from}")
            checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
            self.config.resume_from = os.path.join(
                self.config.resume_from,
                f"checkpoint_{checkpoint_numbers[-1]}",
            )

            accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

    self.accelerator = Accelerator(
        log_with=self.config.log_with,
        mixed_precision=self.config.mixed_precision,
        project_config=accelerator_project_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
        **self.config.accelerator_kwargs,
    )

    is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

    if self.accelerator.is_main_process:
        self.accelerator.init_trackers(
            self.config.tracker_project_name,
            # tensorboard cannot serialise nested dicts, so it receives the
            # flat config; other trackers get it namespaced.
            config=dict(alignprop_trainer_config=config.to_dict())
            if not is_using_tensorboard
            else config.to_dict(),
            init_kwargs=self.config.tracker_kwargs,
        )

    logger.info(f"\n{config}")

    set_seed(self.config.seed, device_specific=True)

    self.sd_pipeline = sd_pipeline

    self.sd_pipeline.set_progress_bar_config(
        position=1,
        disable=not self.accelerator.is_local_main_process,
        leave=False,
        desc="Timestep",
        dynamic_ncols=True,
    )

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    if self.accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif self.accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16
    else:
        inference_dtype = torch.float32

    self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
    self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

    trainable_layers = self.sd_pipeline.get_trainable_layers()

    # Hooks let the pipeline own checkpoint (de)serialisation instead of
    # accelerate's default model handling.
    self.accelerator.register_save_state_pre_hook(self._save_model_hook)
    self.accelerator.register_load_state_pre_hook(self._load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if self.config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    self.optimizer = self._setup_optimizer(
        trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
    )

    # Pre-compute the (negative-)prompt embedding once; reused every step.
    self.neg_prompt_embed = self.sd_pipeline.text_encoder(
        self.sd_pipeline.tokenizer(
            [""] if self.config.negative_prompts is None else self.config.negative_prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)
    )[0]

    # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

    if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
        unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
        self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
    else:
        self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        self.accelerator.load_state(config.resume_from)
        # resume_from ends in "checkpoint_N"; training restarts at epoch N + 1.
        self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
    else:
        self.first_epoch = 0
|
332 |
+
|
333 |
+
def compute_rewards(self, prompt_image_pairs):
    """Score a batch of generated samples with the user-supplied reward function."""
    images = prompt_image_pairs["images"]
    prompts = prompt_image_pairs["prompts"]
    metadata = prompt_image_pairs["prompt_metadata"]
    # reward_fn also returns reward metadata, which is intentionally discarded.
    reward, _ = self.reward_fn(images, prompts, metadata)
    return reward
|
338 |
+
|
339 |
+
def step(self, epoch: int, global_step: int):
    """
    Perform a single step of training.

    Args:
        epoch (int): The current epoch.
        global_step (int): The current global step.

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.
        - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

    Returns:
        global_step (int): The updated global step.
    """
    info = defaultdict(list)

    self.sd_pipeline.unet.train()

    # Each pass accumulates gradients for one sample batch; the optimizer
    # only truly steps once accelerate signals sync_gradients.
    for _ in range(self.config.train_gradient_accumulation_steps):
        with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
            prompt_image_pairs = self._generate_samples(
                batch_size=self.config.train_batch_size,
            )

            rewards = self.compute_rewards(prompt_image_pairs)

            prompt_image_pairs["rewards"] = rewards

            # Gather rewards across processes purely for logging.
            rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()

            loss = self.calculate_loss(rewards)

            self.accelerator.backward(loss)

            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(
                    self.trainable_layers.parameters()
                    if not isinstance(self.trainable_layers, list)
                    else self.trainable_layers,
                    self.config.train_max_grad_norm,
                )

            self.optimizer.step()
            self.optimizer.zero_grad()

        info["reward_mean"].append(rewards_vis.mean())
        info["reward_std"].append(rewards_vis.std())
        info["loss"].append(loss.item())

    # Checks if the accelerator has performed an optimization step behind the scenes
    if self.accelerator.sync_gradients:
        # log training-related stuff
        info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
        info = self.accelerator.reduce(info, reduction="mean")
        info.update({"epoch": epoch})
        self.accelerator.log(info, step=global_step)
        global_step += 1
        info = defaultdict(list)
    else:
        raise ValueError(
            "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
        )
    # Logs generated images
    if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
        self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])

    if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
        self.accelerator.save_state()

    return global_step
|
411 |
+
|
412 |
+
def calculate_loss(self, rewards):
|
413 |
+
"""
|
414 |
+
Calculate the loss for a batch of an unpacked sample
|
415 |
+
|
416 |
+
Args:
|
417 |
+
rewards (torch.Tensor):
|
418 |
+
Differentiable reward scalars for each generated image, shape: [batch_size]
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
loss (torch.Tensor)
|
422 |
+
(all of these are of shape (1,))
|
423 |
+
"""
|
424 |
+
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
425 |
+
loss = 10.0 - (rewards).mean()
|
426 |
+
return loss
|
427 |
+
|
428 |
+
def loss(
|
429 |
+
self,
|
430 |
+
advantages: torch.Tensor,
|
431 |
+
clip_range: float,
|
432 |
+
ratio: torch.Tensor,
|
433 |
+
):
|
434 |
+
unclipped_loss = -advantages * ratio
|
435 |
+
clipped_loss = -advantages * torch.clamp(
|
436 |
+
ratio,
|
437 |
+
1.0 - clip_range,
|
438 |
+
1.0 + clip_range,
|
439 |
+
)
|
440 |
+
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
441 |
+
|
442 |
+
def _setup_optimizer(self, trainable_layers_parameters):
|
443 |
+
if self.config.train_use_8bit_adam:
|
444 |
+
import bitsandbytes
|
445 |
+
|
446 |
+
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
447 |
+
else:
|
448 |
+
optimizer_cls = torch.optim.AdamW
|
449 |
+
|
450 |
+
return optimizer_cls(
|
451 |
+
trainable_layers_parameters,
|
452 |
+
lr=self.config.train_learning_rate,
|
453 |
+
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
454 |
+
weight_decay=self.config.train_adam_weight_decay,
|
455 |
+
eps=self.config.train_adam_epsilon,
|
456 |
+
)
|
457 |
+
|
458 |
+
def _save_model_hook(self, models, weights, output_dir):
|
459 |
+
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
460 |
+
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
461 |
+
|
462 |
+
def _load_model_hook(self, models, input_dir):
|
463 |
+
self.sd_pipeline.load_checkpoint(models, input_dir)
|
464 |
+
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
465 |
+
|
466 |
+
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
467 |
+
"""
|
468 |
+
Generate samples from the model
|
469 |
+
|
470 |
+
Args:
|
471 |
+
batch_size (int): Batch size to use for sampling
|
472 |
+
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
prompt_image_pairs (dict[Any])
|
476 |
+
"""
|
477 |
+
prompt_image_pairs = {}
|
478 |
+
|
479 |
+
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
480 |
+
|
481 |
+
if prompts is None:
|
482 |
+
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
483 |
+
else:
|
484 |
+
prompt_metadata = [{} for _ in range(batch_size)]
|
485 |
+
|
486 |
+
prompt_ids = self.sd_pipeline.tokenizer(
|
487 |
+
prompts,
|
488 |
+
return_tensors="pt",
|
489 |
+
padding="max_length",
|
490 |
+
truncation=True,
|
491 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
492 |
+
).input_ids.to(self.accelerator.device)
|
493 |
+
|
494 |
+
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
495 |
+
|
496 |
+
if with_grad:
|
497 |
+
sd_output = self.sd_pipeline.rgb_with_grad(
|
498 |
+
prompt_embeds=prompt_embeds,
|
499 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
500 |
+
num_inference_steps=self.config.sample_num_steps,
|
501 |
+
guidance_scale=self.config.sample_guidance_scale,
|
502 |
+
eta=self.config.sample_eta,
|
503 |
+
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
504 |
+
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
505 |
+
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
506 |
+
output_type="pt",
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
sd_output = self.sd_pipeline(
|
510 |
+
prompt_embeds=prompt_embeds,
|
511 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
512 |
+
num_inference_steps=self.config.sample_num_steps,
|
513 |
+
guidance_scale=self.config.sample_guidance_scale,
|
514 |
+
eta=self.config.sample_eta,
|
515 |
+
output_type="pt",
|
516 |
+
)
|
517 |
+
|
518 |
+
images = sd_output.images
|
519 |
+
|
520 |
+
prompt_image_pairs["images"] = images
|
521 |
+
prompt_image_pairs["prompts"] = prompts
|
522 |
+
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
523 |
+
|
524 |
+
return prompt_image_pairs
|
525 |
+
|
526 |
+
def train(self, epochs: Optional[int] = None):
|
527 |
+
"""
|
528 |
+
Train the model for a given number of epochs
|
529 |
+
"""
|
530 |
+
global_step = 0
|
531 |
+
if epochs is None:
|
532 |
+
epochs = self.config.num_epochs
|
533 |
+
for epoch in range(self.first_epoch, epochs):
|
534 |
+
global_step = self.step(epoch, global_step)
|
535 |
+
|
536 |
+
def _save_pretrained(self, save_directory):
|
537 |
+
self.sd_pipeline.save_pretrained(save_directory)
|
538 |
+
self.create_model_card()
|
539 |
+
|
540 |
+
def create_model_card(
|
541 |
+
self,
|
542 |
+
model_name: Optional[str] = None,
|
543 |
+
dataset_name: Optional[str] = None,
|
544 |
+
tags: Union[str, list[str], None] = None,
|
545 |
+
):
|
546 |
+
"""
|
547 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
551 |
+
Name of the model.
|
552 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
553 |
+
Name of the dataset used for training.
|
554 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
555 |
+
Tags to be associated with the model card.
|
556 |
+
"""
|
557 |
+
if not self.is_world_process_zero():
|
558 |
+
return
|
559 |
+
|
560 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
561 |
+
base_model = self.model.config._name_or_path
|
562 |
+
else:
|
563 |
+
base_model = None
|
564 |
+
|
565 |
+
tags = tags or []
|
566 |
+
if isinstance(tags, str):
|
567 |
+
tags = [tags]
|
568 |
+
|
569 |
+
if hasattr(self.model.config, "unsloth_version"):
|
570 |
+
tags.append("unsloth")
|
571 |
+
|
572 |
+
citation = textwrap.dedent("""\
|
573 |
+
@article{prabhudesai2024aligning,
|
574 |
+
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
575 |
+
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
576 |
+
year = 2024,
|
577 |
+
eprint = {arXiv:2310.03739}
|
578 |
+
}""")
|
579 |
+
|
580 |
+
model_card = generate_model_card(
|
581 |
+
base_model=base_model,
|
582 |
+
model_name=model_name,
|
583 |
+
hub_model_id=self.hub_model_id,
|
584 |
+
dataset_name=dataset_name,
|
585 |
+
tags=tags,
|
586 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
587 |
+
comet_url=get_comet_experiment_url(),
|
588 |
+
trainer_name="AlignProp",
|
589 |
+
trainer_citation=citation,
|
590 |
+
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
591 |
+
paper_id="2310.03739",
|
592 |
+
)
|
593 |
+
|
594 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
595 |
+
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
596 |
+
"""
|
597 |
+
|
598 |
+
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
599 |
+
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
|
600 |
+
As of now only Stable Diffusion based pipelines are supported
|
601 |
+
|
602 |
+
Attributes:
|
603 |
+
config (`AlignPropConfig`):
|
604 |
+
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
605 |
+
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
606 |
+
Reward function to be used
|
607 |
+
prompt_function (`Callable[[], tuple[str, Any]]`):
|
608 |
+
Function to generate prompts to guide model
|
609 |
+
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
610 |
+
Stable Diffusion pipeline to be used for training.
|
611 |
+
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
612 |
+
Hook to be called to log images
|
613 |
+
|
614 |
+
"""
|
615 |
+
def __init__(
|
616 |
+
self,
|
617 |
+
config,
|
618 |
+
reward_function,
|
619 |
+
prompt_function,
|
620 |
+
sd_pipeline,
|
621 |
+
image_samples_hook = None,
|
622 |
+
**kwargs
|
623 |
+
):
|
624 |
+
if args is None: args = UnslothAlignPropConfig()
|
625 |
+
other_metrics = []
|
626 |
+
|
627 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
628 |
+
PatchRLStatistics('alignprop_trainer', other_metrics)
|
629 |
+
|
630 |
+
super().__init__(
|
631 |
+
config = config,
|
632 |
+
reward_function = reward_function,
|
633 |
+
prompt_function = prompt_function,
|
634 |
+
sd_pipeline = sd_pipeline,
|
635 |
+
image_samples_hook = image_samples_hook,**kwargs)
|
636 |
+
|
637 |
+
pass
|
unsloth_compiled_cache/UnslothBCOTrainer.py
ADDED
@@ -0,0 +1,1822 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothBCOConfig(BCOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`BCOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
54 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
55 |
+
to use the default data collator.
|
56 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
57 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
58 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
59 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
60 |
+
and your model is an encoder-decoder.
|
61 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
62 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
63 |
+
reference model.
|
64 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
65 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
66 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
67 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
68 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
69 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
70 |
+
This argument is required if you want to use the default data collator.
|
71 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
72 |
+
Whether to disable dropout in the model and reference model.
|
73 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
74 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
75 |
+
evaluation.
|
76 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
77 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
78 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
79 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
81 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
82 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
83 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
84 |
+
string.
|
85 |
+
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
86 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
87 |
+
from a string.
|
88 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
+
Number of processes to use for processing the dataset.
|
90 |
+
prompt_sample_size (`int`, *optional*, defaults to `1024`):
|
91 |
+
Number of prompts that are fed to density ratio classifier.
|
92 |
+
min_density_ratio (`float`, *optional*, defaults to `0.5`):
|
93 |
+
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
|
94 |
+
max_density_ratio (`float`, *optional*, defaults to `10.0`):
|
95 |
+
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
|
96 |
+
|
97 |
+
"""
|
98 |
+
vllm_sampling_params: Optional[Any] = field(
|
99 |
+
default = None,
|
100 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
101 |
+
)
|
102 |
+
unsloth_num_chunks : Optional[int] = field(
|
103 |
+
default = -1,
|
104 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
105 |
+
)
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
output_dir = None,
|
109 |
+
overwrite_output_dir = None,
|
110 |
+
do_train = False,
|
111 |
+
do_eval = False,
|
112 |
+
do_predict = False,
|
113 |
+
eval_strategy = 'no',
|
114 |
+
prediction_loss_only = False,
|
115 |
+
per_device_train_batch_size = 4,
|
116 |
+
per_device_eval_batch_size = 4,
|
117 |
+
per_gpu_train_batch_size = None,
|
118 |
+
per_gpu_eval_batch_size = None,
|
119 |
+
gradient_accumulation_steps = 2,
|
120 |
+
eval_accumulation_steps = 2,
|
121 |
+
eval_delay = 0,
|
122 |
+
torch_empty_cache_steps = 250,
|
123 |
+
learning_rate = 5e-05,
|
124 |
+
weight_decay = 0.01,
|
125 |
+
adam_beta1 = 0.9,
|
126 |
+
adam_beta2 = 0.999,
|
127 |
+
adam_epsilon = 1e-08,
|
128 |
+
max_grad_norm = 1.0,
|
129 |
+
num_train_epochs = 3.0,
|
130 |
+
max_steps = -1,
|
131 |
+
lr_scheduler_type = 'linear',
|
132 |
+
warmup_ratio = 0.1,
|
133 |
+
warmup_steps = 0,
|
134 |
+
log_level = 'passive',
|
135 |
+
log_level_replica = 'warning',
|
136 |
+
log_on_each_node = True,
|
137 |
+
logging_dir = None,
|
138 |
+
logging_strategy = 'steps',
|
139 |
+
logging_first_step = False,
|
140 |
+
logging_steps = 1,
|
141 |
+
logging_nan_inf_filter = False,
|
142 |
+
save_strategy = 'steps',
|
143 |
+
save_steps = 500,
|
144 |
+
save_total_limit = None,
|
145 |
+
save_safetensors = True,
|
146 |
+
save_on_each_node = False,
|
147 |
+
save_only_model = False,
|
148 |
+
restore_callback_states_from_checkpoint = False,
|
149 |
+
no_cuda = False,
|
150 |
+
use_cpu = False,
|
151 |
+
use_mps_device = False,
|
152 |
+
seed = 3407,
|
153 |
+
data_seed = 3407,
|
154 |
+
jit_mode_eval = False,
|
155 |
+
use_ipex = False,
|
156 |
+
bf16 = False,
|
157 |
+
fp16 = False,
|
158 |
+
fp16_opt_level = 'O1',
|
159 |
+
half_precision_backend = 'auto',
|
160 |
+
bf16_full_eval = False,
|
161 |
+
fp16_full_eval = False,
|
162 |
+
tf32 = None,
|
163 |
+
local_rank = -1,
|
164 |
+
ddp_backend = None,
|
165 |
+
tpu_num_cores = None,
|
166 |
+
tpu_metrics_debug = False,
|
167 |
+
debug = '',
|
168 |
+
dataloader_drop_last = False,
|
169 |
+
eval_steps = None,
|
170 |
+
dataloader_num_workers = 0,
|
171 |
+
dataloader_prefetch_factor = None,
|
172 |
+
past_index = -1,
|
173 |
+
run_name = None,
|
174 |
+
disable_tqdm = None,
|
175 |
+
remove_unused_columns = True,
|
176 |
+
label_names = None,
|
177 |
+
load_best_model_at_end = False,
|
178 |
+
metric_for_best_model = None,
|
179 |
+
greater_is_better = None,
|
180 |
+
ignore_data_skip = False,
|
181 |
+
fsdp = '',
|
182 |
+
fsdp_min_num_params = 0,
|
183 |
+
fsdp_config = None,
|
184 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
185 |
+
accelerator_config = None,
|
186 |
+
deepspeed = None,
|
187 |
+
label_smoothing_factor = 0.0,
|
188 |
+
optim = 'adamw_8bit',
|
189 |
+
optim_args = None,
|
190 |
+
adafactor = False,
|
191 |
+
group_by_length = False,
|
192 |
+
length_column_name = 'length',
|
193 |
+
report_to = None,
|
194 |
+
ddp_find_unused_parameters = None,
|
195 |
+
ddp_bucket_cap_mb = None,
|
196 |
+
ddp_broadcast_buffers = None,
|
197 |
+
dataloader_pin_memory = True,
|
198 |
+
dataloader_persistent_workers = False,
|
199 |
+
skip_memory_metrics = True,
|
200 |
+
use_legacy_prediction_loop = False,
|
201 |
+
push_to_hub = False,
|
202 |
+
resume_from_checkpoint = None,
|
203 |
+
hub_model_id = None,
|
204 |
+
hub_strategy = 'every_save',
|
205 |
+
hub_token = None,
|
206 |
+
hub_private_repo = None,
|
207 |
+
hub_always_push = False,
|
208 |
+
gradient_checkpointing = False,
|
209 |
+
gradient_checkpointing_kwargs = None,
|
210 |
+
include_inputs_for_metrics = False,
|
211 |
+
eval_do_concat_batches = True,
|
212 |
+
fp16_backend = 'auto',
|
213 |
+
evaluation_strategy = None,
|
214 |
+
push_to_hub_model_id = None,
|
215 |
+
push_to_hub_organization = None,
|
216 |
+
push_to_hub_token = None,
|
217 |
+
mp_parameters = '',
|
218 |
+
auto_find_batch_size = False,
|
219 |
+
full_determinism = False,
|
220 |
+
torchdynamo = None,
|
221 |
+
ray_scope = 'last',
|
222 |
+
ddp_timeout = 1800,
|
223 |
+
torch_compile = False,
|
224 |
+
torch_compile_backend = None,
|
225 |
+
torch_compile_mode = None,
|
226 |
+
dispatch_batches = None,
|
227 |
+
split_batches = None,
|
228 |
+
include_tokens_per_second = False,
|
229 |
+
include_num_input_tokens_seen = False,
|
230 |
+
neftune_noise_alpha = None,
|
231 |
+
optim_target_modules = None,
|
232 |
+
batch_eval_metrics = False,
|
233 |
+
eval_on_start = False,
|
234 |
+
use_liger_kernel = False,
|
235 |
+
eval_use_gather_object = False,
|
236 |
+
average_tokens_across_devices = False,
|
237 |
+
max_length = 1024,
|
238 |
+
max_prompt_length = 512,
|
239 |
+
max_completion_length = None,
|
240 |
+
beta = 0.1,
|
241 |
+
label_pad_token_id = -100,
|
242 |
+
padding_value = None,
|
243 |
+
truncation_mode = 'keep_end',
|
244 |
+
disable_dropout = True,
|
245 |
+
generate_during_eval = False,
|
246 |
+
is_encoder_decoder = None,
|
247 |
+
precompute_ref_log_probs = False,
|
248 |
+
model_init_kwargs = None,
|
249 |
+
ref_model_init_kwargs = None,
|
250 |
+
dataset_num_proc = None,
|
251 |
+
prompt_sample_size = 1024,
|
252 |
+
min_density_ratio = 0.5,
|
253 |
+
max_density_ratio = 10.0,
|
254 |
+
vllm_sampling_params = None,
|
255 |
+
unsloth_num_chunks = -1,
|
256 |
+
**kwargs,
|
257 |
+
):
|
258 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
259 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
260 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
261 |
+
output_dir = 'unsloth_training_checkpoints'
|
262 |
+
save_strategy = 'no'
|
263 |
+
if dataset_num_proc is None:
|
264 |
+
from multiprocessing import cpu_count
|
265 |
+
dataset_num_proc = cpu_count()
|
266 |
+
|
267 |
+
super().__init__(
|
268 |
+
output_dir = output_dir,
|
269 |
+
overwrite_output_dir = overwrite_output_dir,
|
270 |
+
do_train = do_train,
|
271 |
+
do_eval = do_eval,
|
272 |
+
do_predict = do_predict,
|
273 |
+
eval_strategy = eval_strategy,
|
274 |
+
prediction_loss_only = prediction_loss_only,
|
275 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
276 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
277 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
278 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
279 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
280 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
281 |
+
eval_delay = eval_delay,
|
282 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
283 |
+
learning_rate = learning_rate,
|
284 |
+
weight_decay = weight_decay,
|
285 |
+
adam_beta1 = adam_beta1,
|
286 |
+
adam_beta2 = adam_beta2,
|
287 |
+
adam_epsilon = adam_epsilon,
|
288 |
+
max_grad_norm = max_grad_norm,
|
289 |
+
num_train_epochs = num_train_epochs,
|
290 |
+
max_steps = max_steps,
|
291 |
+
lr_scheduler_type = lr_scheduler_type,
|
292 |
+
warmup_ratio = warmup_ratio,
|
293 |
+
warmup_steps = warmup_steps,
|
294 |
+
log_level = log_level,
|
295 |
+
log_level_replica = log_level_replica,
|
296 |
+
log_on_each_node = log_on_each_node,
|
297 |
+
logging_dir = logging_dir,
|
298 |
+
logging_strategy = logging_strategy,
|
299 |
+
logging_first_step = logging_first_step,
|
300 |
+
logging_steps = logging_steps,
|
301 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
302 |
+
save_strategy = save_strategy,
|
303 |
+
save_steps = save_steps,
|
304 |
+
save_total_limit = save_total_limit,
|
305 |
+
save_safetensors = save_safetensors,
|
306 |
+
save_on_each_node = save_on_each_node,
|
307 |
+
save_only_model = save_only_model,
|
308 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
309 |
+
no_cuda = no_cuda,
|
310 |
+
use_cpu = use_cpu,
|
311 |
+
use_mps_device = use_mps_device,
|
312 |
+
seed = seed,
|
313 |
+
data_seed = data_seed,
|
314 |
+
jit_mode_eval = jit_mode_eval,
|
315 |
+
use_ipex = use_ipex,
|
316 |
+
bf16 = bf16,
|
317 |
+
fp16 = fp16,
|
318 |
+
fp16_opt_level = fp16_opt_level,
|
319 |
+
half_precision_backend = half_precision_backend,
|
320 |
+
bf16_full_eval = bf16_full_eval,
|
321 |
+
fp16_full_eval = fp16_full_eval,
|
322 |
+
tf32 = tf32,
|
323 |
+
local_rank = local_rank,
|
324 |
+
ddp_backend = ddp_backend,
|
325 |
+
tpu_num_cores = tpu_num_cores,
|
326 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
327 |
+
debug = debug,
|
328 |
+
dataloader_drop_last = dataloader_drop_last,
|
329 |
+
eval_steps = eval_steps,
|
330 |
+
dataloader_num_workers = dataloader_num_workers,
|
331 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
332 |
+
past_index = past_index,
|
333 |
+
run_name = run_name,
|
334 |
+
disable_tqdm = disable_tqdm,
|
335 |
+
remove_unused_columns = remove_unused_columns,
|
336 |
+
label_names = label_names,
|
337 |
+
load_best_model_at_end = load_best_model_at_end,
|
338 |
+
metric_for_best_model = metric_for_best_model,
|
339 |
+
greater_is_better = greater_is_better,
|
340 |
+
ignore_data_skip = ignore_data_skip,
|
341 |
+
fsdp = fsdp,
|
342 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
343 |
+
fsdp_config = fsdp_config,
|
344 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
345 |
+
accelerator_config = accelerator_config,
|
346 |
+
deepspeed = deepspeed,
|
347 |
+
label_smoothing_factor = label_smoothing_factor,
|
348 |
+
optim = optim,
|
349 |
+
optim_args = optim_args,
|
350 |
+
adafactor = adafactor,
|
351 |
+
group_by_length = group_by_length,
|
352 |
+
length_column_name = length_column_name,
|
353 |
+
report_to = report_to,
|
354 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
355 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
356 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
357 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
358 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
359 |
+
skip_memory_metrics = skip_memory_metrics,
|
360 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
361 |
+
push_to_hub = push_to_hub,
|
362 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
363 |
+
hub_model_id = hub_model_id,
|
364 |
+
hub_strategy = hub_strategy,
|
365 |
+
hub_token = hub_token,
|
366 |
+
hub_private_repo = hub_private_repo,
|
367 |
+
hub_always_push = hub_always_push,
|
368 |
+
gradient_checkpointing = gradient_checkpointing,
|
369 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
370 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
371 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
372 |
+
fp16_backend = fp16_backend,
|
373 |
+
evaluation_strategy = evaluation_strategy,
|
374 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
375 |
+
push_to_hub_organization = push_to_hub_organization,
|
376 |
+
push_to_hub_token = push_to_hub_token,
|
377 |
+
mp_parameters = mp_parameters,
|
378 |
+
auto_find_batch_size = auto_find_batch_size,
|
379 |
+
full_determinism = full_determinism,
|
380 |
+
torchdynamo = torchdynamo,
|
381 |
+
ray_scope = ray_scope,
|
382 |
+
ddp_timeout = ddp_timeout,
|
383 |
+
torch_compile = torch_compile,
|
384 |
+
torch_compile_backend = torch_compile_backend,
|
385 |
+
torch_compile_mode = torch_compile_mode,
|
386 |
+
dispatch_batches = dispatch_batches,
|
387 |
+
split_batches = split_batches,
|
388 |
+
include_tokens_per_second = include_tokens_per_second,
|
389 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
390 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
391 |
+
optim_target_modules = optim_target_modules,
|
392 |
+
batch_eval_metrics = batch_eval_metrics,
|
393 |
+
eval_on_start = eval_on_start,
|
394 |
+
use_liger_kernel = use_liger_kernel,
|
395 |
+
eval_use_gather_object = eval_use_gather_object,
|
396 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
397 |
+
max_length = max_length,
|
398 |
+
max_prompt_length = max_prompt_length,
|
399 |
+
max_completion_length = max_completion_length,
|
400 |
+
beta = beta,
|
401 |
+
label_pad_token_id = label_pad_token_id,
|
402 |
+
padding_value = padding_value,
|
403 |
+
truncation_mode = truncation_mode,
|
404 |
+
disable_dropout = disable_dropout,
|
405 |
+
generate_during_eval = generate_during_eval,
|
406 |
+
is_encoder_decoder = is_encoder_decoder,
|
407 |
+
precompute_ref_log_probs = precompute_ref_log_probs,
|
408 |
+
model_init_kwargs = model_init_kwargs,
|
409 |
+
ref_model_init_kwargs = ref_model_init_kwargs,
|
410 |
+
dataset_num_proc = dataset_num_proc,
|
411 |
+
prompt_sample_size = prompt_sample_size,
|
412 |
+
min_density_ratio = min_density_ratio,
|
413 |
+
max_density_ratio = max_density_ratio,**kwargs)
|
414 |
+
self.vllm_sampling_params = vllm_sampling_params
|
415 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
416 |
+
pass
|
417 |
+
|
418 |
+
class _UnslothBCOTrainer(Trainer):
    r"""Internal base trainer for BCO (Binary Classifier Optimization) as
    compiled by Unsloth from TRL's ``BCOTrainer``.

    Subclasses of ``transformers.Trainer``; the heavy lifting (reference-model
    handling, dataset tokenization, density-ratio classifier for underlying
    distribution matching) happens in ``__init__`` below.
    """

    # Tags added to the model (via `add_model_tags`) so pushed checkpoints are
    # identifiable as TRL/BCO-trained.
    _tag_names = ["trl", "bco"]
|
422 |
+
|
423 |
+
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: BCOConfig = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        data_collator: Optional[DataCollator] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
        model_adapter_name: Optional[str] = None,
        ref_adapter_name: Optional[str] = None,
        embedding_func: Optional[Callable] = None,
        embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ):
        """Build the BCO trainer.

        Validates arguments, optionally instantiates `model`/`ref_model` from a
        checkpoint name, prepares peft/quantized models, tokenizes and processes
        the datasets, calls ``Trainer.__init__``, prepares the reference model
        for deepspeed/accelerate, and — when ``embedding_func`` is given — fits
        a logistic-regression density-ratio classifier (UDM) on sampled prompt
        embeddings from the desirable/undesirable splits.

        Raises:
            ImportError: if scikit-learn is unavailable (needed for the UDM classifier).
            ValueError: on invalid argument combinations (see individual checks below).
        """
        # BCO's density-ratio classifier below is a sklearn LogisticRegression.
        if not is_sklearn_available():
            raise ImportError(
                "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
            )

        # A plain TrainingArguments lacks the BCO-specific fields read below.
        if type(args) is TrainingArguments:
            raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")

        if not isinstance(model, str) and ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must mass a copy of it, or `None` if you use peft."
            )

        # model_init_kwargs only make sense when `model` is a checkpoint name
        # that still needs to be loaded.
        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if an str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                model_init_kwargs["torch_dtype"] = torch_dtype

        # Same validation/normalization for the reference model kwargs.
        if args.ref_model_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
            )
        else:
            ref_model_init_kwargs = args.ref_model_init_kwargs
            torch_dtype = ref_model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if an str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                ref_model_init_kwargs["torch_dtype"] = torch_dtype

        # Load from hub/disk when a checkpoint name was passed.
        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if isinstance(ref_model, str):
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                # Older peft versions do not accept gradient_checkpointing_kwargs.
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            # NOTE(review): the original TRL code wraps with get_peft_model here;
            # this compiled variant leaves the model unchanged — confirm intended.
            model = model
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # Sample logging during eval requires one of the supported loggers.
        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or args.precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if processing_class is None:
            raise ValueError(
                "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
            )
        # Resolve max_length with a warned default of 512.
        if args.max_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
                "It will be set to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        if args.max_length is not None:
            max_length = args.max_length

        # Resolve max_prompt_length with a warned default of 128.
        if args.max_prompt_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
                "It will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        if args.max_prompt_length is not None:
            max_prompt_length = args.max_prompt_length

        # max_completion_length is only meaningful for encoder-decoder models.
        max_completion_length = None
        if args.max_completion_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_completion_length = 128
        if args.max_completion_length is not None and self.is_encoder_decoder:
            max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.max_completion_length = max_completion_length
        self.precompute_ref_log_probs = args.precompute_ref_log_probs

        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
        # keep track of first called to avoid computation of future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        # metric
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # BCO parameter
        self.beta = args.beta
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            warnings.warn(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
                UserWarning,
            )

        # Underlying Distribution Matching argument
        self.embedding_func = embedding_func
        self.embedding_tokenizer = embedding_tokenizer

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
        # issued.
        model.warnings_issued["estimate_tokens"] = True

        # Dataset preparation runs on the local main process first so the
        # mapped/tokenized result is cached and reused by the other ranks.
        with PartialState().local_main_process_first():
            # Apply the chat template if needed
            train_dataset = train_dataset.map(
                maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                )
            # Shuffle the datasets
            train_dataset = train_dataset.shuffle(seed=args.data_seed)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
            # Tokenize and prepare the training datasets
            train_dataset = train_dataset.map(
                _tokenize,
                batched=True,
                fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
                num_proc=args.dataset_num_proc,
                desc="Tokenizing train dataset",
            )

            # Prepare the datasets
            fn_kwargs = {
                "prefix": "",
                "is_encoder_decoder": self.is_encoder_decoder,
                "tokenizer": processing_class,
                "max_length": self.max_length,
                "truncation_mode": self.truncation_mode,
                "label_pad_token_id": self.label_pad_token_id,
                "max_prompt_length": self.max_prompt_length,
                "max_completion_length": self.max_completion_length,
            }
            train_dataset = train_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized train dataset",
            )

            if eval_dataset is not None:
                # Tokenize
                eval_dataset = eval_dataset.map(
                    _tokenize,
                    fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
                    batched=True,
                    num_proc=args.dataset_num_proc,
                    desc="Tokenizing eval dataset",
                )

                # Process
                fn_kwargs = {
                    "prefix": "",
                    "is_encoder_decoder": self.is_encoder_decoder,
                    "tokenizer": processing_class,
                    "max_length": self.max_length,
                    "truncation_mode": self.truncation_mode,
                    "label_pad_token_id": self.label_pad_token_id,
                    "max_prompt_length": self.max_prompt_length,
                    "max_completion_length": self.max_completion_length,
                }
                eval_dataset = eval_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    desc="Processing tokenized eval dataset",
                )

            # Split by the boolean "label" column: True = desirable, False = undesirable.
            desirable = train_dataset.filter(
                lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
            )
            undesirable = train_dataset.filter(
                lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
            )

            desirable = desirable.shuffle(seed=args.data_seed)
            undesirable = undesirable.shuffle(seed=args.data_seed)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed Zero-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        # Running mean/std tracker used by the BCO loss (reward shift).
        self.running = RunningMoments(accelerator=self.accelerator)

        # Without an embedding function there is no UDM classifier to fit.
        if self.embedding_func is None:
            return

        chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
        rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)

        embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
        labels = torch.cat(
            (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
        )

        # Density-ratio classifier: P(desirable | prompt embedding).
        self.clf = LogisticRegression(class_weight="balanced").fit(
            embeddings.cpu().float().numpy(), labels.cpu().numpy()
        )
|
830 |
+
|
831 |
+
@property
|
832 |
+
def match_underlying_distribution(self):
|
833 |
+
return self.embedding_func is not None and self.embedding_tokenizer is not None
|
834 |
+
|
835 |
+
    def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
        """
        Calculates the probability if the given prompt embedding is from desirable dataset.
        This function calculates the probability in the process and ensemble across processes.
        """
        dtype = prompt_embeddings.dtype
        device = prompt_embeddings.device
        rank = self.accelerator.process_index

        # Pad the local batch so every process contributes the same shape to the gather.
        padded_prompt_embeddings = self.accelerator.pad_across_processes(
            prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
        )
        sample_size = padded_prompt_embeddings.shape[0]
        # Mask of genuine (non-padding) rows in this process's padded batch.
        # NOTE(review): a real row whose mean equals pad_token_id would be
        # misclassified as padding — presumably negligible in practice.
        nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
        prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)

        # cannot predict for all empty values
        if prompt_embeddings.shape[0] == 0:
            return torch.tensor([], device=device, dtype=dtype)

        # Column 1 of predict_proba = probability of the positive ("desirable") class.
        prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
        prob = torch.as_tensor(prob, dtype=dtype, device=device)
        # Ensemble: average the per-process classifier outputs across all ranks.
        prob = self.accelerator.reduce(prob, reduction="mean")

        # Slice back this rank's share of the gathered batch, then drop padding rows.
        prob = prob[sample_size * rank : sample_size * (rank + 1)]
        prob = prob[nonzero]

        return prob
|
863 |
+
|
864 |
+
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
865 |
+
"""
|
866 |
+
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
867 |
+
and applies self.embedding_func
|
868 |
+
"""
|
869 |
+
input_ids = torch.where(
|
870 |
+
input_ids == self.processing_class.pad_token_id,
|
871 |
+
self.embedding_tokenizer.pad_token_id,
|
872 |
+
input_ids,
|
873 |
+
)
|
874 |
+
|
875 |
+
with torch.no_grad():
|
876 |
+
embeddings = self.embedding_func(
|
877 |
+
input_ids=input_ids,
|
878 |
+
attention_mask=attention_mask,
|
879 |
+
)
|
880 |
+
|
881 |
+
return embeddings
|
882 |
+
|
883 |
+
def _get_prompt_embeddings(
|
884 |
+
self, batch: dict[str, Union[list, torch.LongTensor]]
|
885 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
886 |
+
"""Extract embeddings from frozen embedding model"""
|
887 |
+
|
888 |
+
if not self.match_underlying_distribution:
|
889 |
+
return None, None
|
890 |
+
|
891 |
+
embeddings = self._vectorize_prompt(
|
892 |
+
input_ids=batch["embedding_input_ids"],
|
893 |
+
attention_mask=batch["embedding_attention_mask"],
|
894 |
+
)
|
895 |
+
|
896 |
+
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
897 |
+
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
898 |
+
|
899 |
+
chosen_embeddings = embeddings[chosen_idx, ...]
|
900 |
+
rejected_embeddings = embeddings[rejected_idx, ...]
|
901 |
+
|
902 |
+
return (chosen_embeddings, rejected_embeddings)
|
903 |
+
|
904 |
+
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
|
905 |
+
"""
|
906 |
+
Sample instances from dataset and get prompt embeddings.
|
907 |
+
Used for density ratio classifier training.
|
908 |
+
"""
|
909 |
+
n_samples = min(len(dataset), sample_size)
|
910 |
+
rand_indices = np.random.choice(len(dataset), size=(n_samples,))
|
911 |
+
|
912 |
+
embedding_dataset = dataset.select(rand_indices)
|
913 |
+
|
914 |
+
dataloader_params = {
|
915 |
+
"batch_size": self.args.per_device_train_batch_size,
|
916 |
+
"collate_fn": self.data_collator,
|
917 |
+
"num_workers": self.args.dataloader_num_workers,
|
918 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
919 |
+
"shuffle": False,
|
920 |
+
}
|
921 |
+
|
922 |
+
# prepare dataloader
|
923 |
+
data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
|
924 |
+
|
925 |
+
with torch.no_grad():
|
926 |
+
all_embeddings = torch.empty(0)
|
927 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
|
928 |
+
embeddings = self._vectorize_prompt(
|
929 |
+
input_ids=padded_batch["embedding_input_ids"],
|
930 |
+
attention_mask=padded_batch["embedding_attention_mask"],
|
931 |
+
)
|
932 |
+
embeddings = self.accelerator.gather_for_metrics(embeddings)
|
933 |
+
all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
|
934 |
+
|
935 |
+
return all_embeddings
|
936 |
+
|
937 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """
    Wrap `model` (the frozen reference model) in a DeepSpeed inference engine whose
    config mirrors the trainer's active DeepSpeed plugin, then put it in eval mode.
    """
    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    # deep copy so the tweaks below never leak back into the live training config
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

    if model is not None:
        if hasattr(model, "config"):
            # some configs expose a list of sizes (e.g. vision towers); take the max
            hidden_size = (
                max(model.config.hidden_sizes)
                if getattr(model.config, "hidden_sizes", None)
                else getattr(model.config, "hidden_size", None)
            )
            if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                config_kwargs.update(
                    {
                        "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                        "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                        "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                    }
                )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if config_kwargs["zero_optimization"]["stage"] != 3:
        config_kwargs["zero_optimization"]["stage"] = 0
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model
|
967 |
+
|
968 |
+
def _save_optimizer_and_scheduler(self, output_dir):
    """
    Save optimizer/scheduler state as usual, then persist BCO-specific state:
    the running reward-delta moments and, when `match_underlying_distribution`
    is enabled, the density-ratio classifier parameters.
    """
    super()._save_optimizer_and_scheduler(output_dir)

    # When saving optimizer and scheduler to checkpoint, save also the running delta object.
    target_dir = self.args.output_dir if output_dir is None else output_dir

    self.running.save_to_json(os.path.join(target_dir, RUNNING_NAME))

    if self.match_underlying_distribution:
        clf_path = os.path.join(target_dir, CLF_NAME)
        torch.save(self.clf.get_params(), clf_path)
|
978 |
+
|
979 |
+
def _load_optimizer_and_scheduler(self, checkpoint):
    """
    Restore optimizer/scheduler state from `checkpoint`, then BCO-specific state:
    the running reward-delta moments and (when `match_underlying_distribution`)
    the density-ratio classifier parameters. Missing files are skipped silently.
    """
    super()._load_optimizer_and_scheduler(checkpoint)

    if checkpoint is None:
        return
    # when loading optimizer and scheduler from checkpoint, also load the running delta object.
    running_file = os.path.join(checkpoint, RUNNING_NAME)
    if os.path.isfile(running_file):
        self.running = RunningMoments.load_from_json(self.accelerator, running_file)

    if self.match_underlying_distribution:
        clf_file = os.path.join(checkpoint, CLF_NAME)
        # Fix: this previously checked `os.path.isfile(running_file)` instead of
        # `clf_file`, so the classifier restore was gated on the wrong file --
        # it could be skipped when present, or torch.load could fail when absent.
        if os.path.isfile(clf_file):
            self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
993 |
+
|
994 |
+
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    Inside the context the model behaves as the *reference* model: either the peft
    adapters are disabled (base model = reference) or a dedicated reference adapter
    is activated. The policy adapter is restored on exit.
    """
    with (
        # disabling adapters turns the peft model back into its frozen base model
        self.accelerator.unwrap_model(self.model).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        if self.ref_adapter_name:
            # a dedicated reference adapter exists: switch to it for the duration
            self.model.set_adapter(self.ref_adapter_name)
        yield
        if self.ref_adapter_name:
            # restore the policy adapter after the caller is done
            self.model.set_adapter(self.model_adapter_name or "default")
|
1007 |
+
|
1008 |
+
def get_train_dataloader(self) -> DataLoader:
    """
    Returns the training [`~torch.utils.data.DataLoader`].

    Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
    The precomputed values are stored once as a `reference_logps` dataset column.
    """

    if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            # no shuffling: logps must line up row-for-row with the dataset column added below
            "shuffle": False,
        }

        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
        reference_completion_logps = []

        for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
            reference_completion_logp = self.compute_reference_log_probs(padded_batch)

            # gather across processes so every rank holds the full column
            reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
            reference_completion_logps.append(reference_completion_logp.cpu())

        self.train_dataset = self.train_dataset.add_column(
            name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
        )

        # only compute once per training run
        self._precomputed_train_ref_log_probs = True

    return super().get_train_dataloader()
|
1041 |
+
|
1042 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    """
    Returns the evaluation [`~torch.utils.data.DataLoader`].

    Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.

    Args:
        eval_dataset (`torch.utils.data.Dataset`, *optional*):
            If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
            by the `model.forward()` method are automatically removed. It must implement `__len__`.
    """
    if eval_dataset is None and self.eval_dataset is None:
        raise ValueError("Trainer: evaluation requires an eval_dataset.")
    eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

    if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
        dataloader_params = {
            "batch_size": self.args.per_device_eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            # no shuffling: precomputed logps must stay aligned with dataset rows
            "shuffle": False,
        }

        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

        reference_completion_logps = []

        for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
            reference_completion_logp = self.compute_reference_log_probs(padded_batch)

            # gather across processes so every rank holds the full column
            reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
            reference_completion_logps.append(reference_completion_logp.cpu())

        eval_dataset = eval_dataset.add_column(
            name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
        )

        # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
        if self.eval_dataset is not None:
            self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

    return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1087 |
+
|
1088 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> torch.FloatTensor:
    """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset.

    Uses `self.ref_model` when available; otherwise uses the policy model with peft
    adapters disabled via `null_ref_context`. Returns a (batch_size,) tensor of
    summed completion log-probabilities (not a dict, despite the original annotation).
    """
    with torch.no_grad():
        if self.ref_model is None:
            # peft case: the base model with adapters disabled acts as the reference
            with self.null_ref_context():
                if self.is_encoder_decoder:
                    completion_logits = self.model(
                        padded_batch["prompt_input_ids"],
                        attention_mask=padded_batch["prompt_attention_mask"],
                        decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                        labels=padded_batch["completion_labels"],
                    ).logits

                else:
                    completion_logits = self.model(
                        padded_batch["completion_input_ids"],
                        attention_mask=padded_batch["completion_attention_mask"],
                    ).logits

        else:
            if self.is_encoder_decoder:
                completion_logits = self.ref_model(
                    padded_batch["prompt_input_ids"],
                    attention_mask=padded_batch["prompt_attention_mask"],
                    decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
                    labels=padded_batch["completion_labels"],
                ).logits

            else:
                completion_logits = self.ref_model(
                    padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
                ).logits

        completion_logps = self.get_batch_logps(
            completion_logits,
            padded_batch["completion_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

    return completion_logps
|
1130 |
+
|
1131 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Unnormalized model logits, shape (batch_size, sequence_length, vocab_size).
        labels: Target token ids, shape (batch_size, sequence_length); positions equal
            to `label_pad_token_id` are excluded from the result.
        average_log_prob: If True, return the mean per-token log probability over the
            non-masked tokens; otherwise return their sum.
        label_pad_token_id: Sentinel id marking label positions to ignore.
        is_encoder_decoder: If False, apply the decoder-only shift (labels lag logits
            by one position).

    Returns:
        Tensor of shape (batch_size,) with the summed/averaged label log probabilities.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if is_encoder_decoder:
        # Fixes end-dec RuntimeError
        labels = labels.clone()
    else:
        # decoder-only convention: token t is predicted by logits at position t-1
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    loss_mask = labels != label_pad_token_id

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == label_pad_token_id] = 0

    per_token_logps = selective_log_softmax(logits, labels)
    masked_total = (per_token_logps * loss_mask).sum(-1)

    return masked_total / loss_mask.sum(-1) if average_log_prob else masked_total
|
1170 |
+
|
1171 |
+
def forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run `model` on the batch and split logps/logits into chosen vs. rejected groups.

    Rows are grouped by `batch["label"]` (True = chosen/desirable, False = rejected).
    When `self.aux_loss_enabled`, a fifth element `outputs.aux_loss` is appended.
    """
    model_kwargs = (
        {
            "labels": batch["completion_labels"],
            "decoder_input_ids": batch.get("completion_decoder_input_ids"),
        }
        if self.is_encoder_decoder
        else {}
    )
    if self.aux_loss_enabled:
        # MoE models: request router logits so the load-balancing aux loss is produced
        model_kwargs["output_router_logits"] = True

    outputs = model(
        batch["completion_input_ids"],
        attention_mask=batch["completion_attention_mask"],
        **model_kwargs,
    )
    completion_logits = outputs.logits

    completion_logps = self.get_batch_logps(
        completion_logits,
        batch["completion_labels"],
        average_log_prob=False,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    if completion_logps.shape[0] != len(batch["label"]):
        raise ValueError(
            "There is a mismatch between the number of examples in this batch and the number of "
            "examples for which an output sequence was predicted."
        )

    # `is True` / `is False`: labels are expected to be real bools, not merely truthy
    chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
    rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]

    chosen_logps = completion_logps[chosen_idx, ...]
    rejected_logps = completion_logps[rejected_idx, ...]

    chosen_logits = completion_logits[chosen_idx, ...]
    rejected_logits = completion_logits[rejected_idx, ...]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
    else:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
1219 |
+
|
1220 |
+
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
1221 |
+
prob_desirable = self._get_chosen_prob(rejected_embeddings)
|
1222 |
+
min_ratio = self.args.min_density_ratio
|
1223 |
+
max_ratio = self.args.max_density_ratio
|
1224 |
+
|
1225 |
+
weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
|
1226 |
+
|
1227 |
+
return weight
|
1228 |
+
|
1229 |
+
def bco_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
    chosen_embeddings: Optional[torch.FloatTensor],
    rejected_embeddings: Optional[torch.FloatTensor],
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the BCO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
        chosen_embeddings: embeddings of desirable prompts
        rejected_embeddings: embeddings of undesirable prompts

    Returns:
        A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
        The losses tensor contains the BCO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        The delta value contains the moving average of all implicit rewards.
    """

    if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        chosen_rewards = self.beta * chosen_logratios
    else:
        # lists can't be empty -- if they are, then accelerate.gather will hang
        chosen_losses = torch.Tensor([]).to(self.accelerator.device)
        chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

    if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        rejected_rewards = self.beta * rejected_logratios
    else:
        # lists can't be empty -- if they are, then accelerate.gather will hang
        rejected_losses = torch.Tensor([]).to(self.accelerator.device)
        rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

    # Update the running mean of implicit rewards BEFORE computing losses, so the
    # reward threshold `delta` already reflects the current batch.
    rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
    self.running.update(rewards)
    delta = self.running.mean

    if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
        chosen_losses = -F.logsigmoid(chosen_rewards - delta)

    if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
        rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))

    if self.match_underlying_distribution:
        # reweight rejected examples by the estimated density ratio (UDM)
        chosen_weight = torch.ones_like(chosen_losses)
        rejected_weight = self._get_udm_weight(rejected_embeddings)

        losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
    else:
        losses = torch.cat((chosen_losses, rejected_losses), dim=0)

    return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
1290 |
+
|
1291 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
):
    """Compute the BCO loss and other metrics for the given batch of inputs for train or test.

    Returns `(loss, metrics)` where `metrics` holds gathered per-batch sums and counts
    that `log` later turns into means.
    """
    metrics = {}
    batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

    forward_output = self.forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
    ) = forward_output[:4]
    if self.aux_loss_enabled:
        aux_loss = forward_output[4]

    # if reference_logps in batch use them, otherwise use the reference model
    if "reference_logps" in batch:
        chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
        rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]

        reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
        reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
    else:
        with torch.no_grad():
            if self.ref_model is None:
                # peft: the base model with adapters disabled acts as the reference
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.forward(self.model, batch)[:4]
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.forward(self.ref_model, batch)[:4]

    chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)

    losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        chosen_embeddings,
        rejected_embeddings,
    )
    metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()

    # per-process counts, gathered so every rank knows the global totals
    num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
    num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

    all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
    all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

    # store sums plus counts; `log` later divides sums by counts to report means
    if all_num_chosen > 0:
        metrics["rewards/chosen_sum"] = (
            self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
        )
        metrics["logps/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
        )
        metrics["logits/chosen_sum"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
        )
        metrics["count/chosen"] = all_num_chosen

    if all_num_rejected > 0:
        metrics["rewards/rejected_sum"] = (
            self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
        )
        metrics["logps/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
        )
        metrics["logits/rejected_sum"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
        )
        metrics["count/rejected"] = all_num_rejected

    # nanmean: the chosen or rejected side may be empty on this rank
    loss = losses.nanmean()
    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
1382 |
+
|
1383 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the BCO training loss for `inputs`; stores batch metrics as a side effect."""
    # under peft + bf16 some hidden states are silently cast, so run under autocast
    compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with compute_loss_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
    loss = loss.to(self.args.device)
    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="train")

    if return_outputs:
        return (loss, metrics)
    return loss
|
1404 |
+
|
1405 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append each metric value to the per-split accumulator, averaged later in `log`."""
    bucket = self._stored_metrics[train_eval]
    for name, value in metrics.items():
        bucket[name].append(value)
|
1408 |
+
|
1409 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1410 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
1411 |
+
return None
|
1412 |
+
return SequentialSampler(self.train_dataset)
|
1413 |
+
|
1414 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
    """Generate samples from the model and reference model for the given batch of inputs.

    Returns decoded (policy_outputs, reference_outputs); both are padded to
    `self.max_length` before decoding.
    """

    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

        # if reference_output in batch use that otherwise use the reference model
        if "reference_output" in batch:
            reference_output = batch["reference_output"]
        else:
            if self.ref_model is None:
                # peft: reference = base model with adapters disabled
                with self.null_ref_context():
                    reference_output = self.model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.processing_class.pad_token_id,
                    )
            else:
                reference_output = self.ref_model.generate(
                    input_ids=batch["prompt_input_ids"],
                    attention_mask=batch["prompt_attention_mask"],
                    max_length=self.max_length,
                    do_sample=True,
                    pad_token_id=self.processing_class.pad_token_id,
                )

    policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
    reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)

    return policy_output_decoded, reference_output_decoded
|
1458 |
+
|
1459 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run a no-grad evaluation step; stores eval metrics and returns (loss, logits, labels)."""
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs)

    # force log the metrics
    if self.accelerator.is_main_process:
        self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    # NOTE(review): get_batch_loss_metrics populates "logits/chosen_sum" /
    # "logits/rejected_sum", not "logits/chosen" / "logits/rejected"; this branch
    # looks like it would raise KeyError -- confirm it is ever reached with
    # prediction_loss_only=False.
    logits_dict = {
        "eval_logits/chosen": metrics["logits/chosen"],
        "eval_logits/rejected": metrics["logits/rejected"],
    }
    logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
    logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
    # dummy labels: BCO evaluation has no per-example classification target here
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
1493 |
+
|
1494 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples
        num_samples = len(dataloader.dataset)
        random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        # generations are only produced for rejected (label is False) examples
        target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
        target_batch = {
            "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
            "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
            "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
        }
        policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)

        # strip the prompt prefix so the table shows completions only
        table = pd.DataFrame(
            columns=["Prompt", "Policy", "Ref Model"],
            data=[
                [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
1550 |
+
|
1551 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    train_eval = "train" if "loss" in logs else "eval"
    # train metrics should have no prefix, eval should have 'eval_'
    prefix = "eval_" if train_eval == "eval" else ""
    # accumulate average metrics from sums and lengths
    for split in ["chosen", "rejected"]:
        if f"count/{split}" in self._stored_metrics[train_eval]:
            count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
            for metric in ["rewards", "logps", "logits"]:
                # mean = (sum over all batches) / (total count over all batches)
                logs[f"{prefix}{metric}/{split}"] = (
                    torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
                    / count_sum
                )
                # delete obsolete metric
                del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
            del self._stored_metrics[train_eval][f"count/{split}"]
    # calculate reward margin
    if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
        logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
    # Add averaged stored metrics to logs
    for key, metrics in self._stored_metrics[train_eval].items():
        logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
    del self._stored_metrics[train_eval]

    # transformers >= 4.47 added the `start_time` argument to Trainer.log
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        return super().log(logs, start_time)
    else:  # transformers<=4.46
        return super().log(logs)
|
1589 |
+
|
1590 |
+
def create_model_card(
|
1591 |
+
self,
|
1592 |
+
model_name: Optional[str] = None,
|
1593 |
+
dataset_name: Optional[str] = None,
|
1594 |
+
tags: Union[str, list[str], None] = None,
|
1595 |
+
):
|
1596 |
+
"""
|
1597 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1598 |
+
|
1599 |
+
Args:
|
1600 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1601 |
+
Name of the model.
|
1602 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1603 |
+
Name of the dataset used for training.
|
1604 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1605 |
+
Tags to be associated with the model card.
|
1606 |
+
"""
|
1607 |
+
if not self.is_world_process_zero():
|
1608 |
+
return
|
1609 |
+
|
1610 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1611 |
+
base_model = self.model.config._name_or_path
|
1612 |
+
else:
|
1613 |
+
base_model = None
|
1614 |
+
|
1615 |
+
tags = tags or []
|
1616 |
+
if isinstance(tags, str):
|
1617 |
+
tags = [tags]
|
1618 |
+
|
1619 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1620 |
+
tags.append("unsloth")
|
1621 |
+
|
1622 |
+
citation = textwrap.dedent("""\
|
1623 |
+
@article{jung2024binary,
|
1624 |
+
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
|
1625 |
+
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
|
1626 |
+
year = 2024,
|
1627 |
+
eprint = {arXiv:2404.04656}
|
1628 |
+
}""")
|
1629 |
+
|
1630 |
+
model_card = generate_model_card(
|
1631 |
+
base_model=base_model,
|
1632 |
+
model_name=model_name,
|
1633 |
+
hub_model_id=self.hub_model_id,
|
1634 |
+
dataset_name=dataset_name,
|
1635 |
+
tags=tags,
|
1636 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1637 |
+
comet_url=get_comet_experiment_url(),
|
1638 |
+
trainer_name="BCO",
|
1639 |
+
trainer_citation=citation,
|
1640 |
+
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
|
1641 |
+
paper_id="2404.04656",
|
1642 |
+
)
|
1643 |
+
|
1644 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1645 |
+
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
1646 |
+
"""
|
1647 |
+
|
1648 |
+
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
1649 |
+
|
1650 |
+
Args:
|
1651 |
+
model (`transformers.PreTrainedModel`):
|
1652 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1653 |
+
ref_model (`PreTrainedModelWrapper`):
|
1654 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
1655 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
1656 |
+
args (`BCOConfig`):
|
1657 |
+
The arguments to use for training.
|
1658 |
+
train_dataset (`datasets.Dataset`):
|
1659 |
+
The dataset to use for training.
|
1660 |
+
eval_dataset (`datasets.Dataset`):
|
1661 |
+
The dataset to use for evaluation.
|
1662 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1663 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1664 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1665 |
+
reuse the fine-tuned model.
|
1666 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
1667 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1668 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1669 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1670 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1671 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1672 |
+
The callbacks to use for training.
|
1673 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1674 |
+
The optimizer and scheduler to use for training.
|
1675 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1676 |
+
The function to use to preprocess the logits before computing the metrics.
|
1677 |
+
peft_config (`dict`, defaults to `None`):
|
1678 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1679 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1680 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1681 |
+
a dictionary string to metric values.
|
1682 |
+
model_adapter_name (`str`, defaults to `None`):
|
1683 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
1684 |
+
ref_adapter_name (`str`, defaults to `None`):
|
1685 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
1686 |
+
|
1687 |
+
"""
|
1688 |
+
def __init__(
|
1689 |
+
self,
|
1690 |
+
model = None,
|
1691 |
+
ref_model = None,
|
1692 |
+
args = None,
|
1693 |
+
train_dataset = None,
|
1694 |
+
eval_dataset = None,
|
1695 |
+
processing_class = None,
|
1696 |
+
data_collator = None,
|
1697 |
+
model_init = None,
|
1698 |
+
callbacks = None,
|
1699 |
+
preprocess_logits_for_metrics = None,
|
1700 |
+
peft_config = None,
|
1701 |
+
compute_metrics = None,
|
1702 |
+
model_adapter_name = None,
|
1703 |
+
ref_adapter_name = None,
|
1704 |
+
embedding_func = None,
|
1705 |
+
embedding_tokenizer = None,
|
1706 |
+
**kwargs
|
1707 |
+
):
|
1708 |
+
if args is None: args = UnslothBCOConfig()
|
1709 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1710 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1711 |
+
force_float32 = False
|
1712 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1713 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1714 |
+
force_float32 = True
|
1715 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1716 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1717 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1718 |
+
from unsloth_zoo.utils import _get_dtype
|
1719 |
+
dtype = _get_dtype(dtype)
|
1720 |
+
float16 = dtype == torch.float16
|
1721 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1722 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1723 |
+
if force_float32:
|
1724 |
+
args.fp16 = False
|
1725 |
+
args.bf16 = False
|
1726 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1727 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1728 |
+
args.fp16 = float16
|
1729 |
+
args.bf16 = not float16
|
1730 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1731 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1732 |
+
args.eval_strategy = 'steps'
|
1733 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1734 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1735 |
+
if ga_steps is not None and ga_steps > 1:
|
1736 |
+
from transformers import __version__ as transformers_version
|
1737 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1738 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1739 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1740 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1741 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1742 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1743 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1744 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1745 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1746 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1747 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1748 |
+
if force_float32:
|
1749 |
+
args.bf16_full_eval = False
|
1750 |
+
args.fp16_full_eval = False
|
1751 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1752 |
+
args.bf16_full_eval = True
|
1753 |
+
args.fp16_full_eval = False
|
1754 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1755 |
+
args.bf16_full_eval = args.bf16
|
1756 |
+
args.fp16_full_eval = args.fp16
|
1757 |
+
_output_logits = False
|
1758 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1759 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1760 |
+
if _output_logits:
|
1761 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1762 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1763 |
+
pass
|
1764 |
+
else:
|
1765 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1766 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1767 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1768 |
+
max_seq_length = model.max_seq_length
|
1769 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1770 |
+
if model is not None and hasattr(model, 'for_training'):
|
1771 |
+
model.for_training()
|
1772 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1773 |
+
if 'processing_class' in locals():
|
1774 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1775 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1776 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1777 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1778 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1779 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1780 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1781 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1782 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1783 |
+
else:
|
1784 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1785 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1786 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1787 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1788 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1789 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1790 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1791 |
+
else:
|
1792 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1793 |
+
other_metrics = []
|
1794 |
+
|
1795 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1796 |
+
PatchRLStatistics('bco_trainer', other_metrics)
|
1797 |
+
|
1798 |
+
super().__init__(
|
1799 |
+
model = model,
|
1800 |
+
ref_model = ref_model,
|
1801 |
+
args = args,
|
1802 |
+
train_dataset = train_dataset,
|
1803 |
+
eval_dataset = eval_dataset,
|
1804 |
+
processing_class = processing_class,
|
1805 |
+
data_collator = data_collator,
|
1806 |
+
model_init = model_init,
|
1807 |
+
callbacks = callbacks,
|
1808 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1809 |
+
peft_config = peft_config,
|
1810 |
+
compute_metrics = compute_metrics,
|
1811 |
+
model_adapter_name = model_adapter_name,
|
1812 |
+
ref_adapter_name = ref_adapter_name,
|
1813 |
+
embedding_func = embedding_func,
|
1814 |
+
embedding_tokenizer = embedding_tokenizer,**kwargs)
|
1815 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1816 |
+
self.neftune_hook_handle.remove()
|
1817 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1818 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1819 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1820 |
+
pass
|
1821 |
+
|
1822 |
+
pass
|
unsloth_compiled_cache/UnslothCPOTrainer.py
ADDED
@@ -0,0 +1,1555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothCPOConfig(CPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`CPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
67 |
+
the [paper](https://huggingface.co/papers/2310.12036).
|
68 |
+
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
69 |
+
Label smoothing factor. This argument is required if you want to use the default data collator.
|
70 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
71 |
+
Type of loss to use. Possible values are:
|
72 |
+
|
73 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
74 |
+
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
|
75 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
76 |
+
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
77 |
+
|
78 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
79 |
+
Whether to disable dropout in the model.
|
80 |
+
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
81 |
+
Weight of the BC regularizer in CPO training.
|
82 |
+
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
83 |
+
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
84 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
85 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
86 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
87 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
88 |
+
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
89 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
90 |
+
This argument is required if you want to use the default data collator.
|
91 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
92 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
93 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
94 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
95 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
96 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
98 |
+
string.
|
99 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
100 |
+
Number of processes to use for processing the dataset.
|
101 |
+
|
102 |
+
"""
|
103 |
+
vllm_sampling_params: Optional[Any] = field(
|
104 |
+
default = None,
|
105 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
106 |
+
)
|
107 |
+
unsloth_num_chunks : Optional[int] = field(
|
108 |
+
default = -1,
|
109 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
110 |
+
)
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
output_dir = None,
|
114 |
+
overwrite_output_dir = None,
|
115 |
+
do_train = False,
|
116 |
+
do_eval = False,
|
117 |
+
do_predict = False,
|
118 |
+
eval_strategy = 'no',
|
119 |
+
prediction_loss_only = False,
|
120 |
+
per_device_train_batch_size = 4,
|
121 |
+
per_device_eval_batch_size = 4,
|
122 |
+
per_gpu_train_batch_size = None,
|
123 |
+
per_gpu_eval_batch_size = None,
|
124 |
+
gradient_accumulation_steps = 2,
|
125 |
+
eval_accumulation_steps = 2,
|
126 |
+
eval_delay = 0,
|
127 |
+
torch_empty_cache_steps = 250,
|
128 |
+
learning_rate = 5e-05,
|
129 |
+
weight_decay = 0.01,
|
130 |
+
adam_beta1 = 0.9,
|
131 |
+
adam_beta2 = 0.999,
|
132 |
+
adam_epsilon = 1e-08,
|
133 |
+
max_grad_norm = 1.0,
|
134 |
+
num_train_epochs = 3.0,
|
135 |
+
max_steps = -1,
|
136 |
+
lr_scheduler_type = 'linear',
|
137 |
+
warmup_ratio = 0.1,
|
138 |
+
warmup_steps = 0,
|
139 |
+
log_level = 'passive',
|
140 |
+
log_level_replica = 'warning',
|
141 |
+
log_on_each_node = True,
|
142 |
+
logging_dir = None,
|
143 |
+
logging_strategy = 'steps',
|
144 |
+
logging_first_step = False,
|
145 |
+
logging_steps = 1,
|
146 |
+
logging_nan_inf_filter = False,
|
147 |
+
save_strategy = 'steps',
|
148 |
+
save_steps = 500,
|
149 |
+
save_total_limit = None,
|
150 |
+
save_safetensors = True,
|
151 |
+
save_on_each_node = False,
|
152 |
+
save_only_model = False,
|
153 |
+
restore_callback_states_from_checkpoint = False,
|
154 |
+
no_cuda = False,
|
155 |
+
use_cpu = False,
|
156 |
+
use_mps_device = False,
|
157 |
+
seed = 3407,
|
158 |
+
data_seed = 3407,
|
159 |
+
jit_mode_eval = False,
|
160 |
+
use_ipex = False,
|
161 |
+
bf16 = False,
|
162 |
+
fp16 = False,
|
163 |
+
fp16_opt_level = 'O1',
|
164 |
+
half_precision_backend = 'auto',
|
165 |
+
bf16_full_eval = False,
|
166 |
+
fp16_full_eval = False,
|
167 |
+
tf32 = None,
|
168 |
+
local_rank = -1,
|
169 |
+
ddp_backend = None,
|
170 |
+
tpu_num_cores = None,
|
171 |
+
tpu_metrics_debug = False,
|
172 |
+
debug = '',
|
173 |
+
dataloader_drop_last = False,
|
174 |
+
eval_steps = None,
|
175 |
+
dataloader_num_workers = 0,
|
176 |
+
dataloader_prefetch_factor = None,
|
177 |
+
past_index = -1,
|
178 |
+
run_name = None,
|
179 |
+
disable_tqdm = None,
|
180 |
+
remove_unused_columns = True,
|
181 |
+
label_names = None,
|
182 |
+
load_best_model_at_end = False,
|
183 |
+
metric_for_best_model = None,
|
184 |
+
greater_is_better = None,
|
185 |
+
ignore_data_skip = False,
|
186 |
+
fsdp = '',
|
187 |
+
fsdp_min_num_params = 0,
|
188 |
+
fsdp_config = None,
|
189 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
190 |
+
accelerator_config = None,
|
191 |
+
deepspeed = None,
|
192 |
+
label_smoothing_factor = 0.0,
|
193 |
+
optim = 'adamw_8bit',
|
194 |
+
optim_args = None,
|
195 |
+
adafactor = False,
|
196 |
+
group_by_length = False,
|
197 |
+
length_column_name = 'length',
|
198 |
+
report_to = None,
|
199 |
+
ddp_find_unused_parameters = None,
|
200 |
+
ddp_bucket_cap_mb = None,
|
201 |
+
ddp_broadcast_buffers = None,
|
202 |
+
dataloader_pin_memory = True,
|
203 |
+
dataloader_persistent_workers = False,
|
204 |
+
skip_memory_metrics = True,
|
205 |
+
use_legacy_prediction_loop = False,
|
206 |
+
push_to_hub = False,
|
207 |
+
resume_from_checkpoint = None,
|
208 |
+
hub_model_id = None,
|
209 |
+
hub_strategy = 'every_save',
|
210 |
+
hub_token = None,
|
211 |
+
hub_private_repo = None,
|
212 |
+
hub_always_push = False,
|
213 |
+
gradient_checkpointing = False,
|
214 |
+
gradient_checkpointing_kwargs = None,
|
215 |
+
include_inputs_for_metrics = False,
|
216 |
+
eval_do_concat_batches = True,
|
217 |
+
fp16_backend = 'auto',
|
218 |
+
evaluation_strategy = None,
|
219 |
+
push_to_hub_model_id = None,
|
220 |
+
push_to_hub_organization = None,
|
221 |
+
push_to_hub_token = None,
|
222 |
+
mp_parameters = '',
|
223 |
+
auto_find_batch_size = False,
|
224 |
+
full_determinism = False,
|
225 |
+
torchdynamo = None,
|
226 |
+
ray_scope = 'last',
|
227 |
+
ddp_timeout = 1800,
|
228 |
+
torch_compile = False,
|
229 |
+
torch_compile_backend = None,
|
230 |
+
torch_compile_mode = None,
|
231 |
+
dispatch_batches = None,
|
232 |
+
split_batches = None,
|
233 |
+
include_tokens_per_second = False,
|
234 |
+
include_num_input_tokens_seen = False,
|
235 |
+
neftune_noise_alpha = None,
|
236 |
+
optim_target_modules = None,
|
237 |
+
batch_eval_metrics = False,
|
238 |
+
eval_on_start = False,
|
239 |
+
use_liger_kernel = False,
|
240 |
+
eval_use_gather_object = False,
|
241 |
+
average_tokens_across_devices = False,
|
242 |
+
max_length = 1024,
|
243 |
+
max_prompt_length = 512,
|
244 |
+
max_completion_length = None,
|
245 |
+
beta = 0.1,
|
246 |
+
label_smoothing = 0.0,
|
247 |
+
loss_type = 'sigmoid',
|
248 |
+
disable_dropout = True,
|
249 |
+
cpo_alpha = 1.0,
|
250 |
+
simpo_gamma = 0.5,
|
251 |
+
label_pad_token_id = -100,
|
252 |
+
padding_value = None,
|
253 |
+
truncation_mode = 'keep_end',
|
254 |
+
generate_during_eval = False,
|
255 |
+
is_encoder_decoder = None,
|
256 |
+
model_init_kwargs = None,
|
257 |
+
dataset_num_proc = None,
|
258 |
+
vllm_sampling_params = None,
|
259 |
+
unsloth_num_chunks = -1,
|
260 |
+
**kwargs,
|
261 |
+
):
|
262 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
263 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
264 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
265 |
+
output_dir = 'unsloth_training_checkpoints'
|
266 |
+
save_strategy = 'no'
|
267 |
+
if dataset_num_proc is None:
|
268 |
+
from multiprocessing import cpu_count
|
269 |
+
dataset_num_proc = cpu_count()
|
270 |
+
|
271 |
+
super().__init__(
|
272 |
+
output_dir = output_dir,
|
273 |
+
overwrite_output_dir = overwrite_output_dir,
|
274 |
+
do_train = do_train,
|
275 |
+
do_eval = do_eval,
|
276 |
+
do_predict = do_predict,
|
277 |
+
eval_strategy = eval_strategy,
|
278 |
+
prediction_loss_only = prediction_loss_only,
|
279 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
280 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
281 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
282 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
283 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
284 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
285 |
+
eval_delay = eval_delay,
|
286 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
287 |
+
learning_rate = learning_rate,
|
288 |
+
weight_decay = weight_decay,
|
289 |
+
adam_beta1 = adam_beta1,
|
290 |
+
adam_beta2 = adam_beta2,
|
291 |
+
adam_epsilon = adam_epsilon,
|
292 |
+
max_grad_norm = max_grad_norm,
|
293 |
+
num_train_epochs = num_train_epochs,
|
294 |
+
max_steps = max_steps,
|
295 |
+
lr_scheduler_type = lr_scheduler_type,
|
296 |
+
warmup_ratio = warmup_ratio,
|
297 |
+
warmup_steps = warmup_steps,
|
298 |
+
log_level = log_level,
|
299 |
+
log_level_replica = log_level_replica,
|
300 |
+
log_on_each_node = log_on_each_node,
|
301 |
+
logging_dir = logging_dir,
|
302 |
+
logging_strategy = logging_strategy,
|
303 |
+
logging_first_step = logging_first_step,
|
304 |
+
logging_steps = logging_steps,
|
305 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
306 |
+
save_strategy = save_strategy,
|
307 |
+
save_steps = save_steps,
|
308 |
+
save_total_limit = save_total_limit,
|
309 |
+
save_safetensors = save_safetensors,
|
310 |
+
save_on_each_node = save_on_each_node,
|
311 |
+
save_only_model = save_only_model,
|
312 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
313 |
+
no_cuda = no_cuda,
|
314 |
+
use_cpu = use_cpu,
|
315 |
+
use_mps_device = use_mps_device,
|
316 |
+
seed = seed,
|
317 |
+
data_seed = data_seed,
|
318 |
+
jit_mode_eval = jit_mode_eval,
|
319 |
+
use_ipex = use_ipex,
|
320 |
+
bf16 = bf16,
|
321 |
+
fp16 = fp16,
|
322 |
+
fp16_opt_level = fp16_opt_level,
|
323 |
+
half_precision_backend = half_precision_backend,
|
324 |
+
bf16_full_eval = bf16_full_eval,
|
325 |
+
fp16_full_eval = fp16_full_eval,
|
326 |
+
tf32 = tf32,
|
327 |
+
local_rank = local_rank,
|
328 |
+
ddp_backend = ddp_backend,
|
329 |
+
tpu_num_cores = tpu_num_cores,
|
330 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
331 |
+
debug = debug,
|
332 |
+
dataloader_drop_last = dataloader_drop_last,
|
333 |
+
eval_steps = eval_steps,
|
334 |
+
dataloader_num_workers = dataloader_num_workers,
|
335 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
336 |
+
past_index = past_index,
|
337 |
+
run_name = run_name,
|
338 |
+
disable_tqdm = disable_tqdm,
|
339 |
+
remove_unused_columns = remove_unused_columns,
|
340 |
+
label_names = label_names,
|
341 |
+
load_best_model_at_end = load_best_model_at_end,
|
342 |
+
metric_for_best_model = metric_for_best_model,
|
343 |
+
greater_is_better = greater_is_better,
|
344 |
+
ignore_data_skip = ignore_data_skip,
|
345 |
+
fsdp = fsdp,
|
346 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
347 |
+
fsdp_config = fsdp_config,
|
348 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
349 |
+
accelerator_config = accelerator_config,
|
350 |
+
deepspeed = deepspeed,
|
351 |
+
label_smoothing_factor = label_smoothing_factor,
|
352 |
+
optim = optim,
|
353 |
+
optim_args = optim_args,
|
354 |
+
adafactor = adafactor,
|
355 |
+
group_by_length = group_by_length,
|
356 |
+
length_column_name = length_column_name,
|
357 |
+
report_to = report_to,
|
358 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
359 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
360 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
361 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
362 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
363 |
+
skip_memory_metrics = skip_memory_metrics,
|
364 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
365 |
+
push_to_hub = push_to_hub,
|
366 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
367 |
+
hub_model_id = hub_model_id,
|
368 |
+
hub_strategy = hub_strategy,
|
369 |
+
hub_token = hub_token,
|
370 |
+
hub_private_repo = hub_private_repo,
|
371 |
+
hub_always_push = hub_always_push,
|
372 |
+
gradient_checkpointing = gradient_checkpointing,
|
373 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
374 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
375 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
376 |
+
fp16_backend = fp16_backend,
|
377 |
+
evaluation_strategy = evaluation_strategy,
|
378 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
379 |
+
push_to_hub_organization = push_to_hub_organization,
|
380 |
+
push_to_hub_token = push_to_hub_token,
|
381 |
+
mp_parameters = mp_parameters,
|
382 |
+
auto_find_batch_size = auto_find_batch_size,
|
383 |
+
full_determinism = full_determinism,
|
384 |
+
torchdynamo = torchdynamo,
|
385 |
+
ray_scope = ray_scope,
|
386 |
+
ddp_timeout = ddp_timeout,
|
387 |
+
torch_compile = torch_compile,
|
388 |
+
torch_compile_backend = torch_compile_backend,
|
389 |
+
torch_compile_mode = torch_compile_mode,
|
390 |
+
dispatch_batches = dispatch_batches,
|
391 |
+
split_batches = split_batches,
|
392 |
+
include_tokens_per_second = include_tokens_per_second,
|
393 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
394 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
395 |
+
optim_target_modules = optim_target_modules,
|
396 |
+
batch_eval_metrics = batch_eval_metrics,
|
397 |
+
eval_on_start = eval_on_start,
|
398 |
+
use_liger_kernel = use_liger_kernel,
|
399 |
+
eval_use_gather_object = eval_use_gather_object,
|
400 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
401 |
+
max_length = max_length,
|
402 |
+
max_prompt_length = max_prompt_length,
|
403 |
+
max_completion_length = max_completion_length,
|
404 |
+
beta = beta,
|
405 |
+
label_smoothing = label_smoothing,
|
406 |
+
loss_type = loss_type,
|
407 |
+
disable_dropout = disable_dropout,
|
408 |
+
cpo_alpha = cpo_alpha,
|
409 |
+
simpo_gamma = simpo_gamma,
|
410 |
+
label_pad_token_id = label_pad_token_id,
|
411 |
+
padding_value = padding_value,
|
412 |
+
truncation_mode = truncation_mode,
|
413 |
+
generate_during_eval = generate_during_eval,
|
414 |
+
is_encoder_decoder = is_encoder_decoder,
|
415 |
+
model_init_kwargs = model_init_kwargs,
|
416 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
417 |
+
self.vllm_sampling_params = vllm_sampling_params
|
418 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
419 |
+
pass
|
420 |
+
|
421 |
+
class _UnslothCPOTrainer(Trainer):
    r"""Internal CPO (Contrastive Preference Optimization) trainer.

    Auto-generated Unsloth variant of TRL's ``CPOTrainer``; subclasses the
    Hugging Face ``Trainer``. The ``_tag_names`` below are attached to the
    model via ``add_model_tags`` during ``__init__`` when the model supports it.
    """

    _tag_names = ["trl", "cpo"]
|
426 |
+
def __init__(
    self,
    model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: Optional[CPOConfig] = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
):
    """Build the CPO trainer.

    Resolves ``model`` (loading it from a name/path when given as a string),
    validates the ``CPOConfig``, constructs a default ``DPODataCollatorWithPadding``
    when none is supplied, pre-tokenizes ``train_dataset``/``eval_dataset``
    with :meth:`tokenize_row`, and finally delegates to ``Trainer.__init__``.

    Raises:
        ValueError: on inconsistent model/config combinations (see inline checks).
        AttributeError: if the parent ``Trainer`` did not create an ``accelerator``.
    """
    # `model_init_kwargs` only make sense when the model still has to be
    # loaded from a string name/path.
    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            # Older `peft` versions do not accept `gradient_checkpointing_kwargs`;
            # only forward it when the installed signature supports it.
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # get peft model with the given config
        # NOTE(review): upstream TRL wraps the model with `get_peft_model(model, peft_config)`
        # here; this generated variant leaves the model unchanged — presumably Unsloth
        # applies PEFT adapters elsewhere. Confirm against the generator.
        model = model
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif getattr(args, "gradient_checkpointing", False):
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    if self.is_encoder_decoder:
        self.decoder_start_token_id = model.config.decoder_start_token_id
        self.pad_token_id = model.config.pad_token_id

    if processing_class is None:
        raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
    # Fall back to conservative defaults when the config omits length limits,
    # warning the user so the implicit choice is visible.
    if args.max_length is None:
        warnings.warn(
            "`max_length` is not set in the CPOConfig's init"
            " it will default to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    else:
        max_length = args.max_length
    if args.max_prompt_length is None:
        warnings.warn(
            "`max_prompt_length` is not set in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    else:
        max_prompt_length = args.max_prompt_length

    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
            " it will default to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_completion_length = 128
    else:
        max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model
    if args.disable_dropout:
        disable_dropout_in_model(model)

    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.processing_class = processing_class

    if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
        warnings.warn(
            f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
            "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
            UserWarning,
        )
    if args.loss_type == "kto_pair":
        raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")

    self.beta = args.beta
    self.label_smoothing = args.label_smoothing
    self.loss_type = args.loss_type
    self.cpo_alpha = args.cpo_alpha
    # Mixture-of-experts models may expose router (auxiliary) loss settings on
    # their config; pick them up here so compute_loss can add the aux term.
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    if args.loss_type == "simpo":
        self.simpo_gamma = args.simpo_gamma

    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
    # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
    # of the input, floating-point operations will not be computed." To suppress this warning, we set the
    # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
    # that the warning has already been issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        # Extract the prompt if needed, and apply the chat template if needed
        train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
        train_dataset = train_dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
            )

        # tokenize the dataset
        train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )
|
689 |
+
def build_tokenized_answer(self, prompt, answer):
    """Tokenize ``answer`` in the context of ``prompt``.

    Tokenizers like Llama's do not guarantee ``enc(a + b) == enc(a) + enc(b)``;
    what does hold is ``enc(a + b) == enc(a) + enc(a + b)[len(enc(a)):]``,
    so the prompt/answer boundary is recovered from the joint encoding.
    Reference:
    https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Returns a dict with ``prompt_input_ids``, ``prompt_attention_mask``,
    ``input_ids`` (answer tokens) and ``attention_mask`` (answer mask).
    """
    tokenized_full = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_only_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    full_ids = tokenized_full["input_ids"]
    full_mask = tokenized_full["attention_mask"]

    # Sanity check: `enc(a) + enc(a + b)[len(enc(a)):]` must reproduce the
    # length of the joint encoding.
    tail_ids = full_ids[len(prompt_only_ids):]
    rebuilt_ids = np.concatenate([prompt_only_ids, tail_ids])
    if len(np.array(full_ids)) != len(rebuilt_ids):
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers (e.g. Llama-2) the last prompt token can merge with
    # the first answer token when prompt+answer are tokenized together, so the
    # stand-alone prompt encoding may no longer be a prefix of the joint one.
    # Back the split point off by one token in that case.
    split_idx = len(prompt_only_ids)
    if prompt_only_ids != full_ids[:split_idx]:
        split_idx -= 1

    prompt_ids = full_ids[:split_idx]
    prompt_mask = full_mask[:split_idx]
    if len(prompt_ids) != len(prompt_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    return dict(
        prompt_input_ids=prompt_ids,
        prompt_attention_mask=prompt_mask,
        input_ids=full_ids[split_idx:],
        attention_mask=full_mask[split_idx:],
    )
739 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from a CPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. First
    we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to
    the sum of the length of the prompt and the chosen/rejected response, with
    label_pad_token_id for the prompt tokens.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Decoder-only path: tokenize prompt and answers jointly so merge
        # effects at the prompt/answer boundary are handled consistently.
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        # Use the shorter of the two answer-context prompt lengths (overrides
        # the stand-alone prompt length computed above).
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        for k, v in prompt_tokens.items():
            prompt_tokens[k] = v[:prompt_len_input_ids]

        # Make sure prompts only have one different token at most
        # and length only differs by 1 at most
        num_diff_tokens = sum(
            [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels: full sequence = prompt + answer, with prompt
        # positions masked out via label_pad_token_id.
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        # Flatten into the output batch with "chosen_"/"rejected_"/"" prefixes,
        # dropping token_type_ids which the collator does not use.
        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens

    else:
        # Encoder-decoder path: prompt and answers are tokenized independently;
        # answers become labels directly.
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    return batch
|
885 |
+
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        Each "concatenated_*" tensor stacks the padded chosen rows first, then
        the padded rejected rows, along dim 0 (so batch size doubles).
    """
    concatenated_batch = {}

    # All chosen/rejected tensors are right-padded to the longer of the two
    # sequence lengths before stacking.
    if is_encoder_decoder:
        max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    # First pass: pad every "chosen*" tensor and store it under the
    # corresponding "concatenated*" key.
    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            # NOTE(review): if a "chosen*" tensor key matches none of the
            # branches above, `pad_value` silently reuses its value from a
            # previous iteration (or is unbound on the first) — this mirrors
            # upstream TRL; confirm key naming stays within these suffixes.
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    # Second pass: pad every "rejected*" tensor and append it below the
    # chosen rows along the batch dimension.
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(device=device)

    # Encoder-decoder models feed the prompt to the encoder, duplicated once
    # for the chosen half and once for the rejected half of the batch.
    if is_encoder_decoder:
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
|
947 |
+
def cpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the CPO loss for a batch of policy model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the CPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.

    Raises:
        ValueError: If ``self.loss_type`` is not one of 'sigmoid', 'hinge', 'ipo', 'simpo'.
    """
    logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)

    # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
    # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about
    # the labels and calculates a conservative CPO loss.
    if self.loss_type in ("simpo", "sigmoid"):
        if self.loss_type == "simpo":
            # SimPO subtracts a fixed target margin (gamma / beta) before the smoothed sigmoid loss;
            # apart from this shift the loss expression is identical to the "sigmoid" branch.
            gamma_logratios = self.simpo_gamma / self.beta
            logits = logits - gamma_logratios
        # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
        losses = (
            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )
    elif self.loss_type == "hinge":
        losses = torch.relu(1 - self.beta * logits)
    elif self.loss_type == "ipo":
        # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
        losses = (logits - 1 / (2 * self.beta)) ** 2
    else:
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
        )

    # Rewards are detached scaled log probabilities; they are used only for logging/metrics.
    chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
    rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()

    return losses, chosen_rewards, rejected_rewards
|
997 |
+
|
998 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        label_pad_token_id: The label pad token id.
        is_encoder_decoder: Whether the model is an encoder-decoder model.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.

    Raises:
        ValueError: If the (batch, sequence) dims of logits and labels do not match.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        # Decoder-only models predict the next token: shift labels left and drop the last logit.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == label_pad_token_id] = 0

    # selective_log_softmax gathers the log-softmax probability of each label token.
    per_token_logps = selective_log_softmax(logits, labels)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
|
1035 |
+
|
1036 |
+
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Returns a tuple (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss),
    with an extra trailing aux_loss element when ``self.aux_loss_enabled`` is set.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # Chosen examples occupy the first half of the concatenated batch.
    len_chosen = batch["chosen_labels"].shape[0]

    # Encoder-decoder models need explicit decoder inputs derived from the labels.
    model_kwargs = (
        {
            "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
        }
        if self.is_encoder_decoder
        else {}
    )

    if self.aux_loss_enabled:
        # MoE models: request router logits so the auxiliary load-balancing loss is returned.
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    def cross_entropy_loss(logits, labels):
        # Standard next-token cross-entropy; used for the optional NLL regularizer on the chosen half.
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        loss = loss_fct(logits, labels)
        return loss

    labels = concatenated_batch["concatenated_labels"].clone()

    if self.cpo_alpha == 0:
        # NLL term is disabled; skip the extra computation entirely.
        nll_loss = torch.tensor(0.0).to(self.accelerator.device)
    else:
        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        # IPO and SimPO are defined on length-normalized (average) log probabilities.
        average_log_prob=self.loss_type in ["ipo", "simpo"],
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    chosen_logits = all_logits[:len_chosen]
    rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
1110 |
+
|
1111 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the CPO loss and other metrics for the given batch of inputs for train or test.

    Returns:
        A tuple ``(loss, metrics)`` where ``loss`` is the scalar training loss and ``metrics``
        is a dict of gathered, averaged float metrics (prefixed with ``eval_`` when evaluating).
    """
    metrics = {}

    forward_output = self.concatenated_forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_nll_loss,
    ) = forward_output[:5]
    if self.aux_loss_enabled:
        # MoE auxiliary (router load-balancing) loss is appended as a sixth element.
        aux_loss = forward_output[5]

    losses, chosen_rewards, rejected_rewards = self.cpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
    )

    # Total loss = preference loss + optional behavior-cloning NLL term weighted by cpo_alpha.
    loss = losses.mean() + self.cpo_alpha * policy_nll_loss
    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
        self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
    )
    metrics[f"{prefix}logps/rejected"] = (
        self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
    )
    metrics[f"{prefix}logps/chosen"] = (
        self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
    )
    metrics[f"{prefix}logits/rejected"] = (
        self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
    )
    metrics[f"{prefix}logits/chosen"] = (
        self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
    )
    metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
1164 |
+
|
1165 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the CPO training loss for one batch and record its metrics.

    When the PEFT adapter was cast to bf16, the forward pass runs under CUDA autocast
    so hidden states keep a consistent dtype; otherwise no context is applied.
    """
    compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with compute_loss_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    if return_outputs:
        return (loss, metrics)
    return loss
|
1183 |
+
|
1184 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
    """Generate samples from the model and reference model for the given batch of inputs.

    Returns the decoded policy outputs (special tokens stripped), padded to ``self.max_length``.
    """

    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with generate_context_manager:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

    # Pad to a fixed length so decoded outputs are comparable across batches.
    policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

    return policy_output_decoded
|
1204 |
+
|
1205 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one evaluation step without gradients and return (loss, logits, labels).

    When ``prediction_loss_only`` is set, logits and labels are ``None``. Otherwise the
    chosen/rejected eval logits metrics are repackaged as a stacked tensor with dummy labels.
    """
    if ignore_keys is None:
        # Fall back to the model config's ignore list when the caller gives none.
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    logits_dict = {
        "eval_logits/chosen": metrics["eval_logits/chosen"],
        "eval_logits/rejected": metrics["eval_logits/rejected"],
    }
    logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
    logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
    # Dummy labels: the Trainer API requires a labels tensor even though CPO has none.
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
1239 |
+
|
1240 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Accumulate metric values into ``self._stored_metrics`` for later averaging in ``log``."""
    for key, value in metrics.items():
        self._stored_metrics[train_eval][key].append(value)
|
1243 |
+
|
1244 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples
        num_samples = len(dataloader.dataset)
        random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # Strip the prompt prefix so the table shows only the generated continuation.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
1293 |
+
|
1294 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
    """
    Log `logs` on the various objects watching training, including stored metrics.

    Args:
        logs (`dict[str, float]`):
            The values to log.
        start_time (`float` or `None`, *optional*, defaults to `None`):
            Start time of the training.
    """
    # logs either has 'loss' or 'eval_loss'
    train_eval = "train" if "loss" in logs else "eval"
    # Add averaged stored metrics to logs
    for key, metrics in self._stored_metrics[train_eval].items():
        logs[key] = torch.tensor(metrics).mean().item()
    del self._stored_metrics[train_eval]

    # transformers >= 4.47 added the start_time argument to Trainer.log.
    if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
        return super().log(logs, start_time)
    else:  # transformers<=4.46
        return super().log(logs)
|
1315 |
+
|
1316 |
+
def _shift_right(self, input_ids):
    """Shift label ids one position right, prepending the decoder start token.

    Mirrors the T5-style `_shift_right`: -100 padding values in labels are replaced
    with the pad token so they are valid decoder inputs.

    Raises:
        ValueError: If ``decoder_start_token_id`` or ``pad_token_id`` is undefined.
    """
    if self.decoder_start_token_id is None:
        raise ValueError(
            "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
        )

    # shift inputs to the right
    if is_torch_fx_proxy(input_ids):
        # Item assignment is not supported natively for proxies.
        shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
        shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
    else:
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.decoder_start_token_id

    if self.pad_token_id is None:
        raise ValueError("model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)

    return shifted_input_ids
|
1338 |
+
|
1339 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    # Only the main process writes the card.
    if not self.is_world_process_zero():
        return

    # A non-directory _name_or_path is a hub model id; a local path gives no base model.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    tags = tags or []
    if isinstance(tags, str):
        tags = [tags]

    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @inproceedings{xu2024contrastive,
        title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
        author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
        year = 2024,
        booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=51iwkioZpn}
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="CPO",
        trainer_citation=citation,
        paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
        paper_id="2401.08417",
    )
    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1395 |
+
class UnslothCPOTrainer(_UnslothCPOTrainer):
    """
    Initialize CPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`CPOConfig`):
            The CPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        if args is None: args = UnslothCPOConfig()
        # Reconcile the requested mixed-precision flags with the model's actual dtype.
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # Enable periodic evaluation automatically when an eval dataset is supplied.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metric callbacks need raw logits returned from the forward pass.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right padding is required for training; patch both the tokenizer and any processor wrapper.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        # Pick a collator consistent with whether the dataset already carries labels.
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processor wrappers expose the real tokenizer on `.tokenizer`.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('cpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)
        # Re-attach NEFTune noise directly on the embeddings; the Trainer-installed hook is removed.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
|
unsloth_compiled_cache/UnslothDDPOTrainer.py
ADDED
@@ -0,0 +1,872 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
class UnslothDDPOConfig(DDPOConfig):
    """
    Configuration for [`DDPOTrainer`], with Unsloth-specific extras.

    All standard DDPO options are accepted and forwarded verbatim to
    [`DDPOConfig`]: experiment bookkeeping (`exp_name`, `run_name`, `seed`,
    `log_with`, `tracker_project_name`, `logdir`), checkpointing
    (`num_epochs`, `save_freq`, `num_checkpoint_limit`, `resume_from`),
    precision (`mixed_precision`, `allow_tf32`), sampler settings
    (`sample_num_steps`, `sample_eta`, `sample_guidance_scale`,
    `sample_batch_size`, `sample_num_batches_per_epoch`), optimizer settings
    (`train_batch_size`, `train_use_8bit_adam`, `train_learning_rate`,
    `train_adam_beta1`, `train_adam_beta2`, `train_adam_weight_decay`,
    `train_adam_epsilon`, `train_gradient_accumulation_steps`,
    `train_max_grad_norm`, `train_num_inner_epochs`), PPO settings
    (`train_cfg`, `train_adv_clip_max`, `train_clip_range`,
    `train_timestep_fraction`), per-prompt reward tracking
    (`per_prompt_stat_tracking`, `per_prompt_stat_tracking_buffer_size`,
    `per_prompt_stat_tracking_min_count`), async rewards
    (`async_reward_computation`, `max_workers`), plus `negative_prompts` and
    `push_to_hub`. See the upstream `DDPOConfig` documentation for details.

    Unsloth additions:
        vllm_sampling_params: Optional vLLM `SamplingParams` object.
        unsloth_num_chunks: Chunk size used to reduce memory; -1 is most efficient.
    """
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )

    def __init__(
        self,
        exp_name = 'main',
        run_name = '',
        seed = 3407,
        log_with = None,
        tracker_project_name = 'trl',
        logdir = 'logs',
        num_epochs = 100,
        save_freq = 1,
        num_checkpoint_limit = 5,
        mixed_precision = 'fp16',
        allow_tf32 = True,
        resume_from = '',
        sample_num_steps = 50,
        sample_eta = 1.0,
        sample_guidance_scale = 5.0,
        sample_batch_size = 1,
        sample_num_batches_per_epoch = 2,
        train_batch_size = 1,
        train_use_8bit_adam = False,
        train_learning_rate = 5e-05,
        train_adam_beta1 = 0.9,
        train_adam_beta2 = 0.999,
        train_adam_weight_decay = 0.01,
        train_adam_epsilon = 1e-08,
        train_gradient_accumulation_steps = 2,
        train_max_grad_norm = 1.0,
        train_num_inner_epochs = 1,
        train_cfg = True,
        train_adv_clip_max = 5.0,
        train_clip_range = 0.0001,
        train_timestep_fraction = 1.0,
        per_prompt_stat_tracking = False,
        per_prompt_stat_tracking_buffer_size = 16,
        per_prompt_stat_tracking_min_count = 16,
        async_reward_computation = False,
        max_workers = 2,
        negative_prompts = '',
        push_to_hub = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Everything except the Unsloth extras is forwarded unchanged to the
        # parent DDPOConfig; extra keyword arguments pass straight through.
        forwarded = dict(
            exp_name = exp_name,
            run_name = run_name,
            seed = seed,
            log_with = log_with,
            tracker_project_name = tracker_project_name,
            logdir = logdir,
            num_epochs = num_epochs,
            save_freq = save_freq,
            num_checkpoint_limit = num_checkpoint_limit,
            mixed_precision = mixed_precision,
            allow_tf32 = allow_tf32,
            resume_from = resume_from,
            sample_num_steps = sample_num_steps,
            sample_eta = sample_eta,
            sample_guidance_scale = sample_guidance_scale,
            sample_batch_size = sample_batch_size,
            sample_num_batches_per_epoch = sample_num_batches_per_epoch,
            train_batch_size = train_batch_size,
            train_use_8bit_adam = train_use_8bit_adam,
            train_learning_rate = train_learning_rate,
            train_adam_beta1 = train_adam_beta1,
            train_adam_beta2 = train_adam_beta2,
            train_adam_weight_decay = train_adam_weight_decay,
            train_adam_epsilon = train_adam_epsilon,
            train_gradient_accumulation_steps = train_gradient_accumulation_steps,
            train_max_grad_norm = train_max_grad_norm,
            train_num_inner_epochs = train_num_inner_epochs,
            train_cfg = train_cfg,
            train_adv_clip_max = train_adv_clip_max,
            train_clip_range = train_clip_range,
            train_timestep_fraction = train_timestep_fraction,
            per_prompt_stat_tracking = per_prompt_stat_tracking,
            per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
            per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
            async_reward_computation = async_reward_computation,
            max_workers = max_workers,
            negative_prompts = negative_prompts,
            push_to_hub = push_to_hub,
        )
        super().__init__(**forwarded, **kwargs)
        # Unsloth-specific extras are plain instance attributes, not forwarded.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
|
233 |
+
|
234 |
+
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
235 |
+
""""""
|
236 |
+
|
237 |
+
_tag_names = ["trl", "ddpo"]
|
238 |
+
|
239 |
+
    def __init__(
        self,
        config: DDPOConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        """Set up DDPO training state: accelerator, pipeline dtypes, optimizer, trackers.

        Args:
            config: DDPO hyper-parameters (see `DDPOConfig`).
            reward_function: Maps (images, prompts, prompt_metadata) to a reward tensor.
            prompt_function: Zero-argument callable producing (prompt, prompt_metadata).
            sd_pipeline: Stable Diffusion pipeline wrapper used for sampling/training.
            image_samples_hook: Optional callback used to log generated images.

        Raises:
            ValueError: if no checkpoints exist under `config.resume_from`, or if
                `self._config_check()` rejects the configuration.
        """
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            # Normalize the user-supplied path before inspecting it.
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(
                    filter(
                        lambda x: "checkpoint_" in x,
                        os.listdir(self.config.resume_from),
                    )
                )
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                # Checkpoint dirs are named "checkpoint_<number>"; pick the highest.
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )

                # Continue project iteration numbering after the resumed checkpoint.
                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        # number of timesteps within each trajectory to train on
        self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
            # the total number of optimizer steps to accumulate across.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
            **self.config.accelerator_kwargs,
        )

        # NOTE(review): _config_check is defined elsewhere on this class (not
        # visible in this chunk); it returns (bool, message).
        is_okay, message = self._config_check()
        if not is_okay:
            raise ValueError(message)

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                # tensorboard cannot handle nested dicts, so only wrap for other trackers.
                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        # Let the pipeline own (de)serialization of its models during accelerate state save/load.
        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        # Pre-compute the (shared) negative-prompt embedding once.
        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        if config.per_prompt_stat_tracking:
            self.stat_tracker = PerPromptStatTracker(
                config.per_prompt_stat_tracking_buffer_size,
                config.per_prompt_stat_tracking_min_count,
            )

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            # For LoRA, only the parameters that still require grad are trainable.
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if self.config.async_reward_computation:
            self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            # Epoch number is recovered from the "checkpoint_<number>" directory name.
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0
|
380 |
+
|
381 |
+
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
382 |
+
if not is_async:
|
383 |
+
rewards = []
|
384 |
+
for images, prompts, prompt_metadata in prompt_image_pairs:
|
385 |
+
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
386 |
+
rewards.append(
|
387 |
+
(
|
388 |
+
torch.as_tensor(reward, device=self.accelerator.device),
|
389 |
+
reward_metadata,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
else:
|
393 |
+
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
394 |
+
rewards = [
|
395 |
+
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
396 |
+
for reward, reward_metadata in rewards
|
397 |
+
]
|
398 |
+
|
399 |
+
return zip(*rewards)
|
400 |
+
|
401 |
+
    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )

        # Attach the reward and its metadata to each (images, prompts, metadata) record.
        for i, image_data in enumerate(prompt_image_data):
            image_data.extend([rewards[i], rewards_metadata[i]])

        if self.image_samples_callback is not None:
            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

        rewards = torch.cat(rewards)
        # Gather rewards from every process so advantages are normalized globally.
        rewards = self.accelerator.gather(rewards).cpu().numpy()

        self.accelerator.log(
            {
                "reward": rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
            },
            step=global_step,
        )

        if self.config.per_prompt_stat_tracking:
            # gather the prompts across processes
            prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            advantages = self.stat_tracker.update(prompts, rewards)
        else:
            # Global z-score normalization of rewards (epsilon avoids divide-by-zero).
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; keep the entries corresponding to the samples on this process
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
            .to(self.accelerator.device)
        )

        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape

        for inner_epoch in range(self.config.train_num_inner_epochs):
            # shuffle samples along batch dimension
            perm = torch.randperm(total_batch_size, device=self.accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}

            # shuffle along time dimension independently for each sample
            # still trying to understand the code below
            perms = torch.stack(
                [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
            )

            # Apply each row's independent timestep permutation via advanced indexing.
            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                samples[key] = samples[key][
                    torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                    perms,
                ]

            original_keys = samples.keys()
            original_values = samples.values()
            # rebatch them as user defined train_batch_size is different from sample_batch_size
            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

            # Transpose the list of original values
            transposed_values = zip(*reshaped_values)
            # Create new dictionaries for each row of transposed values
            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

            self.sd_pipeline.unet.train()
            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
            # ensure optimization step at the end of the inner epoch
            if not self.accelerator.sync_gradients:
                raise ValueError(
                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
                )

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step
|
506 |
+
|
507 |
+
    def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            timesteps (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            embeds (torch.Tensor):
                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
                Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds

            Returns:
                loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
                (all of these are of shape (1,))
        """
        with self.autocast():
            if self.config.train_cfg:
                # Classifier-free guidance: run unconditional and conditional halves in one batch.
                noise_pred = self.sd_pipeline.unet(
                    torch.cat([latents] * 2),
                    torch.cat([timesteps] * 2),
                    embeds,
                ).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            else:
                noise_pred = self.sd_pipeline.unet(
                    latents,
                    timesteps,
                    embeds,
                ).sample
            # compute the log prob of next_latents given latents under the current model

            scheduler_step_output = self.sd_pipeline.scheduler_step(
                noise_pred,
                timesteps,
                latents,
                eta=self.config.sample_eta,
                prev_sample=next_latents,
            )

            log_prob = scheduler_step_output.log_probs

        # Clip advantages to stabilize the PPO update.
        advantages = torch.clamp(
            advantages,
            -self.config.train_adv_clip_max,
            self.config.train_adv_clip_max,
        )

        # Importance ratio between the current policy and the sampling policy.
        ratio = torch.exp(log_prob - log_probs)

        loss = self.loss(advantages, self.config.train_clip_range, ratio)

        # Monitoring statistics: quadratic KL estimate and fraction of clipped ratios.
        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)

        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())

        return loss, approx_kl, clipfrac
|
574 |
+
|
575 |
+
def loss(
|
576 |
+
self,
|
577 |
+
advantages: torch.Tensor,
|
578 |
+
clip_range: float,
|
579 |
+
ratio: torch.Tensor,
|
580 |
+
):
|
581 |
+
unclipped_loss = -advantages * ratio
|
582 |
+
clipped_loss = -advantages * torch.clamp(
|
583 |
+
ratio,
|
584 |
+
1.0 - clip_range,
|
585 |
+
1.0 + clip_range,
|
586 |
+
)
|
587 |
+
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
588 |
+
|
589 |
+
def _setup_optimizer(self, trainable_layers_parameters):
|
590 |
+
if self.config.train_use_8bit_adam:
|
591 |
+
import bitsandbytes
|
592 |
+
|
593 |
+
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
594 |
+
else:
|
595 |
+
optimizer_cls = torch.optim.AdamW
|
596 |
+
|
597 |
+
return optimizer_cls(
|
598 |
+
trainable_layers_parameters,
|
599 |
+
lr=self.config.train_learning_rate,
|
600 |
+
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
601 |
+
weight_decay=self.config.train_adam_weight_decay,
|
602 |
+
eps=self.config.train_adam_epsilon,
|
603 |
+
)
|
604 |
+
|
605 |
+
    def _save_model_hook(self, models, weights, output_dir):
        """Accelerate save_state pre-hook: delegate checkpoint writing to the pipeline."""
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
608 |
+
|
609 |
+
    def _load_model_hook(self, models, input_dir):
        """Accelerate load_state pre-hook: delegate checkpoint loading to the pipeline."""
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
612 |
+
|
613 |
+
def _generate_samples(self, iterations, batch_size):
    """
    Generate samples from the model

    Args:
        iterations (int): Number of iterations to generate samples for
        batch_size (int): Batch size to use for sampling

    Returns:
        samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
    """
    samples = []
    prompt_image_pairs = []
    # Sampling only: put the UNet in eval mode (disables dropout etc.).
    self.sd_pipeline.unet.eval()

    # One shared unconditional embedding, tiled across the batch.
    sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

    for _ in range(iterations):
        # prompt_fn returns (prompt, metadata) pairs; zip(*...) splits them
        # into two parallel tuples of length batch_size.
        prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

        prompt_ids = self.sd_pipeline.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.sd_pipeline.tokenizer.model_max_length,
        ).input_ids.to(self.accelerator.device)
        prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

        # Run diffusion under mixed precision; output_type="pt" keeps images
        # as tensors rather than PIL images.
        with self.autocast():
            sd_output = self.sd_pipeline(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=sample_neg_prompt_embeds,
                num_inference_steps=self.config.sample_num_steps,
                guidance_scale=self.config.sample_guidance_scale,
                eta=self.config.sample_eta,
                output_type="pt",
            )

            images = sd_output.images
            latents = sd_output.latents
            log_probs = sd_output.log_probs

        # Stack the per-step lists into tensors with a time dimension.
        latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
        log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
        timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

        samples.append(
            {
                "prompt_ids": prompt_ids,
                "prompt_embeds": prompt_embeds,
                "timesteps": timesteps,
                "latents": latents[:, :-1],  # each entry is the latent before timestep t
                "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                "log_probs": log_probs,
                "negative_prompt_embeds": sample_neg_prompt_embeds,
            }
        )
        prompt_image_pairs.append([images, prompts, prompt_metadata])

    return samples, prompt_image_pairs
|
674 |
+
|
675 |
+
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
    """
    Train on a batch of samples. Main training segment

    Args:
        inner_epoch (int): The current inner epoch
        epoch (int): The current epoch
        global_step (int): The current global step
        batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on

    Side Effects:
        - Model weights are updated
        - Logs the statistics to the accelerator trackers.

    Returns:
        global_step (int): The updated global step
    """
    # Accumulates per-microbatch stats until the next real optimizer step.
    info = defaultdict(list)
    for _i, sample in enumerate(batched_samples):
        if self.config.train_cfg:
            # concat negative prompts to sample prompts to avoid two forward passes
            embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
        else:
            embeds = sample["prompt_embeds"]

        # One PPO update per diffusion timestep of the sampled trajectory.
        for j in range(self.num_train_timesteps):
            with self.accelerator.accumulate(self.sd_pipeline.unet):
                loss, approx_kl, clipfrac = self.calculate_loss(
                    sample["latents"][:, j],
                    sample["timesteps"][:, j],
                    sample["next_latents"][:, j],
                    sample["log_probs"][:, j],
                    sample["advantages"],
                    embeds,
                )
                info["approx_kl"].append(approx_kl)
                info["clipfrac"].append(clipfrac)
                info["loss"].append(loss)

                self.accelerator.backward(loss)
                # Only clip when accelerate is about to apply the gradients
                # (i.e. on the final accumulation micro-step).
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )
                self.optimizer.step()
                self.optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if self.accelerator.sync_gradients:
                # log training-related stuff
                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = self.accelerator.reduce(info, reduction="mean")
                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(info, step=global_step)
                global_step += 1
                # Reset accumulators for the next optimizer step.
                info = defaultdict(list)
    return global_step
|
735 |
+
|
736 |
+
def _config_check(self) -> tuple[bool, str]:
    """Validate batch-size relationships implied by the config.

    Returns:
        ``(True, "")`` when the configuration is consistent, otherwise
        ``(False, reason)`` describing the violated constraint.
    """
    sample_bs = self.config.sample_batch_size
    train_bs = self.config.train_batch_size
    num_procs = self.accelerator.num_processes
    samples_per_epoch = sample_bs * num_procs * self.config.sample_num_batches_per_epoch
    total_train_batch_size = train_bs * num_procs * self.config.train_gradient_accumulation_steps

    if sample_bs < train_bs:
        return (
            False,
            f"Sample batch size ({sample_bs}) must be greater than or equal to the train batch size ({train_bs})",
        )
    if sample_bs % train_bs != 0:
        return (
            False,
            f"Sample batch size ({sample_bs}) must be divisible by the train batch size ({train_bs})",
        )
    if samples_per_epoch % total_train_batch_size != 0:
        return (
            False,
            f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
        )
    return True, ""
|
762 |
+
|
763 |
+
def train(self, epochs: Optional[int] = None):
    """Run the training loop.

    Args:
        epochs: Total number of epochs to train up to; falls back to
            ``config.num_epochs`` when not provided. Training starts from
            ``self.first_epoch`` (e.g. after resuming from a checkpoint).
    """
    target_epochs = self.config.num_epochs if epochs is None else epochs
    global_step = 0
    for epoch in range(self.first_epoch, target_epochs):
        global_step = self.step(epoch, global_step)
|
772 |
+
|
773 |
+
def _save_pretrained(self, save_directory):
    # Persist the whole Stable Diffusion pipeline (weights + components).
    self.sd_pipeline.save_pretrained(save_directory)
    # Write a draft model card (README.md) alongside the checkpoint.
    self.create_model_card()
|
776 |
+
|
777 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    # Only the main process writes the model card.
    if not self.is_world_process_zero():
        return

    # A non-directory _name_or_path is treated as a hub model id; a local
    # directory gives no usable base-model reference.
    if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
        base_model = self.model.config._name_or_path
    else:
        base_model = None

    # Normalize tags to a list; a bare string becomes a single-element list.
    tags = tags or []
    if isinstance(tags, str):
        tags = [tags]

    # Models patched by unsloth carry this marker attribute.
    if hasattr(self.model.config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
    @inproceedings{black2024training,
        title = {{Training Diffusion Models with Reinforcement Learning}},
        author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
        year = 2024,
        booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
        publisher = {OpenReview.net},
        url = {https://openreview.net/forum?id=YCWjhGrJFD},
    }""")

    model_card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        # Experiment-tracking links are included only when available.
        wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
        comet_url=get_comet_experiment_url(),
        trainer_name="DDPO",
        trainer_citation=citation,
        paper_title="Training Diffusion Models with Reinforcement Learning",
        paper_id="2305.13301",
    )

    model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
834 |
+
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
    """
    
    The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
    Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
    As of now only Stable Diffusion based pipelines are supported

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
            details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images

    """
    def __init__(
        self,
        config,
        reward_function,
        prompt_function,
        sd_pipeline,
        image_samples_hook = None,
        **kwargs
    ):
        # BUG FIX: the generated code checked `args`, which is not a
        # parameter of this __init__ (the config parameter is `config`),
        # raising NameError and never applying the default config.
        if config is None: config = UnslothDDPOConfig()
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('ddpo_trainer', other_metrics)

        super().__init__(
            config = config,
            reward_function = reward_function,
            prompt_function = prompt_function,
            sd_pipeline = sd_pipeline,
            image_samples_hook = image_samples_hook,**kwargs)

    pass
|
unsloth_compiled_cache/UnslothDPOTrainer.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
unsloth_compiled_cache/UnslothGKDTrainer.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
# Shared options for every torch.compile call in this module: keep epilogue
# fusion and shape padding on, but disable autotuning, tracing, and CUDA
# graphs for faster, more predictable compilation.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Return log-softmax probabilities of the tokens selected by `index`.

    Equivalent to gathering `index` from `logits.log_softmax(-1)`, but
    computes only the selected entries (gather + logsumexp) in float32,
    avoiding materializing the full log-softmax tensor.
    """
    # Upcast: logsumexp in low precision is numerically unstable.
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
|
42 |
+
@dataclass
class UnslothGKDConfig(GKDConfig):
    """
    
    Configuration class for [`GKDTrainer`].
    
    Args:
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        lmbda (`float`, *optional*, defaults to `0.5`):
            Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
            student-generated outputs).
        beta (`float`, *optional*, defaults to `0.5`):
            Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
            beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
        max_new_tokens (`int`, *optional*, defaults to `128`):
            Maximum number of tokens to generate per completion.
        teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
            Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
            being trained.
        teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        seq_kd (`bool`, *optional*, defaults to `False`):
            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
            on teacher-generated output).
    
    """
    # Unsloth-specific extension: optional vLLM sampling parameters object.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    # Unsloth-specific extension: chunking factor for memory-efficient loss.
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        model_init_kwargs = None,
        use_liger = False,
        dataset_text_field = 'text',
        dataset_kwargs = None,
        dataset_num_proc = None,
        max_seq_length = None,
        packing = False,
        eval_packing = None,
        dataset_batch_size = None,
        num_of_sequences = None,
        chars_per_token = None,
        temperature = 0.9,
        lmbda = 0.5,
        beta = 0.5,
        max_new_tokens = 128,
        teacher_model_name_or_path = None,
        teacher_model_init_kwargs = None,
        disable_dropout = True,
        seq_kd = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Guard against learning rates that would make gradient updates
        # vanish or explode.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # When the caller left output_dir/save settings at their defaults,
        # redirect checkpoints and disable step-based saving.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default dataset preprocessing parallelism to all available CPUs.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()
        
        # Forward every standard argument to GKDConfig; the two
        # unsloth-specific fields are set on self afterwards.
        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            use_liger = use_liger,
            dataset_text_field = dataset_text_field,
            dataset_kwargs = dataset_kwargs,
            dataset_num_proc = dataset_num_proc,
            max_seq_length = max_seq_length,
            packing = packing,
            eval_packing = eval_packing,
            dataset_batch_size = dataset_batch_size,
            num_of_sequences = num_of_sequences,
            chars_per_token = chars_per_token,
            temperature = temperature,
            lmbda = lmbda,
            beta = beta,
            max_new_tokens = max_new_tokens,
            teacher_model_name_or_path = teacher_model_name_or_path,
            teacher_model_init_kwargs = teacher_model_init_kwargs,
            disable_dropout = disable_dropout,
            seq_kd = seq_kd,**kwargs)
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass
|
395 |
+
|
396 |
+
class _UnslothGKDTrainer(SFTTrainer):
    """Generalized Knowledge Distillation (GKD) trainer.

    Distills a frozen teacher model into the trained (student) model by
    minimising a generalized Jensen-Shannon divergence between their
    token-level distributions (https://huggingface.co/papers/2306.13649),
    optionally mixing in on-policy data generated by the student itself.
    """

    _tag_names = ["trl", "gkd"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
        args: Optional[GKDConfig] = None,
        data_collator: Optional[DataCollator] = None,  # type: ignore
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        formatting_func: Optional[Callable] = None,
    ):
        # add remove_unused_columns=False to the dataclass args
        # NOTE(review): `args` is dereferenced immediately, so despite the
        # Optional annotation a real GKDConfig is required here.
        args.remove_unused_columns = False
        # The ChatML collator replaces any caller-supplied `data_collator`.
        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)

        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        # Resolve how the teacher model should be instantiated:
        # init kwargs are only meaningful when `teacher_model` is a model-name string.
        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}
        elif not isinstance(teacher_model, str):
            raise ValueError(
                "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
            )
        else:
            teacher_model_init_kwargs = args.teacher_model_init_kwargs
            # Convert a string dtype (e.g. "bfloat16") into the torch dtype object;
            # "auto" and None are passed through unchanged.
            teacher_model_init_kwargs["torch_dtype"] = (
                teacher_model_init_kwargs["torch_dtype"]
                if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
            )

        if isinstance(teacher_model, str):
            if args.use_liger:
                teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
            else:
                teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        # The teacher is only ever used for inference (evaluation mode).
        if self.is_deepspeed_enabled:
            self.teacher_model = self._prepare_deepspeed(teacher_model)
        else:
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        self.lmbda = args.lmbda              # probability of using student-generated (on-policy) batches
        self.beta = args.beta                # interpolation coefficient of the generalized JSD
        self.temperature = args.temperature  # sampling temperature for the generation config below
        self.seq_kd = args.seq_kd            # sequence-level KD: train on teacher generations

        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            do_sample=True,
            top_k=0,
            use_cache=False if args.gradient_checkpointing else True,
            pad_token_id=self.processing_class.pad_token_id,
        )
        # Set custom EOS tokens if they are specified by the model's generation
        # config. This is important for models with the Llama 3 chat template,
        # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
        # turns or messages.
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    def _prepare_dataset(self, dataset, *args):
        """Prepare `dataset` while preserving the raw ``messages`` column."""
        # SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
        # need to keep the messages column as it is. We use the following workaround to keep the messages column.
        dataset = dataset.add_column("_messages", dataset["messages"])
        dataset = super()._prepare_dataset(dataset, *args)
        dataset = dataset.rename_column("_messages", "messages")
        return dataset

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
        """
        Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
        of https://huggingface.co/papers/2306.13649 for the definition.

        Args:
            student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
            labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
            beta: Interpolation coefficient between 0 and 1 (default: 0.5)
            temperature: Softmax temperature (default: 1.0)
            reduction: Specifies the reduction to apply to the output (default: 'batchmean')

        Returns:
            loss: Scalar tensor with the generalized JSD loss
        """

        # Apply temperature scaling
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # Compute log probabilities for student and probabilities for teacher
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        # Compute the log of the mixture distribution
        # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
        # NOTE(review): beta is created on the default device; presumably fine
        # because it only scales via torch.log broadcasts — confirm on GPU runs.
        beta = torch.tensor(beta, dtype=student_log_probs.dtype)
        mixture_log_probs = torch.logsumexp(
            torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
            dim=0,
        )

        # Compute KL divergences using F.kl_div
        # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

        # Compute the Generalized Jensen-Shannon Divergence
        jsd = beta * kl_teacher + (1 - beta) * kl_student

        # Masking: drop positions labelled -100 (padding) before reduction
        if labels is not None:
            mask = labels != -100
            jsd = jsd[mask]

        # Apply reduction
        if reduction == "batchmean":
            return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
        elif reduction == "sum":
            return jsd.sum()
        elif reduction == "mean":
            return jsd.mean()
        else:
            return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the generalized JSD loss over the completion tokens.

        `num_items_in_batch` is accepted for Trainer API compatibility but is
        not used in this implementation.
        """
        # compute student output
        outputs_student = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # compute teacher output in eval mode
        self.teacher_model.eval()
        with torch.no_grad():
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        # slice the logits for the generated tokens using the inputs["prompts"] lengths
        prompt_lengths = inputs["prompts"].shape[1]
        # Shift by one: logits at position i predict token i+1.
        shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
        shifted_labels = inputs["labels"][:, prompt_lengths:]

        # compute loss
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
            teacher_logits=shifted_teacher_logits,
            labels=shifted_labels,
            beta=self.beta,
        )

        # empty cache
        empty_cache()

        # Return loss
        return (loss, outputs_student) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        """Generate completions for ``inputs["prompts"]`` and build the
        matching attention mask and labels (pad tokens set to -100)."""
        # Generate output with respect to the prompt only
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],
            attention_mask=inputs.get("prompt_attention_mask", None),
            generation_config=generation_config,
            return_dict_in_generate=True,
        )

        # Get the generated token IDs
        generated_tokens = generated_outputs.sequences
        # Calculate new attention mask
        new_attention_mask = torch.ones_like(generated_tokens)
        new_labels = generated_tokens.clone()

        # If there's pad_token_id, set attention mask to 0 for padding tokens
        if pad_token_id is not None:
            new_labels[new_labels == pad_token_id] = -100
            new_attention_mask[generated_tokens == pad_token_id] = 0

        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        Perform a training step for the Generalized Knowledge Distillation (GKD) model.

        This method implements the on-policy learning approach described in the GKD paper.
        With probability `self.lmbda`, it generates new responses using the student model,
        which are then used for training instead of the original inputs.
        """
        # Sequence-level KD: replace the batch with teacher generations first.
        if self.seq_kd:
            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels
        # On-policy step: with probability lmbda, train on the student's own generations.
        if random.random() <= self.lmbda:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels

        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
        """Wrap the (frozen) teacher model in a DeepSpeed inference engine."""
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

        if model is not None:
            if hasattr(model, "config"):
                hidden_size = (
                    max(model.config.hidden_sizes)
                    if getattr(model.config, "hidden_sizes", None)
                    else getattr(model.config, "hidden_size", None)
                )
                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                    config_kwargs.update(
                        {
                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                        }
                    )

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        if config_kwargs["zero_optimization"]["stage"] != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        # Only the main process writes the card.
        if not self.is_world_process_zero():
            return

        # A hub model id looks like a base model; a local directory does not.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @inproceedings{agarwal2024on-policy,
            title        = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
            author       = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=3zKtaqxLhW},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GKD",
            trainer_citation=citation,
            paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
            paper_id="2306.13649",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
731 |
+
class UnslothGKDTrainer(_UnslothGKDTrainer):
    """Unsloth-patched GKD trainer.

    Thin wrapper that normalises mixed-precision flags, evaluation settings,
    sequence length and the data collator (driven largely by `UNSLOTH_*`
    environment variables) before delegating to `_UnslothGKDTrainer`.
    """
    def __init__(
        self,
        model = None,
        teacher_model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        formatting_func = None,
        **kwargs
    ):
        if args is None: args = UnslothGKDConfig()
        # --- Mixed-precision normalisation -----------------------------------
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Infer the model dtype from its config, falling back to the embeddings.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject mismatched precision requests (model dtype vs. fp16/bf16 flags).
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag set: derive mixed precision from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # --- Evaluation settings ---------------------------------------------
        # Turn evaluation on if an eval dataset was configured on args but no
        # strategy was chosen.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # Older transformers have a known gradient-accumulation bug.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                    '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Shrink the (default) eval batch size to the train batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        # Keep full-eval dtype consistent with the training dtype.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # --- Logits output ----------------------------------------------------
        # Metrics callbacks need raw logits, so signal Unsloth to return them.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # --- Sequence length --------------------------------------------------
        # Propagate the model's max_seq_length onto args when args lacks one.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right padding is required for training.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        # --- Data collator fix-ups -------------------------------------------
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Pick the collator matching whether the dataset already has labels.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator prepares everything itself; disable text preprocessing.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processor wrappers expose the real tokenizer on `.tokenizer`.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('gkd_trainer', other_metrics)

        super().__init__(
            model = model,
            teacher_model = teacher_model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            formatting_func = formatting_func,**kwargs)
        # Re-install NEFTune directly on the embeddings instead of via the hook.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

pass
|
unsloth_compiled_cache/UnslothGRPOTrainer.py
ADDED
@@ -0,0 +1,1436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, wandb, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, wandb, GRPOTrainer, Trainer, gather, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
# Options forwarded to torch.compile for the kernels compiled in this module.
# NOTE(review): max_autotune and CUDA graphs are deliberately disabled —
# presumably to keep compile times and memory behavior predictable; confirm
# before enabling either.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """Per-token log-probabilities of `index` under `logits`.

    Equivalent to gathering from `log_softmax(logits, dim=-1)` at `index`,
    but computed as x_i - logsumexp(x) without materialising the full
    log-softmax tensor. Computation is done in float32.
    """
    fp32_logits = logits.to(torch.float32)
    # Pick the logit of each target token along the vocabulary dimension.
    chosen = torch.gather(fp32_logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # log_softmax(x_i) = x_i - logsumexp(x)
    normalizer = torch.logsumexp(fp32_logits, dim = -1)
    return chosen - normalizer
|
42 |
+
|
43 |
+
def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
    # All Unsloth Zoo code licensed under LGPLv3
    """Per-sequence GRPO objective.

    Combines the policy-gradient surrogate with a reverse-KL penalty (weight
    ``beta``) against the reference ("old") policy, normalised over unmasked
    completion tokens and averaged over the batch.

    Returns ``(loss, completion_length, mean_kl)``; the last two are
    gradient-free metrics.
    """
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
    token_index = input_ids.unsqueeze(-1)

    # Per-token log-probabilities via log_softmax(x)_i = x_i - logsumexp(x)
    ref_picked = torch.gather(old_logits, dim = -1, index = token_index).squeeze(-1)
    cur_picked = torch.gather(new_logits, dim = -1, index = token_index).squeeze(-1)
    ref_logps = ref_picked - torch.logsumexp(old_logits, dim = -1)
    cur_logps = cur_picked - torch.logsumexp(new_logits, dim = -1)

    # Reverse KL estimator: exp(d) - d - 1  with  d = log p_ref - log p_cur
    delta = ref_logps - cur_logps
    kl_i = torch.exp(delta) - delta - 1.0

    # exp(x - x.detach()) equals 1 in value but carries exp(...)'s gradient;
    # the detach is required - otherwise gradients do not propagate correctly.
    loss_i = torch.exp(cur_logps - cur_logps.detach()) * advantages.unsqueeze(1)
    loss_i = -(loss_i - beta * kl_i)

    mask = mask.to(torch.float32)
    n_mask_per_reward = mask.sum(1)

    # Normalise per sequence, then average over the batch.
    # See https://github.com/huggingface/trl/pull/2881
    loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
    loss = loss_per_reward.mean()

    # Metrics only - computed without autograd tracking
    with torch.inference_mode():
        completion_length = n_mask_per_reward.mean()
        mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
        mean_kl = mean_kl_per_reward.mean()
    return loss, completion_length, mean_kl
|
84 |
+
class UnslothEfficientGRPO(torch.autograd.Function):
    # All Unsloth Zoo code licensed under LGPLv3
    # Memory-efficient GRPO loss. The forward pass evaluates, chunk-by-chunk
    # over the batch, both the loss AND its gradient w.r.t. the new hidden
    # states (torch.func.grad_and_value under torch.compile), writes the
    # gradient into a preallocated buffer, and saves it; backward simply
    # replays the stored gradient. This avoids materialising full-vocab
    # logits for the whole batch at once.
    @staticmethod
    def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
        # Loss for one batch chunk: project hidden states through the LM head,
        # drop the last position (it predicts the token after the sequence),
        # then call grpo_compute_loss. `scaling` applies the AMP grad-scaler
        # factor so the returned gradient is pre-scaled.
        def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
            new_logits = torch.matmul(new_hidden_states, lm_head.t())
            new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
            old_logits = torch.matmul(old_hidden_states, lm_head.t())
            old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
            loss, completion_length, mean_kl = grpo_compute_loss(
                old_logits, new_logits, input_ids, mask, beta, advantages,
            )
            # Scale loss if needed for mixed precision training
            scaled_loss = loss * scaling
            # Must add .loss.detach otherwise autograd uses 2x VRAM
            return scaled_loss, (loss.detach(), completion_length, mean_kl,)
        pass

        device = _new_hidden_states.device
        # Gradient buffer filled one chunk at a time, plus running metric sums.
        grad_inputs = torch.empty_like(_new_hidden_states)
        accumulated_loss              = torch.zeros(1, device = device)
        accumulated_completion_length = torch.zeros(1, device = device)
        accumulated_mean_kl           = torch.zeros(1, device = device)

        # grad_and_value(argnums=(0,), has_aux=True) returns the gradient
        # w.r.t. the chunk's new hidden states together with the (scaled)
        # loss and the unscaled-loss/metric auxiliaries.
        def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
            (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
                compute_loss,
                argnums = (0,),
                has_aux = True,
            )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            accumulated_loss             .add_(unscaled_loss)
            accumulated_completion_length.add_(chunk_completion_length)
            accumulated_mean_kl          .add_(chunk_mean_kl)
            return chunk_grad_input
        pass

        accumulate_chunk = torch.compile(
            accumulate_chunk,
            fullgraph = True,
            options = torch_compile_options,
        )

        # Split every tensor along the batch dimension into n_chunks pieces.
        grad_inputs_chunks = torch.chunk(grad_inputs,        chunks = n_chunks, dim = 0)
        new_hidden_states  = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
        old_hidden_states  = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
        input_ids          = torch.chunk(_input_ids,         chunks = n_chunks, dim = 0)
        mask               = torch.chunk(_mask,              chunks = n_chunks, dim = 0)
        advantages         = torch.chunk(_advantages,        chunks = n_chunks, dim = 0)

        # Get mixed precision scaling if seen
        scaling = scaler.get_scale() if scaler is not None else 1.0

        # Force torch.compile to use dynamic shapes for seqlen dim
        mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)

        for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
            zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):

            mark_dynamic(new_hidden_states_j)
            mark_dynamic(old_hidden_states_j)
            mark_dynamic(input_ids_j)
            mark_dynamic(mask_j)

            # torch.chunk returns views of grad_inputs, so this copy_ fills
            # the corresponding slice of the full gradient buffer in place.
            grad_inputs_j.copy_(
                accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
            )
        pass

        # Average chunk sums so the result matches a single full-batch pass.
        grad_inputs                  .div_(n_chunks)
        accumulated_loss             .div_(n_chunks)
        accumulated_completion_length.div_(n_chunks)
        accumulated_mean_kl          .div_(n_chunks)
        ctx.save_for_backward(grad_inputs)

        return (
            accumulated_loss,
            accumulated_completion_length,
            accumulated_mean_kl,
        )
    pass

    @staticmethod
    def backward(ctx, grad_output, dcompletion_length, dmean_kl):
        # Replay the gradient precomputed in forward(); only the first input
        # (_new_hidden_states) receives a gradient.
        # NOTE(review): grad_output is not multiplied in — this assumes the
        # returned loss is consumed directly by .backward() (grad_output == 1);
        # confirm against the caller.
        (grad_input,) = ctx.saved_tensors
        return (grad_input, None, None, None, None, None, None, None, None,)
    pass
|
171 |
+
def grpo_accumulated_loss(
    trainer,
    input_ids,
    logits_to_keep,
    completion_mask,
    advantages,
    n_chunks = -1,
):
    # All Unsloth Zoo code licensed under LGPLv3
    """Compute the GRPO loss via the chunked ``UnslothEfficientGRPO`` path.

    Runs the reference policy (same model with adapters disabled, no grads)
    and the current policy to obtain hidden states, then delegates to
    ``UnslothEfficientGRPO.apply`` which evaluates the loss chunk-by-chunk
    to bound peak memory.

    Args:
        trainer: GRPO trainer exposing ``.model``, ``.accelerator`` and ``.beta``.
        input_ids: ``(bsz, qlen)`` prompt + completion token ids.
        logits_to_keep: number of trailing (completion) positions to score.
        completion_mask: mask over the completion tokens.
        advantages: per-sequence advantage estimates.
        n_chunks: number of batch chunks; ``-1`` means one chunk per row
            (most memory efficient). Snapped to the nearest divisor of ``bsz``.

    Returns:
        ``(loss, completion_length, mean_kl)``.
    """
    bsz = input_ids.shape[0]
    # Snap n_chunks to the closest divisor of bsz so every chunk is equal-sized.
    factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
    if n_chunks == -1: n_chunks = bsz
    n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors) - 1)]

    mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
    # Ask Unsloth's patched model forward to return hidden states in `.logits`.
    os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"

    completion_input_ids = input_ids[:, -logits_to_keep:]
    lm_head = trainer.model.get_output_embeddings().weight

    with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
        # Reference policy: same weights with LoRA adapters disabled, no grads.
        with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
            old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
        pass

        new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits

        loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
            new_hidden_states, old_hidden_states, lm_head,
            completion_input_ids, completion_mask, advantages, trainer.beta,
            trainer.accelerator.scaler,
            n_chunks,
        )
    # NOTE(review): an unreachable "old non efficient code path" used to follow
    # this return; it was dead code and has been removed.
    return loss, completion_length, mean_kl
pass
|
218 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
    # torch.compile'd variant of `grpo_compute_loss` with identical math,
    # used by the non-chunked (slow) code path.
    # All Unsloth Zoo code licensed under LGPLv3
    old_logits = old_logits.to(torch.float32)
    new_logits = new_logits.to(torch.float32)
    input_ids = input_ids.unsqueeze(-1)

    # Per-token log-probs: x_i - logsumexp(x_i)
    old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
    new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
    old = old_x - torch.logsumexp(old_logits, dim = -1)
    new = new_x - torch.logsumexp(new_logits, dim = -1)

    # Reverse KL
    kl_i = torch.exp(old - new) - (old - new) - 1.0
    # Full correct reverse KL divergence?? Missing term maybe?
    # kl_i = torch.exp(new) * kl_i

    # Below is forward KL (normal KL)
    # kl_i = torch.exp(old) * (old - new)

    # Must detach - otherwise gradients are not propagated correctly!
    # exp(x - x) == 1
    loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
    loss_i = -(loss_i - beta * kl_i)

    mask = mask.to(torch.float32)
    n_mask_per_reward = mask.sum(1)

    # Normalise per sequence, then average over the batch.
    # See https://github.com/huggingface/trl/pull/2881
    loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
    loss = loss_per_reward.mean()
    # loss = (loss_i * mask).sum() / mask.sum()

    # Get metrics as well which are folded
    with torch.inference_mode():
        completion_length = n_mask_per_reward.mean()
        mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
        mean_kl = mean_kl_per_reward.mean()
    pass
    return loss, completion_length, mean_kl
|
260 |
+
def vLLMSamplingParams(**kwargs):
    """Construct a vLLM ``SamplingParams`` while remembering the raw kwargs.

    The kwargs are stashed on ``_set_kwargs`` so Unsloth can later inspect or
    re-create the sampling parameters.
    """
    from vllm import SamplingParams
    params = SamplingParams(**kwargs)
    params._set_kwargs = kwargs
    return params
265 |
+
@dataclass
class UnslothGRPOConfig(GRPOConfig):
    """

    Configuration class for the [`GRPOTrainer`].

    Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
    [`~transformers.TrainingArguments`] documentation.

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`GRPOTrainer`] is provided as a string.

        > Parameters that control the data preprocessing

        remove_unused_columns (`bool`, *optional*, defaults to `False`):
            Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
            requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
        num_generations (`int` or `None`, *optional*, defaults to `8`):
            Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
            must be divisible by this value.
        temperature (`float`, *optional*, defaults to `0.9`):
            Temperature for sampling. The higher the temperature, the more random the completions.
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.
        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
            improving generation speed. However, disabling this option allows training models that exceed the VRAM
            capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
            with vLLM generation.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs. If only one device is available, the device will be
            shared between both training and vLLM.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        vllm_dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originites from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
            α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
            between the current policy and the previous reference policy during updates. The reference policy is
            updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
            must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `64`):
            τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
            frequently the current policy is synchronized with the reference policy. To use this parameter, you must
            set `sync_ref_model=True`.

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
            Whether to log the completions during training.

    """
    # Extra Unsloth-only fields on top of GRPOConfig
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = False,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        model_init_kwargs = None,
        max_prompt_length = 512,
        num_generations = 8,
        temperature = 0.9,
        max_completion_length = 256,
        ds3_gather_for_generation = True,
        use_vllm = False,
        vllm_device = 'auto',
        vllm_gpu_memory_utilization = 0.9,
        vllm_dtype = 'auto',
        vllm_max_model_len = None,
        beta = 0.04,
        reward_weights = None,
        sync_ref_model = False,
        ref_model_mixup_alpha = 0.9,
        ref_model_sync_steps = 64,
        log_completions = False,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Sanity-check the learning rate before handing off to the base class.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # With no output dir and default save settings, avoid writing checkpoints.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Fix: `num_generations` is documented as `int or None`; guard against
        # None/0 which would raise TypeError/ZeroDivisionError in the division.
        if num_generations is not None and num_generations > 0:
            div = per_device_train_batch_size // num_generations
            if div * num_generations != per_device_train_batch_size:
                print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
                per_device_train_batch_size = num_generations

        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            model_init_kwargs = model_init_kwargs,
            max_prompt_length = max_prompt_length,
            num_generations = num_generations,
            temperature = temperature,
            max_completion_length = max_completion_length,
            ds3_gather_for_generation = ds3_gather_for_generation,
            use_vllm = use_vllm,
            vllm_device = vllm_device,
            vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
            vllm_dtype = vllm_dtype,
            vllm_max_model_len = vllm_max_model_len,
            beta = beta,
            reward_weights = reward_weights,
            sync_ref_model = sync_ref_model,
            ref_model_mixup_alpha = ref_model_mixup_alpha,
            ref_model_sync_steps = ref_model_sync_steps,
            log_completions = log_completions,**kwargs)
        # Unsloth-only extras, not forwarded to GRPOConfig.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
    pass
678 |
+
|
679 |
+
class _UnslothGRPOTrainer(Trainer):
    """Unsloth-patched implementation of TRL's GRPO trainer.

    Subclasses ``transformers.Trainer`` and implements Group Relative Policy
    Optimization: for each prompt it samples ``num_generations`` completions
    (via vLLM or ``model.generate``), scores them with one or more reward
    functions, normalizes rewards within each prompt group into advantages,
    and optimizes a KL-regularized policy loss against a reference model.

    NOTE(review): this class is auto-generated/patched code. Several methods
    are deliberately stubbed so that Unsloth's fused kernels take over
    (``_get_per_token_logps`` returns ``None`` unless ``UNSLOTH_USE_NEW_MODEL``
    is set; ``_move_model_to_vllm`` is a no-op because the vLLM engine shares
    weights with the training model). Do not "fix" these without checking the
    Unsloth runtime they pair with.
    """

    # Tags attached to the model card / hub metadata.
    _tag_names = ["trl", "grpo"]

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):

        # Unsloth patch: if the model carries an embedded vLLM engine, force vLLM generation on.
        if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
            )
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        # NOTE(review): intentionally dead branch left by the Unsloth patcher
        # (upstream TRL wraps `model` with PEFT here; Unsloth handles PEFT itself).
        if False:
            model = model

        # Reference model
        if is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif not is_peft_model(model):
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

        # Processing class
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")

        # Reward functions: strings are treated as hub IDs of sequence-classification reward models.
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward weights (one scalar per reward function; defaults to all ones).
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(
                    f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
                    f"functions ({len(reward_funcs)})"
                )
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError("The number of reward processing classes must match the number of reward functions.")

        for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = reward_processing_class.eos_token
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm

        # KL-penalty coefficient for the GRPO loss.
        self.beta = args.beta

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self.log_completions = args.log_completions

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )

        # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
                f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
                f"batch size, the valid values for the number of generations are: {possible_values}."
            )
        if self.args.eval_strategy != "no":
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
                    f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
                    f"eval batch size, the valid values for the number of generations are: {possible_values}."
                )

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)

        if self.use_vllm:
            # Unsloth patch: reuse the engine embedded on the model instead of spawning a separate vLLM server;
            # extra sampling kwargs come from `args.vllm_sampling_params` when provided.
            self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
                temperature=args.temperature,
                max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
        else:
            self.generation_config = GenerationConfig(
                max_new_tokens=self.max_completion_length,
                do_sample=True,
                temperature=args.temperature,
                pad_token_id=processing_class.pad_token_id,
            )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags to the model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))

        # Model-based reward functions must live on the right device but are never trained.
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    def _get_train_sampler(self) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)

    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        # Unsloth patch: by default (UNSLOTH_USE_NEW_MODEL unset/'0') return None so that
        # `compute_loss` falls through to the fused `grpo_accumulated_loss` path instead.
        if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
            return None # Unsloth efficient GRPO
        # Otherwise, calculate normally:
        if not hasattr(self, '_autocast_dtype'):
            self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
            # NOTE(review): UNSLOTH_FORCE_FLOAT32=='1' selecting float16 here looks inverted — confirm
            # against the Unsloth runtime before changing (mixed-precision upcasting may happen elsewhere).
            if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
        with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

            input_ids = input_ids[:, -logits_to_keep:]
            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
            # See https://github.com/huggingface/trl/issues/2770
            logits = logits[:, -logits_to_keep:]
            # NOTE(review): returns raw logits, not log-probs — the selective_log_softmax step below is
            # intentionally disabled; downstream consumers are expected to handle this.
            return logits
            # return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
        pass

    def _move_model_to_vllm(self, *args, **kwargs):
        # Unsloth patch: no weight transfer needed — the vLLM engine shares weights with the training model.
        return None

    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        """Generate completions for a batch of prompts and compute per-group advantages.

        Returns a dict with prompt/completion token ids and masks, the reference model's
        per-token output (may be None on the Unsloth fast path), and normalized advantages.
        """
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # Truncate from the left so the end of the prompt (closest to the completion) is kept.
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Generate completions using either vLLM or regular generation
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                self._move_model_to_vllm()
                self._last_loaded_step = self.state.global_step

            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
            # corresponding slice.
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        else:
            # Regular generation path
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                prompt_completion_ids = unwrapped_model.generate(
                    prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
                )

            # Compute prompt length and extract completion ids
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                # PEFT case: disabling the adapter yields the reference (base) model.
                with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
                    ref_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )

        # Decode the generated completions
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        if is_conversational(inputs[0]):
            completions = []
            for prompt, completion in zip(prompts, completions_text):
                # If the prompt ends with a partial assistant turn, prepend it to the completion.
                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
                completions.append([{"role": "assistant", "content": bootstrap + completion}])
        else:
            completions = completions_text

        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                if is_conversational(inputs[0]):
                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                else:
                    texts = [p + c for p, c in zip(prompts, completions)]
                reward_inputs = reward_processing_class(
                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
                )
                reward_inputs = super()._prepare_inputs(reward_inputs)
                with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
            else:
                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
        rewards_per_func = gather(rewards_per_func)

        # Apply weights to each reward function's output and sum
        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

        # Compute grouped-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        # 1e-4 guards against division by zero when all rewards in a group are identical.
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(prompts),
            (self.accelerator.process_index + 1) * len(prompts),
        )
        advantages = advantages[process_slice]

        # Log the metrics
        reward_per_func = rewards_per_func.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(rewards.mean().item())
        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

        if (
            self.log_completions
            and self.state.global_step % self.args.logging_steps == 0
            and "wandb" in self.args.report_to
        ):
            import pandas as pd

            # For logging
            table = {
                "step": [str(self.state.global_step)] * len(rewards),
                "prompt": gather_object(prompts_text),
                "completion": gather_object(completions_text),
                "reward": rewards.tolist(),
            }
            df = pd.DataFrame(table)

            if wandb.run is not None and self.accelerator.is_main_process:
                wandb.log({"completions": wandb.Table(dataframe=df)})

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
        """Compute the GRPO policy loss for one prepared batch.

        Uses the slow reference path when `_get_per_token_logps` produced tensors,
        otherwise the Unsloth fused/chunked `grpo_accumulated_loss` path.
        """
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        bsz, qlen = input_ids.shape
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        # attention_mask = None
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
        # Keep originals: the fused loss path consumes the full (prompt+completion) ids.
        _input_ids = input_ids
        _logits_to_keep = logits_to_keep

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

        # x - x.detach() allows for preserving gradients from x
        advantages = inputs["advantages"]
        # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        input_ids = input_ids[:, -logits_to_keep:]
        if per_token_logps is not None:
            loss, completion_length, mean_kl = grpo_compute_loss_slow(
                ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
            )
        else:
            loss, completion_length, mean_kl = grpo_accumulated_loss(
                self, _input_ids, logits_to_keep, completion_mask, advantages,
                n_chunks = self.args.unsloth_num_chunks,
            )

        # Log the metrics
        # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()

        # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        # Support both flat metrics and the newer train/eval-nested metrics layout.
        if "train" in self._metrics:
            mode = "eval" if self.control.should_evaluate else "train"
            self._metrics[mode]["completion_length"].append(completion_length.item())
            self._metrics[mode]["kl"].append(mean_kl.item())
        else:
            self._metrics["completion_length"].append(completion_length.item())
            self._metrics["kl"].append(mean_kl.item())
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        # Evaluation reuses the training loss; no logits/labels are returned.
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            loss = loss.mean().detach()
        return loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        # Average each accumulated metric list before handing off to the base logger.
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if next(iter(logs.keys())).startswith("eval_"):
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent(
            """\
            @article{zhihong2024deepseekmath,
                title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
                author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
                year         = 2024,
                eprint       = {arXiv:2402.03300},
            }
            """
        )

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="GRPO",
            trainer_citation=citation,
            paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
            paper_id="2402.03300",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothGRPOTrainer(_UnslothGRPOTrainer):
|
1238 |
+
"""
|
1239 |
+
|
1240 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
1241 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
1242 |
+
|
1243 |
+
Example:
|
1244 |
+
|
1245 |
+
```python
|
1246 |
+
from datasets import load_dataset
|
1247 |
+
from trl import GRPOTrainer
|
1248 |
+
|
1249 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
1250 |
+
|
1251 |
+
def reward_func(completions, **kwargs):
|
1252 |
+
# Dummy reward function that rewards completions with more unique letters.
|
1253 |
+
return [float(len(set(completion))) for completion in completions]
|
1254 |
+
|
1255 |
+
trainer = GRPOTrainer(
|
1256 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
1257 |
+
reward_funcs=reward_func,
|
1258 |
+
train_dataset=dataset,
|
1259 |
+
)
|
1260 |
+
|
1261 |
+
trainer.train()
|
1262 |
+
```
|
1263 |
+
|
1264 |
+
Args:
|
1265 |
+
model (`Union[str, PreTrainedModel]`):
|
1266 |
+
Model to be trained. Can be either:
|
1267 |
+
|
1268 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
1269 |
+
a path to a *directory* containing model weights saved using
|
1270 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
1271 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
1272 |
+
in `args.model_init_kwargs`.
|
1273 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
1274 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
1275 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
1276 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
1277 |
+
|
1278 |
+
- A single reward function, such as:
|
1279 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
1280 |
+
path to a *directory* containing model weights saved using
|
1281 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
1282 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
1283 |
+
keyword arguments in `args.model_init_kwargs`.
|
1284 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
1285 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
1286 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
1287 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
1288 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
1289 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
1290 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
1291 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
1292 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
1293 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
1294 |
+
ignored. The format of the samples can be either:
|
1295 |
+
|
1296 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
1297 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
1298 |
+
and content).
|
1299 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
1300 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
1301 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
1302 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
1303 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
1304 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
1305 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
1306 |
+
|
1307 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
1308 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
1309 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
1310 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
1311 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
1312 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
1313 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
1314 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
1315 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
1316 |
+
|
1317 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
1318 |
+
method.
|
1319 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
1320 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
1321 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
1322 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
1323 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
1324 |
+
|
1325 |
+
"""
|
1326 |
+
def __init__(
|
1327 |
+
self,
|
1328 |
+
model,
|
1329 |
+
reward_funcs,
|
1330 |
+
args = None,
|
1331 |
+
train_dataset = None,
|
1332 |
+
eval_dataset = None,
|
1333 |
+
processing_class = None,
|
1334 |
+
reward_processing_classes = None,
|
1335 |
+
callbacks = None,
|
1336 |
+
peft_config = None,
|
1337 |
+
**kwargs
|
1338 |
+
):
|
1339 |
+
if args is None: args = UnslothGRPOConfig()
|
1340 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1341 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1342 |
+
force_float32 = False
|
1343 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1344 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1345 |
+
force_float32 = True
|
1346 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1347 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1348 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1349 |
+
from unsloth_zoo.utils import _get_dtype
|
1350 |
+
dtype = _get_dtype(dtype)
|
1351 |
+
float16 = dtype == torch.float16
|
1352 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1353 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1354 |
+
if force_float32:
|
1355 |
+
args.fp16 = False
|
1356 |
+
args.bf16 = False
|
1357 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1358 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1359 |
+
args.fp16 = float16
|
1360 |
+
args.bf16 = not float16
|
1361 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1362 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1363 |
+
args.eval_strategy = 'steps'
|
1364 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1365 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1366 |
+
if ga_steps is not None and ga_steps > 1:
|
1367 |
+
from transformers import __version__ as transformers_version
|
1368 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1369 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1370 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1371 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1372 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1373 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1374 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1375 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1376 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1377 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1378 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1379 |
+
if force_float32:
|
1380 |
+
args.bf16_full_eval = False
|
1381 |
+
args.fp16_full_eval = False
|
1382 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1383 |
+
args.bf16_full_eval = True
|
1384 |
+
args.fp16_full_eval = False
|
1385 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1386 |
+
args.bf16_full_eval = args.bf16
|
1387 |
+
args.fp16_full_eval = args.fp16
|
1388 |
+
_output_logits = False
|
1389 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1390 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1391 |
+
if _output_logits:
|
1392 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1393 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1394 |
+
pass
|
1395 |
+
else:
|
1396 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1397 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1398 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1399 |
+
max_seq_length = model.max_seq_length
|
1400 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1401 |
+
if model is not None and hasattr(model, 'for_training'):
|
1402 |
+
model.for_training()
|
1403 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1404 |
+
if 'processing_class' in locals():
|
1405 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1406 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1407 |
+
other_metrics = []
|
1408 |
+
if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
|
1409 |
+
else: _reward_funcs = reward_funcs
|
1410 |
+
for reward_func in _reward_funcs:
|
1411 |
+
try:
|
1412 |
+
reward_func_name = reward_func.__name__
|
1413 |
+
other_metrics.append(f'rewards/{reward_func_name}')
|
1414 |
+
except: pass
|
1415 |
+
|
1416 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1417 |
+
PatchRLStatistics('grpo_trainer', other_metrics)
|
1418 |
+
|
1419 |
+
super().__init__(
|
1420 |
+
model = model,
|
1421 |
+
reward_funcs = reward_funcs,
|
1422 |
+
args = args,
|
1423 |
+
train_dataset = train_dataset,
|
1424 |
+
eval_dataset = eval_dataset,
|
1425 |
+
processing_class = processing_class,
|
1426 |
+
reward_processing_classes = reward_processing_classes,
|
1427 |
+
callbacks = callbacks,
|
1428 |
+
peft_config = peft_config,**kwargs)
|
1429 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1430 |
+
self.neftune_hook_handle.remove()
|
1431 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1432 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1433 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1434 |
+
pass
|
1435 |
+
|
1436 |
+
pass
|
unsloth_compiled_cache/UnslothKTOTrainer.py
ADDED
@@ -0,0 +1,1838 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothKTOConfig(KTOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`KTOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
+
reference model.
|
67 |
+
loss_type (`str`, *optional*, defaults to `"kto"`):
|
68 |
+
Type of loss to use. Possible values are:
|
69 |
+
|
70 |
+
- `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
|
71 |
+
- `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
|
72 |
+
|
73 |
+
desirable_weight (`float`, *optional*, defaults to `1.0`):
|
74 |
+
Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
75 |
+
undesirable_weight (`float`, *optional*, defaults to `1.0`):
|
76 |
+
Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
77 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
78 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
79 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
80 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
81 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
82 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
83 |
+
This argument is required if you want to use the default data collator.
|
84 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
85 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
86 |
+
evaluation.
|
87 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
88 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
89 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
90 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
92 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
93 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
94 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
95 |
+
string.
|
96 |
+
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
98 |
+
from a string.
|
99 |
+
dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
|
100 |
+
Number of processes to use for processing the dataset.
|
101 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
102 |
+
Whether to disable dropout in the model and reference model.
|
103 |
+
|
104 |
+
"""
|
105 |
+
vllm_sampling_params: Optional[Any] = field(
|
106 |
+
default = None,
|
107 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
108 |
+
)
|
109 |
+
unsloth_num_chunks : Optional[int] = field(
|
110 |
+
default = -1,
|
111 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
112 |
+
)
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
output_dir = None,
|
116 |
+
overwrite_output_dir = None,
|
117 |
+
do_train = False,
|
118 |
+
do_eval = False,
|
119 |
+
do_predict = False,
|
120 |
+
eval_strategy = 'no',
|
121 |
+
prediction_loss_only = False,
|
122 |
+
per_device_train_batch_size = 4,
|
123 |
+
per_device_eval_batch_size = 4,
|
124 |
+
per_gpu_train_batch_size = None,
|
125 |
+
per_gpu_eval_batch_size = None,
|
126 |
+
gradient_accumulation_steps = 2,
|
127 |
+
eval_accumulation_steps = 2,
|
128 |
+
eval_delay = 0,
|
129 |
+
torch_empty_cache_steps = 250,
|
130 |
+
learning_rate = 5e-05,
|
131 |
+
weight_decay = 0.01,
|
132 |
+
adam_beta1 = 0.9,
|
133 |
+
adam_beta2 = 0.999,
|
134 |
+
adam_epsilon = 1e-08,
|
135 |
+
max_grad_norm = 1.0,
|
136 |
+
num_train_epochs = 3.0,
|
137 |
+
max_steps = -1,
|
138 |
+
lr_scheduler_type = 'linear',
|
139 |
+
warmup_ratio = 0.1,
|
140 |
+
warmup_steps = 0,
|
141 |
+
log_level = 'passive',
|
142 |
+
log_level_replica = 'warning',
|
143 |
+
log_on_each_node = True,
|
144 |
+
logging_dir = None,
|
145 |
+
logging_strategy = 'steps',
|
146 |
+
logging_first_step = False,
|
147 |
+
logging_steps = 1,
|
148 |
+
logging_nan_inf_filter = False,
|
149 |
+
save_strategy = 'steps',
|
150 |
+
save_steps = 500,
|
151 |
+
save_total_limit = None,
|
152 |
+
save_safetensors = True,
|
153 |
+
save_on_each_node = False,
|
154 |
+
save_only_model = False,
|
155 |
+
restore_callback_states_from_checkpoint = False,
|
156 |
+
no_cuda = False,
|
157 |
+
use_cpu = False,
|
158 |
+
use_mps_device = False,
|
159 |
+
seed = 3407,
|
160 |
+
data_seed = 3407,
|
161 |
+
jit_mode_eval = False,
|
162 |
+
use_ipex = False,
|
163 |
+
bf16 = False,
|
164 |
+
fp16 = False,
|
165 |
+
fp16_opt_level = 'O1',
|
166 |
+
half_precision_backend = 'auto',
|
167 |
+
bf16_full_eval = False,
|
168 |
+
fp16_full_eval = False,
|
169 |
+
tf32 = None,
|
170 |
+
local_rank = -1,
|
171 |
+
ddp_backend = None,
|
172 |
+
tpu_num_cores = None,
|
173 |
+
tpu_metrics_debug = False,
|
174 |
+
debug = '',
|
175 |
+
dataloader_drop_last = False,
|
176 |
+
eval_steps = None,
|
177 |
+
dataloader_num_workers = 0,
|
178 |
+
dataloader_prefetch_factor = None,
|
179 |
+
past_index = -1,
|
180 |
+
run_name = None,
|
181 |
+
disable_tqdm = None,
|
182 |
+
remove_unused_columns = True,
|
183 |
+
label_names = None,
|
184 |
+
load_best_model_at_end = False,
|
185 |
+
metric_for_best_model = None,
|
186 |
+
greater_is_better = None,
|
187 |
+
ignore_data_skip = False,
|
188 |
+
fsdp = '',
|
189 |
+
fsdp_min_num_params = 0,
|
190 |
+
fsdp_config = None,
|
191 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
192 |
+
accelerator_config = None,
|
193 |
+
deepspeed = None,
|
194 |
+
label_smoothing_factor = 0.0,
|
195 |
+
optim = 'adamw_8bit',
|
196 |
+
optim_args = None,
|
197 |
+
adafactor = False,
|
198 |
+
group_by_length = False,
|
199 |
+
length_column_name = 'length',
|
200 |
+
report_to = None,
|
201 |
+
ddp_find_unused_parameters = None,
|
202 |
+
ddp_bucket_cap_mb = None,
|
203 |
+
ddp_broadcast_buffers = None,
|
204 |
+
dataloader_pin_memory = True,
|
205 |
+
dataloader_persistent_workers = False,
|
206 |
+
skip_memory_metrics = True,
|
207 |
+
use_legacy_prediction_loop = False,
|
208 |
+
push_to_hub = False,
|
209 |
+
resume_from_checkpoint = None,
|
210 |
+
hub_model_id = None,
|
211 |
+
hub_strategy = 'every_save',
|
212 |
+
hub_token = None,
|
213 |
+
hub_private_repo = None,
|
214 |
+
hub_always_push = False,
|
215 |
+
gradient_checkpointing = False,
|
216 |
+
gradient_checkpointing_kwargs = None,
|
217 |
+
include_inputs_for_metrics = False,
|
218 |
+
eval_do_concat_batches = True,
|
219 |
+
fp16_backend = 'auto',
|
220 |
+
evaluation_strategy = None,
|
221 |
+
push_to_hub_model_id = None,
|
222 |
+
push_to_hub_organization = None,
|
223 |
+
push_to_hub_token = None,
|
224 |
+
mp_parameters = '',
|
225 |
+
auto_find_batch_size = False,
|
226 |
+
full_determinism = False,
|
227 |
+
torchdynamo = None,
|
228 |
+
ray_scope = 'last',
|
229 |
+
ddp_timeout = 1800,
|
230 |
+
torch_compile = False,
|
231 |
+
torch_compile_backend = None,
|
232 |
+
torch_compile_mode = None,
|
233 |
+
dispatch_batches = None,
|
234 |
+
split_batches = None,
|
235 |
+
include_tokens_per_second = False,
|
236 |
+
include_num_input_tokens_seen = False,
|
237 |
+
neftune_noise_alpha = None,
|
238 |
+
optim_target_modules = None,
|
239 |
+
batch_eval_metrics = False,
|
240 |
+
eval_on_start = False,
|
241 |
+
use_liger_kernel = False,
|
242 |
+
eval_use_gather_object = False,
|
243 |
+
average_tokens_across_devices = False,
|
244 |
+
max_length = 1024,
|
245 |
+
max_prompt_length = 512,
|
246 |
+
max_completion_length = None,
|
247 |
+
beta = 0.1,
|
248 |
+
loss_type = 'kto',
|
249 |
+
desirable_weight = 1.0,
|
250 |
+
undesirable_weight = 1.0,
|
251 |
+
label_pad_token_id = -100,
|
252 |
+
padding_value = None,
|
253 |
+
truncation_mode = 'keep_end',
|
254 |
+
generate_during_eval = False,
|
255 |
+
is_encoder_decoder = None,
|
256 |
+
disable_dropout = True,
|
257 |
+
precompute_ref_log_probs = False,
|
258 |
+
model_init_kwargs = None,
|
259 |
+
ref_model_init_kwargs = None,
|
260 |
+
dataset_num_proc = None,
|
261 |
+
vllm_sampling_params = None,
|
262 |
+
unsloth_num_chunks = -1,
|
263 |
+
**kwargs,
|
264 |
+
):
|
265 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
266 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
267 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
268 |
+
output_dir = 'unsloth_training_checkpoints'
|
269 |
+
save_strategy = 'no'
|
270 |
+
if dataset_num_proc is None:
|
271 |
+
from multiprocessing import cpu_count
|
272 |
+
dataset_num_proc = cpu_count()
|
273 |
+
|
274 |
+
super().__init__(
|
275 |
+
output_dir = output_dir,
|
276 |
+
overwrite_output_dir = overwrite_output_dir,
|
277 |
+
do_train = do_train,
|
278 |
+
do_eval = do_eval,
|
279 |
+
do_predict = do_predict,
|
280 |
+
eval_strategy = eval_strategy,
|
281 |
+
prediction_loss_only = prediction_loss_only,
|
282 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
283 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
284 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
285 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
286 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
287 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
288 |
+
eval_delay = eval_delay,
|
289 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
290 |
+
learning_rate = learning_rate,
|
291 |
+
weight_decay = weight_decay,
|
292 |
+
adam_beta1 = adam_beta1,
|
293 |
+
adam_beta2 = adam_beta2,
|
294 |
+
adam_epsilon = adam_epsilon,
|
295 |
+
max_grad_norm = max_grad_norm,
|
296 |
+
num_train_epochs = num_train_epochs,
|
297 |
+
max_steps = max_steps,
|
298 |
+
lr_scheduler_type = lr_scheduler_type,
|
299 |
+
warmup_ratio = warmup_ratio,
|
300 |
+
warmup_steps = warmup_steps,
|
301 |
+
log_level = log_level,
|
302 |
+
log_level_replica = log_level_replica,
|
303 |
+
log_on_each_node = log_on_each_node,
|
304 |
+
logging_dir = logging_dir,
|
305 |
+
logging_strategy = logging_strategy,
|
306 |
+
logging_first_step = logging_first_step,
|
307 |
+
logging_steps = logging_steps,
|
308 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
309 |
+
save_strategy = save_strategy,
|
310 |
+
save_steps = save_steps,
|
311 |
+
save_total_limit = save_total_limit,
|
312 |
+
save_safetensors = save_safetensors,
|
313 |
+
save_on_each_node = save_on_each_node,
|
314 |
+
save_only_model = save_only_model,
|
315 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
316 |
+
no_cuda = no_cuda,
|
317 |
+
use_cpu = use_cpu,
|
318 |
+
use_mps_device = use_mps_device,
|
319 |
+
seed = seed,
|
320 |
+
data_seed = data_seed,
|
321 |
+
jit_mode_eval = jit_mode_eval,
|
322 |
+
use_ipex = use_ipex,
|
323 |
+
bf16 = bf16,
|
324 |
+
fp16 = fp16,
|
325 |
+
fp16_opt_level = fp16_opt_level,
|
326 |
+
half_precision_backend = half_precision_backend,
|
327 |
+
bf16_full_eval = bf16_full_eval,
|
328 |
+
fp16_full_eval = fp16_full_eval,
|
329 |
+
tf32 = tf32,
|
330 |
+
local_rank = local_rank,
|
331 |
+
ddp_backend = ddp_backend,
|
332 |
+
tpu_num_cores = tpu_num_cores,
|
333 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
334 |
+
debug = debug,
|
335 |
+
dataloader_drop_last = dataloader_drop_last,
|
336 |
+
eval_steps = eval_steps,
|
337 |
+
dataloader_num_workers = dataloader_num_workers,
|
338 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
339 |
+
past_index = past_index,
|
340 |
+
run_name = run_name,
|
341 |
+
disable_tqdm = disable_tqdm,
|
342 |
+
remove_unused_columns = remove_unused_columns,
|
343 |
+
label_names = label_names,
|
344 |
+
load_best_model_at_end = load_best_model_at_end,
|
345 |
+
metric_for_best_model = metric_for_best_model,
|
346 |
+
greater_is_better = greater_is_better,
|
347 |
+
ignore_data_skip = ignore_data_skip,
|
348 |
+
fsdp = fsdp,
|
349 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
350 |
+
fsdp_config = fsdp_config,
|
351 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
352 |
+
accelerator_config = accelerator_config,
|
353 |
+
deepspeed = deepspeed,
|
354 |
+
label_smoothing_factor = label_smoothing_factor,
|
355 |
+
optim = optim,
|
356 |
+
optim_args = optim_args,
|
357 |
+
adafactor = adafactor,
|
358 |
+
group_by_length = group_by_length,
|
359 |
+
length_column_name = length_column_name,
|
360 |
+
report_to = report_to,
|
361 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
362 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
363 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
364 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
365 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
366 |
+
skip_memory_metrics = skip_memory_metrics,
|
367 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
368 |
+
push_to_hub = push_to_hub,
|
369 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
370 |
+
hub_model_id = hub_model_id,
|
371 |
+
hub_strategy = hub_strategy,
|
372 |
+
hub_token = hub_token,
|
373 |
+
hub_private_repo = hub_private_repo,
|
374 |
+
hub_always_push = hub_always_push,
|
375 |
+
gradient_checkpointing = gradient_checkpointing,
|
376 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
377 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
378 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
379 |
+
fp16_backend = fp16_backend,
|
380 |
+
evaluation_strategy = evaluation_strategy,
|
381 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
382 |
+
push_to_hub_organization = push_to_hub_organization,
|
383 |
+
push_to_hub_token = push_to_hub_token,
|
384 |
+
mp_parameters = mp_parameters,
|
385 |
+
auto_find_batch_size = auto_find_batch_size,
|
386 |
+
full_determinism = full_determinism,
|
387 |
+
torchdynamo = torchdynamo,
|
388 |
+
ray_scope = ray_scope,
|
389 |
+
ddp_timeout = ddp_timeout,
|
390 |
+
torch_compile = torch_compile,
|
391 |
+
torch_compile_backend = torch_compile_backend,
|
392 |
+
torch_compile_mode = torch_compile_mode,
|
393 |
+
dispatch_batches = dispatch_batches,
|
394 |
+
split_batches = split_batches,
|
395 |
+
include_tokens_per_second = include_tokens_per_second,
|
396 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
397 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
398 |
+
optim_target_modules = optim_target_modules,
|
399 |
+
batch_eval_metrics = batch_eval_metrics,
|
400 |
+
eval_on_start = eval_on_start,
|
401 |
+
use_liger_kernel = use_liger_kernel,
|
402 |
+
eval_use_gather_object = eval_use_gather_object,
|
403 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
404 |
+
max_length = max_length,
|
405 |
+
max_prompt_length = max_prompt_length,
|
406 |
+
max_completion_length = max_completion_length,
|
407 |
+
beta = beta,
|
408 |
+
loss_type = loss_type,
|
409 |
+
desirable_weight = desirable_weight,
|
410 |
+
undesirable_weight = undesirable_weight,
|
411 |
+
label_pad_token_id = label_pad_token_id,
|
412 |
+
padding_value = padding_value,
|
413 |
+
truncation_mode = truncation_mode,
|
414 |
+
generate_during_eval = generate_during_eval,
|
415 |
+
is_encoder_decoder = is_encoder_decoder,
|
416 |
+
disable_dropout = disable_dropout,
|
417 |
+
precompute_ref_log_probs = precompute_ref_log_probs,
|
418 |
+
model_init_kwargs = model_init_kwargs,
|
419 |
+
ref_model_init_kwargs = ref_model_init_kwargs,
|
420 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
421 |
+
self.vllm_sampling_params = vllm_sampling_params
|
422 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
423 |
+
pass
|
424 |
+
|
425 |
+
class _UnslothKTOTrainer(Trainer):
|
426 |
+
r""""""
|
427 |
+
|
428 |
+
_tag_names = ["trl", "kto"]
|
429 |
+
|
430 |
+
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module, str] = None,
    ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
    args: KTOConfig = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
    processing_class: Optional[
        Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
    ] = None,
    data_collator: Optional[DataCollator] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    callbacks: Optional[list[TrainerCallback]] = None,
    optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    peft_config: Optional[dict] = None,
    compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
    model_adapter_name: Optional[str] = None,
    ref_adapter_name: Optional[str] = None,
):
    """Set up the KTO trainer.

    Validates arguments, optionally loads `model`/`ref_model` from model-id
    strings, wires up the default DPO-style data collator, tokenizes and
    processes the train/eval datasets (including the flipped "KL" pairs used
    to estimate the KL term), then delegates to `Trainer.__init__` and
    prepares the reference model for distributed execution.

    Args:
        model: Policy model to train, or a model id string loaded via
            `AutoModelForCausalLM.from_pretrained`.
        ref_model: Frozen reference model or model id. Must not be the same
            object as `model`; pass `None` when using PEFT or
            `precompute_ref_log_probs`.
        args: `KTOConfig` with training and KTO-specific hyperparameters.
        train_dataset: Training preference dataset (prompt/completion/label).
        eval_dataset: Optional evaluation dataset(s), processed the same way.
        processing_class: Tokenizer/processor; required for the default collator.
        data_collator: Optional collator; defaults to `DPODataCollatorWithPadding`.
        model_init: Forwarded to `transformers.Trainer`.
        callbacks: Forwarded to `transformers.Trainer`.
        optimizers: `(optimizer, lr_scheduler)` pair forwarded to the Trainer.
        preprocess_logits_for_metrics: Forwarded to `transformers.Trainer`.
        peft_config: Optional PEFT configuration dict.
        compute_metrics: Forwarded to `transformers.Trainer`.
        model_adapter_name: Adapter name for the policy when sharing one PEFT model.
        ref_adapter_name: Adapter name for the reference when sharing one PEFT model.

    Raises:
        ValueError: On the invalid argument combinations checked below.
        AttributeError: If the parent `Trainer` did not create an `accelerator`.
    """
    # Must be a KTOConfig: plain TrainingArguments lacks the KTO fields read below.
    if type(args) is TrainingArguments:
        raise ValueError("Please use `KTOConfig` instead TrainingArguments.")

    # Sharing one object would make the reference model train with the policy.
    if not isinstance(model, str) and ref_model is model:
        raise ValueError(
            "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
            "same as `model`, you must mass a copy of it, or `None` if you use peft."
        )

    # `model_init_kwargs` only make sense when `model` is a string to be loaded here.
    if args.model_init_kwargs is None:
        model_init_kwargs = {}
    elif not isinstance(model, str):
        raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
    else:
        model_init_kwargs = args.model_init_kwargs
        torch_dtype = model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            model_init_kwargs["torch_dtype"] = torch_dtype

    # Same validation/normalization for the reference model's init kwargs.
    if args.ref_model_init_kwargs is None:
        ref_model_init_kwargs = {}
    elif not isinstance(ref_model, str):
        raise ValueError(
            "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
        )
    else:
        ref_model_init_kwargs = args.ref_model_init_kwargs
        torch_dtype = ref_model_init_kwargs.get("torch_dtype")
        if torch_dtype is not None:
            # Convert to `torch.dtype` if an str is passed
            if isinstance(torch_dtype, str) and torch_dtype != "auto":
                torch_dtype = getattr(torch, torch_dtype)
            if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            ref_model_init_kwargs["torch_dtype"] = torch_dtype

    # Load models from the Hub when ids were given instead of instances.
    if isinstance(model, str):
        model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

    if isinstance(ref_model, str):
        ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

    # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
    # has been called in order to properly call autocast if needed.
    self._peft_has_been_casted_to_bf16 = False

    if not is_peft_available() and peft_config is not None:
        raise ValueError(
            "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
        )
    elif is_peft_available() and peft_config is not None:
        # if model is a peft model and we have a peft_config, we merge and unload it first
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()

        if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
            # Older peft versions do not accept `gradient_checkpointing_kwargs`;
            # feature-detect it via the function signature before passing it.
            _support_gc_kwargs = hasattr(
                args, "gradient_checkpointing_kwargs"
            ) and "gradient_checkpointing_kwargs" in list(
                inspect.signature(prepare_model_for_kbit_training).parameters
            )

            prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

            if _support_gc_kwargs:
                prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

            model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    # Forward hook: force embedding outputs to require grad so
                    # gradient checkpointing can backprop through frozen layers.
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # get peft model with the given config
        # NOTE(review): this generated code leaves `model` unchanged here (no
        # `get_peft_model(...)` call as in upstream TRL) — presumably the PEFT
        # wrapping is done elsewhere by the Unsloth pipeline; confirm before
        # relying on `peft_config` taking effect through this path.
        model = model
        if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
            peft_module_casting_to_bf16(model)
            # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
            self._peft_has_been_casted_to_bf16 = True

    # For models that use gradient_checkpointing, we need to attach a hook that enables input
    # to explicitly have `requires_grad=True`, otherwise training will either silently
    # fail or completely fail.
    elif getattr(args, "gradient_checkpointing", False):
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # Sample generation during evaluation is only logged through W&B or Comet.
    if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
        raise ValueError(
            "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
            " Please install `wandb` or `comet-ml` to resolve."
        )

    if model is not None:
        self.is_encoder_decoder = model.config.is_encoder_decoder
    elif args.is_encoder_decoder is None:
        raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
    else:
        self.is_encoder_decoder = args.is_encoder_decoder

    self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
    self.model_adapter_name = model_adapter_name
    self.ref_adapter_name = ref_adapter_name

    if ref_model:
        self.ref_model = ref_model
    elif self.is_peft_model or args.precompute_ref_log_probs:
        # The `model` with adapters turned off will be used as the reference model
        self.ref_model = None
    else:
        self.ref_model = create_reference_model(model)

    if processing_class is None:
        raise ValueError(
            "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
        )
    if args.max_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
            " it will be set to `512` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_length = 512
    if args.max_length is not None:
        max_length = args.max_length

    if args.max_prompt_length is None:
        warnings.warn(
            "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
            " it will be set to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_prompt_length = 128
    if args.max_prompt_length is not None:
        max_prompt_length = args.max_prompt_length

    # `max_completion_length` is only meaningful for encoder-decoder architectures.
    max_completion_length = None
    if args.max_completion_length is None and self.is_encoder_decoder:
        warnings.warn(
            "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
            " it will be set to `128` by default, but you should do it yourself in the future.",
            UserWarning,
        )
        max_completion_length = 128
    if args.max_completion_length is not None and self.is_encoder_decoder:
        max_completion_length = args.max_completion_length

    if data_collator is None:
        data_collator = DPODataCollatorWithPadding(
            pad_token_id=processing_class.pad_token_id,
            label_pad_token_id=args.label_pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        if args.remove_unused_columns:
            args.remove_unused_columns = False
            # warn users
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
                " we have set it for you, but you should do it yourself in the future.",
                UserWarning,
            )

        self.use_dpo_data_collator = True
    else:
        self.use_dpo_data_collator = False

    # Disable dropout in the model and reference model
    if args.disable_dropout:
        disable_dropout_in_model(model)
        if self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

    self.loss_type = args.loss_type
    self.max_length = max_length
    self.generate_during_eval = args.generate_during_eval
    self.label_pad_token_id = args.label_pad_token_id
    self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
    self.max_prompt_length = max_prompt_length
    self.truncation_mode = args.truncation_mode
    self.max_completion_length = max_completion_length
    self.processing_class = processing_class
    self.precompute_ref_log_probs = args.precompute_ref_log_probs

    # Not all losses require a KL calculation
    self.calculate_KL = True
    if self.loss_type in ["apo_zero_unpaired"]:
        self.calculate_KL = False

    # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
    # keep track of first called to avoid computation of future calls
    self._precomputed_train_ref_log_probs = False
    self._precomputed_eval_ref_log_probs = False

    # metric
    self._stored_metrics = defaultdict(lambda: defaultdict(list))

    # KTO parameter
    self.beta = args.beta
    self.desirable_weight = args.desirable_weight
    self.undesirable_weight = args.undesirable_weight
    self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
    self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
    if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
        warnings.warn(
            "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
            "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
            "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
            "loss.",
            UserWarning,
        )

    # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
    # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
    # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
    # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
    # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
    # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
    # issued.
    model.warnings_issued["estimate_tokens"] = True

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        # Extract the prompt if needed
        train_dataset = train_dataset.map(
            maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
        )
        # Unpair the dataset if needed
        train_dataset = maybe_unpair_preference_dataset(
            train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
        )
        # Apply the chat template if needed
        train_dataset = train_dataset.map(
            maybe_apply_chat_template,
            fn_kwargs={"tokenizer": processing_class},
            num_proc=args.dataset_num_proc,
            desc="Applying chat template to train dataset",
        )
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(
                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
            )
            eval_dataset = maybe_unpair_preference_dataset(
                eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
            )
            eval_dataset = eval_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to eval dataset",
            )

        # Tokenize and prepare the training datasets
        train_dataset = train_dataset.map(
            _tokenize,
            batched=True,
            fn_kwargs={"tokenizer": self.processing_class},
            num_proc=args.dataset_num_proc,
            desc="Tokenizing train dataset",
        )

        # Shared kwargs for `_process_tokens`; `prefix` is flipped to "KL_"
        # further below when processing the flipped KL pairs.
        fn_kwargs = {
            "prefix": "",
            "is_encoder_decoder": self.is_encoder_decoder,
            "tokenizer": self.processing_class,
            "max_length": self.max_length,
            "truncation_mode": self.truncation_mode,
            "label_pad_token_id": self.label_pad_token_id,
            "max_prompt_length": self.max_prompt_length,
            "max_completion_length": self.max_completion_length,
        }

        train_dataset = train_dataset.map(
            _process_tokens,
            fn_kwargs=fn_kwargs,
            num_proc=args.dataset_num_proc,
            desc="Processing tokenized train dataset",
        )

        # Tokenize and prepare the eval datasets
        if eval_dataset is not None:
            eval_dataset = eval_dataset.map(
                _tokenize,
                fn_kwargs={"tokenizer": self.processing_class},
                batched=True,
                num_proc=args.dataset_num_proc,
                desc="Tokenizing eval dataset",
            )

            eval_dataset = eval_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized eval dataset",
            )

        # Get KL datasets if needed
        if self.calculate_KL:
            if args.per_device_train_batch_size <= 1:
                raise ValueError(
                    "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
                )

            # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
            # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
            train_kl_dataset = train_dataset.map(
                _get_kl_dataset,
                batched=True,
                batch_size=args.per_device_train_batch_size,
                num_proc=args.dataset_num_proc,
                desc="Extracting KL train dataset",
            )

            fn_kwargs["prefix"] = "KL_"
            train_kl_dataset = train_kl_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
                desc="Processing tokenized train KL dataset",
            )

            # merge the datasets
            train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)

            if eval_dataset is not None:
                # Get KL dataset
                eval_kl_dataset = eval_dataset.map(
                    _get_kl_dataset,
                    batched=True,
                    batch_size=args.per_device_train_batch_size,
                    num_proc=args.dataset_num_proc,
                    desc="Extracting eval KL dataset",
                )

                eval_kl_dataset = eval_kl_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
                    desc="Processing tokenized eval KL dataset",
                )

                # merge the datasets
                eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

        # calculate dataset desirability balance
        num_desirable = max(sum(train_dataset["label"]), 1)
        num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary

        if num_desirable != num_undesirable:
            # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
            des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
            des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
            und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
            und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)

            des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
            und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound

            if not (des_weight_in_range or und_weight_in_range):
                warnings.warn(
                    "You have different amounts of desirable/positive and undesirable/negative examples but the "
                    "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
                    f"on your data, we recommend EITHER "
                    f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
                    f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
                    "See the documentation on how to optimally set these weights.",
                    UserWarning,
                )

    super().__init__(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processing_class,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
    # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
    # self.model_accepts_loss_kwargs to False to enable scaling.
    self.model_accepts_loss_kwargs = False

    # Add tags for models that have been loaded with the correct transformers version
    if hasattr(self.model, "add_model_tags"):
        self.model.add_model_tags(self._tag_names)

    if not hasattr(self, "accelerator"):
        raise AttributeError(
            "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
        )

    # Deepspeed Zero-3 does not support precompute_ref_log_probs
    if self.is_deepspeed_enabled:
        if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
            raise ValueError(
                "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
            )

    if self.ref_model is None:
        if not (self.is_peft_model or self.precompute_ref_log_probs):
            raise ValueError(
                "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
            )
    else:
        if self.is_deepspeed_enabled:
            self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
890 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """Wrap `model` in an inference-only DeepSpeed engine.

    Clones the accelerator's DeepSpeed config, tunes the ZeRO-3 bucket sizes
    from the model's hidden size when ZeRO-3 is active, and otherwise forces
    stage 0 so the (reference) model is replicated rather than sharded.
    """
    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    plugin = self.accelerator.state.deepspeed_plugin
    ds_config = deepcopy(plugin.deepspeed_config)

    if model is not None and hasattr(model, "config"):
        sizes = getattr(model.config, "hidden_sizes", None)
        hidden_size = max(sizes) if sizes else getattr(model.config, "hidden_size", None)
        if hidden_size is not None and ds_config["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            ds_config.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if ds_config["zero_optimization"]["stage"] != 3:
        ds_config["zero_optimization"]["stage"] = 0
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)
    engine.eval()
    return engine
921 |
+
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation).

    When the policy is a PEFT model without a dedicated reference adapter, its
    adapters are disabled for the duration of the context; when a reference
    adapter name is configured, it is swapped in and restored on exit.
    """
    if self.is_peft_model and not self.ref_adapter_name:
        guard = self.accelerator.unwrap_model(self.model).disable_adapter()
    else:
        guard = nullcontext()
    with guard:
        if self.ref_adapter_name:
            self.model.set_adapter(self.ref_adapter_name)
        yield
        if self.ref_adapter_name:
            # Restore the policy adapter after the caller is done.
            self.model.set_adapter(self.model_adapter_name or "default")
935 |
+
def get_train_dataloader(self) -> DataLoader:
|
936 |
+
"""
|
937 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
938 |
+
|
939 |
+
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
940 |
+
"""
|
941 |
+
|
942 |
+
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
943 |
+
dataloader_params = {
|
944 |
+
"batch_size": self.args.per_device_train_batch_size,
|
945 |
+
"collate_fn": self.data_collator,
|
946 |
+
"num_workers": self.args.dataloader_num_workers,
|
947 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
948 |
+
"shuffle": False,
|
949 |
+
}
|
950 |
+
|
951 |
+
# prepare dataloader
|
952 |
+
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
953 |
+
reference_completion_logps = []
|
954 |
+
reference_KL_logps = []
|
955 |
+
|
956 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
957 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
958 |
+
|
959 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
960 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
961 |
+
|
962 |
+
if self.calculate_KL:
|
963 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
964 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
965 |
+
|
966 |
+
self.train_dataset = self.train_dataset.add_column(
|
967 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
968 |
+
)
|
969 |
+
|
970 |
+
if self.calculate_KL:
|
971 |
+
self.train_dataset = self.train_dataset.add_column(
|
972 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
973 |
+
)
|
974 |
+
|
975 |
+
self._precomputed_train_ref_log_probs = True
|
976 |
+
|
977 |
+
return super().get_train_dataloader()
|
978 |
+
|
979 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
980 |
+
"""
|
981 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
982 |
+
|
983 |
+
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
984 |
+
|
985 |
+
Args:
|
986 |
+
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
987 |
+
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
988 |
+
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
989 |
+
"""
|
990 |
+
if eval_dataset is None and self.eval_dataset is None:
|
991 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
992 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
993 |
+
|
994 |
+
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
995 |
+
dataloader_params = {
|
996 |
+
"batch_size": self.args.per_device_eval_batch_size,
|
997 |
+
"collate_fn": self.data_collator,
|
998 |
+
"num_workers": self.args.dataloader_num_workers,
|
999 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
1000 |
+
"shuffle": False,
|
1001 |
+
}
|
1002 |
+
|
1003 |
+
# prepare dataloader
|
1004 |
+
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
1005 |
+
|
1006 |
+
reference_completion_logps = []
|
1007 |
+
reference_KL_logps = []
|
1008 |
+
|
1009 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
1010 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
1011 |
+
|
1012 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1013 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
1014 |
+
|
1015 |
+
if self.calculate_KL:
|
1016 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
1017 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
1018 |
+
|
1019 |
+
eval_dataset = eval_dataset.add_column(
|
1020 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1021 |
+
)
|
1022 |
+
if self.calculate_KL:
|
1023 |
+
eval_dataset = eval_dataset.add_column(
|
1024 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
1028 |
+
if self.eval_dataset is not None:
|
1029 |
+
self.eval_dataset = eval_dataset
|
1030 |
+
self._precomputed_eval_ref_log_probs = True
|
1031 |
+
|
1032 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1033 |
+
|
1034 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
1035 |
+
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
|
1036 |
+
with torch.no_grad():
|
1037 |
+
if self.ref_model is None:
|
1038 |
+
with self.null_ref_context():
|
1039 |
+
if self.is_encoder_decoder:
|
1040 |
+
completion_logits = self.model(
|
1041 |
+
padded_batch["prompt_input_ids"],
|
1042 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1043 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1044 |
+
labels=padded_batch["completion_labels"],
|
1045 |
+
).logits
|
1046 |
+
|
1047 |
+
if self.calculate_KL:
|
1048 |
+
KL_logits = self.model(
|
1049 |
+
padded_batch["KL_prompt_input_ids"],
|
1050 |
+
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1051 |
+
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1052 |
+
labels=padded_batch["KL_completion_labels"],
|
1053 |
+
).logits
|
1054 |
+
else:
|
1055 |
+
completion_logits = self.model(
|
1056 |
+
padded_batch["completion_input_ids"],
|
1057 |
+
attention_mask=padded_batch["completion_attention_mask"],
|
1058 |
+
).logits
|
1059 |
+
|
1060 |
+
if self.calculate_KL:
|
1061 |
+
KL_logits = self.model(
|
1062 |
+
padded_batch["KL_completion_input_ids"],
|
1063 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1064 |
+
).logits
|
1065 |
+
else:
|
1066 |
+
if self.is_encoder_decoder:
|
1067 |
+
completion_logits = self.ref_model(
|
1068 |
+
padded_batch["prompt_input_ids"],
|
1069 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1070 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1071 |
+
labels=padded_batch["completion_labels"],
|
1072 |
+
).logits
|
1073 |
+
|
1074 |
+
if self.calculate_KL:
|
1075 |
+
KL_logits = self.ref_model(
|
1076 |
+
padded_batch["KL_prompt_input_ids"],
|
1077 |
+
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1078 |
+
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1079 |
+
labels=padded_batch["KL_completion_labels"],
|
1080 |
+
).logits
|
1081 |
+
else:
|
1082 |
+
completion_logits = self.ref_model(
|
1083 |
+
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
1084 |
+
).logits
|
1085 |
+
|
1086 |
+
if self.calculate_KL:
|
1087 |
+
KL_logits = self.ref_model(
|
1088 |
+
padded_batch["KL_completion_input_ids"],
|
1089 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1090 |
+
).logits
|
1091 |
+
|
1092 |
+
completion_logps = self.get_batch_logps(
|
1093 |
+
completion_logits,
|
1094 |
+
padded_batch["completion_labels"],
|
1095 |
+
average_log_prob=False,
|
1096 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1097 |
+
label_pad_token_id=self.label_pad_token_id,
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
if self.calculate_KL:
|
1101 |
+
KL_logps = self.get_batch_logps(
|
1102 |
+
KL_logits,
|
1103 |
+
padded_batch["KL_completion_labels"],
|
1104 |
+
average_log_prob=False,
|
1105 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1106 |
+
label_pad_token_id=self.label_pad_token_id,
|
1107 |
+
)
|
1108 |
+
else:
|
1109 |
+
KL_logps = None
|
1110 |
+
|
1111 |
+
return completion_logps, KL_logps
|
1112 |
+
|
1113 |
+
@staticmethod
|
1114 |
+
def get_batch_logps(
|
1115 |
+
logits: torch.FloatTensor,
|
1116 |
+
labels: torch.LongTensor,
|
1117 |
+
average_log_prob: bool = False,
|
1118 |
+
label_pad_token_id: int = -100,
|
1119 |
+
is_encoder_decoder: bool = False,
|
1120 |
+
) -> torch.FloatTensor:
|
1121 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
1122 |
+
|
1123 |
+
Args:
|
1124 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1125 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1126 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1127 |
+
|
1128 |
+
Returns:
|
1129 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1130 |
+
"""
|
1131 |
+
if logits.shape[:-1] != labels.shape:
|
1132 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1133 |
+
|
1134 |
+
if not is_encoder_decoder:
|
1135 |
+
labels = labels[:, 1:].clone()
|
1136 |
+
logits = logits[:, :-1, :]
|
1137 |
+
else:
|
1138 |
+
# Fixes end-dec RuntimeError
|
1139 |
+
labels = labels.clone()
|
1140 |
+
|
1141 |
+
loss_mask = labels != label_pad_token_id
|
1142 |
+
|
1143 |
+
# dummy token; we'll ignore the losses on these tokens later
|
1144 |
+
labels[labels == label_pad_token_id] = 0
|
1145 |
+
|
1146 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
1147 |
+
|
1148 |
+
if average_log_prob:
|
1149 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1150 |
+
else:
|
1151 |
+
return (per_token_logps * loss_mask).sum(-1)
|
1152 |
+
|
1153 |
+
def forward(
|
1154 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1155 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1156 |
+
if self.calculate_KL:
|
1157 |
+
KL_logps = None
|
1158 |
+
KL_model_kwargs = (
|
1159 |
+
{
|
1160 |
+
"input_ids": batch["KL_prompt_input_ids"],
|
1161 |
+
"attention_mask": batch["KL_prompt_attention_mask"],
|
1162 |
+
"labels": batch["KL_completion_labels"],
|
1163 |
+
"decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
|
1164 |
+
}
|
1165 |
+
if self.is_encoder_decoder
|
1166 |
+
else {
|
1167 |
+
"input_ids": batch["KL_completion_input_ids"],
|
1168 |
+
"attention_mask": batch["KL_completion_attention_mask"],
|
1169 |
+
}
|
1170 |
+
)
|
1171 |
+
with torch.no_grad():
|
1172 |
+
KL_logits = model(
|
1173 |
+
**KL_model_kwargs,
|
1174 |
+
).logits
|
1175 |
+
|
1176 |
+
KL_logps = self.get_batch_logps(
|
1177 |
+
KL_logits,
|
1178 |
+
batch["KL_completion_labels"],
|
1179 |
+
average_log_prob=False,
|
1180 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1181 |
+
label_pad_token_id=self.label_pad_token_id,
|
1182 |
+
)
|
1183 |
+
else:
|
1184 |
+
KL_logps = None
|
1185 |
+
|
1186 |
+
model_kwargs = (
|
1187 |
+
{
|
1188 |
+
"labels": batch["completion_labels"],
|
1189 |
+
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
1190 |
+
}
|
1191 |
+
if self.is_encoder_decoder
|
1192 |
+
else {}
|
1193 |
+
)
|
1194 |
+
if self.aux_loss_enabled:
|
1195 |
+
model_kwargs["output_router_logits"] = True
|
1196 |
+
|
1197 |
+
outputs = model(
|
1198 |
+
batch["completion_input_ids"],
|
1199 |
+
attention_mask=batch["completion_attention_mask"],
|
1200 |
+
**model_kwargs,
|
1201 |
+
)
|
1202 |
+
completion_logits = outputs.logits
|
1203 |
+
|
1204 |
+
completion_logps = self.get_batch_logps(
|
1205 |
+
completion_logits,
|
1206 |
+
batch["completion_labels"],
|
1207 |
+
average_log_prob=False,
|
1208 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1209 |
+
label_pad_token_id=self.label_pad_token_id,
|
1210 |
+
)
|
1211 |
+
|
1212 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
1213 |
+
raise ValueError(
|
1214 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
1215 |
+
"examples for which an output sequence was predicted."
|
1216 |
+
)
|
1217 |
+
|
1218 |
+
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
1219 |
+
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
1220 |
+
|
1221 |
+
chosen_logps = completion_logps[chosen_idx, ...]
|
1222 |
+
rejected_logps = completion_logps[rejected_idx, ...]
|
1223 |
+
|
1224 |
+
chosen_logits = completion_logits[chosen_idx, ...]
|
1225 |
+
rejected_logits = completion_logits[rejected_idx, ...]
|
1226 |
+
|
1227 |
+
if self.aux_loss_enabled:
|
1228 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
|
1229 |
+
else:
|
1230 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
1231 |
+
|
1232 |
+
def kto_loss(
|
1233 |
+
self,
|
1234 |
+
policy_chosen_logps: torch.FloatTensor,
|
1235 |
+
policy_rejected_logps: torch.FloatTensor,
|
1236 |
+
policy_KL_logps: torch.FloatTensor,
|
1237 |
+
reference_chosen_logps: torch.FloatTensor,
|
1238 |
+
reference_rejected_logps: torch.FloatTensor,
|
1239 |
+
reference_KL_logps: torch.FloatTensor,
|
1240 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1241 |
+
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
|
1242 |
+
|
1243 |
+
Args:
|
1244 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1245 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1246 |
+
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
|
1247 |
+
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1248 |
+
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1249 |
+
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
|
1250 |
+
|
1251 |
+
Returns:
|
1252 |
+
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
|
1253 |
+
The losses tensor contains the KTO loss for each example in the batch.
|
1254 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
1255 |
+
The KL tensor contains the detached KL divergence estimate between the policy and reference models.
|
1256 |
+
"""
|
1257 |
+
if self.calculate_KL:
|
1258 |
+
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
1259 |
+
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
1260 |
+
else:
|
1261 |
+
kl = torch.zeros(1).to(policy_chosen_logps.device)
|
1262 |
+
|
1263 |
+
# Chosen losses
|
1264 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1265 |
+
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
1266 |
+
|
1267 |
+
if self.loss_type == "kto":
|
1268 |
+
# Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
|
1269 |
+
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
|
1270 |
+
elif self.loss_type == "apo_zero_unpaired":
|
1271 |
+
# Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
|
1272 |
+
# Use this loss when you believe the chosen outputs are better than your model's default output
|
1273 |
+
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
|
1274 |
+
|
1275 |
+
chosen_rewards = self.beta * chosen_logratios.detach()
|
1276 |
+
|
1277 |
+
else:
|
1278 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1279 |
+
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
1280 |
+
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1281 |
+
|
1282 |
+
# Rejected losses
|
1283 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1284 |
+
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
1285 |
+
|
1286 |
+
if self.loss_type == "kto":
|
1287 |
+
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
|
1288 |
+
elif self.loss_type == "apo_zero_unpaired":
|
1289 |
+
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
|
1290 |
+
|
1291 |
+
rejected_rewards = self.beta * rejected_logratios.detach()
|
1292 |
+
else:
|
1293 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1294 |
+
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
1295 |
+
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1296 |
+
|
1297 |
+
losses = torch.cat(
|
1298 |
+
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
|
1299 |
+
0,
|
1300 |
+
)
|
1301 |
+
|
1302 |
+
return losses, chosen_rewards, rejected_rewards, kl
|
1303 |
+
|
1304 |
+
def get_batch_loss_metrics(
|
1305 |
+
self,
|
1306 |
+
model,
|
1307 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
1308 |
+
):
|
1309 |
+
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
1310 |
+
metrics = {}
|
1311 |
+
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
1312 |
+
|
1313 |
+
forward_output = self.forward(model, batch)
|
1314 |
+
(
|
1315 |
+
policy_chosen_logps,
|
1316 |
+
policy_rejected_logps,
|
1317 |
+
policy_chosen_logits,
|
1318 |
+
policy_rejected_logits,
|
1319 |
+
policy_KL_logps,
|
1320 |
+
) = forward_output[:5]
|
1321 |
+
if self.aux_loss_enabled:
|
1322 |
+
aux_loss = forward_output[5]
|
1323 |
+
|
1324 |
+
# if reference_logps in batch use them, otherwise use the reference model
|
1325 |
+
if "reference_logps" in batch:
|
1326 |
+
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
1327 |
+
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
1328 |
+
|
1329 |
+
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
1330 |
+
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
1331 |
+
if self.calculate_KL:
|
1332 |
+
reference_KL_logps = batch["reference_KL_logps"]
|
1333 |
+
else:
|
1334 |
+
reference_KL_logps = None
|
1335 |
+
else:
|
1336 |
+
with torch.no_grad():
|
1337 |
+
if self.ref_model is None:
|
1338 |
+
with self.null_ref_context():
|
1339 |
+
(
|
1340 |
+
reference_chosen_logps,
|
1341 |
+
reference_rejected_logps,
|
1342 |
+
_,
|
1343 |
+
_,
|
1344 |
+
reference_KL_logps,
|
1345 |
+
) = self.forward(self.model, batch)[:5]
|
1346 |
+
else:
|
1347 |
+
(
|
1348 |
+
reference_chosen_logps,
|
1349 |
+
reference_rejected_logps,
|
1350 |
+
_,
|
1351 |
+
_,
|
1352 |
+
reference_KL_logps,
|
1353 |
+
) = self.forward(self.ref_model, batch)[:5]
|
1354 |
+
|
1355 |
+
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
1356 |
+
policy_chosen_logps,
|
1357 |
+
policy_rejected_logps,
|
1358 |
+
policy_KL_logps,
|
1359 |
+
reference_chosen_logps,
|
1360 |
+
reference_rejected_logps,
|
1361 |
+
reference_KL_logps,
|
1362 |
+
)
|
1363 |
+
metrics["kl"] = kl.item()
|
1364 |
+
|
1365 |
+
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
1366 |
+
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
1367 |
+
|
1368 |
+
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
1369 |
+
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
1370 |
+
|
1371 |
+
if all_num_chosen > 0:
|
1372 |
+
metrics["rewards/chosen_sum"] = (
|
1373 |
+
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
1374 |
+
)
|
1375 |
+
metrics["logps/chosen_sum"] = (
|
1376 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
1377 |
+
)
|
1378 |
+
metrics["logits/chosen_sum"] = (
|
1379 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
1380 |
+
)
|
1381 |
+
metrics["count/chosen"] = all_num_chosen
|
1382 |
+
|
1383 |
+
if all_num_rejected > 0:
|
1384 |
+
metrics["rewards/rejected_sum"] = (
|
1385 |
+
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
1386 |
+
)
|
1387 |
+
metrics["logps/rejected_sum"] = (
|
1388 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
1389 |
+
)
|
1390 |
+
metrics["logits/rejected_sum"] = (
|
1391 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
1392 |
+
)
|
1393 |
+
metrics["count/rejected"] = all_num_rejected
|
1394 |
+
|
1395 |
+
loss = losses.nanmean()
|
1396 |
+
if self.aux_loss_enabled:
|
1397 |
+
loss += self.aux_loss_coef * aux_loss
|
1398 |
+
|
1399 |
+
return loss, metrics
|
1400 |
+
|
1401 |
+
def compute_loss(
|
1402 |
+
self,
|
1403 |
+
model: Union[PreTrainedModel, nn.Module],
|
1404 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1405 |
+
return_outputs=False,
|
1406 |
+
num_items_in_batch=None,
|
1407 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1408 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1409 |
+
|
1410 |
+
with compute_loss_context_manager:
|
1411 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1412 |
+
|
1413 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1414 |
+
loss = loss.to(self.args.device)
|
1415 |
+
# force log the metrics
|
1416 |
+
if self.accelerator.is_main_process:
|
1417 |
+
self.store_metrics(metrics, train_eval="train")
|
1418 |
+
|
1419 |
+
if return_outputs:
|
1420 |
+
return (loss, metrics)
|
1421 |
+
return loss
|
1422 |
+
|
1423 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1424 |
+
for key, value in metrics.items():
|
1425 |
+
self._stored_metrics[train_eval][key].append(value)
|
1426 |
+
|
1427 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1428 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
1429 |
+
return None
|
1430 |
+
return SequentialSampler(self.train_dataset)
|
1431 |
+
|
1432 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
1433 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1434 |
+
|
1435 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1436 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1437 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1438 |
+
|
1439 |
+
with generate_context_manager:
|
1440 |
+
policy_output = model.generate(
|
1441 |
+
input_ids=batch["prompt_input_ids"],
|
1442 |
+
attention_mask=batch["prompt_attention_mask"],
|
1443 |
+
max_length=self.max_length,
|
1444 |
+
do_sample=True,
|
1445 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1446 |
+
)
|
1447 |
+
|
1448 |
+
# if reference_output in batch use that otherwise use the reference model
|
1449 |
+
if "reference_output" in batch:
|
1450 |
+
reference_output = batch["reference_output"]
|
1451 |
+
else:
|
1452 |
+
if self.ref_model is None:
|
1453 |
+
with self.null_ref_context():
|
1454 |
+
reference_output = self.model.generate(
|
1455 |
+
input_ids=batch["prompt_input_ids"],
|
1456 |
+
attention_mask=batch["prompt_attention_mask"],
|
1457 |
+
max_length=self.max_length,
|
1458 |
+
do_sample=True,
|
1459 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1460 |
+
)
|
1461 |
+
else:
|
1462 |
+
reference_output = self.ref_model.generate(
|
1463 |
+
input_ids=batch["prompt_input_ids"],
|
1464 |
+
attention_mask=batch["prompt_attention_mask"],
|
1465 |
+
max_length=self.max_length,
|
1466 |
+
do_sample=True,
|
1467 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1468 |
+
)
|
1469 |
+
|
1470 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1471 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1472 |
+
|
1473 |
+
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
1474 |
+
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
1475 |
+
|
1476 |
+
return policy_output_decoded, reference_output_decoded
|
1477 |
+
|
1478 |
+
def prediction_step(
|
1479 |
+
self,
|
1480 |
+
model: Union[PreTrainedModel, nn.Module],
|
1481 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1482 |
+
prediction_loss_only: bool,
|
1483 |
+
ignore_keys: Optional[list[str]] = None,
|
1484 |
+
):
|
1485 |
+
if ignore_keys is None:
|
1486 |
+
if hasattr(model, "config"):
|
1487 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1488 |
+
else:
|
1489 |
+
ignore_keys = []
|
1490 |
+
|
1491 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1492 |
+
with torch.no_grad(), prediction_context_manager:
|
1493 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1494 |
+
|
1495 |
+
# force log the metrics
|
1496 |
+
if self.accelerator.is_main_process:
|
1497 |
+
self.store_metrics(metrics, train_eval="eval")
|
1498 |
+
|
1499 |
+
if prediction_loss_only:
|
1500 |
+
return (loss.detach(), None, None)
|
1501 |
+
|
1502 |
+
# logits for the chosen and rejected samples from model
|
1503 |
+
logits_dict = {
|
1504 |
+
"eval_logits/chosen": metrics["logits/chosen"],
|
1505 |
+
"eval_logits/rejected": metrics["logits/rejected"],
|
1506 |
+
}
|
1507 |
+
logits = torch.tensor(
|
1508 |
+
[v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
|
1509 |
+
)
|
1510 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1511 |
+
|
1512 |
+
return (loss.detach(), logits, labels)
|
1513 |
+
|
1514 |
+
def evaluation_loop(
|
1515 |
+
self,
|
1516 |
+
dataloader: DataLoader,
|
1517 |
+
description: str,
|
1518 |
+
prediction_loss_only: Optional[bool] = None,
|
1519 |
+
ignore_keys: Optional[list[str]] = None,
|
1520 |
+
metric_key_prefix: str = "eval",
|
1521 |
+
) -> EvalLoopOutput:
|
1522 |
+
"""
|
1523 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
1524 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1525 |
+
|
1526 |
+
Works both with or without labels.
|
1527 |
+
"""
|
1528 |
+
|
1529 |
+
# Sample and save to game log if requested (for one batch to save time)
|
1530 |
+
if self.generate_during_eval:
|
1531 |
+
# Generate random indices within the range of the total number of samples
|
1532 |
+
num_samples = len(dataloader.dataset)
|
1533 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1534 |
+
|
1535 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1536 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1537 |
+
random_batch = self.data_collator(random_batch_dataset)
|
1538 |
+
random_batch = self._prepare_inputs(random_batch)
|
1539 |
+
|
1540 |
+
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
1541 |
+
target_batch = {
|
1542 |
+
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
1543 |
+
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
1544 |
+
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
1545 |
+
}
|
1546 |
+
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
1547 |
+
|
1548 |
+
table = pd.DataFrame(
|
1549 |
+
columns=["Prompt", "Policy", "Ref Model"],
|
1550 |
+
data=[
|
1551 |
+
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
1552 |
+
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
1553 |
+
],
|
1554 |
+
)
|
1555 |
+
if "wandb" in self.args.report_to:
|
1556 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
1557 |
+
|
1558 |
+
if "comet_ml" in self.args.report_to:
|
1559 |
+
log_table_to_comet_experiment(
|
1560 |
+
name="game_log.csv",
|
1561 |
+
table=table,
|
1562 |
+
)
|
1563 |
+
|
1564 |
+
# Base evaluation
|
1565 |
+
initial_output = super().evaluation_loop(
|
1566 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1567 |
+
)
|
1568 |
+
|
1569 |
+
return initial_output
|
1570 |
+
|
1571 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1572 |
+
"""
|
1573 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1574 |
+
|
1575 |
+
Args:
|
1576 |
+
logs (`dict[str, float]`):
|
1577 |
+
The values to log.
|
1578 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1579 |
+
Start time of the training.
|
1580 |
+
"""
|
1581 |
+
# logs either has 'loss' or 'eval_loss'
|
1582 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1583 |
+
# train metrics should have no prefix, eval should have 'eval_'
|
1584 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
1585 |
+
# accumulate average metrics from sums and lengths
|
1586 |
+
for split in ["chosen", "rejected"]:
|
1587 |
+
if f"count/{split}" in self._stored_metrics[train_eval]:
|
1588 |
+
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
1589 |
+
for metric in ["rewards", "logps", "logits"]:
|
1590 |
+
logs[f"{prefix}{metric}/{split}"] = (
|
1591 |
+
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
1592 |
+
/ count_sum
|
1593 |
+
)
|
1594 |
+
# delete obsolete metric
|
1595 |
+
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
1596 |
+
del self._stored_metrics[train_eval][f"count/{split}"]
|
1597 |
+
# calculate reward margin
|
1598 |
+
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
1599 |
+
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
1600 |
+
# Add averaged stored metrics to logs
|
1601 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1602 |
+
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
1603 |
+
del self._stored_metrics[train_eval]
|
1604 |
+
|
1605 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1606 |
+
return super().log(logs, start_time)
|
1607 |
+
else: # transformers<=4.46
|
1608 |
+
return super().log(logs)
|
1609 |
+
|
1610 |
+
def create_model_card(
    self,
    model_name: Optional[str] = None,
    dataset_name: Optional[str] = None,
    tags: Union[str, list[str], None] = None,
):
    """
    Creates a draft of a model card using the information available to the `Trainer`.

    Args:
        model_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the model.
        dataset_name (`str` or `None`, *optional*, defaults to `None`):
            Name of the dataset used for training.
        tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
            Tags to be associated with the model card.
    """
    # Only the main process writes the model card.
    if not self.is_world_process_zero():
        return

    # A `_name_or_path` that is not a local directory is taken to be the Hub id
    # of the base model; otherwise no base model is recorded.
    config = self.model.config
    base_model = None
    if hasattr(config, "_name_or_path") and not os.path.isdir(config._name_or_path):
        base_model = config._name_or_path

    # Normalize `tags` to a list (a single string becomes a one-element list).
    tags = tags or []
    if isinstance(tags, str):
        tags = [tags]

    # Unsloth-patched models advertise themselves via `config.unsloth_version`.
    if hasattr(config, "unsloth_version"):
        tags.append("unsloth")

    citation = textwrap.dedent("""\
        @article{ethayarajh2024kto,
            title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
            author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
            year = 2024,
            eprint = {arXiv:2402.01306},
        }""")

    # wandb URL is only attached when a run is actually active.
    run_url = wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None

    card = generate_model_card(
        base_model=base_model,
        model_name=model_name,
        hub_model_id=self.hub_model_id,
        dataset_name=dataset_name,
        tags=tags,
        wandb_url=run_url,
        comet_url=get_comet_experiment_url(),
        trainer_name="KTO",
        trainer_citation=citation,
        paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
        paper_id="2402.01306",
    )

    card.save(os.path.join(self.args.output_dir, "README.md"))
|
1665 |
+
class UnslothKTOTrainer(_UnslothKTOTrainer):
    """

    Initialize KTOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        ref_model (`PreTrainedModelWrapper`):
            Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
            reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
        args (`KTOConfig`):
            The arguments to use for training.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.

    """
    def __init__(
        self,
        model = None,
        ref_model = None,
        args = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        data_collator = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        model_adapter_name = None,
        ref_adapter_name = None,
        **kwargs
    ):
        # Auto-generated Unsloth wrapper: reconciles the training args' mixed-precision
        # settings with the model's actual dtype before delegating to the TRL KTOTrainer.
        if args is None: args = UnslothKTOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        # Environment escape hatch: force full float32 training for models that
        # cannot run in float16.
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Resolve the model's storage dtype; falls back to the embedding dtype when
        # the config does not declare one.
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject fp16/bf16 flags that contradict the model's actual precision.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag set: derive mixed precision from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # An eval dataset with no eval strategy would silently never evaluate.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            # Old transformers versions have a known gradient-accumulation bug.
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Shrink the (default) eval batch size down to the train batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        # Keep full-eval precision consistent with the training precision.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metrics callbacks need the raw logits, so disable the logit-skipping optimization.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # Propagate the model's max_seq_length into args when args lacks one.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right padding is required for training; patch the tokenizer/processor in place.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Swap the collator so it matches whether the dataset already carries labels.
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator prepares batches itself; stop the Trainer from interfering.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processors without a `pad` method delegate to their inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('kto_trainer', other_metrics)

        super().__init__(
            model = model,
            ref_model = ref_model,
            args = args,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            data_collator = data_collator,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,
            model_adapter_name = model_adapter_name,
            ref_adapter_name = ref_adapter_name,**kwargs)
        # NEFTune: remove the hook the base Trainer installed and re-attach the
        # noise alpha directly on the embedding layer (Unsloth applies it itself).
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

    pass
unsloth_compiled_cache/UnslothNashMDTrainer.py
ADDED
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
class UnslothNashMDConfig(NashMDConfig):
    """

    Configuration class for the [`NashMDTrainer`].

    Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:

    Parameters:
        mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
            Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
            mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
            epochs.

    """
    # Extra Unsloth-only fields layered on top of NashMDConfig.
    vllm_sampling_params: Optional[Any] = field(
        default = None,
        metadata = {'help': 'vLLM SamplingParams'},
    )
    unsloth_num_chunks : Optional[int] = field(
        default = -1,
        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
    )
    def __init__(
        self,
        output_dir = None,
        overwrite_output_dir = None,
        do_train = False,
        do_eval = False,
        do_predict = False,
        eval_strategy = 'no',
        prediction_loss_only = False,
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        per_gpu_train_batch_size = None,
        per_gpu_eval_batch_size = None,
        gradient_accumulation_steps = 2,
        eval_accumulation_steps = 2,
        eval_delay = 0,
        torch_empty_cache_steps = 250,
        learning_rate = 5e-05,
        weight_decay = 0.01,
        adam_beta1 = 0.9,
        adam_beta2 = 0.999,
        adam_epsilon = 1e-08,
        max_grad_norm = 1.0,
        num_train_epochs = 3.0,
        max_steps = -1,
        lr_scheduler_type = 'linear',
        warmup_ratio = 0.1,
        warmup_steps = 0,
        log_level = 'passive',
        log_level_replica = 'warning',
        log_on_each_node = True,
        logging_dir = None,
        logging_strategy = 'steps',
        logging_first_step = False,
        logging_steps = 1,
        logging_nan_inf_filter = False,
        save_strategy = 'steps',
        save_steps = 500,
        save_total_limit = None,
        save_safetensors = True,
        save_on_each_node = False,
        save_only_model = False,
        restore_callback_states_from_checkpoint = False,
        no_cuda = False,
        use_cpu = False,
        use_mps_device = False,
        seed = 3407,
        data_seed = 3407,
        jit_mode_eval = False,
        use_ipex = False,
        bf16 = False,
        fp16 = False,
        fp16_opt_level = 'O1',
        half_precision_backend = 'auto',
        bf16_full_eval = False,
        fp16_full_eval = False,
        tf32 = None,
        local_rank = -1,
        ddp_backend = None,
        tpu_num_cores = None,
        tpu_metrics_debug = False,
        debug = '',
        dataloader_drop_last = False,
        eval_steps = None,
        dataloader_num_workers = 0,
        dataloader_prefetch_factor = None,
        past_index = -1,
        run_name = None,
        disable_tqdm = None,
        remove_unused_columns = True,
        label_names = None,
        load_best_model_at_end = False,
        metric_for_best_model = None,
        greater_is_better = None,
        ignore_data_skip = False,
        fsdp = '',
        fsdp_min_num_params = 0,
        fsdp_config = None,
        fsdp_transformer_layer_cls_to_wrap = None,
        accelerator_config = None,
        deepspeed = None,
        label_smoothing_factor = 0.0,
        optim = 'adamw_8bit',
        optim_args = None,
        adafactor = False,
        group_by_length = False,
        length_column_name = 'length',
        report_to = None,
        ddp_find_unused_parameters = None,
        ddp_bucket_cap_mb = None,
        ddp_broadcast_buffers = None,
        dataloader_pin_memory = True,
        dataloader_persistent_workers = False,
        skip_memory_metrics = True,
        use_legacy_prediction_loop = False,
        push_to_hub = False,
        resume_from_checkpoint = None,
        hub_model_id = None,
        hub_strategy = 'every_save',
        hub_token = None,
        hub_private_repo = None,
        hub_always_push = False,
        gradient_checkpointing = False,
        gradient_checkpointing_kwargs = None,
        include_inputs_for_metrics = False,
        eval_do_concat_batches = True,
        fp16_backend = 'auto',
        evaluation_strategy = None,
        push_to_hub_model_id = None,
        push_to_hub_organization = None,
        push_to_hub_token = None,
        mp_parameters = '',
        auto_find_batch_size = False,
        full_determinism = False,
        torchdynamo = None,
        ray_scope = 'last',
        ddp_timeout = 1800,
        torch_compile = False,
        torch_compile_backend = None,
        torch_compile_mode = None,
        dispatch_batches = None,
        split_batches = None,
        include_tokens_per_second = False,
        include_num_input_tokens_seen = False,
        neftune_noise_alpha = None,
        optim_target_modules = None,
        batch_eval_metrics = False,
        eval_on_start = False,
        use_liger_kernel = False,
        eval_use_gather_object = False,
        average_tokens_across_devices = False,
        reward_model_path = None,
        judge = None,
        max_new_tokens = 64,
        max_length = 512,
        temperature = 0.9,
        missing_eos_penalty = None,
        loss_type = 'sigmoid',
        dataset_num_proc = None,
        disable_dropout = True,
        use_vllm = False,
        ds3_gather_for_generation = True,
        vllm_sampling_params = None,
        unsloth_num_chunks = -1,
        **kwargs,
    ):
        # Reject degenerate learning rates up-front, before the base config accepts them.
        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
        # If the user left every checkpoint setting at its default, write to a scratch
        # directory and disable saving entirely.
        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
            output_dir = 'unsloth_training_checkpoints'
            save_strategy = 'no'
        # Default dataset preprocessing parallelism to one worker per CPU core.
        if dataset_num_proc is None:
            from multiprocessing import cpu_count
            dataset_num_proc = cpu_count()

        # Forward everything (except the two Unsloth-only fields) to NashMDConfig.
        super().__init__(
            output_dir = output_dir,
            overwrite_output_dir = overwrite_output_dir,
            do_train = do_train,
            do_eval = do_eval,
            do_predict = do_predict,
            eval_strategy = eval_strategy,
            prediction_loss_only = prediction_loss_only,
            per_device_train_batch_size = per_device_train_batch_size,
            per_device_eval_batch_size = per_device_eval_batch_size,
            per_gpu_train_batch_size = per_gpu_train_batch_size,
            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            eval_accumulation_steps = eval_accumulation_steps,
            eval_delay = eval_delay,
            torch_empty_cache_steps = torch_empty_cache_steps,
            learning_rate = learning_rate,
            weight_decay = weight_decay,
            adam_beta1 = adam_beta1,
            adam_beta2 = adam_beta2,
            adam_epsilon = adam_epsilon,
            max_grad_norm = max_grad_norm,
            num_train_epochs = num_train_epochs,
            max_steps = max_steps,
            lr_scheduler_type = lr_scheduler_type,
            warmup_ratio = warmup_ratio,
            warmup_steps = warmup_steps,
            log_level = log_level,
            log_level_replica = log_level_replica,
            log_on_each_node = log_on_each_node,
            logging_dir = logging_dir,
            logging_strategy = logging_strategy,
            logging_first_step = logging_first_step,
            logging_steps = logging_steps,
            logging_nan_inf_filter = logging_nan_inf_filter,
            save_strategy = save_strategy,
            save_steps = save_steps,
            save_total_limit = save_total_limit,
            save_safetensors = save_safetensors,
            save_on_each_node = save_on_each_node,
            save_only_model = save_only_model,
            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
            no_cuda = no_cuda,
            use_cpu = use_cpu,
            use_mps_device = use_mps_device,
            seed = seed,
            data_seed = data_seed,
            jit_mode_eval = jit_mode_eval,
            use_ipex = use_ipex,
            bf16 = bf16,
            fp16 = fp16,
            fp16_opt_level = fp16_opt_level,
            half_precision_backend = half_precision_backend,
            bf16_full_eval = bf16_full_eval,
            fp16_full_eval = fp16_full_eval,
            tf32 = tf32,
            local_rank = local_rank,
            ddp_backend = ddp_backend,
            tpu_num_cores = tpu_num_cores,
            tpu_metrics_debug = tpu_metrics_debug,
            debug = debug,
            dataloader_drop_last = dataloader_drop_last,
            eval_steps = eval_steps,
            dataloader_num_workers = dataloader_num_workers,
            dataloader_prefetch_factor = dataloader_prefetch_factor,
            past_index = past_index,
            run_name = run_name,
            disable_tqdm = disable_tqdm,
            remove_unused_columns = remove_unused_columns,
            label_names = label_names,
            load_best_model_at_end = load_best_model_at_end,
            metric_for_best_model = metric_for_best_model,
            greater_is_better = greater_is_better,
            ignore_data_skip = ignore_data_skip,
            fsdp = fsdp,
            fsdp_min_num_params = fsdp_min_num_params,
            fsdp_config = fsdp_config,
            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
            accelerator_config = accelerator_config,
            deepspeed = deepspeed,
            label_smoothing_factor = label_smoothing_factor,
            optim = optim,
            optim_args = optim_args,
            adafactor = adafactor,
            group_by_length = group_by_length,
            length_column_name = length_column_name,
            report_to = report_to,
            ddp_find_unused_parameters = ddp_find_unused_parameters,
            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
            ddp_broadcast_buffers = ddp_broadcast_buffers,
            dataloader_pin_memory = dataloader_pin_memory,
            dataloader_persistent_workers = dataloader_persistent_workers,
            skip_memory_metrics = skip_memory_metrics,
            use_legacy_prediction_loop = use_legacy_prediction_loop,
            push_to_hub = push_to_hub,
            resume_from_checkpoint = resume_from_checkpoint,
            hub_model_id = hub_model_id,
            hub_strategy = hub_strategy,
            hub_token = hub_token,
            hub_private_repo = hub_private_repo,
            hub_always_push = hub_always_push,
            gradient_checkpointing = gradient_checkpointing,
            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
            include_inputs_for_metrics = include_inputs_for_metrics,
            eval_do_concat_batches = eval_do_concat_batches,
            fp16_backend = fp16_backend,
            evaluation_strategy = evaluation_strategy,
            push_to_hub_model_id = push_to_hub_model_id,
            push_to_hub_organization = push_to_hub_organization,
            push_to_hub_token = push_to_hub_token,
            mp_parameters = mp_parameters,
            auto_find_batch_size = auto_find_batch_size,
            full_determinism = full_determinism,
            torchdynamo = torchdynamo,
            ray_scope = ray_scope,
            ddp_timeout = ddp_timeout,
            torch_compile = torch_compile,
            torch_compile_backend = torch_compile_backend,
            torch_compile_mode = torch_compile_mode,
            dispatch_batches = dispatch_batches,
            split_batches = split_batches,
            include_tokens_per_second = include_tokens_per_second,
            include_num_input_tokens_seen = include_num_input_tokens_seen,
            neftune_noise_alpha = neftune_noise_alpha,
            optim_target_modules = optim_target_modules,
            batch_eval_metrics = batch_eval_metrics,
            eval_on_start = eval_on_start,
            use_liger_kernel = use_liger_kernel,
            eval_use_gather_object = eval_use_gather_object,
            average_tokens_across_devices = average_tokens_across_devices,
            reward_model_path = reward_model_path,
            judge = judge,
            max_new_tokens = max_new_tokens,
            max_length = max_length,
            temperature = temperature,
            missing_eos_penalty = missing_eos_penalty,
            loss_type = loss_type,
            dataset_num_proc = dataset_num_proc,
            disable_dropout = disable_dropout,
            use_vllm = use_vllm,
            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
        # The two Unsloth-only fields are stored directly on the instance.
        self.vllm_sampling_params = vllm_sampling_params
        self.unsloth_num_chunks = unsloth_num_chunks
        pass
|
364 |
+
|
365 |
+
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
366 |
+
r""""""
|
367 |
+
|
368 |
+
_tag_names = ["trl", "nash-md"]
|
369 |
+
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
model: Union[PreTrainedModel, nn.Module] = None,
|
373 |
+
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
374 |
+
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
375 |
+
judge: Optional[BasePairwiseJudge] = None,
|
376 |
+
args: Optional[NashMDConfig] = None,
|
377 |
+
data_collator: Optional[Callable] = None,
|
378 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
379 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
380 |
+
processing_class: Optional[
|
381 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
382 |
+
] = None,
|
383 |
+
peft_config: Optional[dict] = None,
|
384 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
385 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
386 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
387 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
388 |
+
) -> None:
|
389 |
+
super().__init__(
|
390 |
+
model=model,
|
391 |
+
ref_model=ref_model,
|
392 |
+
reward_model=reward_model,
|
393 |
+
judge=judge,
|
394 |
+
args=args,
|
395 |
+
data_collator=data_collator,
|
396 |
+
train_dataset=train_dataset,
|
397 |
+
eval_dataset=eval_dataset,
|
398 |
+
processing_class=processing_class,
|
399 |
+
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
|
400 |
+
peft_config=peft_config,
|
401 |
+
compute_metrics=compute_metrics,
|
402 |
+
callbacks=callbacks,
|
403 |
+
optimizers=optimizers,
|
404 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
405 |
+
)
|
406 |
+
|
407 |
+
self._mixture_coef = self.args.mixture_coef
|
408 |
+
|
409 |
+
# Overwrite the stats dictionary to include NashMD specific statistics
|
410 |
+
self.stats = {
|
411 |
+
# Remove "non_score_reward", "rlhf_reward", "scores_margin"
|
412 |
+
# Add "mixture_coef"
|
413 |
+
"loss/kl": [],
|
414 |
+
"objective/entropy": [],
|
415 |
+
"loss/score": [],
|
416 |
+
"rewards/probabilities": [],
|
417 |
+
"rewards/accuracies": [],
|
418 |
+
"rewards/margins": [],
|
419 |
+
"logps/chosen": [],
|
420 |
+
"logps/rejected": [],
|
421 |
+
"val/model_contain_eos_token": [],
|
422 |
+
"val/ref_contain_eos_token": [],
|
423 |
+
"beta": [],
|
424 |
+
"mixture_coef": [],
|
425 |
+
}
|
426 |
+
if self.reward_model is not None:
|
427 |
+
self.stats["rewards/chosen"] = []
|
428 |
+
self.stats["rewards/rejected"] = []
|
429 |
+
|
430 |
+
@property
|
431 |
+
def mixture_coef(self):
|
432 |
+
if isinstance(self._mixture_coef, list):
|
433 |
+
epoch = self.state.epoch
|
434 |
+
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
435 |
+
else:
|
436 |
+
return self._mixture_coef
|
437 |
+
|
438 |
+
def _generate_completions(self, model, prompts):
|
439 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
440 |
+
model_output = unwrapped_model.generate(
|
441 |
+
input_ids=prompts["input_ids"],
|
442 |
+
attention_mask=prompts["attention_mask"],
|
443 |
+
generation_config=self.generation_config,
|
444 |
+
)
|
445 |
+
|
446 |
+
ref_model = model if self.ref_model is None else self.ref_model
|
447 |
+
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
448 |
+
mixture_model = GeometricMixtureWrapper(
|
449 |
+
model=unwrapped_model,
|
450 |
+
ref_model=unwrapped_ref_model,
|
451 |
+
generation_config=self.generation_config,
|
452 |
+
mixture_coef=self.mixture_coef,
|
453 |
+
device=self.accelerator.device,
|
454 |
+
)
|
455 |
+
|
456 |
+
mixture_output = mixture_model.generate(
|
457 |
+
input_ids=prompts["input_ids"],
|
458 |
+
attention_mask=prompts["attention_mask"],
|
459 |
+
generation_config=self.generation_config,
|
460 |
+
)
|
461 |
+
|
462 |
+
return model_output, mixture_output
|
463 |
+
|
464 |
+
def _process_completions(self, model_output, mixture_output, prompts):
|
465 |
+
context_length = prompts["input_ids"].shape[1]
|
466 |
+
|
467 |
+
# Process model completions
|
468 |
+
model_completion_ids = model_output[:, context_length:]
|
469 |
+
model_completion_ids, model_completion_mask = truncate_right(
|
470 |
+
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
471 |
+
)
|
472 |
+
model_data = {
|
473 |
+
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
474 |
+
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
475 |
+
"raw": prompts["raw"],
|
476 |
+
}
|
477 |
+
|
478 |
+
# Process reference model completions
|
479 |
+
mixture_completion_ids = mixture_output[:, context_length:]
|
480 |
+
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
481 |
+
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
482 |
+
)
|
483 |
+
mixture_data = {
|
484 |
+
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
485 |
+
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
486 |
+
"raw": prompts["raw"],
|
487 |
+
}
|
488 |
+
|
489 |
+
return model_data, mixture_data
|
490 |
+
|
491 |
+
def _compute_rewards(self, model_data, mixture_data, context_length):
|
492 |
+
with torch.no_grad():
|
493 |
+
_, model_scores, _ = get_reward(
|
494 |
+
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
495 |
+
)
|
496 |
+
_, mixture_scores, _ = get_reward(
|
497 |
+
self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
498 |
+
)
|
499 |
+
|
500 |
+
# Apply EOS penalty if needed
|
501 |
+
if self.args.missing_eos_penalty is not None:
|
502 |
+
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
503 |
+
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
504 |
+
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
505 |
+
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
506 |
+
|
507 |
+
return model_scores, mixture_scores
|
508 |
+
|
509 |
+
def _compute_judge(self, model_data, mixture_data, context_length):
|
510 |
+
prompts = model_data["raw"]
|
511 |
+
model_data_completions = self.processing_class.batch_decode(
|
512 |
+
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
513 |
+
)
|
514 |
+
model_data_completions = [completion.strip() for completion in model_data_completions]
|
515 |
+
|
516 |
+
mixture_data_completions = self.processing_class.batch_decode(
|
517 |
+
mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
|
518 |
+
)
|
519 |
+
mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
|
520 |
+
if is_conversational({"prompt": prompts[0]}):
|
521 |
+
model_data_completions = [
|
522 |
+
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
523 |
+
]
|
524 |
+
environment = jinja2.Environment()
|
525 |
+
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
526 |
+
prompts = [template.render(messages=message) for message in prompts]
|
527 |
+
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
528 |
+
|
529 |
+
mixture_data_completions = [
|
530 |
+
[{"role": "assistant", "content": completion}] for completion in mixture_data_completions
|
531 |
+
]
|
532 |
+
mixture_data_completions = [
|
533 |
+
template.render(messages=completion) for completion in mixture_data_completions
|
534 |
+
]
|
535 |
+
|
536 |
+
probability = self.judge.judge(
|
537 |
+
prompts,
|
538 |
+
list(zip(model_data_completions, mixture_data_completions)),
|
539 |
+
return_scores=True,
|
540 |
+
)
|
541 |
+
return torch.tensor(probability, device=model_data["input_ids"].device)
|
542 |
+
|
543 |
+
def _compute_logprobs(self, model, model_data, context_length):
|
544 |
+
def compute_logprobs_for_data(m, data):
|
545 |
+
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
546 |
+
logits = output.logits[:, context_length - 1 : -1]
|
547 |
+
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
548 |
+
return token_logprobs
|
549 |
+
|
550 |
+
# Compute logprobs for model completions under the model
|
551 |
+
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
552 |
+
|
553 |
+
# Compute logprobs of model completions under the reference model
|
554 |
+
with torch.no_grad():
|
555 |
+
if self.ref_model is None:
|
556 |
+
with model.disable_adapter():
|
557 |
+
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
558 |
+
else:
|
559 |
+
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
560 |
+
|
561 |
+
# Mask padding tokens
|
562 |
+
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
563 |
+
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
564 |
+
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
565 |
+
|
566 |
+
return (model_logprobs_model_data, ref_logprobs_model_data)
|
567 |
+
|
568 |
+
def _compute_losses(
|
569 |
+
self,
|
570 |
+
model_logprobs_model_data,
|
571 |
+
ref_logprobs_model_data,
|
572 |
+
probability,
|
573 |
+
):
|
574 |
+
# reinforce score where 0.5 is a control variate
|
575 |
+
score = (probability - 0.5) * model_logprobs_model_data.sum(1)
|
576 |
+
|
577 |
+
# kl divergence via reinforce
|
578 |
+
with torch.no_grad():
|
579 |
+
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
|
580 |
+
kl_div_log = log_ratio.sum(1)
|
581 |
+
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
|
582 |
+
|
583 |
+
# final loss
|
584 |
+
loss = self.beta * kl_div_loss - score
|
585 |
+
|
586 |
+
return loss.mean(), score, kl_div_log
|
587 |
+
|
588 |
+
def _log_statistics(
|
589 |
+
self,
|
590 |
+
model_data,
|
591 |
+
mixture_data,
|
592 |
+
model_logprobs_model_data,
|
593 |
+
ref_logprobs_model_data,
|
594 |
+
probability,
|
595 |
+
score,
|
596 |
+
kl_div,
|
597 |
+
context_length,
|
598 |
+
model_scores=None,
|
599 |
+
mixture_scores=None,
|
600 |
+
):
|
601 |
+
# Helper function to gather and compute mean
|
602 |
+
def gather_mean(tensor):
|
603 |
+
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
604 |
+
|
605 |
+
# Log score
|
606 |
+
self.stats["loss/score"].append(gather_mean(score))
|
607 |
+
# Log KL divergence
|
608 |
+
self.stats["loss/kl"].append(gather_mean(kl_div))
|
609 |
+
|
610 |
+
# Log logprobs
|
611 |
+
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
612 |
+
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
613 |
+
|
614 |
+
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
615 |
+
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
616 |
+
|
617 |
+
# Log rewards
|
618 |
+
if self.reward_model is not None:
|
619 |
+
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
620 |
+
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
621 |
+
|
622 |
+
# Log probabilities
|
623 |
+
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
624 |
+
|
625 |
+
# Calculate entropy for model data
|
626 |
+
entropy_model_data = -model_logprobs_model_data.sum(1)
|
627 |
+
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
628 |
+
|
629 |
+
# Calculate margins
|
630 |
+
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
631 |
+
self.stats["rewards/margins"].append(gather_mean(margin))
|
632 |
+
|
633 |
+
# Calculate accuracy
|
634 |
+
accuracy = (margin > 0).float()
|
635 |
+
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
636 |
+
|
637 |
+
# Log EOS token statistics
|
638 |
+
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
639 |
+
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
640 |
+
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
641 |
+
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
642 |
+
|
643 |
+
# Log beta and mixture coef
|
644 |
+
self.stats["beta"].append(self.beta)
|
645 |
+
self.stats["mixture_coef"].append(self.mixture_coef)
|
646 |
+
|
647 |
+
def training_step(
|
648 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
649 |
+
) -> torch.Tensor:
|
650 |
+
model.train()
|
651 |
+
|
652 |
+
# Apply chat template and tokenize the input
|
653 |
+
batch_size = len(next(iter(inputs.values())))
|
654 |
+
prompts = inputs["prompt"]
|
655 |
+
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
656 |
+
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
657 |
+
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
658 |
+
inputs = self.data_collator(inputs)
|
659 |
+
|
660 |
+
# need the prompt_ only
|
661 |
+
inputs = self._prepare_inputs(inputs)
|
662 |
+
context_length = inputs["prompt_input_ids"].shape[1]
|
663 |
+
prompts = {
|
664 |
+
"input_ids": inputs["prompt_input_ids"],
|
665 |
+
"attention_mask": inputs["prompt_attention_mask"],
|
666 |
+
"raw": prompts,
|
667 |
+
}
|
668 |
+
del inputs
|
669 |
+
|
670 |
+
# Sample completions from both the model and the reference model
|
671 |
+
model_output, mixture_output = self._generate_completions(model, prompts)
|
672 |
+
|
673 |
+
# Process model completions
|
674 |
+
model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
|
675 |
+
|
676 |
+
# Compute rewards
|
677 |
+
if self.reward_model is not None:
|
678 |
+
model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
|
679 |
+
# probability of the model data vs the mixture data
|
680 |
+
probability = F.sigmoid(model_scores - mixture_scores)
|
681 |
+
else:
|
682 |
+
model_scores, mixture_scores = None, None
|
683 |
+
probability = self._compute_judge(model_data, mixture_data, context_length)
|
684 |
+
|
685 |
+
# Compute logprobs
|
686 |
+
model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
|
687 |
+
|
688 |
+
# Compute loss
|
689 |
+
loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
|
690 |
+
|
691 |
+
# Log everything
|
692 |
+
self._log_statistics(
|
693 |
+
model_data,
|
694 |
+
mixture_data,
|
695 |
+
model_logprobs_model_data.detach(),
|
696 |
+
ref_logprobs_model_data,
|
697 |
+
probability,
|
698 |
+
score.detach(),
|
699 |
+
kl_div.detach(),
|
700 |
+
context_length,
|
701 |
+
model_scores,
|
702 |
+
mixture_scores,
|
703 |
+
)
|
704 |
+
|
705 |
+
if (
|
706 |
+
self.args.torch_empty_cache_steps is not None
|
707 |
+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
708 |
+
):
|
709 |
+
empty_cache()
|
710 |
+
|
711 |
+
kwargs = {}
|
712 |
+
# For LOMO optimizers you need to explicitly use the learning rate
|
713 |
+
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
714 |
+
kwargs["learning_rate"] = self._get_learning_rate()
|
715 |
+
|
716 |
+
if self.args.n_gpu > 1:
|
717 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
718 |
+
|
719 |
+
if self.use_apex:
|
720 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
721 |
+
scaled_loss.backward()
|
722 |
+
else:
|
723 |
+
self.accelerator.backward(loss, **kwargs)
|
724 |
+
|
725 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
726 |
+
|
727 |
+
def create_model_card(
|
728 |
+
self,
|
729 |
+
model_name: Optional[str] = None,
|
730 |
+
dataset_name: Optional[str] = None,
|
731 |
+
tags: Union[str, list[str], None] = None,
|
732 |
+
):
|
733 |
+
"""
|
734 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
735 |
+
|
736 |
+
Args:
|
737 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
738 |
+
Name of the model.
|
739 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
740 |
+
Name of the dataset used for training.
|
741 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
742 |
+
Tags to be associated with the model card.
|
743 |
+
"""
|
744 |
+
if not self.is_world_process_zero():
|
745 |
+
return
|
746 |
+
|
747 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
748 |
+
base_model = self.model.config._name_or_path
|
749 |
+
else:
|
750 |
+
base_model = None
|
751 |
+
|
752 |
+
tags = tags or []
|
753 |
+
if isinstance(tags, str):
|
754 |
+
tags = [tags]
|
755 |
+
|
756 |
+
if hasattr(self.model.config, "unsloth_version"):
|
757 |
+
tags.append("unsloth")
|
758 |
+
|
759 |
+
citation = textwrap.dedent("""\
|
760 |
+
@inproceedings{munos2024nash,
|
761 |
+
title = {{Nash Learning from Human Feedback}},
|
762 |
+
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
763 |
+
year = 2024,
|
764 |
+
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
765 |
+
publisher = {OpenReview.net},
|
766 |
+
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
767 |
+
}""")
|
768 |
+
|
769 |
+
model_card = generate_model_card(
|
770 |
+
base_model=base_model,
|
771 |
+
model_name=model_name,
|
772 |
+
hub_model_id=self.hub_model_id,
|
773 |
+
dataset_name=dataset_name,
|
774 |
+
tags=tags,
|
775 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
776 |
+
comet_url=get_comet_experiment_url(),
|
777 |
+
trainer_name="Nash-MD",
|
778 |
+
trainer_citation=citation,
|
779 |
+
paper_title="Nash Learning from Human Feedback",
|
780 |
+
paper_id="2312.00886",
|
781 |
+
)
|
782 |
+
|
783 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
784 |
+
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
785 |
+
"""
|
786 |
+
|
787 |
+
Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
|
788 |
+
|
789 |
+
Args:
|
790 |
+
model (`transformers.PreTrainedModel`):
|
791 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
792 |
+
ref_model (`PreTrainedModelWrapper`):
|
793 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
794 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
795 |
+
reward_model (`transformers.PreTrainedModel`):
|
796 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
797 |
+
judge (`BasePairwiseJudge`):
|
798 |
+
The judge to use for pairwise comparison of model completions.
|
799 |
+
args (`NashMDConfig`):
|
800 |
+
The NashMD config arguments to use for training.
|
801 |
+
data_collator (`transformers.DataCollator`):
|
802 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
803 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
804 |
+
train_dataset (`datasets.Dataset`):
|
805 |
+
The dataset to use for training.
|
806 |
+
eval_dataset (`datasets.Dataset`):
|
807 |
+
The dataset to use for evaluation.
|
808 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
809 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
810 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
811 |
+
reuse the fine-tuned model.
|
812 |
+
peft_config (`dict`):
|
813 |
+
The peft config to use for training.
|
814 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
815 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
816 |
+
a dictionary string to metric values.
|
817 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
818 |
+
The callbacks to use for training.
|
819 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
820 |
+
The optimizer and scheduler to use for training.
|
821 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
822 |
+
The function to use to preprocess the logits before computing the metrics.
|
823 |
+
|
824 |
+
"""
|
825 |
+
def __init__(
|
826 |
+
self,
|
827 |
+
model = None,
|
828 |
+
ref_model = None,
|
829 |
+
reward_model = None,
|
830 |
+
judge = None,
|
831 |
+
args = None,
|
832 |
+
data_collator = None,
|
833 |
+
train_dataset = None,
|
834 |
+
eval_dataset = None,
|
835 |
+
processing_class = None,
|
836 |
+
peft_config = None,
|
837 |
+
compute_metrics = None,
|
838 |
+
callbacks = None,
|
839 |
+
preprocess_logits_for_metrics = None,
|
840 |
+
**kwargs
|
841 |
+
):
|
842 |
+
if args is None: args = UnslothNashMDConfig()
|
843 |
+
use_bf16 = getattr(args, 'bf16', False)
|
844 |
+
use_fp16 = getattr(args, 'fp16', False)
|
845 |
+
force_float32 = False
|
846 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
847 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
848 |
+
force_float32 = True
|
849 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
850 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
851 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
852 |
+
from unsloth_zoo.utils import _get_dtype
|
853 |
+
dtype = _get_dtype(dtype)
|
854 |
+
float16 = dtype == torch.float16
|
855 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
856 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
857 |
+
if force_float32:
|
858 |
+
args.fp16 = False
|
859 |
+
args.bf16 = False
|
860 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
861 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
862 |
+
args.fp16 = float16
|
863 |
+
args.bf16 = not float16
|
864 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
865 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
866 |
+
args.eval_strategy = 'steps'
|
867 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
868 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
869 |
+
if ga_steps is not None and ga_steps > 1:
|
870 |
+
from transformers import __version__ as transformers_version
|
871 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
872 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
873 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
874 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
875 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
876 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
877 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
878 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
879 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
880 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
881 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
882 |
+
if force_float32:
|
883 |
+
args.bf16_full_eval = False
|
884 |
+
args.fp16_full_eval = False
|
885 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
886 |
+
args.bf16_full_eval = True
|
887 |
+
args.fp16_full_eval = False
|
888 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
889 |
+
args.bf16_full_eval = args.bf16
|
890 |
+
args.fp16_full_eval = args.fp16
|
891 |
+
_output_logits = False
|
892 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
893 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
894 |
+
if _output_logits:
|
895 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
896 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
897 |
+
pass
|
898 |
+
else:
|
899 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
900 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
901 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
902 |
+
max_seq_length = model.max_seq_length
|
903 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
904 |
+
if model is not None and hasattr(model, 'for_training'):
|
905 |
+
model.for_training()
|
906 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
907 |
+
if 'processing_class' in locals():
|
908 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
909 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
910 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
911 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
912 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
913 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
914 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
915 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
916 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
917 |
+
else:
|
918 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
919 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
920 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
921 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
922 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
923 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
924 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
925 |
+
else:
|
926 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
927 |
+
other_metrics = []
|
928 |
+
|
929 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
930 |
+
PatchRLStatistics('nash_md_trainer', other_metrics)
|
931 |
+
|
932 |
+
super().__init__(
|
933 |
+
model = model,
|
934 |
+
ref_model = ref_model,
|
935 |
+
reward_model = reward_model,
|
936 |
+
judge = judge,
|
937 |
+
args = args,
|
938 |
+
data_collator = data_collator,
|
939 |
+
train_dataset = train_dataset,
|
940 |
+
eval_dataset = eval_dataset,
|
941 |
+
processing_class = processing_class,
|
942 |
+
peft_config = peft_config,
|
943 |
+
compute_metrics = compute_metrics,
|
944 |
+
callbacks = callbacks,
|
945 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
946 |
+
if hasattr(self, 'neftune_hook_handle'):
|
947 |
+
self.neftune_hook_handle.remove()
|
948 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
949 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
950 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
951 |
+
pass
|
952 |
+
|
953 |
+
pass
|
unsloth_compiled_cache/UnslothORPOTrainer.py
ADDED
@@ -0,0 +1,1541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, wandb, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothORPOConfig(ORPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`ORPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
|
66 |
+
it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
|
67 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to disable dropout in the model.
|
69 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
70 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
71 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
72 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
73 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
74 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
75 |
+
This argument is required if you want to use the default data collator.
|
76 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
77 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
78 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
79 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
80 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
81 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
82 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
83 |
+
string.
|
84 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
85 |
+
Number of processes to use for processing the dataset.
|
86 |
+
|
87 |
+
"""
|
88 |
+
vllm_sampling_params: Optional[Any] = field(
|
89 |
+
default = None,
|
90 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
91 |
+
)
|
92 |
+
unsloth_num_chunks : Optional[int] = field(
|
93 |
+
default = -1,
|
94 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
95 |
+
)
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
output_dir = None,
|
99 |
+
overwrite_output_dir = None,
|
100 |
+
do_train = False,
|
101 |
+
do_eval = False,
|
102 |
+
do_predict = False,
|
103 |
+
eval_strategy = 'no',
|
104 |
+
prediction_loss_only = False,
|
105 |
+
per_device_train_batch_size = 4,
|
106 |
+
per_device_eval_batch_size = 4,
|
107 |
+
per_gpu_train_batch_size = None,
|
108 |
+
per_gpu_eval_batch_size = None,
|
109 |
+
gradient_accumulation_steps = 2,
|
110 |
+
eval_accumulation_steps = 2,
|
111 |
+
eval_delay = 0,
|
112 |
+
torch_empty_cache_steps = 250,
|
113 |
+
learning_rate = 5e-05,
|
114 |
+
weight_decay = 0.01,
|
115 |
+
adam_beta1 = 0.9,
|
116 |
+
adam_beta2 = 0.999,
|
117 |
+
adam_epsilon = 1e-08,
|
118 |
+
max_grad_norm = 1.0,
|
119 |
+
num_train_epochs = 3.0,
|
120 |
+
max_steps = -1,
|
121 |
+
lr_scheduler_type = 'linear',
|
122 |
+
warmup_ratio = 0.1,
|
123 |
+
warmup_steps = 0,
|
124 |
+
log_level = 'passive',
|
125 |
+
log_level_replica = 'warning',
|
126 |
+
log_on_each_node = True,
|
127 |
+
logging_dir = None,
|
128 |
+
logging_strategy = 'steps',
|
129 |
+
logging_first_step = False,
|
130 |
+
logging_steps = 1,
|
131 |
+
logging_nan_inf_filter = False,
|
132 |
+
save_strategy = 'steps',
|
133 |
+
save_steps = 500,
|
134 |
+
save_total_limit = None,
|
135 |
+
save_safetensors = True,
|
136 |
+
save_on_each_node = False,
|
137 |
+
save_only_model = False,
|
138 |
+
restore_callback_states_from_checkpoint = False,
|
139 |
+
no_cuda = False,
|
140 |
+
use_cpu = False,
|
141 |
+
use_mps_device = False,
|
142 |
+
seed = 3407,
|
143 |
+
data_seed = 3407,
|
144 |
+
jit_mode_eval = False,
|
145 |
+
use_ipex = False,
|
146 |
+
bf16 = False,
|
147 |
+
fp16 = False,
|
148 |
+
fp16_opt_level = 'O1',
|
149 |
+
half_precision_backend = 'auto',
|
150 |
+
bf16_full_eval = False,
|
151 |
+
fp16_full_eval = False,
|
152 |
+
tf32 = None,
|
153 |
+
local_rank = -1,
|
154 |
+
ddp_backend = None,
|
155 |
+
tpu_num_cores = None,
|
156 |
+
tpu_metrics_debug = False,
|
157 |
+
debug = '',
|
158 |
+
dataloader_drop_last = False,
|
159 |
+
eval_steps = None,
|
160 |
+
dataloader_num_workers = 0,
|
161 |
+
dataloader_prefetch_factor = None,
|
162 |
+
past_index = -1,
|
163 |
+
run_name = None,
|
164 |
+
disable_tqdm = None,
|
165 |
+
remove_unused_columns = True,
|
166 |
+
label_names = None,
|
167 |
+
load_best_model_at_end = False,
|
168 |
+
metric_for_best_model = None,
|
169 |
+
greater_is_better = None,
|
170 |
+
ignore_data_skip = False,
|
171 |
+
fsdp = '',
|
172 |
+
fsdp_min_num_params = 0,
|
173 |
+
fsdp_config = None,
|
174 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
175 |
+
accelerator_config = None,
|
176 |
+
deepspeed = None,
|
177 |
+
label_smoothing_factor = 0.0,
|
178 |
+
optim = 'adamw_8bit',
|
179 |
+
optim_args = None,
|
180 |
+
adafactor = False,
|
181 |
+
group_by_length = False,
|
182 |
+
length_column_name = 'length',
|
183 |
+
report_to = None,
|
184 |
+
ddp_find_unused_parameters = None,
|
185 |
+
ddp_bucket_cap_mb = None,
|
186 |
+
ddp_broadcast_buffers = None,
|
187 |
+
dataloader_pin_memory = True,
|
188 |
+
dataloader_persistent_workers = False,
|
189 |
+
skip_memory_metrics = True,
|
190 |
+
use_legacy_prediction_loop = False,
|
191 |
+
push_to_hub = False,
|
192 |
+
resume_from_checkpoint = None,
|
193 |
+
hub_model_id = None,
|
194 |
+
hub_strategy = 'every_save',
|
195 |
+
hub_token = None,
|
196 |
+
hub_private_repo = None,
|
197 |
+
hub_always_push = False,
|
198 |
+
gradient_checkpointing = False,
|
199 |
+
gradient_checkpointing_kwargs = None,
|
200 |
+
include_inputs_for_metrics = False,
|
201 |
+
eval_do_concat_batches = True,
|
202 |
+
fp16_backend = 'auto',
|
203 |
+
evaluation_strategy = None,
|
204 |
+
push_to_hub_model_id = None,
|
205 |
+
push_to_hub_organization = None,
|
206 |
+
push_to_hub_token = None,
|
207 |
+
mp_parameters = '',
|
208 |
+
auto_find_batch_size = False,
|
209 |
+
full_determinism = False,
|
210 |
+
torchdynamo = None,
|
211 |
+
ray_scope = 'last',
|
212 |
+
ddp_timeout = 1800,
|
213 |
+
torch_compile = False,
|
214 |
+
torch_compile_backend = None,
|
215 |
+
torch_compile_mode = None,
|
216 |
+
dispatch_batches = None,
|
217 |
+
split_batches = None,
|
218 |
+
include_tokens_per_second = False,
|
219 |
+
include_num_input_tokens_seen = False,
|
220 |
+
neftune_noise_alpha = None,
|
221 |
+
optim_target_modules = None,
|
222 |
+
batch_eval_metrics = False,
|
223 |
+
eval_on_start = False,
|
224 |
+
use_liger_kernel = False,
|
225 |
+
eval_use_gather_object = False,
|
226 |
+
average_tokens_across_devices = False,
|
227 |
+
max_length = 1024,
|
228 |
+
max_prompt_length = 512,
|
229 |
+
max_completion_length = None,
|
230 |
+
beta = 0.1,
|
231 |
+
disable_dropout = True,
|
232 |
+
label_pad_token_id = -100,
|
233 |
+
padding_value = None,
|
234 |
+
truncation_mode = 'keep_end',
|
235 |
+
generate_during_eval = False,
|
236 |
+
is_encoder_decoder = None,
|
237 |
+
model_init_kwargs = None,
|
238 |
+
dataset_num_proc = None,
|
239 |
+
vllm_sampling_params = None,
|
240 |
+
unsloth_num_chunks = -1,
|
241 |
+
**kwargs,
|
242 |
+
):
|
243 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
244 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
245 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
246 |
+
output_dir = 'unsloth_training_checkpoints'
|
247 |
+
save_strategy = 'no'
|
248 |
+
if dataset_num_proc is None:
|
249 |
+
from multiprocessing import cpu_count
|
250 |
+
dataset_num_proc = cpu_count()
|
251 |
+
|
252 |
+
super().__init__(
|
253 |
+
output_dir = output_dir,
|
254 |
+
overwrite_output_dir = overwrite_output_dir,
|
255 |
+
do_train = do_train,
|
256 |
+
do_eval = do_eval,
|
257 |
+
do_predict = do_predict,
|
258 |
+
eval_strategy = eval_strategy,
|
259 |
+
prediction_loss_only = prediction_loss_only,
|
260 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
261 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
262 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
263 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
264 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
265 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
266 |
+
eval_delay = eval_delay,
|
267 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
268 |
+
learning_rate = learning_rate,
|
269 |
+
weight_decay = weight_decay,
|
270 |
+
adam_beta1 = adam_beta1,
|
271 |
+
adam_beta2 = adam_beta2,
|
272 |
+
adam_epsilon = adam_epsilon,
|
273 |
+
max_grad_norm = max_grad_norm,
|
274 |
+
num_train_epochs = num_train_epochs,
|
275 |
+
max_steps = max_steps,
|
276 |
+
lr_scheduler_type = lr_scheduler_type,
|
277 |
+
warmup_ratio = warmup_ratio,
|
278 |
+
warmup_steps = warmup_steps,
|
279 |
+
log_level = log_level,
|
280 |
+
log_level_replica = log_level_replica,
|
281 |
+
log_on_each_node = log_on_each_node,
|
282 |
+
logging_dir = logging_dir,
|
283 |
+
logging_strategy = logging_strategy,
|
284 |
+
logging_first_step = logging_first_step,
|
285 |
+
logging_steps = logging_steps,
|
286 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
287 |
+
save_strategy = save_strategy,
|
288 |
+
save_steps = save_steps,
|
289 |
+
save_total_limit = save_total_limit,
|
290 |
+
save_safetensors = save_safetensors,
|
291 |
+
save_on_each_node = save_on_each_node,
|
292 |
+
save_only_model = save_only_model,
|
293 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
294 |
+
no_cuda = no_cuda,
|
295 |
+
use_cpu = use_cpu,
|
296 |
+
use_mps_device = use_mps_device,
|
297 |
+
seed = seed,
|
298 |
+
data_seed = data_seed,
|
299 |
+
jit_mode_eval = jit_mode_eval,
|
300 |
+
use_ipex = use_ipex,
|
301 |
+
bf16 = bf16,
|
302 |
+
fp16 = fp16,
|
303 |
+
fp16_opt_level = fp16_opt_level,
|
304 |
+
half_precision_backend = half_precision_backend,
|
305 |
+
bf16_full_eval = bf16_full_eval,
|
306 |
+
fp16_full_eval = fp16_full_eval,
|
307 |
+
tf32 = tf32,
|
308 |
+
local_rank = local_rank,
|
309 |
+
ddp_backend = ddp_backend,
|
310 |
+
tpu_num_cores = tpu_num_cores,
|
311 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
312 |
+
debug = debug,
|
313 |
+
dataloader_drop_last = dataloader_drop_last,
|
314 |
+
eval_steps = eval_steps,
|
315 |
+
dataloader_num_workers = dataloader_num_workers,
|
316 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
317 |
+
past_index = past_index,
|
318 |
+
run_name = run_name,
|
319 |
+
disable_tqdm = disable_tqdm,
|
320 |
+
remove_unused_columns = remove_unused_columns,
|
321 |
+
label_names = label_names,
|
322 |
+
load_best_model_at_end = load_best_model_at_end,
|
323 |
+
metric_for_best_model = metric_for_best_model,
|
324 |
+
greater_is_better = greater_is_better,
|
325 |
+
ignore_data_skip = ignore_data_skip,
|
326 |
+
fsdp = fsdp,
|
327 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
328 |
+
fsdp_config = fsdp_config,
|
329 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
330 |
+
accelerator_config = accelerator_config,
|
331 |
+
deepspeed = deepspeed,
|
332 |
+
label_smoothing_factor = label_smoothing_factor,
|
333 |
+
optim = optim,
|
334 |
+
optim_args = optim_args,
|
335 |
+
adafactor = adafactor,
|
336 |
+
group_by_length = group_by_length,
|
337 |
+
length_column_name = length_column_name,
|
338 |
+
report_to = report_to,
|
339 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
340 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
341 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
342 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
343 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
344 |
+
skip_memory_metrics = skip_memory_metrics,
|
345 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
346 |
+
push_to_hub = push_to_hub,
|
347 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
348 |
+
hub_model_id = hub_model_id,
|
349 |
+
hub_strategy = hub_strategy,
|
350 |
+
hub_token = hub_token,
|
351 |
+
hub_private_repo = hub_private_repo,
|
352 |
+
hub_always_push = hub_always_push,
|
353 |
+
gradient_checkpointing = gradient_checkpointing,
|
354 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
355 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
356 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
357 |
+
fp16_backend = fp16_backend,
|
358 |
+
evaluation_strategy = evaluation_strategy,
|
359 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
360 |
+
push_to_hub_organization = push_to_hub_organization,
|
361 |
+
push_to_hub_token = push_to_hub_token,
|
362 |
+
mp_parameters = mp_parameters,
|
363 |
+
auto_find_batch_size = auto_find_batch_size,
|
364 |
+
full_determinism = full_determinism,
|
365 |
+
torchdynamo = torchdynamo,
|
366 |
+
ray_scope = ray_scope,
|
367 |
+
ddp_timeout = ddp_timeout,
|
368 |
+
torch_compile = torch_compile,
|
369 |
+
torch_compile_backend = torch_compile_backend,
|
370 |
+
torch_compile_mode = torch_compile_mode,
|
371 |
+
dispatch_batches = dispatch_batches,
|
372 |
+
split_batches = split_batches,
|
373 |
+
include_tokens_per_second = include_tokens_per_second,
|
374 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
375 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
376 |
+
optim_target_modules = optim_target_modules,
|
377 |
+
batch_eval_metrics = batch_eval_metrics,
|
378 |
+
eval_on_start = eval_on_start,
|
379 |
+
use_liger_kernel = use_liger_kernel,
|
380 |
+
eval_use_gather_object = eval_use_gather_object,
|
381 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
382 |
+
max_length = max_length,
|
383 |
+
max_prompt_length = max_prompt_length,
|
384 |
+
max_completion_length = max_completion_length,
|
385 |
+
beta = beta,
|
386 |
+
disable_dropout = disable_dropout,
|
387 |
+
label_pad_token_id = label_pad_token_id,
|
388 |
+
padding_value = padding_value,
|
389 |
+
truncation_mode = truncation_mode,
|
390 |
+
generate_during_eval = generate_during_eval,
|
391 |
+
is_encoder_decoder = is_encoder_decoder,
|
392 |
+
model_init_kwargs = model_init_kwargs,
|
393 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
394 |
+
self.vllm_sampling_params = vllm_sampling_params
|
395 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
396 |
+
pass
|
397 |
+
|
398 |
+
class _UnslothORPOTrainer(Trainer):
|
399 |
+
r""""""
|
400 |
+
|
401 |
+
_tag_names = ["trl", "orpo"]
|
402 |
+
|
403 |
+
def __init__(
|
404 |
+
self,
|
405 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
406 |
+
args: Optional[ORPOConfig] = None,
|
407 |
+
data_collator: Optional[DataCollator] = None,
|
408 |
+
train_dataset: Optional[Dataset] = None,
|
409 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
410 |
+
processing_class: Optional[
|
411 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
412 |
+
] = None,
|
413 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
414 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
415 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
416 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
417 |
+
peft_config: Optional[dict] = None,
|
418 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
419 |
+
):
|
420 |
+
if args.model_init_kwargs is None:
|
421 |
+
model_init_kwargs = {}
|
422 |
+
elif not isinstance(model, str):
|
423 |
+
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
424 |
+
else:
|
425 |
+
model_init_kwargs = args.model_init_kwargs
|
426 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
427 |
+
if torch_dtype is not None:
|
428 |
+
# Convert to `torch.dtype` if an str is passed
|
429 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
430 |
+
torch_dtype = getattr(torch, torch_dtype)
|
431 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
432 |
+
raise ValueError(
|
433 |
+
f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
434 |
+
)
|
435 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
436 |
+
|
437 |
+
if isinstance(model, str):
|
438 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
439 |
+
|
440 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
441 |
+
# has been called in order to properly call autocast if needed.
|
442 |
+
self._peft_has_been_casted_to_bf16 = False
|
443 |
+
|
444 |
+
if not is_peft_available() and peft_config is not None:
|
445 |
+
raise ValueError(
|
446 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
447 |
+
)
|
448 |
+
elif is_peft_available() and peft_config is not None:
|
449 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
450 |
+
if isinstance(model, PeftModel):
|
451 |
+
model = model.merge_and_unload()
|
452 |
+
|
453 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
454 |
+
_support_gc_kwargs = hasattr(
|
455 |
+
args, "gradient_checkpointing_kwargs"
|
456 |
+
) and "gradient_checkpointing_kwargs" in list(
|
457 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
458 |
+
)
|
459 |
+
|
460 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
461 |
+
|
462 |
+
if _support_gc_kwargs:
|
463 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
464 |
+
|
465 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
466 |
+
elif getattr(args, "gradient_checkpointing", False):
|
467 |
+
# For backward compatibility with older versions of transformers
|
468 |
+
if hasattr(model, "enable_input_require_grads"):
|
469 |
+
model.enable_input_require_grads()
|
470 |
+
else:
|
471 |
+
|
472 |
+
def make_inputs_require_grad(module, input, output):
|
473 |
+
output.requires_grad_(True)
|
474 |
+
|
475 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
476 |
+
|
477 |
+
# get peft model with the given config
|
478 |
+
model = model
|
479 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
480 |
+
peft_module_casting_to_bf16(model)
|
481 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
482 |
+
self._peft_has_been_casted_to_bf16 = True
|
483 |
+
|
484 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
485 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
486 |
+
# fail or completely fail.
|
487 |
+
elif getattr(args, "gradient_checkpointing", False):
|
488 |
+
# For backward compatibility with older versions of transformers
|
489 |
+
if hasattr(model, "enable_input_require_grads"):
|
490 |
+
model.enable_input_require_grads()
|
491 |
+
else:
|
492 |
+
|
493 |
+
def make_inputs_require_grad(module, input, output):
|
494 |
+
output.requires_grad_(True)
|
495 |
+
|
496 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
497 |
+
|
498 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
499 |
+
raise ValueError(
|
500 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
501 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
502 |
+
)
|
503 |
+
|
504 |
+
if model is not None:
|
505 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
506 |
+
elif args.is_encoder_decoder is None:
|
507 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
508 |
+
else:
|
509 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
510 |
+
|
511 |
+
if self.is_encoder_decoder:
|
512 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
513 |
+
self.pad_token_id = model.config.pad_token_id
|
514 |
+
|
515 |
+
if processing_class is None:
|
516 |
+
raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
|
517 |
+
if args.max_length is None:
|
518 |
+
warnings.warn(
|
519 |
+
"`max_length` is not set in the ORPOConfig's init"
|
520 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
521 |
+
UserWarning,
|
522 |
+
)
|
523 |
+
max_length = 512
|
524 |
+
else:
|
525 |
+
max_length = args.max_length
|
526 |
+
if args.max_prompt_length is None:
|
527 |
+
warnings.warn(
|
528 |
+
"`max_prompt_length` is not set in the ORPOConfig's init"
|
529 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
530 |
+
UserWarning,
|
531 |
+
)
|
532 |
+
max_prompt_length = 128
|
533 |
+
else:
|
534 |
+
max_prompt_length = args.max_prompt_length
|
535 |
+
|
536 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
537 |
+
warnings.warn(
|
538 |
+
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
539 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
540 |
+
UserWarning,
|
541 |
+
)
|
542 |
+
self.max_completion_length = 128
|
543 |
+
else:
|
544 |
+
self.max_completion_length = args.max_completion_length
|
545 |
+
|
546 |
+
if data_collator is None:
|
547 |
+
data_collator = DPODataCollatorWithPadding(
|
548 |
+
pad_token_id=processing_class.pad_token_id,
|
549 |
+
label_pad_token_id=args.label_pad_token_id,
|
550 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
551 |
+
)
|
552 |
+
|
553 |
+
if args.remove_unused_columns:
|
554 |
+
args.remove_unused_columns = False
|
555 |
+
# warn users
|
556 |
+
warnings.warn(
|
557 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
558 |
+
" we have set it for you, but you should do it yourself in the future.",
|
559 |
+
UserWarning,
|
560 |
+
)
|
561 |
+
|
562 |
+
self.use_dpo_data_collator = True
|
563 |
+
else:
|
564 |
+
self.use_dpo_data_collator = False
|
565 |
+
|
566 |
+
# Disable dropout in the model and reference model
|
567 |
+
if args.disable_dropout:
|
568 |
+
disable_dropout_in_model(model)
|
569 |
+
|
570 |
+
self.max_length = max_length
|
571 |
+
self.generate_during_eval = args.generate_during_eval
|
572 |
+
self.label_pad_token_id = args.label_pad_token_id
|
573 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
574 |
+
self.max_prompt_length = max_prompt_length
|
575 |
+
self.truncation_mode = args.truncation_mode
|
576 |
+
self.processing_class = processing_class
|
577 |
+
|
578 |
+
self.beta = args.beta
|
579 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
580 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
581 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
582 |
+
warnings.warn(
|
583 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
584 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
585 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
586 |
+
"loss.",
|
587 |
+
UserWarning,
|
588 |
+
)
|
589 |
+
|
590 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
591 |
+
|
592 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
593 |
+
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
|
594 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
595 |
+
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
596 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
597 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
598 |
+
# that the warning has already been issued.
|
599 |
+
model.warnings_issued["estimate_tokens"] = True
|
600 |
+
|
601 |
+
# Compute that only on the main process for faster data processing.
|
602 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
603 |
+
with PartialState().local_main_process_first():
|
604 |
+
# Extract the prompt if needed, and apply the chat template if needed
|
605 |
+
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
606 |
+
train_dataset = train_dataset.map(
|
607 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
608 |
+
)
|
609 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
610 |
+
if eval_dataset is not None:
|
611 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
612 |
+
eval_dataset = eval_dataset.map(
|
613 |
+
maybe_apply_chat_template,
|
614 |
+
fn_kwargs={"tokenizer": processing_class},
|
615 |
+
num_proc=args.dataset_num_proc,
|
616 |
+
)
|
617 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
618 |
+
|
619 |
+
super().__init__(
|
620 |
+
model=model,
|
621 |
+
args=args,
|
622 |
+
data_collator=data_collator,
|
623 |
+
train_dataset=train_dataset,
|
624 |
+
eval_dataset=eval_dataset,
|
625 |
+
processing_class=processing_class,
|
626 |
+
model_init=model_init,
|
627 |
+
compute_metrics=compute_metrics,
|
628 |
+
callbacks=callbacks,
|
629 |
+
optimizers=optimizers,
|
630 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
631 |
+
)
|
632 |
+
|
633 |
+
# Add tags for models that have been loaded with the correct transformers version
|
634 |
+
if hasattr(self.model, "add_model_tags"):
|
635 |
+
self.model.add_model_tags(self._tag_names)
|
636 |
+
|
637 |
+
if not hasattr(self, "accelerator"):
|
638 |
+
raise AttributeError(
|
639 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
640 |
+
)
|
641 |
+
|
642 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
    """Initialize `model` as its own DeepSpeed engine, in eval mode, using this trainer's DeepSpeed plugin config."""
    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    # Work on a copy so the trainer's live DeepSpeed config is not mutated below.
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

    if model is not None:
        if hasattr(model, "config"):
            # Prefer the largest entry of `hidden_sizes` when present, else fall back to `hidden_size`.
            hidden_size = (
                max(model.config.hidden_sizes)
                if getattr(model.config, "hidden_sizes", None)
                else getattr(model.config, "hidden_size", None)
            )
            if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
                # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                config_kwargs.update(
                    {
                        "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                        "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                        "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                    }
                )

    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if config_kwargs["zero_optimization"]["stage"] != 3:
        config_kwargs["zero_optimization"]["stage"] = 0
    # No optimizer/scheduler passed: the returned engine is inference-only.
    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model
|
672 |
+
|
673 |
+
def build_tokenized_answer(self, prompt, answer):
    """Tokenize `prompt + answer` jointly and split the result at the prompt boundary.

    Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
    It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
    Reference:
        https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257

    Returns:
        dict with `prompt_input_ids`, `prompt_attention_mask` (the prompt part of the
        joint encoding) and `input_ids`, `attention_mask` (the answer part).
    """

    full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]

    answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
    answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]

    # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
    full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])

    # Prepare input tokens for token by token comparison
    full_input_ids = np.array(full_tokenized["input_ids"])

    if len(full_input_ids) != len(full_concat_input_ids):
        raise ValueError("Prompt input ids and answer input ids should have the same length.")

    # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This could result
    # on the last token from the prompt being different when tokenized on its own
    # vs when done as prompt+answer.
    response_token_ids_start_idx = len(prompt_input_ids)

    # If tokenized prompt is different than both prompt+answer, then it means the
    # last token has changed due to merging.
    if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
        # Give the merged token to the answer side by shrinking the prompt span by one.
        response_token_ids_start_idx -= 1

    prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
    prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]

    if len(prompt_input_ids) != len(prompt_attention_mask):
        raise ValueError("Prompt input ids and attention mask should have the same length.")

    answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
    answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

    return dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        input_ids=answer_input_ids,
        attention_mask=answer_attention_mask,
    )
|
722 |
+
|
723 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
    """Tokenize a single row from a ORPO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
    in case the prompt + chosen or prompt + rejected responses is/are too long. First
    we truncate the prompt; if we're still too long, we truncate the chosen/rejected.

    We also create the labels for the chosen/rejected responses, which are of length equal to
    the sum of the length of the prompt and the chosen/rejected response, with
    label_pad_token_id for the prompt tokens.

    Args:
        feature: mapping with string entries "prompt", "chosen", "rejected".
        model: optional model; only used on the encoder-decoder path to build
            decoder input ids from labels.

    Returns:
        dict of token-id / attention-mask / label lists keyed by
        prompt_/chosen_/rejected_ prefixes (decoder-only) or the
        encoder-decoder equivalents.
    """
    batch = {}
    prompt = feature["prompt"]
    chosen = feature["chosen"]
    rejected = feature["rejected"]

    if not self.is_encoder_decoder:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")
        prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
        prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

        if not isinstance(chosen, str):
            raise ValueError(f"chosen should be an str but got {type(chosen)}")
        chosen_tokens = self.build_tokenized_answer(prompt, chosen)

        if not isinstance(rejected, str):
            raise ValueError(f"rejected should be an str but got {type(rejected)}")
        rejected_tokens = self.build_tokenized_answer(prompt, rejected)

        # Last prompt token might get merged by tokenizer and
        # it should not be included for generation if that happens
        prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])

        chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
        rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
        # Use the shorter of the two joint-encoding prompt spans as the canonical prompt length.
        prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)

        for k, v in prompt_tokens.items():
            prompt_tokens[k] = v[:prompt_len_input_ids]

        # Make sure prompts only have one different token at most an
        # and length only differs by 1 at most
        num_diff_tokens = sum(
            [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
        )
        num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
        if num_diff_tokens > 1 or num_diff_len > 1:
            raise ValueError(
                "Chosen and rejected prompt_input_ids might only differ on the "
                "last token due to tokenizer merge ops."
            )

        # add BOS token to head of prompt. Avoid adding if it's already there
        prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
            self.processing_class.bos_token_id,
            prompt_len_input_ids,
            prompt_tokens,
            chosen_prompt_len_input_ids,
            chosen_tokens,
            rejected_prompt_len_input_ids,
            rejected_tokens,
        )

        # add EOS token to end of answer. Avoid adding if it's already there
        chosen_tokens, rejected_tokens = add_eos_token_if_needed(
            self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
        )

        longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))

        # if combined sequence is too long, truncate the prompt
        for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                if self.truncation_mode == "keep_start":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
                elif self.truncation_mode == "keep_end":
                    for k in ["prompt_input_ids", "prompt_attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
                else:
                    raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")

        # if that's still too long, truncate the response
        for answer_tokens in [chosen_tokens, rejected_tokens]:
            if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
                for k in ["input_ids", "attention_mask"]:
                    answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]

        # Create labels
        chosen_sequence_tokens = {
            k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        rejected_sequence_tokens = {
            k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
        }
        # Labels copy the full sequence ids, then the prompt span is masked out with label_pad_token_id.
        chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
        chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(chosen_tokens["prompt_input_ids"])
        rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
        rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
            self.label_pad_token_id
        ] * len(rejected_tokens["prompt_input_ids"])

        # Flatten into the batch dict; token_type_ids (if the tokenizer emits them) are dropped.
        for k, toks in {
            "chosen_": chosen_sequence_tokens,
            "rejected_": rejected_sequence_tokens,
            "": prompt_tokens,
        }.items():
            for type_key, tokens in toks.items():
                if type_key == "token_type_ids":
                    continue
                batch[f"{k}{type_key}"] = tokens

    else:
        # Encoder-decoder path: prompt and completions are tokenized independently.
        chosen_tokens = self.processing_class(
            chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        rejected_tokens = self.processing_class(
            rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
        )
        prompt_tokens = self.processing_class(
            prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
        )

        batch["chosen_labels"] = chosen_tokens["input_ids"]
        batch["rejected_labels"] = rejected_tokens["input_ids"]
        batch["prompt_input_ids"] = prompt_tokens["input_ids"]
        batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["rejected_labels"])
            )
            batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch["chosen_labels"])
            )

    if is_torch_xla_available():
        # Pad the sequences to global max_length to avoid TorchXLA recompilation
        for k in batch:
            if "labels" in k or self.is_encoder_decoder:
                pad_value = self.label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = self.padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
    return batch
|
878 |
+
|
879 |
+
@staticmethod
def concatenated_inputs(
    batch: dict[str, Union[list, torch.LongTensor]],
    is_encoder_decoder: bool = False,
    label_pad_token_id: int = -100,
    padding_value: int = 0,
    device: Optional[torch.device] = None,
) -> dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
        is_encoder_decoder: Whether the model is an encoder-decoder model.
        label_pad_token_id: The label pad token id.
        padding_value: The padding value to use for the concatenated inputs_ids.
        device: The device for the concatenated inputs.

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    concatenated_batch = {}

    # Pad every tensor to the longer of the chosen/rejected lengths so the two
    # halves can be stacked along the batch dimension.
    if is_encoder_decoder:
        max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
    else:
        max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])

    # First pass: pad each chosen_* tensor and store it under the concatenated_* key.
    for k in batch:
        if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            # NOTE(review): `pad_value` is left unset if a chosen_* tensor key matches none
            # of the suffixes above — assumes only labels/input_ids/attention_mask occur.
            concatenated_key = k.replace("chosen", "concatenated")
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    # Second pass: pad each rejected_* tensor and append it after its chosen counterpart (dim 0).
    for k in batch:
        if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
            if "labels" in k or is_encoder_decoder:
                pad_value = label_pad_token_id
            elif k.endswith("_input_ids"):
                pad_value = padding_value
            elif k.endswith("_attention_mask"):
                pad_value = 0
            concatenated_key = k.replace("rejected", "concatenated")
            concatenated_batch[concatenated_key] = torch.cat(
                (
                    concatenated_batch[concatenated_key],
                    pad_to_length(batch[k], max_length, pad_value=pad_value),
                ),
                dim=0,
            ).to(device=device)

    if is_encoder_decoder:
        # Encoder inputs are the (shared) prompt, duplicated for the chosen and rejected halves.
        concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
        concatenated_batch["concatenated_attention_mask"] = (
            batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
        )

    return concatenated_batch
|
940 |
+
|
941 |
+
def odds_ratio_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of five tensors: (losses, chosen_rewards, rejected_rewards,
        mean log-sigmoid of the log odds ratio, mean log odds ratio) — the last
        two are for logging purposes.
    """
    # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691
    # using log identities and exp(log(P(y|x))) = P(y|x). Each term below is
    # log-odds(p) = log p - log(1 - p), computed entirely in log space.
    chosen_log_odds = policy_chosen_logps - torch.log1p(-torch.exp(policy_chosen_logps))
    rejected_log_odds = policy_rejected_logps - torch.log1p(-torch.exp(policy_rejected_logps))
    log_odds = chosen_log_odds - rejected_log_odds

    sig_ratio = F.logsigmoid(log_odds)
    losses = self.beta * sig_ratio

    # Rewards are the (detached) scaled sequence log-probs, moved to the trainer device.
    device = self.accelerator.device
    chosen_rewards = (self.beta * policy_chosen_logps.to(device)).detach()
    rejected_rewards = (self.beta * policy_rejected_logps.to(device)).detach()

    return losses, chosen_rewards, rejected_rewards, torch.mean(sig_ratio), torch.mean(log_odds)
|
971 |
+
|
972 |
+
@staticmethod
def get_batch_logps(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
    is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
    """Compute per-sequence log probabilities of `labels` under `logits`.

    Args:
        logits: Unnormalized model scores, shape (batch_size, sequence_length, vocab_size).
        labels: Target token ids, shape (batch_size, sequence_length); positions equal
            to `label_pad_token_id` are excluded from the result.
        average_log_prob: If True, return the mean per non-masked token; otherwise the sum.
        label_pad_token_id: Sentinel id marking positions to ignore.
        is_encoder_decoder: When False (decoder-only), labels/logits are shifted so
            logits at position t score the token at t+1.

    Returns:
        Tensor of shape (batch_size,) with the summed or averaged label log probs.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

    if not is_encoder_decoder:
        # Decoder-only alignment: drop the first label and the last logit position.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    valid = labels != label_pad_token_id

    # Replace padded labels with a harmless index; those positions are masked out below.
    gather_labels = labels.masked_fill(~valid, 0)

    token_logps = selective_log_softmax(logits, gather_labels)

    total = (token_logps * valid).sum(-1)
    return total / valid.sum(-1) if average_log_prob else total
|
1009 |
+
|
1010 |
+
def concatenated_forward(
    self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.

    Returns (chosen_logps, rejected_logps, chosen_logits, rejected_logits,
    chosen_nll_loss), plus the router aux loss as a sixth element when
    `self.aux_loss_enabled` is set.
    """
    concatenated_batch = self.concatenated_inputs(
        batch,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
        padding_value=self.padding_value,
        device=self.accelerator.device,
    )
    # The first `len_chosen` rows of the concatenated tensors are the chosen half.
    len_chosen = batch["chosen_labels"].shape[0]

    model_kwargs = (
        {
            "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
        }
        if self.is_encoder_decoder
        else {}
    )

    if self.aux_loss_enabled:
        model_kwargs["output_router_logits"] = True

    outputs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        **model_kwargs,
    )
    all_logits = outputs.logits

    # Standard causal-LM cross entropy over the supplied (logits, labels) pair.
    def cross_entropy_loss(logits, labels):
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        loss = loss_fct(logits, labels)
        return loss

    if self.is_encoder_decoder:
        labels = concatenated_batch["concatenated_labels"].clone()
    else:
        # Decoder-only: labels are the input ids with padding positions masked out.
        labels = concatenated_batch["concatenated_input_ids"].clone()
        attention_mask = concatenated_batch["concatenated_attention_mask"]
        labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
    # orpo chosen nll loss is computed over the full prompt and response
    chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

    # Length-normalized sequence log-probs (average_log_prob=True), as ORPO requires.
    all_logps = self.get_batch_logps(
        all_logits,
        concatenated_batch["concatenated_labels"],
        average_log_prob=True,
        is_encoder_decoder=self.is_encoder_decoder,
        label_pad_token_id=self.label_pad_token_id,
    )

    chosen_logps = all_logps[:len_chosen]
    rejected_logps = all_logps[len_chosen:]

    if not self.is_encoder_decoder:
        # Drop the final position so the logits line up with the shifted labels.
        chosen_logits = all_logits[:len_chosen, :-1, :]
        rejected_logits = all_logits[len_chosen:, :-1, :]
    else:
        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

    if self.aux_loss_enabled:
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

    return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
1090 |
+
|
1091 |
+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.

    Returns a `(loss, metrics)` pair; metric keys are prefixed with "eval_" when
    `train_eval == "eval"` and values are converted to Python floats via `.item()`.
    """
    metrics = {}

    forward_output = self.concatenated_forward(model, batch)
    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_nll_loss,
    ) = forward_output[:5]
    # A sixth element (router aux loss) is only present when aux loss is enabled.
    if self.aux_loss_enabled:
        aux_loss = forward_output[5]

    losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
        policy_chosen_logps, policy_rejected_logps
    )
    # full ORPO loss
    loss = policy_nll_loss - losses.mean()

    reward_accuracies = (chosen_rewards > rejected_rewards).float()

    prefix = "eval_" if train_eval == "eval" else ""
    # Each metric is gathered via the accelerator before averaging.
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
    metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
        chosen_rewards - rejected_rewards
    ).mean()
    metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
    metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
    metrics[f"{prefix}logits/rejected"] = (
        self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
    )
    metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
    metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
    metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
    metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
    if is_torch_xla_available():
        xm.mark_step()  # needed because .item() calls
    for k, v in metrics.items():
        metrics[k] = v.item()
    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
|
1143 |
+
|
1144 |
+
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
    num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
    """Compute the ORPO training loss for one batch, recording its metrics as a side effect."""
    # Autocast only when the PEFT adapters were cast to bf16; otherwise a no-op context.
    if self._peft_has_been_casted_to_bf16:
        loss_ctx = amp.autocast("cuda")
    else:
        loss_ctx = nullcontext()

    with loss_ctx:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

    # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
    loss = loss.to(self.args.device)

    # force log the metrics
    self.store_metrics(metrics, train_eval="train")

    return (loss, metrics) if return_outputs else loss
|
1165 |
+
|
1166 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
    """Sample completions from `model` for the batch's prompts and decode them to strings."""
    # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
    # the torch cuda amp context manager as some hidden states are silently casted to full precision.
    if self._peft_has_been_casted_to_bf16:
        gen_ctx = amp.autocast("cuda")
    else:
        gen_ctx = nullcontext()

    with gen_ctx:
        policy_output = model.generate(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
        )

    padded_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
    return self.processing_class.batch_decode(padded_output, skip_special_tokens=True)
|
1186 |
+
|
1187 |
+
def prediction_step(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[list[str]] = None,
):
    """Run one no-grad evaluation step and return `(loss, logits, labels)`.

    `logits` and `labels` are `None` when `prediction_loss_only` is set; otherwise
    the chosen/rejected eval logit means are stacked with dummy zero labels.
    """
    if not self.use_dpo_data_collator:
        warnings.warn(
            "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
            "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
        )
    if ignore_keys is None:
        if hasattr(model, "config"):
            ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    # Autocast only when PEFT weights were cast to bf16; otherwise a no-op context.
    prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

    with torch.no_grad(), prediction_context_manager:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

    # force log the metrics
    self.store_metrics(metrics, train_eval="eval")

    if prediction_loss_only:
        return (loss.detach(), None, None)

    # logits for the chosen and rejected samples from model
    logits_dict = {
        "eval_logits/chosen": metrics["eval_logits/chosen"],
        "eval_logits/rejected": metrics["eval_logits/rejected"],
    }
    logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
    logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
    # Dummy labels: this trainer has no per-example labels for these eval logits.
    labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

    return (loss.detach(), logits, labels)
|
1226 |
+
|
1227 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
    """Append a batch's metric values to the per-split history later averaged by `log`."""
    split_history = self._stored_metrics[train_eval]
    for name, value in metrics.items():
        split_history[name].append(value)
|
1230 |
+
|
1231 |
+
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[list[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Overriding built-in evaluation loop to store metrics for each batch.
    Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

    Works both with or without labels.
    """

    # Sample and save to game log if requested (for one batch to save time)
    if self.generate_during_eval:
        # Generate random indices within the range of the total number of samples
        num_samples = len(dataloader.dataset)
        random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)

        # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
        random_batch_dataset = dataloader.dataset.select(random_indices)
        random_batch = self.data_collator(random_batch_dataset)
        random_batch = self._prepare_inputs(random_batch)

        policy_output_decoded = self.generate_from_model(self.model, random_batch)

        # Completions are sliced with len(prompt) to drop the echoed prompt text.
        table = pd.DataFrame(
            columns=["Prompt", "Policy"],
            data=[
                [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
            ],
        )
        if "wandb" in self.args.report_to:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in self.args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

    # Base evaluation
    initial_output = super().evaluation_loop(
        dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
    )

    return initial_output
|
1280 |
+
|
1281 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1282 |
+
"""
|
1283 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1284 |
+
|
1285 |
+
Args:
|
1286 |
+
logs (`dict[str, float]`):
|
1287 |
+
The values to log.
|
1288 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1289 |
+
Start time of the training.
|
1290 |
+
"""
|
1291 |
+
# logs either has 'loss' or 'eval_loss'
|
1292 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1293 |
+
# Add averaged stored metrics to logs
|
1294 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1295 |
+
logs[key] = torch.tensor(metrics).mean().item()
|
1296 |
+
del self._stored_metrics[train_eval]
|
1297 |
+
|
1298 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1299 |
+
return super().log(logs, start_time)
|
1300 |
+
else: # transformers<=4.46
|
1301 |
+
return super().log(logs)
|
1302 |
+
|
1303 |
+
def _shift_right(self, input_ids):
|
1304 |
+
if self.decoder_start_token_id is None:
|
1305 |
+
raise ValueError(
|
1306 |
+
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
1307 |
+
)
|
1308 |
+
|
1309 |
+
# shift inputs to the right
|
1310 |
+
if is_torch_fx_proxy(input_ids):
|
1311 |
+
# Item assignment is not supported natively for proxies.
|
1312 |
+
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
1313 |
+
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
1314 |
+
else:
|
1315 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1316 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1317 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1318 |
+
|
1319 |
+
if self.pad_token_id is None:
|
1320 |
+
raise ValueError("model.config.pad_token_id has to be defined.")
|
1321 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1322 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
1323 |
+
|
1324 |
+
return shifted_input_ids
|
1325 |
+
|
1326 |
+
    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        # Only the main process writes the model card.
        if not self.is_world_process_zero():
            return

        # A local directory path is a checkpoint on disk, not a Hub id, so it
        # cannot be cited as the base model.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # Normalise `tags` to a list (a single string becomes a one-element list).
        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        # Tag Unsloth-patched models so they are discoverable.
        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{hong2024orpo,
            title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
            author = {Jiwoo Hong and Noah Lee and James Thorne},
            year = 2024,
            eprint = {arXiv:2403.07691}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            # Link the tracking run when one is active; otherwise omit the URL.
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="ORPO",
            trainer_citation=citation,
            paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
            paper_id="2403.07691",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1381 |
+
class UnslothORPOTrainer(_UnslothORPOTrainer):
    """

    Initialize ORPOTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForSequenceClassification`.
        args (`ORPOConfig`):
            The ORPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        compute_metrics = None,
        **kwargs
    ):
        # Auto-generated Unsloth wrapper: reconciles the requested mixed-precision
        # settings with the model's actual dtype, fixes up eval/collator defaults,
        # then delegates to the patched ORPO trainer.
        if args is None: args = UnslothORPOConfig()
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        # Environment escape hatch for models that cannot train in float16.
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        # Resolve the model's effective dtype (config first, embeddings as fallback).
        dtype = getattr(model.config, 'torch_dtype', None)
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Reject fp16/bf16 flags that contradict the model's own precision.
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag set: pick mixed precision matching the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
        # If an eval dataset was given but evaluation is disabled, enable it.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            # Older transformers have a known gradient-accumulation bug; warn.
            from transformers import __version__ as transformers_version
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                    '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Shrink the (default) eval batch size to the train batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        # Keep the full-eval dtype consistent with the training dtype.
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16
        # Metrics callbacks need raw logits; signal the patched model to keep them.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
        # Propagate the model's max_seq_length onto args when args lack one.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
        # Switch Unsloth-patched models into training mode.
        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right padding is required for training; fix tokenizer/processor alike.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        # Swap the collator if its kind does not match the dataset's label layout.
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator prepares everything itself; disable dataset prep.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processors wrap a tokenizer but lack `.pad`; unwrap for the collator.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('orpo_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,
            compute_metrics = compute_metrics,**kwargs)
        # Re-install NEFTune manually: the base trainer's hook must be removed
        # so the Unsloth embedding path applies the noise instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass

    pass
|
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
ADDED
@@ -0,0 +1,1267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
# Options passed to `torch.compile` for the helpers in this file: keep epilogue
# fusion and shape padding, but disable autotuning, tracing and CUDA graphs for
# faster, more predictable compilation.
torch_compile_options = {
    "epilogue_fusion" : True,
    "max_autotune" : False,
    "shape_padding" : True,
    "trace.enabled" : False,
    "triton.cudagraphs" : False,
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    """
    Return the log-softmax of `logits` gathered at the positions in `index`,
    i.e. per-token log-probabilities, without materializing the full
    log-softmax tensor.

    Args:
        logits: float tensor of shape (..., vocab); upcast to float32 internally.
        index: integer tensor of shape (...) selecting one vocab entry per row.

    Returns:
        Tensor of shape (...) with the selected log-probabilities (float32).
    """
    # Upcast so the log-sum-exp below is numerically stable in half precision.
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # Alternative loop form kept for reference: lower peak memory, slower.
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
|
42 |
+
def vLLMSamplingParams(**kwargs):
    """Build a vLLM `SamplingParams`, remembering the kwargs used to create it."""
    from vllm import SamplingParams
    params = SamplingParams(**kwargs)
    # Stash the raw kwargs so downstream code can inspect or recreate the config.
    params._set_kwargs = kwargs
    return params
|
47 |
+
@dataclass
|
48 |
+
class UnslothOnlineDPOConfig(OnlineDPOConfig):
|
49 |
+
"""
|
50 |
+
|
51 |
+
Configuration class for the [`OnlineDPOTrainer`].
|
52 |
+
|
53 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
54 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
55 |
+
command line.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
59 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
60 |
+
[`~transformers.TrainingArguments`].
|
61 |
+
reward_model_path (`str` or `None`, *optional*, defaults to `None`):
|
62 |
+
Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
|
63 |
+
judge (`str` or `None`, *optional*, defaults to `None`):
|
64 |
+
Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
|
65 |
+
max_new_tokens (`int`, *optional*, defaults to `64`):
|
66 |
+
Maximum number of tokens to generate per completion.
|
67 |
+
max_length (`int`, *optional*, defaults to `256`):
|
68 |
+
Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
|
69 |
+
sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
|
70 |
+
possible.
|
71 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
72 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
73 |
+
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
|
74 |
+
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
|
75 |
+
to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
|
76 |
+
value.
|
77 |
+
beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
|
78 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
79 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
80 |
+
the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
|
81 |
+
selected for each new epoch and the last β is used for the rest of the epochs.
|
82 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
83 |
+
Type of loss to use. Possible values are:
|
84 |
+
|
85 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
86 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
87 |
+
|
88 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
+
Number of processes to use for processing the dataset.
|
90 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
91 |
+
Whether to disable dropout in the model and reference model.
|
92 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
93 |
+
Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
|
94 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
95 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
96 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
97 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
98 |
+
|
99 |
+
"""
|
100 |
+
vllm_sampling_params: Optional[Any] = field(
|
101 |
+
default = None,
|
102 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
103 |
+
)
|
104 |
+
unsloth_num_chunks : Optional[int] = field(
|
105 |
+
default = -1,
|
106 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
107 |
+
)
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
output_dir = None,
|
111 |
+
overwrite_output_dir = None,
|
112 |
+
do_train = False,
|
113 |
+
do_eval = False,
|
114 |
+
do_predict = False,
|
115 |
+
eval_strategy = 'no',
|
116 |
+
prediction_loss_only = False,
|
117 |
+
per_device_train_batch_size = 4,
|
118 |
+
per_device_eval_batch_size = 4,
|
119 |
+
per_gpu_train_batch_size = None,
|
120 |
+
per_gpu_eval_batch_size = None,
|
121 |
+
gradient_accumulation_steps = 2,
|
122 |
+
eval_accumulation_steps = 2,
|
123 |
+
eval_delay = 0,
|
124 |
+
torch_empty_cache_steps = 250,
|
125 |
+
learning_rate = 5e-05,
|
126 |
+
weight_decay = 0.01,
|
127 |
+
adam_beta1 = 0.9,
|
128 |
+
adam_beta2 = 0.999,
|
129 |
+
adam_epsilon = 1e-08,
|
130 |
+
max_grad_norm = 1.0,
|
131 |
+
num_train_epochs = 3.0,
|
132 |
+
max_steps = -1,
|
133 |
+
lr_scheduler_type = 'linear',
|
134 |
+
warmup_ratio = 0.1,
|
135 |
+
warmup_steps = 0,
|
136 |
+
log_level = 'passive',
|
137 |
+
log_level_replica = 'warning',
|
138 |
+
log_on_each_node = True,
|
139 |
+
logging_dir = None,
|
140 |
+
logging_strategy = 'steps',
|
141 |
+
logging_first_step = False,
|
142 |
+
logging_steps = 1,
|
143 |
+
logging_nan_inf_filter = False,
|
144 |
+
save_strategy = 'steps',
|
145 |
+
save_steps = 500,
|
146 |
+
save_total_limit = None,
|
147 |
+
save_safetensors = True,
|
148 |
+
save_on_each_node = False,
|
149 |
+
save_only_model = False,
|
150 |
+
restore_callback_states_from_checkpoint = False,
|
151 |
+
no_cuda = False,
|
152 |
+
use_cpu = False,
|
153 |
+
use_mps_device = False,
|
154 |
+
seed = 3407,
|
155 |
+
data_seed = 3407,
|
156 |
+
jit_mode_eval = False,
|
157 |
+
use_ipex = False,
|
158 |
+
bf16 = False,
|
159 |
+
fp16 = False,
|
160 |
+
fp16_opt_level = 'O1',
|
161 |
+
half_precision_backend = 'auto',
|
162 |
+
bf16_full_eval = False,
|
163 |
+
fp16_full_eval = False,
|
164 |
+
tf32 = None,
|
165 |
+
local_rank = -1,
|
166 |
+
ddp_backend = None,
|
167 |
+
tpu_num_cores = None,
|
168 |
+
tpu_metrics_debug = False,
|
169 |
+
debug = '',
|
170 |
+
dataloader_drop_last = False,
|
171 |
+
eval_steps = None,
|
172 |
+
dataloader_num_workers = 0,
|
173 |
+
dataloader_prefetch_factor = None,
|
174 |
+
past_index = -1,
|
175 |
+
run_name = None,
|
176 |
+
disable_tqdm = None,
|
177 |
+
remove_unused_columns = True,
|
178 |
+
label_names = None,
|
179 |
+
load_best_model_at_end = False,
|
180 |
+
metric_for_best_model = None,
|
181 |
+
greater_is_better = None,
|
182 |
+
ignore_data_skip = False,
|
183 |
+
fsdp = '',
|
184 |
+
fsdp_min_num_params = 0,
|
185 |
+
fsdp_config = None,
|
186 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
187 |
+
accelerator_config = None,
|
188 |
+
deepspeed = None,
|
189 |
+
label_smoothing_factor = 0.0,
|
190 |
+
optim = 'adamw_8bit',
|
191 |
+
optim_args = None,
|
192 |
+
adafactor = False,
|
193 |
+
group_by_length = False,
|
194 |
+
length_column_name = 'length',
|
195 |
+
report_to = None,
|
196 |
+
ddp_find_unused_parameters = None,
|
197 |
+
ddp_bucket_cap_mb = None,
|
198 |
+
ddp_broadcast_buffers = None,
|
199 |
+
dataloader_pin_memory = True,
|
200 |
+
dataloader_persistent_workers = False,
|
201 |
+
skip_memory_metrics = True,
|
202 |
+
use_legacy_prediction_loop = False,
|
203 |
+
push_to_hub = False,
|
204 |
+
resume_from_checkpoint = None,
|
205 |
+
hub_model_id = None,
|
206 |
+
hub_strategy = 'every_save',
|
207 |
+
hub_token = None,
|
208 |
+
hub_private_repo = None,
|
209 |
+
hub_always_push = False,
|
210 |
+
gradient_checkpointing = False,
|
211 |
+
gradient_checkpointing_kwargs = None,
|
212 |
+
include_inputs_for_metrics = False,
|
213 |
+
eval_do_concat_batches = True,
|
214 |
+
fp16_backend = 'auto',
|
215 |
+
evaluation_strategy = None,
|
216 |
+
push_to_hub_model_id = None,
|
217 |
+
push_to_hub_organization = None,
|
218 |
+
push_to_hub_token = None,
|
219 |
+
mp_parameters = '',
|
220 |
+
auto_find_batch_size = False,
|
221 |
+
full_determinism = False,
|
222 |
+
torchdynamo = None,
|
223 |
+
ray_scope = 'last',
|
224 |
+
ddp_timeout = 1800,
|
225 |
+
torch_compile = False,
|
226 |
+
torch_compile_backend = None,
|
227 |
+
torch_compile_mode = None,
|
228 |
+
dispatch_batches = None,
|
229 |
+
split_batches = None,
|
230 |
+
include_tokens_per_second = False,
|
231 |
+
include_num_input_tokens_seen = False,
|
232 |
+
neftune_noise_alpha = None,
|
233 |
+
optim_target_modules = None,
|
234 |
+
batch_eval_metrics = False,
|
235 |
+
eval_on_start = False,
|
236 |
+
use_liger_kernel = False,
|
237 |
+
eval_use_gather_object = False,
|
238 |
+
average_tokens_across_devices = False,
|
239 |
+
reward_model_path = None,
|
240 |
+
judge = None,
|
241 |
+
max_new_tokens = 64,
|
242 |
+
max_length = 512,
|
243 |
+
temperature = 0.9,
|
244 |
+
missing_eos_penalty = None,
|
245 |
+
loss_type = 'sigmoid',
|
246 |
+
dataset_num_proc = None,
|
247 |
+
disable_dropout = True,
|
248 |
+
use_vllm = False,
|
249 |
+
ds3_gather_for_generation = True,
|
250 |
+
vllm_sampling_params = None,
|
251 |
+
unsloth_num_chunks = -1,
|
252 |
+
**kwargs,
|
253 |
+
):
|
254 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
255 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
256 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
257 |
+
output_dir = 'unsloth_training_checkpoints'
|
258 |
+
save_strategy = 'no'
|
259 |
+
if dataset_num_proc is None:
|
260 |
+
from multiprocessing import cpu_count
|
261 |
+
dataset_num_proc = cpu_count()
|
262 |
+
|
263 |
+
super().__init__(
|
264 |
+
output_dir = output_dir,
|
265 |
+
overwrite_output_dir = overwrite_output_dir,
|
266 |
+
do_train = do_train,
|
267 |
+
do_eval = do_eval,
|
268 |
+
do_predict = do_predict,
|
269 |
+
eval_strategy = eval_strategy,
|
270 |
+
prediction_loss_only = prediction_loss_only,
|
271 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
272 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
273 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
274 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
275 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
276 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
277 |
+
eval_delay = eval_delay,
|
278 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
279 |
+
learning_rate = learning_rate,
|
280 |
+
weight_decay = weight_decay,
|
281 |
+
adam_beta1 = adam_beta1,
|
282 |
+
adam_beta2 = adam_beta2,
|
283 |
+
adam_epsilon = adam_epsilon,
|
284 |
+
max_grad_norm = max_grad_norm,
|
285 |
+
num_train_epochs = num_train_epochs,
|
286 |
+
max_steps = max_steps,
|
287 |
+
lr_scheduler_type = lr_scheduler_type,
|
288 |
+
warmup_ratio = warmup_ratio,
|
289 |
+
warmup_steps = warmup_steps,
|
290 |
+
log_level = log_level,
|
291 |
+
log_level_replica = log_level_replica,
|
292 |
+
log_on_each_node = log_on_each_node,
|
293 |
+
logging_dir = logging_dir,
|
294 |
+
logging_strategy = logging_strategy,
|
295 |
+
logging_first_step = logging_first_step,
|
296 |
+
logging_steps = logging_steps,
|
297 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
298 |
+
save_strategy = save_strategy,
|
299 |
+
save_steps = save_steps,
|
300 |
+
save_total_limit = save_total_limit,
|
301 |
+
save_safetensors = save_safetensors,
|
302 |
+
save_on_each_node = save_on_each_node,
|
303 |
+
save_only_model = save_only_model,
|
304 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
305 |
+
no_cuda = no_cuda,
|
306 |
+
use_cpu = use_cpu,
|
307 |
+
use_mps_device = use_mps_device,
|
308 |
+
seed = seed,
|
309 |
+
data_seed = data_seed,
|
310 |
+
jit_mode_eval = jit_mode_eval,
|
311 |
+
use_ipex = use_ipex,
|
312 |
+
bf16 = bf16,
|
313 |
+
fp16 = fp16,
|
314 |
+
fp16_opt_level = fp16_opt_level,
|
315 |
+
half_precision_backend = half_precision_backend,
|
316 |
+
bf16_full_eval = bf16_full_eval,
|
317 |
+
fp16_full_eval = fp16_full_eval,
|
318 |
+
tf32 = tf32,
|
319 |
+
local_rank = local_rank,
|
320 |
+
ddp_backend = ddp_backend,
|
321 |
+
tpu_num_cores = tpu_num_cores,
|
322 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
323 |
+
debug = debug,
|
324 |
+
dataloader_drop_last = dataloader_drop_last,
|
325 |
+
eval_steps = eval_steps,
|
326 |
+
dataloader_num_workers = dataloader_num_workers,
|
327 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
328 |
+
past_index = past_index,
|
329 |
+
run_name = run_name,
|
330 |
+
disable_tqdm = disable_tqdm,
|
331 |
+
remove_unused_columns = remove_unused_columns,
|
332 |
+
label_names = label_names,
|
333 |
+
load_best_model_at_end = load_best_model_at_end,
|
334 |
+
metric_for_best_model = metric_for_best_model,
|
335 |
+
greater_is_better = greater_is_better,
|
336 |
+
ignore_data_skip = ignore_data_skip,
|
337 |
+
fsdp = fsdp,
|
338 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
339 |
+
fsdp_config = fsdp_config,
|
340 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
341 |
+
accelerator_config = accelerator_config,
|
342 |
+
deepspeed = deepspeed,
|
343 |
+
label_smoothing_factor = label_smoothing_factor,
|
344 |
+
optim = optim,
|
345 |
+
optim_args = optim_args,
|
346 |
+
adafactor = adafactor,
|
347 |
+
group_by_length = group_by_length,
|
348 |
+
length_column_name = length_column_name,
|
349 |
+
report_to = report_to,
|
350 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
351 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
352 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
353 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
354 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
355 |
+
skip_memory_metrics = skip_memory_metrics,
|
356 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
357 |
+
push_to_hub = push_to_hub,
|
358 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
359 |
+
hub_model_id = hub_model_id,
|
360 |
+
hub_strategy = hub_strategy,
|
361 |
+
hub_token = hub_token,
|
362 |
+
hub_private_repo = hub_private_repo,
|
363 |
+
hub_always_push = hub_always_push,
|
364 |
+
gradient_checkpointing = gradient_checkpointing,
|
365 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
366 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
367 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
368 |
+
fp16_backend = fp16_backend,
|
369 |
+
evaluation_strategy = evaluation_strategy,
|
370 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
371 |
+
push_to_hub_organization = push_to_hub_organization,
|
372 |
+
push_to_hub_token = push_to_hub_token,
|
373 |
+
mp_parameters = mp_parameters,
|
374 |
+
auto_find_batch_size = auto_find_batch_size,
|
375 |
+
full_determinism = full_determinism,
|
376 |
+
torchdynamo = torchdynamo,
|
377 |
+
ray_scope = ray_scope,
|
378 |
+
ddp_timeout = ddp_timeout,
|
379 |
+
torch_compile = torch_compile,
|
380 |
+
torch_compile_backend = torch_compile_backend,
|
381 |
+
torch_compile_mode = torch_compile_mode,
|
382 |
+
dispatch_batches = dispatch_batches,
|
383 |
+
split_batches = split_batches,
|
384 |
+
include_tokens_per_second = include_tokens_per_second,
|
385 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
386 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
387 |
+
optim_target_modules = optim_target_modules,
|
388 |
+
batch_eval_metrics = batch_eval_metrics,
|
389 |
+
eval_on_start = eval_on_start,
|
390 |
+
use_liger_kernel = use_liger_kernel,
|
391 |
+
eval_use_gather_object = eval_use_gather_object,
|
392 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
393 |
+
reward_model_path = reward_model_path,
|
394 |
+
judge = judge,
|
395 |
+
max_new_tokens = max_new_tokens,
|
396 |
+
max_length = max_length,
|
397 |
+
temperature = temperature,
|
398 |
+
missing_eos_penalty = missing_eos_penalty,
|
399 |
+
loss_type = loss_type,
|
400 |
+
dataset_num_proc = dataset_num_proc,
|
401 |
+
disable_dropout = disable_dropout,
|
402 |
+
use_vllm = use_vllm,
|
403 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
404 |
+
self.vllm_sampling_params = vllm_sampling_params
|
405 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
406 |
+
pass
|
407 |
+
|
408 |
+
class _UnslothOnlineDPOTrainer(Trainer):
|
409 |
+
r""""""
|
410 |
+
|
411 |
+
_tag_names = ["trl", "online-dpo"]
|
412 |
+
|
413 |
+
def __init__(
|
414 |
+
self,
|
415 |
+
model: Union[PreTrainedModel, nn.Module],
|
416 |
+
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
|
417 |
+
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
418 |
+
judge: Optional[BasePairwiseJudge] = None,
|
419 |
+
args: Optional[OnlineDPOConfig] = None,
|
420 |
+
data_collator: Optional[DataCollator] = None,
|
421 |
+
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
422 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
|
423 |
+
processing_class: Optional[
|
424 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
425 |
+
] = None,
|
426 |
+
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
427 |
+
peft_config: Optional[dict] = None,
|
428 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
429 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
430 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
431 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
432 |
+
) -> None:
|
433 |
+
|
434 |
+
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
435 |
+
if ref_model is model:
|
436 |
+
raise ValueError(
|
437 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
438 |
+
"same as `model`, either omit the `ref_model` argument or pass `None`."
|
439 |
+
)
|
440 |
+
|
441 |
+
self.ref_model = ref_model
|
442 |
+
|
443 |
+
if reward_model is not None and judge is not None:
|
444 |
+
warnings.warn(
|
445 |
+
"Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
|
446 |
+
"Ignoring `judge` and using `reward_model`.",
|
447 |
+
UserWarning,
|
448 |
+
)
|
449 |
+
judge = None
|
450 |
+
elif reward_model is None and judge is None:
|
451 |
+
raise ValueError("Either `reward_model` or `judge` must be provided.")
|
452 |
+
|
453 |
+
self.reward_model = reward_model
|
454 |
+
self.reward_processing_class = reward_processing_class
|
455 |
+
self.judge = judge
|
456 |
+
|
457 |
+
if args.missing_eos_penalty is not None and judge is not None:
|
458 |
+
raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
|
459 |
+
|
460 |
+
if args is None:
|
461 |
+
raise ValueError("`args` must be provided.")
|
462 |
+
|
463 |
+
# Check that the processing_class is provided
|
464 |
+
if processing_class is None:
|
465 |
+
raise ValueError("`processing_class` must be provided.")
|
466 |
+
|
467 |
+
# Convert to PEFT model if peft_config is provided
|
468 |
+
if False:
|
469 |
+
# Check if PEFT is available
|
470 |
+
if not is_peft_available():
|
471 |
+
raise ImportError(
|
472 |
+
"PEFT is not available and passed `peft_config`. Please install PEFT with "
|
473 |
+
"`pip install peft` to use it."
|
474 |
+
)
|
475 |
+
|
476 |
+
# If the model is already a PeftModel, we need to merge and unload it.
|
477 |
+
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
|
478 |
+
if isinstance(model, PeftModel):
|
479 |
+
model = model.merge_and_unload()
|
480 |
+
|
481 |
+
# Get peft model with the given config
|
482 |
+
model = model
|
483 |
+
|
484 |
+
# Disable dropout in the model and reference model
|
485 |
+
if args.disable_dropout:
|
486 |
+
disable_dropout_in_model(model)
|
487 |
+
if self.ref_model is not None:
|
488 |
+
disable_dropout_in_model(self.ref_model)
|
489 |
+
|
490 |
+
# Handle the ref_model
|
491 |
+
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
|
492 |
+
# get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
|
493 |
+
# the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
|
494 |
+
if ref_model is None: # No ref model provided, the most common case
|
495 |
+
if False:
|
496 |
+
self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
|
497 |
+
else:
|
498 |
+
self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
|
499 |
+
else: # rare case, the user provided a ref model
|
500 |
+
self.ref_model = ref_model
|
501 |
+
self.ref_model.eval()
|
502 |
+
|
503 |
+
# Disable the gradient and set the reward model in eval mode
|
504 |
+
if self.reward_model is not None:
|
505 |
+
self.reward_model.eval()
|
506 |
+
|
507 |
+
# Define the collator is not provided
|
508 |
+
if data_collator is None:
|
509 |
+
data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
|
510 |
+
|
511 |
+
self.max_length = args.max_length
|
512 |
+
|
513 |
+
self.stats = {
|
514 |
+
"objective/kl": [],
|
515 |
+
"objective/entropy": [],
|
516 |
+
"objective/non_score_reward": [],
|
517 |
+
"rewards/chosen": [],
|
518 |
+
"rewards/rejected": [],
|
519 |
+
"rewards/accuracies": [],
|
520 |
+
"rewards/margins": [],
|
521 |
+
"logps/chosen": [],
|
522 |
+
"logps/rejected": [],
|
523 |
+
"val/contain_eos_token": [],
|
524 |
+
"beta": [],
|
525 |
+
}
|
526 |
+
if self.reward_model is not None:
|
527 |
+
self.stats["objective/rlhf_reward"] = []
|
528 |
+
self.stats["objective/scores_margin"] = []
|
529 |
+
self.stats["objective/scores"] = []
|
530 |
+
|
531 |
+
if args.use_vllm:
|
532 |
+
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
|
533 |
+
n=2, max_tokens=args.max_new_tokens,
|
534 |
+
temperature=args.temperature,
|
535 |
+
top_k=50,
|
536 |
+
top_p=1.0,
|
537 |
+
detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
538 |
+
else:
|
539 |
+
self.generation_config = GenerationConfig(
|
540 |
+
max_new_tokens=args.max_new_tokens,
|
541 |
+
temperature=args.temperature,
|
542 |
+
top_k=50,
|
543 |
+
top_p=1.0,
|
544 |
+
do_sample=True,
|
545 |
+
use_cache=False if args.gradient_checkpointing else True,
|
546 |
+
)
|
547 |
+
|
548 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
549 |
+
# input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
|
550 |
+
# the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
551 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
552 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
553 |
+
# that the warning has already been issued.
|
554 |
+
model.warnings_issued["estimate_tokens"] = True
|
555 |
+
|
556 |
+
super().__init__(
|
557 |
+
model=model,
|
558 |
+
args=args,
|
559 |
+
data_collator=data_collator,
|
560 |
+
train_dataset=train_dataset,
|
561 |
+
eval_dataset=eval_dataset,
|
562 |
+
processing_class=processing_class,
|
563 |
+
compute_metrics=compute_metrics,
|
564 |
+
callbacks=callbacks,
|
565 |
+
optimizers=optimizers,
|
566 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
567 |
+
)
|
568 |
+
|
569 |
+
# Add tags for models that have been loaded with the correct transformers version
|
570 |
+
if hasattr(self.model, "add_model_tags"):
|
571 |
+
self.model.add_model_tags(self._tag_names)
|
572 |
+
|
573 |
+
self._beta = args.beta
|
574 |
+
|
575 |
+
# Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
|
576 |
+
if self.is_deepspeed_enabled:
|
577 |
+
if self.reward_model is not None:
|
578 |
+
self.reward_model = prepare_deepspeed(
|
579 |
+
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
580 |
+
)
|
581 |
+
if self.ref_model is not None:
|
582 |
+
self.ref_model = prepare_deepspeed(
|
583 |
+
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
584 |
+
)
|
585 |
+
else:
|
586 |
+
if self.ref_model is not None:
|
587 |
+
self.ref_model = self.ref_model.to(self.accelerator.device)
|
588 |
+
if self.reward_model is not None:
|
589 |
+
self.reward_model = self.reward_model.to(self.accelerator.device)
|
590 |
+
|
591 |
+
@property
|
592 |
+
def beta(self):
|
593 |
+
if isinstance(self._beta, list):
|
594 |
+
epoch = self.state.epoch
|
595 |
+
return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
|
596 |
+
else:
|
597 |
+
return self._beta
|
598 |
+
|
599 |
+
@staticmethod
|
600 |
+
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
|
601 |
+
"""Tokenize a single row from a DPO specific dataset."""
|
602 |
+
if not is_encoder_decoder:
|
603 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=False)
|
604 |
+
# Add BOS token to head of prompt. Avoid adding if it's already there
|
605 |
+
if tokenizer.bos_token_id is not None:
|
606 |
+
prompt_len_input_ids = len(batch["input_ids"])
|
607 |
+
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
|
608 |
+
batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
|
609 |
+
batch["attention_mask"] = [1] + batch["attention_mask"]
|
610 |
+
else:
|
611 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=True)
|
612 |
+
batch = {f"prompt_{key}": value for key, value in batch.items()}
|
613 |
+
return batch
|
614 |
+
|
615 |
+
# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
|
616 |
+
@wraps(Trainer.get_train_dataloader)
|
617 |
+
def get_train_dataloader(self) -> DataLoader:
|
618 |
+
if self.train_dataset is None:
|
619 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
620 |
+
|
621 |
+
train_dataset = self.train_dataset
|
622 |
+
data_collator = self.data_collator
|
623 |
+
dataloader_params = {
|
624 |
+
"batch_size": self._train_batch_size,
|
625 |
+
"collate_fn": data_collator,
|
626 |
+
"num_workers": self.args.dataloader_num_workers,
|
627 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
628 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
629 |
+
}
|
630 |
+
|
631 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
632 |
+
dataloader_params["sampler"] = self._get_train_sampler()
|
633 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
634 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
635 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
636 |
+
|
637 |
+
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
638 |
+
|
639 |
+
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
640 |
+
@wraps(Trainer.get_eval_dataloader)
|
641 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
642 |
+
if eval_dataset is None and self.eval_dataset is None:
|
643 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
644 |
+
|
645 |
+
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
646 |
+
# don't change during training
|
647 |
+
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
648 |
+
if (
|
649 |
+
hasattr(self, "_eval_dataloaders")
|
650 |
+
and dataloader_key in self._eval_dataloaders
|
651 |
+
and self.args.dataloader_persistent_workers
|
652 |
+
):
|
653 |
+
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
654 |
+
|
655 |
+
eval_dataset = (
|
656 |
+
self.eval_dataset[eval_dataset]
|
657 |
+
if isinstance(eval_dataset, str)
|
658 |
+
else eval_dataset
|
659 |
+
if eval_dataset is not None
|
660 |
+
else self.eval_dataset
|
661 |
+
)
|
662 |
+
data_collator = self.data_collator
|
663 |
+
|
664 |
+
dataloader_params = {
|
665 |
+
"batch_size": self.args.eval_batch_size,
|
666 |
+
"collate_fn": data_collator,
|
667 |
+
"num_workers": self.args.dataloader_num_workers,
|
668 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
669 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
670 |
+
}
|
671 |
+
|
672 |
+
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
673 |
+
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
674 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
675 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
676 |
+
|
677 |
+
# accelerator.free_memory() will destroy the references, so
|
678 |
+
# we need to store the non-prepared version
|
679 |
+
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
680 |
+
if self.args.dataloader_persistent_workers:
|
681 |
+
if hasattr(self, "_eval_dataloaders"):
|
682 |
+
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
683 |
+
else:
|
684 |
+
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
685 |
+
|
686 |
+
return self.accelerator.prepare(eval_dataloader)
|
687 |
+
|
688 |
+
    def _generate_vllm(self, model, prompts):
        """Sample two completions per prompt through the model's vLLM engine.

        Returns (prompt_ids, prompt_mask, completion_ids, completion_mask) as
        tensors on the accelerator device; prompts are left-padded and
        completions right-padded to `generation_config.max_tokens`.
        """
        eos_token_id = self.processing_class.eos_token_id
        pad_token_id = self.processing_class.pad_token_id

        # Load the latest weights

        pass

        pass

        # Chat-style prompts go through `chat`, plain strings through
        # `generate`; both attach this trainer's Unsloth LoRA adapter.
        if is_conversational({"prompt": prompts[0]}):
            outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
        else:
            outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))

        # The sample index `i` is the OUTER loop, so the result is ordered as
        # [sample 0 of every prompt] + [sample 1 of every prompt], matching the
        # first-half/second-half split used later in training_step.
        completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
        prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]

        # Create mask and pad the prompt (left-pad) and completion (right-pad)
        max_prompt_length = max(len(ids) for ids in prompt_ids)
        prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
        prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
        max_tokens = self.generation_config.max_tokens
        completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
        # Append EOS when generation stopped early without emitting one.
        # NOTE(review): the mask above is built BEFORE this append, so the
        # appended EOS token falls into the masked-out region — confirm this is
        # intended upstream.
        completion_ids = [
            ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
            for ids in completion_ids
        ]
        completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]

        # Convert to tensors
        prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
        prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
        completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
        completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)

        return prompt_ids, prompt_mask, completion_ids, completion_mask
def _generate(self, model, prompts):
|
727 |
+
eos_token_id = self.processing_class.eos_token_id
|
728 |
+
pad_token_id = self.processing_class.pad_token_id
|
729 |
+
|
730 |
+
# Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
|
731 |
+
# policies with different tokenizers / chat templates.
|
732 |
+
inputs = [{"prompt": prompt} for prompt in prompts]
|
733 |
+
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
734 |
+
inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
735 |
+
inputs = self.data_collator(inputs)
|
736 |
+
|
737 |
+
# Sample 2 completions per prompt of size `max_new_tokens` from the model
|
738 |
+
inputs = self._prepare_inputs(inputs)
|
739 |
+
prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
|
740 |
+
prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
|
741 |
+
with unwrap_model_for_generation(
|
742 |
+
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
743 |
+
) as unwrapped_model:
|
744 |
+
output = unwrapped_model.generate(
|
745 |
+
input_ids=prompt_ids,
|
746 |
+
attention_mask=prompt_mask,
|
747 |
+
generation_config=self.generation_config,
|
748 |
+
)
|
749 |
+
|
750 |
+
completion_ids = output[:, prompt_ids.size(1) :]
|
751 |
+
completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
|
752 |
+
|
753 |
+
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
754 |
+
|
755 |
+
def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
|
756 |
+
# Get the number of tokens to truncate from prompt
|
757 |
+
num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
|
758 |
+
|
759 |
+
# Truncate left to avoid oom
|
760 |
+
prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
|
761 |
+
prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
|
762 |
+
|
763 |
+
# Concat the prompt and completion
|
764 |
+
prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
|
765 |
+
prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
|
766 |
+
|
767 |
+
# Get the logprobs of the completions from the model
|
768 |
+
output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
|
769 |
+
|
770 |
+
# There is 1 offset, because the model predict the next token
|
771 |
+
logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
|
772 |
+
|
773 |
+
# Take the completion tokens logprob
|
774 |
+
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
|
775 |
+
return logprobs
|
776 |
+
|
777 |
+
def training_step(
|
778 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
779 |
+
) -> torch.Tensor:
|
780 |
+
model.train()
|
781 |
+
|
782 |
+
prompts = inputs["prompt"]
|
783 |
+
batch_size = len(prompts)
|
784 |
+
|
785 |
+
if self.args.use_vllm:
|
786 |
+
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
|
787 |
+
else:
|
788 |
+
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
|
789 |
+
|
790 |
+
contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
|
791 |
+
|
792 |
+
logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
793 |
+
with torch.no_grad():
|
794 |
+
if self.ref_model is not None:
|
795 |
+
ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
796 |
+
else: # peft case: we just need to disable the adapter
|
797 |
+
with self.model.disable_adapter():
|
798 |
+
ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
799 |
+
|
800 |
+
# Decode the completions, and format them if the input is conversational
|
801 |
+
device = logprobs.device
|
802 |
+
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
803 |
+
if is_conversational({"prompt": prompts[0]}):
|
804 |
+
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
805 |
+
|
806 |
+
# Get the reward from the reward model or judge
|
807 |
+
if self.judge is not None:
|
808 |
+
# Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
|
809 |
+
# directly understandable by the judge and could alter its judgment. To avoid this and make the judge
|
810 |
+
# independent of the model's chat template, we use the raw conversation data, and apply our own chat
|
811 |
+
# template to it.
|
812 |
+
if is_conversational({"prompt": prompts[0]}):
|
813 |
+
environment = jinja2.Environment()
|
814 |
+
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
815 |
+
prompts = [template.render(messages=prompt) for prompt in prompts]
|
816 |
+
completions = [template.render(messages=completion) for completion in completions]
|
817 |
+
|
818 |
+
ranks_of_first_completion = self.judge.judge(
|
819 |
+
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
|
820 |
+
)
|
821 |
+
|
822 |
+
# convert ranks to a True/False mask:
|
823 |
+
# when rank == 0, it means the first completion is the best
|
824 |
+
# when rank == 1, it means the second completion is the best
|
825 |
+
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
|
826 |
+
else:
|
827 |
+
# The reward model may not have the same chat template or tokenizer as the model, so we need to use the
|
828 |
+
# raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
|
829 |
+
prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
|
830 |
+
if is_conversational({"prompt": prompts[0]}):
|
831 |
+
examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
|
832 |
+
examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
|
833 |
+
prompts = [example["prompt"] for example in examples]
|
834 |
+
completions = [example["completion"] for example in examples]
|
835 |
+
|
836 |
+
# Tokenize the prompts
|
837 |
+
prompts_ids = self.reward_processing_class(
|
838 |
+
prompts, padding=True, return_tensors="pt", padding_side="left"
|
839 |
+
)["input_ids"].to(device)
|
840 |
+
context_length = prompts_ids.shape[1]
|
841 |
+
|
842 |
+
# Tokenize the completions
|
843 |
+
completions_ids = self.reward_processing_class(
|
844 |
+
completions, padding=True, return_tensors="pt", padding_side="right"
|
845 |
+
)["input_ids"].to(device)
|
846 |
+
|
847 |
+
# Concatenate the prompts and completions and get the reward
|
848 |
+
prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
|
849 |
+
with torch.inference_mode():
|
850 |
+
_, scores, _ = get_reward(
|
851 |
+
self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
|
852 |
+
)
|
853 |
+
|
854 |
+
# Filter completion. Ensure that the sample contains stop_token_id
|
855 |
+
# Completions not passing that filter will receive a lower score.
|
856 |
+
if self.args.missing_eos_penalty is not None:
|
857 |
+
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
858 |
+
|
859 |
+
# Split the scores in 2 (the prompts of the first half are the same as the second half)
|
860 |
+
first_half, second_half = scores.split(batch_size)
|
861 |
+
|
862 |
+
# Get the indices of the chosen and rejected examples
|
863 |
+
mask = first_half >= second_half
|
864 |
+
|
865 |
+
batch_range = torch.arange(batch_size, device=device)
|
866 |
+
chosen_indices = batch_range + (~mask * batch_size)
|
867 |
+
rejected_indices = batch_range + (mask * batch_size)
|
868 |
+
|
869 |
+
# Build tensor so that the first half is the chosen examples and the second half the rejected examples
|
870 |
+
cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
|
871 |
+
cr_logprobs = logprobs[cr_indices]
|
872 |
+
cr_ref_logprobs = ref_logprobs[cr_indices]
|
873 |
+
|
874 |
+
# mask out the padding tokens
|
875 |
+
padding_mask = ~completion_mask.bool()
|
876 |
+
cr_padding_mask = padding_mask[cr_indices]
|
877 |
+
|
878 |
+
cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
|
879 |
+
cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
|
880 |
+
|
881 |
+
# Split the chosen and rejected examples
|
882 |
+
chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
|
883 |
+
chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
|
884 |
+
pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
|
885 |
+
ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
|
886 |
+
|
887 |
+
logits = pi_logratios - ref_logratios
|
888 |
+
|
889 |
+
if self.args.loss_type == "sigmoid":
|
890 |
+
losses = -F.logsigmoid(self.beta * logits)
|
891 |
+
elif self.args.loss_type == "ipo":
|
892 |
+
losses = (logits - 1 / (2 * self.beta)) ** 2
|
893 |
+
else:
|
894 |
+
raise NotImplementedError(f"invalid loss type {self.loss_type}")
|
895 |
+
|
896 |
+
loss = losses.mean()
|
897 |
+
|
898 |
+
# Log everything
|
899 |
+
if self.reward_model is not None:
|
900 |
+
scores_margin = scores[chosen_indices] - scores[rejected_indices]
|
901 |
+
self.stats["objective/scores_margin"].append(
|
902 |
+
self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
|
903 |
+
)
|
904 |
+
self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
|
905 |
+
self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
|
906 |
+
self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
|
907 |
+
self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
|
908 |
+
|
909 |
+
kl = logprobs - ref_logprobs
|
910 |
+
mean_kl = kl.sum(1).mean()
|
911 |
+
self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
912 |
+
non_score_reward = (-self.beta * kl).sum(1)
|
913 |
+
mean_non_score_reward = non_score_reward.mean()
|
914 |
+
self.stats["objective/non_score_reward"].append(
|
915 |
+
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
916 |
+
)
|
917 |
+
if self.reward_model is not None:
|
918 |
+
rlhf_reward = scores + non_score_reward
|
919 |
+
self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
|
920 |
+
mean_entropy = -logprobs.sum(1).mean()
|
921 |
+
self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
|
922 |
+
chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
|
923 |
+
gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
|
924 |
+
self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
|
925 |
+
rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
|
926 |
+
gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
|
927 |
+
self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
|
928 |
+
margin = gathered_chosen_rewards - gathered_rejected_rewards
|
929 |
+
self.stats["rewards/margins"].append(margin.mean().item())
|
930 |
+
accuracy = margin > 0
|
931 |
+
self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
|
932 |
+
self.stats["beta"].append(self.beta)
|
933 |
+
|
934 |
+
if (
|
935 |
+
self.args.torch_empty_cache_steps is not None
|
936 |
+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
937 |
+
):
|
938 |
+
empty_cache()
|
939 |
+
|
940 |
+
kwargs = {}
|
941 |
+
|
942 |
+
# For LOMO optimizers you need to explicitly use the learnign rate
|
943 |
+
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
944 |
+
kwargs["learning_rate"] = self._get_learning_rate()
|
945 |
+
|
946 |
+
if self.args.n_gpu > 1:
|
947 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
948 |
+
|
949 |
+
if self.use_apex:
|
950 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
951 |
+
scaled_loss.backward()
|
952 |
+
else:
|
953 |
+
self.accelerator.backward(loss, **kwargs)
|
954 |
+
|
955 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
956 |
+
|
957 |
+
    # Same as Trainer._maybe_log_save_evaluate but log our metrics
    # start_time defaults to None to allow compatibility with transformers<=4.46
    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
        """Flush accumulated Online-DPO stats into the logs, then run the
        standard evaluate/save control-flow from transformers.Trainer."""
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            logs: dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs["learning_rate"] = self._get_learning_rate()

            # Add our metrics: average each buffered stat since the last log.
            # NOTE(review): assumes every list in self.stats is non-empty at
            # log time (otherwise sum/len divides by zero) — confirm the log
            # cadence guarantees at least one training_step per interval.
            for key, val in self.stats.items():
                logs[key] = sum(val) / len(val)
            self.stats = {key: [] for key in self.stats}  # reset stats

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            # transformers >= 4.47 added the start_time argument to log().
            if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
                self.log(logs, start_time)
            else:  # transformers<=4.46
                self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

            # save_strategy="best" only checkpoints when the metric improved.
            if self.args.save_strategy == "best":
                self.control.should_save = is_new_best_metric

        if self.control.should_save:
            self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1000 |
+
# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
|
1001 |
+
# This can be removed once the minimum transformers version is updated to 4.47.
|
1002 |
+
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
|
1003 |
+
def _determine_best_metric(self, metrics, trial):
|
1004 |
+
"""
|
1005 |
+
Determine if the model should be saved based on the evaluation metrics.
|
1006 |
+
If args.metric_for_best_model is not set, the loss is used.
|
1007 |
+
Returns:
|
1008 |
+
bool: True if a new best metric was found, else False
|
1009 |
+
"""
|
1010 |
+
is_new_best_metric = False
|
1011 |
+
|
1012 |
+
if self.args.metric_for_best_model is not None:
|
1013 |
+
metric_to_check = self.args.metric_for_best_model
|
1014 |
+
|
1015 |
+
if not metric_to_check.startswith("eval_"):
|
1016 |
+
metric_to_check = f"eval_{metric_to_check}"
|
1017 |
+
|
1018 |
+
try:
|
1019 |
+
metric_value = metrics[metric_to_check]
|
1020 |
+
except KeyError as exc:
|
1021 |
+
raise KeyError(
|
1022 |
+
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
|
1023 |
+
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
|
1024 |
+
) from exc
|
1025 |
+
|
1026 |
+
operator = np.greater if self.args.greater_is_better else np.less
|
1027 |
+
|
1028 |
+
if self.state.best_metric is None:
|
1029 |
+
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
1030 |
+
|
1031 |
+
if operator(metric_value, self.state.best_metric):
|
1032 |
+
run_dir = self._get_output_dir(trial=trial)
|
1033 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
1034 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
1035 |
+
self.state.best_metric = metric_value
|
1036 |
+
self.state.best_model_checkpoint = output_dir
|
1037 |
+
|
1038 |
+
is_new_best_metric = True
|
1039 |
+
|
1040 |
+
return is_new_best_metric
|
1041 |
+
|
1042 |
+
def create_model_card(
|
1043 |
+
self,
|
1044 |
+
model_name: Optional[str] = None,
|
1045 |
+
dataset_name: Optional[str] = None,
|
1046 |
+
tags: Union[str, list[str], None] = None,
|
1047 |
+
):
|
1048 |
+
"""
|
1049 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1050 |
+
|
1051 |
+
Args:
|
1052 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1053 |
+
Name of the model.
|
1054 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1055 |
+
Name of the dataset used for training.
|
1056 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1057 |
+
Tags to be associated with the model card.
|
1058 |
+
"""
|
1059 |
+
if not self.is_world_process_zero():
|
1060 |
+
return
|
1061 |
+
|
1062 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1063 |
+
base_model = self.model.config._name_or_path
|
1064 |
+
else:
|
1065 |
+
base_model = None
|
1066 |
+
|
1067 |
+
tags = tags or []
|
1068 |
+
if isinstance(tags, str):
|
1069 |
+
tags = [tags]
|
1070 |
+
|
1071 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1072 |
+
tags.append("unsloth")
|
1073 |
+
|
1074 |
+
citation = textwrap.dedent("""\
|
1075 |
+
@article{guo2024direct,
|
1076 |
+
title = {{Direct Language Model Alignment from Online AI Feedback}},
|
1077 |
+
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
|
1078 |
+
year = 2024,
|
1079 |
+
eprint = {arXiv:2402.04792}
|
1080 |
+
}""")
|
1081 |
+
|
1082 |
+
model_card = generate_model_card(
|
1083 |
+
base_model=base_model,
|
1084 |
+
model_name=model_name,
|
1085 |
+
hub_model_id=self.hub_model_id,
|
1086 |
+
dataset_name=dataset_name,
|
1087 |
+
tags=tags,
|
1088 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1089 |
+
comet_url=get_comet_experiment_url(),
|
1090 |
+
trainer_name="Online DPO",
|
1091 |
+
trainer_citation=citation,
|
1092 |
+
paper_title="Direct Language Model Alignment from Online AI Feedback",
|
1093 |
+
paper_id="2402.04792",
|
1094 |
+
)
|
1095 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1096 |
+
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
|
1097 |
+
"""
|
1098 |
+
|
1099 |
+
Initialize OnlineDPOTrainer.
|
1100 |
+
|
1101 |
+
Args:
|
1102 |
+
model (`transformers.PreTrainedModel` or `torch.nn.Module`):
|
1103 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
1104 |
+
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1105 |
+
The reference model to use for training. If None is specified, the reference model will be created from
|
1106 |
+
the model.
|
1107 |
+
reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1108 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
1109 |
+
judge (`BasePairwiseJudge`):
|
1110 |
+
The judge to use for pairwise comparison of model completions.
|
1111 |
+
args (`OnlineDPOConfig`):
|
1112 |
+
The online DPO config arguments to use for training.
|
1113 |
+
data_collator (`transformers.DataCollator`):
|
1114 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1115 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1116 |
+
train_dataset (`datasets.Dataset`):
|
1117 |
+
The dataset to use for training.
|
1118 |
+
eval_dataset (`datasets.Dataset`):
|
1119 |
+
The dataset to use for evaluation.
|
1120 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1121 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1122 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1123 |
+
reuse the fine-tuned model.
|
1124 |
+
peft_config (`dict`):
|
1125 |
+
The peft config to use for training.
|
1126 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1127 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1128 |
+
a dictionary string to metric values.
|
1129 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1130 |
+
The callbacks to use for training.
|
1131 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1132 |
+
The optimizer and scheduler to use for training.
|
1133 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1134 |
+
The function to use to preprocess the logits before computing the metrics.
|
1135 |
+
|
1136 |
+
"""
|
1137 |
+
def __init__(
|
1138 |
+
self,
|
1139 |
+
model,
|
1140 |
+
ref_model = None,
|
1141 |
+
reward_model = None,
|
1142 |
+
judge = None,
|
1143 |
+
args = None,
|
1144 |
+
data_collator = None,
|
1145 |
+
train_dataset = None,
|
1146 |
+
eval_dataset = None,
|
1147 |
+
processing_class = None,
|
1148 |
+
reward_processing_class = None,
|
1149 |
+
peft_config = None,
|
1150 |
+
compute_metrics = None,
|
1151 |
+
callbacks = None,
|
1152 |
+
preprocess_logits_for_metrics = None,
|
1153 |
+
**kwargs
|
1154 |
+
):
|
1155 |
+
if args is None: args = UnslothOnlineDPOConfig()
|
1156 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1157 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1158 |
+
force_float32 = False
|
1159 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1160 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1161 |
+
force_float32 = True
|
1162 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1163 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1164 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1165 |
+
from unsloth_zoo.utils import _get_dtype
|
1166 |
+
dtype = _get_dtype(dtype)
|
1167 |
+
float16 = dtype == torch.float16
|
1168 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1169 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1170 |
+
if force_float32:
|
1171 |
+
args.fp16 = False
|
1172 |
+
args.bf16 = False
|
1173 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1174 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1175 |
+
args.fp16 = float16
|
1176 |
+
args.bf16 = not float16
|
1177 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1178 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1179 |
+
args.eval_strategy = 'steps'
|
1180 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1181 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1182 |
+
if ga_steps is not None and ga_steps > 1:
|
1183 |
+
from transformers import __version__ as transformers_version
|
1184 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1185 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1186 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1187 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1188 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1189 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1190 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1191 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1192 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1193 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1194 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1195 |
+
if force_float32:
|
1196 |
+
args.bf16_full_eval = False
|
1197 |
+
args.fp16_full_eval = False
|
1198 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1199 |
+
args.bf16_full_eval = True
|
1200 |
+
args.fp16_full_eval = False
|
1201 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1202 |
+
args.bf16_full_eval = args.bf16
|
1203 |
+
args.fp16_full_eval = args.fp16
|
1204 |
+
_output_logits = False
|
1205 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1206 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1207 |
+
if _output_logits:
|
1208 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1209 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1210 |
+
pass
|
1211 |
+
else:
|
1212 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1213 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1214 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1215 |
+
max_seq_length = model.max_seq_length
|
1216 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1217 |
+
if model is not None and hasattr(model, 'for_training'):
|
1218 |
+
model.for_training()
|
1219 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1220 |
+
if 'processing_class' in locals():
|
1221 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1222 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1223 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1224 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1225 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1226 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1227 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1228 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1229 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1230 |
+
else:
|
1231 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1232 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1233 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1234 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1235 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1236 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1237 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1238 |
+
else:
|
1239 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1240 |
+
other_metrics = []
|
1241 |
+
|
1242 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1243 |
+
PatchRLStatistics('online_dpo_trainer', other_metrics)
|
1244 |
+
|
1245 |
+
super().__init__(
|
1246 |
+
model = model,
|
1247 |
+
ref_model = ref_model,
|
1248 |
+
reward_model = reward_model,
|
1249 |
+
judge = judge,
|
1250 |
+
args = args,
|
1251 |
+
data_collator = data_collator,
|
1252 |
+
train_dataset = train_dataset,
|
1253 |
+
eval_dataset = eval_dataset,
|
1254 |
+
processing_class = processing_class,
|
1255 |
+
reward_processing_class = reward_processing_class,
|
1256 |
+
peft_config = peft_config,
|
1257 |
+
compute_metrics = compute_metrics,
|
1258 |
+
callbacks = callbacks,
|
1259 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
1260 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1261 |
+
self.neftune_hook_handle.remove()
|
1262 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1263 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1264 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1265 |
+
pass
|
1266 |
+
|
1267 |
+
pass
|
unsloth_compiled_cache/UnslothPPOTrainer.py
ADDED
@@ -0,0 +1,1257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothPPOConfig(PPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`PPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
54 |
+
Name of this experiment.
|
55 |
+
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
56 |
+
Path to the reward model.
|
57 |
+
model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
58 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
59 |
+
ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
60 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
61 |
+
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
62 |
+
Number of epochs to train.
|
63 |
+
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
64 |
+
Whether to whiten the rewards.
|
65 |
+
kl_coef (`float`, *optional*, defaults to `0.05`):
|
66 |
+
KL coefficient.
|
67 |
+
cliprange (`float`, *optional*, defaults to `0.2`):
|
68 |
+
Clip range.
|
69 |
+
vf_coef (`float`, *optional*, defaults to `0.1`):
|
70 |
+
Value function coefficient.
|
71 |
+
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
72 |
+
Clip range for the value function.
|
73 |
+
gamma (`float`, *optional*, defaults to `1.0`):
|
74 |
+
Discount factor.
|
75 |
+
lam (`float`, *optional*, defaults to `0.95`):
|
76 |
+
Lambda value for GAE.
|
77 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
78 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
79 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
80 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
81 |
+
|
82 |
+
"""
|
83 |
+
vllm_sampling_params: Optional[Any] = field(
|
84 |
+
default = None,
|
85 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
86 |
+
)
|
87 |
+
unsloth_num_chunks : Optional[int] = field(
|
88 |
+
default = -1,
|
89 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
90 |
+
)
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
output_dir = None,
|
94 |
+
overwrite_output_dir = None,
|
95 |
+
do_train = False,
|
96 |
+
do_eval = False,
|
97 |
+
do_predict = False,
|
98 |
+
eval_strategy = 'no',
|
99 |
+
prediction_loss_only = False,
|
100 |
+
per_device_train_batch_size = 4,
|
101 |
+
per_device_eval_batch_size = 4,
|
102 |
+
per_gpu_train_batch_size = None,
|
103 |
+
per_gpu_eval_batch_size = None,
|
104 |
+
gradient_accumulation_steps = 2,
|
105 |
+
eval_accumulation_steps = 2,
|
106 |
+
eval_delay = 0,
|
107 |
+
torch_empty_cache_steps = 250,
|
108 |
+
learning_rate = 5e-05,
|
109 |
+
weight_decay = 0.01,
|
110 |
+
adam_beta1 = 0.9,
|
111 |
+
adam_beta2 = 0.999,
|
112 |
+
adam_epsilon = 1e-08,
|
113 |
+
max_grad_norm = 1.0,
|
114 |
+
num_train_epochs = 3.0,
|
115 |
+
max_steps = -1,
|
116 |
+
lr_scheduler_type = 'linear',
|
117 |
+
warmup_ratio = 0.1,
|
118 |
+
warmup_steps = 0,
|
119 |
+
log_level = 'passive',
|
120 |
+
log_level_replica = 'warning',
|
121 |
+
log_on_each_node = True,
|
122 |
+
logging_dir = None,
|
123 |
+
logging_strategy = 'steps',
|
124 |
+
logging_first_step = False,
|
125 |
+
logging_steps = 1,
|
126 |
+
logging_nan_inf_filter = False,
|
127 |
+
save_strategy = 'steps',
|
128 |
+
save_steps = 500,
|
129 |
+
save_total_limit = None,
|
130 |
+
save_safetensors = True,
|
131 |
+
save_on_each_node = False,
|
132 |
+
save_only_model = False,
|
133 |
+
restore_callback_states_from_checkpoint = False,
|
134 |
+
no_cuda = False,
|
135 |
+
use_cpu = False,
|
136 |
+
use_mps_device = False,
|
137 |
+
seed = 3407,
|
138 |
+
data_seed = 3407,
|
139 |
+
jit_mode_eval = False,
|
140 |
+
use_ipex = False,
|
141 |
+
bf16 = False,
|
142 |
+
fp16 = False,
|
143 |
+
fp16_opt_level = 'O1',
|
144 |
+
half_precision_backend = 'auto',
|
145 |
+
bf16_full_eval = False,
|
146 |
+
fp16_full_eval = False,
|
147 |
+
tf32 = None,
|
148 |
+
local_rank = -1,
|
149 |
+
ddp_backend = None,
|
150 |
+
tpu_num_cores = None,
|
151 |
+
tpu_metrics_debug = False,
|
152 |
+
debug = '',
|
153 |
+
dataloader_drop_last = False,
|
154 |
+
eval_steps = None,
|
155 |
+
dataloader_num_workers = 0,
|
156 |
+
dataloader_prefetch_factor = None,
|
157 |
+
past_index = -1,
|
158 |
+
run_name = None,
|
159 |
+
disable_tqdm = None,
|
160 |
+
remove_unused_columns = True,
|
161 |
+
label_names = None,
|
162 |
+
load_best_model_at_end = False,
|
163 |
+
metric_for_best_model = None,
|
164 |
+
greater_is_better = None,
|
165 |
+
ignore_data_skip = False,
|
166 |
+
fsdp = '',
|
167 |
+
fsdp_min_num_params = 0,
|
168 |
+
fsdp_config = None,
|
169 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
170 |
+
accelerator_config = None,
|
171 |
+
deepspeed = None,
|
172 |
+
label_smoothing_factor = 0.0,
|
173 |
+
optim = 'adamw_8bit',
|
174 |
+
optim_args = None,
|
175 |
+
adafactor = False,
|
176 |
+
group_by_length = False,
|
177 |
+
length_column_name = 'length',
|
178 |
+
report_to = None,
|
179 |
+
ddp_find_unused_parameters = None,
|
180 |
+
ddp_bucket_cap_mb = None,
|
181 |
+
ddp_broadcast_buffers = None,
|
182 |
+
dataloader_pin_memory = True,
|
183 |
+
dataloader_persistent_workers = False,
|
184 |
+
skip_memory_metrics = True,
|
185 |
+
use_legacy_prediction_loop = False,
|
186 |
+
push_to_hub = False,
|
187 |
+
resume_from_checkpoint = None,
|
188 |
+
hub_model_id = None,
|
189 |
+
hub_strategy = 'every_save',
|
190 |
+
hub_token = None,
|
191 |
+
hub_private_repo = None,
|
192 |
+
hub_always_push = False,
|
193 |
+
gradient_checkpointing = False,
|
194 |
+
gradient_checkpointing_kwargs = None,
|
195 |
+
include_inputs_for_metrics = False,
|
196 |
+
eval_do_concat_batches = True,
|
197 |
+
fp16_backend = 'auto',
|
198 |
+
evaluation_strategy = None,
|
199 |
+
push_to_hub_model_id = None,
|
200 |
+
push_to_hub_organization = None,
|
201 |
+
push_to_hub_token = None,
|
202 |
+
mp_parameters = '',
|
203 |
+
auto_find_batch_size = False,
|
204 |
+
full_determinism = False,
|
205 |
+
torchdynamo = None,
|
206 |
+
ray_scope = 'last',
|
207 |
+
ddp_timeout = 1800,
|
208 |
+
torch_compile = False,
|
209 |
+
torch_compile_backend = None,
|
210 |
+
torch_compile_mode = None,
|
211 |
+
dispatch_batches = None,
|
212 |
+
split_batches = None,
|
213 |
+
include_tokens_per_second = False,
|
214 |
+
include_num_input_tokens_seen = False,
|
215 |
+
neftune_noise_alpha = None,
|
216 |
+
optim_target_modules = None,
|
217 |
+
batch_eval_metrics = False,
|
218 |
+
eval_on_start = False,
|
219 |
+
use_liger_kernel = False,
|
220 |
+
eval_use_gather_object = False,
|
221 |
+
average_tokens_across_devices = False,
|
222 |
+
dataset_num_proc = None,
|
223 |
+
num_mini_batches = 1,
|
224 |
+
total_episodes = None,
|
225 |
+
local_rollout_forward_batch_size = 64,
|
226 |
+
num_sample_generations = 10,
|
227 |
+
response_length = 53,
|
228 |
+
stop_token = None,
|
229 |
+
stop_token_id = None,
|
230 |
+
temperature = 0.7,
|
231 |
+
missing_eos_penalty = None,
|
232 |
+
sft_model_path = 'EleutherAI/pythia-160m',
|
233 |
+
world_size = None,
|
234 |
+
num_total_batches = None,
|
235 |
+
micro_batch_size = None,
|
236 |
+
local_batch_size = None,
|
237 |
+
batch_size = None,
|
238 |
+
local_mini_batch_size = None,
|
239 |
+
mini_batch_size = None,
|
240 |
+
exp_name = 'ppo_config',
|
241 |
+
reward_model_path = 'EleutherAI/pythia-160m',
|
242 |
+
model_adapter_name = None,
|
243 |
+
ref_adapter_name = None,
|
244 |
+
num_ppo_epochs = 4,
|
245 |
+
whiten_rewards = False,
|
246 |
+
kl_coef = 0.05,
|
247 |
+
cliprange = 0.2,
|
248 |
+
vf_coef = 0.1,
|
249 |
+
cliprange_value = 0.2,
|
250 |
+
gamma = 1.0,
|
251 |
+
lam = 0.95,
|
252 |
+
ds3_gather_for_generation = True,
|
253 |
+
vllm_sampling_params = None,
|
254 |
+
unsloth_num_chunks = -1,
|
255 |
+
**kwargs,
|
256 |
+
):
|
257 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
258 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
259 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
260 |
+
output_dir = 'unsloth_training_checkpoints'
|
261 |
+
save_strategy = 'no'
|
262 |
+
if dataset_num_proc is None:
|
263 |
+
from multiprocessing import cpu_count
|
264 |
+
dataset_num_proc = cpu_count()
|
265 |
+
|
266 |
+
super().__init__(
|
267 |
+
output_dir = output_dir,
|
268 |
+
overwrite_output_dir = overwrite_output_dir,
|
269 |
+
do_train = do_train,
|
270 |
+
do_eval = do_eval,
|
271 |
+
do_predict = do_predict,
|
272 |
+
eval_strategy = eval_strategy,
|
273 |
+
prediction_loss_only = prediction_loss_only,
|
274 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
275 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
276 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
277 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
278 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
279 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
280 |
+
eval_delay = eval_delay,
|
281 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
282 |
+
learning_rate = learning_rate,
|
283 |
+
weight_decay = weight_decay,
|
284 |
+
adam_beta1 = adam_beta1,
|
285 |
+
adam_beta2 = adam_beta2,
|
286 |
+
adam_epsilon = adam_epsilon,
|
287 |
+
max_grad_norm = max_grad_norm,
|
288 |
+
num_train_epochs = num_train_epochs,
|
289 |
+
max_steps = max_steps,
|
290 |
+
lr_scheduler_type = lr_scheduler_type,
|
291 |
+
warmup_ratio = warmup_ratio,
|
292 |
+
warmup_steps = warmup_steps,
|
293 |
+
log_level = log_level,
|
294 |
+
log_level_replica = log_level_replica,
|
295 |
+
log_on_each_node = log_on_each_node,
|
296 |
+
logging_dir = logging_dir,
|
297 |
+
logging_strategy = logging_strategy,
|
298 |
+
logging_first_step = logging_first_step,
|
299 |
+
logging_steps = logging_steps,
|
300 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
301 |
+
save_strategy = save_strategy,
|
302 |
+
save_steps = save_steps,
|
303 |
+
save_total_limit = save_total_limit,
|
304 |
+
save_safetensors = save_safetensors,
|
305 |
+
save_on_each_node = save_on_each_node,
|
306 |
+
save_only_model = save_only_model,
|
307 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
308 |
+
no_cuda = no_cuda,
|
309 |
+
use_cpu = use_cpu,
|
310 |
+
use_mps_device = use_mps_device,
|
311 |
+
seed = seed,
|
312 |
+
data_seed = data_seed,
|
313 |
+
jit_mode_eval = jit_mode_eval,
|
314 |
+
use_ipex = use_ipex,
|
315 |
+
bf16 = bf16,
|
316 |
+
fp16 = fp16,
|
317 |
+
fp16_opt_level = fp16_opt_level,
|
318 |
+
half_precision_backend = half_precision_backend,
|
319 |
+
bf16_full_eval = bf16_full_eval,
|
320 |
+
fp16_full_eval = fp16_full_eval,
|
321 |
+
tf32 = tf32,
|
322 |
+
local_rank = local_rank,
|
323 |
+
ddp_backend = ddp_backend,
|
324 |
+
tpu_num_cores = tpu_num_cores,
|
325 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
326 |
+
debug = debug,
|
327 |
+
dataloader_drop_last = dataloader_drop_last,
|
328 |
+
eval_steps = eval_steps,
|
329 |
+
dataloader_num_workers = dataloader_num_workers,
|
330 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
331 |
+
past_index = past_index,
|
332 |
+
run_name = run_name,
|
333 |
+
disable_tqdm = disable_tqdm,
|
334 |
+
remove_unused_columns = remove_unused_columns,
|
335 |
+
label_names = label_names,
|
336 |
+
load_best_model_at_end = load_best_model_at_end,
|
337 |
+
metric_for_best_model = metric_for_best_model,
|
338 |
+
greater_is_better = greater_is_better,
|
339 |
+
ignore_data_skip = ignore_data_skip,
|
340 |
+
fsdp = fsdp,
|
341 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
342 |
+
fsdp_config = fsdp_config,
|
343 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
344 |
+
accelerator_config = accelerator_config,
|
345 |
+
deepspeed = deepspeed,
|
346 |
+
label_smoothing_factor = label_smoothing_factor,
|
347 |
+
optim = optim,
|
348 |
+
optim_args = optim_args,
|
349 |
+
adafactor = adafactor,
|
350 |
+
group_by_length = group_by_length,
|
351 |
+
length_column_name = length_column_name,
|
352 |
+
report_to = report_to,
|
353 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
354 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
355 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
356 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
357 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
358 |
+
skip_memory_metrics = skip_memory_metrics,
|
359 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
360 |
+
push_to_hub = push_to_hub,
|
361 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
362 |
+
hub_model_id = hub_model_id,
|
363 |
+
hub_strategy = hub_strategy,
|
364 |
+
hub_token = hub_token,
|
365 |
+
hub_private_repo = hub_private_repo,
|
366 |
+
hub_always_push = hub_always_push,
|
367 |
+
gradient_checkpointing = gradient_checkpointing,
|
368 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
369 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
370 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
371 |
+
fp16_backend = fp16_backend,
|
372 |
+
evaluation_strategy = evaluation_strategy,
|
373 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
374 |
+
push_to_hub_organization = push_to_hub_organization,
|
375 |
+
push_to_hub_token = push_to_hub_token,
|
376 |
+
mp_parameters = mp_parameters,
|
377 |
+
auto_find_batch_size = auto_find_batch_size,
|
378 |
+
full_determinism = full_determinism,
|
379 |
+
torchdynamo = torchdynamo,
|
380 |
+
ray_scope = ray_scope,
|
381 |
+
ddp_timeout = ddp_timeout,
|
382 |
+
torch_compile = torch_compile,
|
383 |
+
torch_compile_backend = torch_compile_backend,
|
384 |
+
torch_compile_mode = torch_compile_mode,
|
385 |
+
dispatch_batches = dispatch_batches,
|
386 |
+
split_batches = split_batches,
|
387 |
+
include_tokens_per_second = include_tokens_per_second,
|
388 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
389 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
390 |
+
optim_target_modules = optim_target_modules,
|
391 |
+
batch_eval_metrics = batch_eval_metrics,
|
392 |
+
eval_on_start = eval_on_start,
|
393 |
+
use_liger_kernel = use_liger_kernel,
|
394 |
+
eval_use_gather_object = eval_use_gather_object,
|
395 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
396 |
+
dataset_num_proc = dataset_num_proc,
|
397 |
+
num_mini_batches = num_mini_batches,
|
398 |
+
total_episodes = total_episodes,
|
399 |
+
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
400 |
+
num_sample_generations = num_sample_generations,
|
401 |
+
response_length = response_length,
|
402 |
+
stop_token = stop_token,
|
403 |
+
stop_token_id = stop_token_id,
|
404 |
+
temperature = temperature,
|
405 |
+
missing_eos_penalty = missing_eos_penalty,
|
406 |
+
sft_model_path = sft_model_path,
|
407 |
+
world_size = world_size,
|
408 |
+
num_total_batches = num_total_batches,
|
409 |
+
micro_batch_size = micro_batch_size,
|
410 |
+
local_batch_size = local_batch_size,
|
411 |
+
batch_size = batch_size,
|
412 |
+
local_mini_batch_size = local_mini_batch_size,
|
413 |
+
mini_batch_size = mini_batch_size,
|
414 |
+
exp_name = exp_name,
|
415 |
+
reward_model_path = reward_model_path,
|
416 |
+
model_adapter_name = model_adapter_name,
|
417 |
+
ref_adapter_name = ref_adapter_name,
|
418 |
+
num_ppo_epochs = num_ppo_epochs,
|
419 |
+
whiten_rewards = whiten_rewards,
|
420 |
+
kl_coef = kl_coef,
|
421 |
+
cliprange = cliprange,
|
422 |
+
vf_coef = vf_coef,
|
423 |
+
cliprange_value = cliprange_value,
|
424 |
+
gamma = gamma,
|
425 |
+
lam = lam,
|
426 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
427 |
+
self.vllm_sampling_params = vllm_sampling_params
|
428 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
429 |
+
pass
|
430 |
+
|
431 |
+
class _UnslothPPOTrainer(Trainer):
|
432 |
+
_tag_names = ["trl", "ppo"]
|
433 |
+
|
434 |
+
    def __init__(
        self,
        args: PPOConfig,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        model: nn.Module,
        ref_model: Optional[nn.Module],
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: Optional[nn.Module] = None,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ) -> None:
        """Set up a PPO trainer: models, batch-size bookkeeping, callbacks, and dataloaders.

        Args:
            args: PPO configuration; several of its fields (``world_size``,
                ``local_batch_size``, ``batch_size``, ``mini_batch_size``,
                ``num_total_batches``, ``run_name``, ``total_episodes``) are
                *mutated in place* here from the runtime environment.
            processing_class: tokenizer/processor used for padding and for
                resolving the ``"eos"`` stop token.
            model: the policy model to be optimized.
            ref_model: frozen reference model for the KL penalty; must not be
                the same object as ``model``. May be ``None`` for PEFT models
                (the disabled-adapter base acts as the reference).
            reward_model: scores the generated completions.
            train_dataset: rollout prompt dataset.
            value_model: value head model wrapped together with the policy.
            data_collator: defaults to ``DataCollatorWithPadding``.
            eval_dataset: optional evaluation dataset.
            optimizers: optional pre-built ``(optimizer, lr_scheduler)`` pair.
            callbacks: extra trainer callbacks appended to the defaults.
            peft_config: if given (and PEFT is installed), the policy is
                wrapped as a PEFT model.

        Raises:
            ValueError: if ``ref_model is model``, if both ``stop_token`` and
                ``stop_token_id`` are set, if ``stop_token`` is an unknown
                value, or if there is no reference model and the policy is not
                a PEFT model.
            ImportError: if ``peft_config`` is passed but PEFT is unavailable.
        """
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must make a copy of it, or `None` if you use peft."
            )

        self.args = args
        self.processing_class = processing_class
        self.policy_model = model

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Handle stop token settings: update policy model's generation_config to use provided stop token.
        # `stop_token` and `stop_token_id` are mutually exclusive; only `stop_token="eos"` is supported.
        if args.stop_token and args.stop_token_id:
            raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
        elif args.stop_token:
            if args.stop_token == "eos":
                self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
            else:
                raise ValueError(
                    f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
                )
        else:
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

        # peft support
        if not is_peft_available() and peft_config is not None:
            raise ImportError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is already a peft model and we have a peft_config, we merge and unload it first
            if isinstance(self.policy_model, PeftModel):
                self.policy_model = self.policy_model.merge_and_unload()

            # get peft model with the given config
            self.policy_model = get_peft_model(self.policy_model, peft_config)
            if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(self.policy_model)

        self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
        self.model_adapter_name = args.model_adapter_name
        self.ref_adapter_name = args.ref_adapter_name

        # Reference model resolution: explicit > implicit PEFT (None) > deep copy of the policy.
        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model:
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(self.policy_model)

        self.reward_model = reward_model
        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.value_model = value_model
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        #########
        # calculate various batch sizes
        # NOTE: these writes mutate `args` in place so later code (and `train`)
        # can read the derived sizes back from the config object.
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        # local = per-process; without the `local_` prefix = summed across all processes.
        args.local_batch_size = (
            args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
        )
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        if args.whiten_rewards:
            assert (
                args.local_mini_batch_size >= 8
            ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
        # `per_rank_rollout_batch_size` is our `args.local_batch_size`
        # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
        args.num_total_batches = math.ceil(
            args.total_episodes / args.batch_size
        )  # we may train for more than `total_episodes`
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        # Per-process seed; 100003 is prime so process seeds don't collide or correlate.
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        self.local_dataloader_batch_size = args.local_batch_size

        #########
        # setup model, optimizer, and others
        #########
        # Dropout must be off so that logprobs computed at rollout time match
        # the logprobs recomputed during the PPO optimization epochs.
        for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
            if module is not None:
                disable_dropout_in_model(module)
        self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
        self.model.config = self.policy_model.config  # needed for pushing to hub
        self.create_optimizer_and_scheduler(
            num_training_steps=args.num_total_batches
        )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

        #########
        ### trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )
        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        ### setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        self.eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=True,
        )  # no need to shuffle eval dataset
        self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

        # Place the frozen models: under DeepSpeed they get their own inference
        # engines; otherwise they are moved to the accelerator device directly.
        if self.is_deepspeed_enabled:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )

            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = prepare_deepspeed(
                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
        else:
            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = self.ref_model.to(self.accelerator.device)
            self.reward_model = self.reward_model.to(self.accelerator.device)
|
637 |
+
|
638 |
+
def get_train_dataloader(self) -> DataLoader:
|
639 |
+
return self.dataloader
|
640 |
+
|
641 |
+
def get_eval_dataloader(self) -> DataLoader:
|
642 |
+
return self.eval_dataloader
|
643 |
+
|
644 |
+
@contextmanager
|
645 |
+
def null_ref_context(self):
|
646 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
647 |
+
with (
|
648 |
+
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
|
649 |
+
if self.is_peft_model and not self.ref_adapter_name
|
650 |
+
else nullcontext()
|
651 |
+
):
|
652 |
+
if self.ref_adapter_name:
|
653 |
+
self.model.policy.set_adapter(self.ref_adapter_name)
|
654 |
+
yield
|
655 |
+
if self.ref_adapter_name:
|
656 |
+
self.model.policy.set_adapter(self.model_adapter_name or "default")
|
657 |
+
|
658 |
+
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
659 |
+
backup_model = self.model
|
660 |
+
self.model = self.model.policy # save only the policy
|
661 |
+
|
662 |
+
if self.is_deepspeed_enabled:
|
663 |
+
backup_deepspeed = self.deepspeed
|
664 |
+
self.deepspeed = self.model
|
665 |
+
|
666 |
+
super().save_model(output_dir, _internal_call)
|
667 |
+
|
668 |
+
self.model = backup_model
|
669 |
+
|
670 |
+
if self.is_deepspeed_enabled:
|
671 |
+
self.deepspeed = backup_deepspeed
|
672 |
+
|
673 |
+
def train(self):
|
674 |
+
args = self.args
|
675 |
+
accelerator = self.accelerator
|
676 |
+
optimizer = self.optimizer
|
677 |
+
model = self.model
|
678 |
+
ref_policy = self.ref_model
|
679 |
+
reward_model = self.reward_model
|
680 |
+
processing_class = self.processing_class
|
681 |
+
dataloader = self.dataloader
|
682 |
+
device = accelerator.device
|
683 |
+
|
684 |
+
def repeat_generator():
|
685 |
+
while True:
|
686 |
+
yield from dataloader
|
687 |
+
|
688 |
+
iter_dataloader = iter(repeat_generator())
|
689 |
+
generation_config = GenerationConfig(
|
690 |
+
max_new_tokens=args.response_length,
|
691 |
+
temperature=(args.temperature + 1e-7),
|
692 |
+
top_k=0.0,
|
693 |
+
top_p=1.0,
|
694 |
+
do_sample=True,
|
695 |
+
)
|
696 |
+
|
697 |
+
accelerator.print("===training policy===")
|
698 |
+
start_time = time.time()
|
699 |
+
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
700 |
+
approxkl_stats = torch.zeros(stats_shape, device=device)
|
701 |
+
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
702 |
+
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
703 |
+
vf_loss_stats = torch.zeros(stats_shape, device=device)
|
704 |
+
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
705 |
+
entropy_stats = torch.zeros(stats_shape, device=device)
|
706 |
+
ratio_stats = torch.zeros(stats_shape, device=device)
|
707 |
+
model.train()
|
708 |
+
|
709 |
+
# trainer state initialization
|
710 |
+
self.state.global_step = 0
|
711 |
+
self.state.episode = 0
|
712 |
+
self.state.max_steps = args.num_total_batches * args.num_mini_batches
|
713 |
+
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
714 |
+
# Compute absolute values for logging, eval, and save if given as ratio
|
715 |
+
if args.logging_steps is not None:
|
716 |
+
if args.logging_steps < 1:
|
717 |
+
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
718 |
+
else:
|
719 |
+
self.state.logging_steps = args.logging_steps
|
720 |
+
if args.eval_steps is not None:
|
721 |
+
if args.eval_steps < 1:
|
722 |
+
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
723 |
+
else:
|
724 |
+
self.state.eval_steps = args.eval_steps
|
725 |
+
if args.save_steps is not None:
|
726 |
+
if args.save_steps < 1:
|
727 |
+
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
728 |
+
else:
|
729 |
+
self.state.save_steps = args.save_steps
|
730 |
+
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
731 |
+
|
732 |
+
# backward compatibility
|
733 |
+
if self.is_deepspeed_enabled:
|
734 |
+
self.deepspeed = self.model
|
735 |
+
self.model_wrapped = self.model
|
736 |
+
|
737 |
+
for update in range(1, args.num_total_batches + 1):
|
738 |
+
self.state.episode += 1 * args.batch_size
|
739 |
+
data = next(iter_dataloader)
|
740 |
+
with torch.no_grad():
|
741 |
+
queries = data["input_ids"].to(device)
|
742 |
+
context_length = queries.shape[1]
|
743 |
+
responses = []
|
744 |
+
postprocessed_responses = []
|
745 |
+
logprobs = []
|
746 |
+
ref_logprobs = []
|
747 |
+
scores = []
|
748 |
+
sequence_lengths = []
|
749 |
+
values = []
|
750 |
+
with unwrap_model_for_generation(
|
751 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
752 |
+
) as unwrapped_model:
|
753 |
+
query_responses, logitss = batch_generation(
|
754 |
+
unwrapped_model.policy,
|
755 |
+
queries,
|
756 |
+
args.local_rollout_forward_batch_size,
|
757 |
+
processing_class.pad_token_id,
|
758 |
+
generation_config,
|
759 |
+
)
|
760 |
+
|
761 |
+
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
762 |
+
query = queries[i : i + args.local_rollout_forward_batch_size]
|
763 |
+
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
764 |
+
response = query_response[:, context_length:]
|
765 |
+
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
766 |
+
logprob = selective_log_softmax(logits, response)
|
767 |
+
del logits
|
768 |
+
torch.cuda.empty_cache()
|
769 |
+
|
770 |
+
if ref_policy is None:
|
771 |
+
with self.null_ref_context():
|
772 |
+
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
|
773 |
+
else:
|
774 |
+
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
775 |
+
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
776 |
+
ref_logits /= args.temperature + 1e-7
|
777 |
+
ref_logprob = selective_log_softmax(ref_logits, response)
|
778 |
+
del ref_output, ref_logits
|
779 |
+
torch.cuda.empty_cache()
|
780 |
+
|
781 |
+
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
782 |
+
postprocessed_response = response
|
783 |
+
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
784 |
+
postprocessed_response = truncate_response(
|
785 |
+
self.stop_token_id, processing_class.pad_token_id, response
|
786 |
+
)
|
787 |
+
|
788 |
+
# Response Processing 2. run reward model on the truncated responses
|
789 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
790 |
+
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
791 |
+
unwrapped_value_model = accelerator.unwrap_model(model).value_model
|
792 |
+
full_value, _, _ = get_reward(
|
793 |
+
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
|
794 |
+
)
|
795 |
+
value = full_value[:, context_length - 1 : -1].squeeze(-1)
|
796 |
+
_, score, _ = get_reward(
|
797 |
+
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
798 |
+
)
|
799 |
+
|
800 |
+
responses.append(response)
|
801 |
+
postprocessed_responses.append(postprocessed_response)
|
802 |
+
logprobs.append(logprob)
|
803 |
+
ref_logprobs.append(ref_logprob)
|
804 |
+
sequence_lengths.append(sequence_length)
|
805 |
+
scores.append(score)
|
806 |
+
values.append(value)
|
807 |
+
responses = torch.cat(responses, 0)
|
808 |
+
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
809 |
+
logprobs = torch.cat(logprobs, 0)
|
810 |
+
ref_logprobs = torch.cat(ref_logprobs, 0)
|
811 |
+
sequence_lengths = torch.cat(sequence_lengths, 0)
|
812 |
+
scores = torch.cat(scores, 0)
|
813 |
+
values = torch.cat(values, 0)
|
814 |
+
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
|
815 |
+
torch.cuda.empty_cache()
|
816 |
+
gc.collect()
|
817 |
+
|
818 |
+
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
|
819 |
+
# Completions not passing that filter will receive a lower score.
|
820 |
+
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
|
821 |
+
if self.args.missing_eos_penalty is not None:
|
822 |
+
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
823 |
+
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
824 |
+
|
825 |
+
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
826 |
+
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
827 |
+
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
828 |
+
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
829 |
+
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
830 |
+
sequence_lengths_p1 = sequence_lengths + 1
|
831 |
+
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
|
832 |
+
values = torch.masked_fill(values, padding_mask_p1, 0)
|
833 |
+
|
834 |
+
# 4. compute rewards
|
835 |
+
kl = logprobs - ref_logprobs
|
836 |
+
non_score_reward = -args.kl_coef * kl
|
837 |
+
rewards = non_score_reward.clone()
|
838 |
+
actual_start = torch.arange(rewards.size(0), device=rewards.device)
|
839 |
+
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
|
840 |
+
rewards[[actual_start, actual_end]] += scores
|
841 |
+
|
842 |
+
# 5. whiten rewards
|
843 |
+
if args.whiten_rewards:
|
844 |
+
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
|
845 |
+
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
|
846 |
+
|
847 |
+
# 6. compute advantages and returns
|
848 |
+
lastgaelam = 0
|
849 |
+
advantages_reversed = []
|
850 |
+
gen_length = responses.shape[1]
|
851 |
+
for t in reversed(range(gen_length)):
|
852 |
+
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
|
853 |
+
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
|
854 |
+
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
855 |
+
advantages_reversed.append(lastgaelam)
|
856 |
+
advantages = torch.stack(advantages_reversed[::-1], axis=1)
|
857 |
+
returns = advantages + values
|
858 |
+
advantages = masked_whiten(advantages, ~padding_mask)
|
859 |
+
advantages = torch.masked_fill(advantages, padding_mask, 0)
|
860 |
+
torch.cuda.empty_cache()
|
861 |
+
|
862 |
+
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
863 |
+
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
864 |
+
b_inds = np.random.permutation(args.local_batch_size)
|
865 |
+
minibatch_idx = 0
|
866 |
+
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
867 |
+
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
868 |
+
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
869 |
+
gradient_accumulation_idx = 0
|
870 |
+
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
871 |
+
with accelerator.accumulate(model):
|
872 |
+
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
873 |
+
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
874 |
+
mb_advantage = advantages[micro_batch_inds]
|
875 |
+
mb_responses = responses[micro_batch_inds]
|
876 |
+
mb_query_responses = query_responses[micro_batch_inds]
|
877 |
+
mb_logprobs = logprobs[micro_batch_inds]
|
878 |
+
mb_return = returns[micro_batch_inds]
|
879 |
+
mb_values = values[micro_batch_inds]
|
880 |
+
|
881 |
+
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
|
882 |
+
logits = output.logits[:, context_length - 1 : -1]
|
883 |
+
logits /= args.temperature + 1e-7
|
884 |
+
new_logprobs = selective_log_softmax(logits, mb_responses)
|
885 |
+
new_logprobs = torch.masked_fill(
|
886 |
+
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
887 |
+
)
|
888 |
+
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
|
889 |
+
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
|
890 |
+
vpredclipped = torch.clamp(
|
891 |
+
vpred,
|
892 |
+
mb_values - args.cliprange_value,
|
893 |
+
mb_values + args.cliprange_value,
|
894 |
+
)
|
895 |
+
vf_losses1 = torch.square(vpred - mb_return)
|
896 |
+
vf_losses2 = torch.square(vpredclipped - mb_return)
|
897 |
+
vf_loss_max = torch.max(vf_losses1, vf_losses2)
|
898 |
+
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
|
899 |
+
vf_clipfrac = masked_mean(
|
900 |
+
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
|
901 |
+
)
|
902 |
+
logprobs_diff = new_logprobs - mb_logprobs
|
903 |
+
ratio = torch.exp(logprobs_diff)
|
904 |
+
pg_losses = -mb_advantage * ratio
|
905 |
+
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
906 |
+
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
907 |
+
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
|
908 |
+
loss = pg_loss + args.vf_coef * vf_loss
|
909 |
+
accelerator.backward(loss)
|
910 |
+
optimizer.step()
|
911 |
+
optimizer.zero_grad()
|
912 |
+
with torch.no_grad():
|
913 |
+
pg_clipfrac = masked_mean(
|
914 |
+
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
|
915 |
+
)
|
916 |
+
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
917 |
+
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
918 |
+
approxkl = 0.5 * (logprobs_diff**2).mean()
|
919 |
+
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
920 |
+
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
921 |
+
pg_clipfrac
|
922 |
+
)
|
923 |
+
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
924 |
+
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
|
925 |
+
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
926 |
+
vf_clipfrac
|
927 |
+
)
|
928 |
+
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
929 |
+
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
|
930 |
+
gradient_accumulation_idx += 1
|
931 |
+
minibatch_idx += 1
|
932 |
+
# del everything and empty cache
|
933 |
+
# fmt: off
|
934 |
+
del (
|
935 |
+
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
|
936 |
+
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
|
937 |
+
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
|
938 |
+
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
|
939 |
+
)
|
940 |
+
# fmt: on
|
941 |
+
torch.cuda.empty_cache()
|
942 |
+
with torch.no_grad():
|
943 |
+
mean_kl = kl.sum(1).mean()
|
944 |
+
mean_entropy = (-logprobs).sum(1).mean()
|
945 |
+
mean_non_score_reward = non_score_reward.sum(1).mean()
|
946 |
+
rlhf_reward = mean_non_score_reward + scores.mean()
|
947 |
+
eps = int(self.state.episode / (time.time() - start_time))
|
948 |
+
metrics = {}
|
949 |
+
metrics["eps"] = eps
|
950 |
+
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
951 |
+
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
952 |
+
metrics["objective/non_score_reward"] = (
|
953 |
+
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
954 |
+
)
|
955 |
+
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
956 |
+
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
957 |
+
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
958 |
+
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
959 |
+
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
960 |
+
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
|
961 |
+
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
962 |
+
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
963 |
+
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
964 |
+
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
965 |
+
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
966 |
+
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
967 |
+
metrics["episode"] = self.state.episode
|
968 |
+
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
|
969 |
+
self.state.global_step += 1
|
970 |
+
self.log(metrics)
|
971 |
+
|
972 |
+
self.lr_scheduler.step()
|
973 |
+
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
974 |
+
if self.control.should_save:
|
975 |
+
self._save_checkpoint(model, trial=None)
|
976 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
977 |
+
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
|
978 |
+
torch.cuda.empty_cache()
|
979 |
+
gc.collect()
|
980 |
+
|
981 |
+
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
982 |
+
self.generate_completions(sampling=True)
|
983 |
+
torch.cuda.empty_cache()
|
984 |
+
del (
|
985 |
+
query_responses,
|
986 |
+
responses,
|
987 |
+
postprocessed_responses,
|
988 |
+
logprobs,
|
989 |
+
ref_logprobs,
|
990 |
+
values,
|
991 |
+
sequence_lengths,
|
992 |
+
contain_eos_token,
|
993 |
+
sequence_lengths_p1,
|
994 |
+
response_idxs,
|
995 |
+
padding_mask,
|
996 |
+
padding_mask_p1,
|
997 |
+
rewards,
|
998 |
+
actual_start,
|
999 |
+
actual_end,
|
1000 |
+
advantages,
|
1001 |
+
returns,
|
1002 |
+
)
|
1003 |
+
torch.cuda.empty_cache()
|
1004 |
+
|
1005 |
+
# HF trainer specifics
|
1006 |
+
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
1007 |
+
if self.control.should_save:
|
1008 |
+
self._save_checkpoint(model, trial=None, metrics=None)
|
1009 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
1010 |
+
|
1011 |
+
def generate_completions(self, sampling: bool = False):
|
1012 |
+
args = self.args
|
1013 |
+
processing_class = self.processing_class
|
1014 |
+
generation_config = GenerationConfig(
|
1015 |
+
max_new_tokens=self.args.response_length,
|
1016 |
+
temperature=(0.01 + 1e-7),
|
1017 |
+
top_k=0.0,
|
1018 |
+
top_p=1.0,
|
1019 |
+
do_sample=True,
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
table = defaultdict(list)
|
1023 |
+
with unwrap_model_for_generation(
|
1024 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
1025 |
+
) as unwrapped_model:
|
1026 |
+
for batch in self.eval_dataloader:
|
1027 |
+
query = batch["input_ids"]
|
1028 |
+
with torch.no_grad():
|
1029 |
+
context_length = query.shape[1]
|
1030 |
+
query_response, _ = batch_generation(
|
1031 |
+
unwrapped_model.policy,
|
1032 |
+
query,
|
1033 |
+
query.shape[0],
|
1034 |
+
processing_class.pad_token_id,
|
1035 |
+
generation_config,
|
1036 |
+
)
|
1037 |
+
response = query_response[:, context_length:]
|
1038 |
+
postprocessed_response = response
|
1039 |
+
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
1040 |
+
postprocessed_response = truncate_response(
|
1041 |
+
self.stop_token_id, processing_class.pad_token_id, response
|
1042 |
+
)
|
1043 |
+
table["query"].extend(
|
1044 |
+
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
1045 |
+
)
|
1046 |
+
table["model response"].extend(
|
1047 |
+
gather_object(processing_class.batch_decode(postprocessed_response))
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
1051 |
+
_, score, _ = get_reward(
|
1052 |
+
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
1053 |
+
)
|
1054 |
+
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
1055 |
+
|
1056 |
+
if sampling:
|
1057 |
+
break
|
1058 |
+
df = pd.DataFrame(table)
|
1059 |
+
|
1060 |
+
if self.accelerator.is_main_process:
|
1061 |
+
print_rich_table(df.iloc[0 : 0 + 5])
|
1062 |
+
if "wandb" in args.report_to:
|
1063 |
+
import wandb
|
1064 |
+
|
1065 |
+
if wandb.run is not None:
|
1066 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
1067 |
+
|
1068 |
+
if "comet_ml" in args.report_to:
|
1069 |
+
log_table_to_comet_experiment(
|
1070 |
+
name="completions.csv",
|
1071 |
+
table=df,
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
def create_model_card(
|
1075 |
+
self,
|
1076 |
+
model_name: Optional[str] = None,
|
1077 |
+
dataset_name: Optional[str] = None,
|
1078 |
+
tags: Union[str, list[str], None] = None,
|
1079 |
+
):
|
1080 |
+
"""
|
1081 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1082 |
+
|
1083 |
+
Args:
|
1084 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1085 |
+
Name of the model.
|
1086 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1087 |
+
Name of the dataset used for training.
|
1088 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1089 |
+
Tags to be associated with the model card.
|
1090 |
+
"""
|
1091 |
+
if not self.is_world_process_zero():
|
1092 |
+
return
|
1093 |
+
|
1094 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1095 |
+
base_model = self.model.config._name_or_path
|
1096 |
+
else:
|
1097 |
+
base_model = None
|
1098 |
+
|
1099 |
+
tags = tags or []
|
1100 |
+
if isinstance(tags, str):
|
1101 |
+
tags = [tags]
|
1102 |
+
|
1103 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1104 |
+
tags.append("unsloth")
|
1105 |
+
|
1106 |
+
citation = textwrap.dedent("""\
|
1107 |
+
@article{mziegler2019fine-tuning,
|
1108 |
+
title = {{Fine-Tuning Language Models from Human Preferences}},
|
1109 |
+
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
|
1110 |
+
year = 2019,
|
1111 |
+
eprint = {arXiv:1909.08593}
|
1112 |
+
}""")
|
1113 |
+
|
1114 |
+
model_card = generate_model_card(
|
1115 |
+
base_model=base_model,
|
1116 |
+
model_name=model_name,
|
1117 |
+
hub_model_id=self.hub_model_id,
|
1118 |
+
dataset_name=dataset_name,
|
1119 |
+
tags=tags,
|
1120 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1121 |
+
comet_url=get_comet_experiment_url(),
|
1122 |
+
trainer_name="PPO",
|
1123 |
+
trainer_citation=citation,
|
1124 |
+
paper_title="Fine-Tuning Language Models from Human Preferences",
|
1125 |
+
paper_id="1909.08593",
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1129 |
+
class UnslothPPOTrainer(_UnslothPPOTrainer):
|
1130 |
+
"""
|
1131 |
+
|
1132 |
+
"""
|
1133 |
+
def __init__(
|
1134 |
+
self,
|
1135 |
+
args,
|
1136 |
+
processing_class,
|
1137 |
+
model,
|
1138 |
+
ref_model,
|
1139 |
+
reward_model,
|
1140 |
+
train_dataset,
|
1141 |
+
value_model = None,
|
1142 |
+
data_collator = None,
|
1143 |
+
eval_dataset = None,
|
1144 |
+
callbacks = None,
|
1145 |
+
peft_config = None,
|
1146 |
+
**kwargs
|
1147 |
+
):
|
1148 |
+
if args is None: args = UnslothPPOConfig()
|
1149 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1150 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1151 |
+
force_float32 = False
|
1152 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1153 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1154 |
+
force_float32 = True
|
1155 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1156 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1157 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1158 |
+
from unsloth_zoo.utils import _get_dtype
|
1159 |
+
dtype = _get_dtype(dtype)
|
1160 |
+
float16 = dtype == torch.float16
|
1161 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1162 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1163 |
+
if force_float32:
|
1164 |
+
args.fp16 = False
|
1165 |
+
args.bf16 = False
|
1166 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1167 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1168 |
+
args.fp16 = float16
|
1169 |
+
args.bf16 = not float16
|
1170 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1171 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1172 |
+
args.eval_strategy = 'steps'
|
1173 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1174 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1175 |
+
if ga_steps is not None and ga_steps > 1:
|
1176 |
+
from transformers import __version__ as transformers_version
|
1177 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1178 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1179 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1180 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1181 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1182 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1183 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1184 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1185 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1186 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1187 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1188 |
+
if force_float32:
|
1189 |
+
args.bf16_full_eval = False
|
1190 |
+
args.fp16_full_eval = False
|
1191 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1192 |
+
args.bf16_full_eval = True
|
1193 |
+
args.fp16_full_eval = False
|
1194 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1195 |
+
args.bf16_full_eval = args.bf16
|
1196 |
+
args.fp16_full_eval = args.fp16
|
1197 |
+
_output_logits = False
|
1198 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1199 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1200 |
+
if _output_logits:
|
1201 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1202 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1203 |
+
pass
|
1204 |
+
else:
|
1205 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1206 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1207 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1208 |
+
max_seq_length = model.max_seq_length
|
1209 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1210 |
+
if model is not None and hasattr(model, 'for_training'):
|
1211 |
+
model.for_training()
|
1212 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1213 |
+
if 'processing_class' in locals():
|
1214 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1215 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1216 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1217 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1218 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1219 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1220 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1221 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1222 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1223 |
+
else:
|
1224 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1225 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1226 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1227 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1228 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1229 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1230 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1231 |
+
else:
|
1232 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1233 |
+
other_metrics = []
|
1234 |
+
|
1235 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1236 |
+
PatchRLStatistics('ppo_trainer', other_metrics)
|
1237 |
+
|
1238 |
+
super().__init__(
|
1239 |
+
args = args,
|
1240 |
+
processing_class = processing_class,
|
1241 |
+
model = model,
|
1242 |
+
ref_model = ref_model,
|
1243 |
+
reward_model = reward_model,
|
1244 |
+
train_dataset = train_dataset,
|
1245 |
+
value_model = value_model,
|
1246 |
+
data_collator = data_collator,
|
1247 |
+
eval_dataset = eval_dataset,
|
1248 |
+
callbacks = callbacks,
|
1249 |
+
peft_config = peft_config,**kwargs)
|
1250 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1251 |
+
self.neftune_hook_handle.remove()
|
1252 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1253 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1254 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1255 |
+
pass
|
1256 |
+
|
1257 |
+
pass
|
unsloth_compiled_cache/UnslothPRMTrainer.py
ADDED
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.13
|
3 |
+
2025.3.15
|
4 |
+
4.48.3
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothPRMConfig(PRMConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`PRMTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-5`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) used for truncation.
|
58 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
59 |
+
Maximum length of the prompt used for truncation.
|
60 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
61 |
+
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
62 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
63 |
+
Whether to disable dropout in the model.
|
64 |
+
step_separator (`str`, *optional*, defaults to `"\n"`):
|
65 |
+
Separator used to separate each step of the reasoning process.
|
66 |
+
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
67 |
+
Whether to train only on the last step.
|
68 |
+
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
69 |
+
Number of processes to use for processing the dataset.
|
70 |
+
|
71 |
+
"""
|
72 |
+
vllm_sampling_params: Optional[Any] = field(
|
73 |
+
default = None,
|
74 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
75 |
+
)
|
76 |
+
unsloth_num_chunks : Optional[int] = field(
|
77 |
+
default = -1,
|
78 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
79 |
+
)
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
output_dir = None,
|
83 |
+
overwrite_output_dir = None,
|
84 |
+
do_train = False,
|
85 |
+
do_eval = False,
|
86 |
+
do_predict = False,
|
87 |
+
eval_strategy = 'no',
|
88 |
+
prediction_loss_only = False,
|
89 |
+
per_device_train_batch_size = 4,
|
90 |
+
per_device_eval_batch_size = 4,
|
91 |
+
per_gpu_train_batch_size = None,
|
92 |
+
per_gpu_eval_batch_size = None,
|
93 |
+
gradient_accumulation_steps = 2,
|
94 |
+
eval_accumulation_steps = 2,
|
95 |
+
eval_delay = 0,
|
96 |
+
torch_empty_cache_steps = 250,
|
97 |
+
learning_rate = 5e-05,
|
98 |
+
weight_decay = 0.01,
|
99 |
+
adam_beta1 = 0.9,
|
100 |
+
adam_beta2 = 0.999,
|
101 |
+
adam_epsilon = 1e-08,
|
102 |
+
max_grad_norm = 1.0,
|
103 |
+
num_train_epochs = 3.0,
|
104 |
+
max_steps = -1,
|
105 |
+
lr_scheduler_type = 'linear',
|
106 |
+
warmup_ratio = 0.1,
|
107 |
+
warmup_steps = 0,
|
108 |
+
log_level = 'passive',
|
109 |
+
log_level_replica = 'warning',
|
110 |
+
log_on_each_node = True,
|
111 |
+
logging_dir = None,
|
112 |
+
logging_strategy = 'steps',
|
113 |
+
logging_first_step = False,
|
114 |
+
logging_steps = 1,
|
115 |
+
logging_nan_inf_filter = False,
|
116 |
+
save_strategy = 'steps',
|
117 |
+
save_steps = 500,
|
118 |
+
save_total_limit = None,
|
119 |
+
save_safetensors = True,
|
120 |
+
save_on_each_node = False,
|
121 |
+
save_only_model = False,
|
122 |
+
restore_callback_states_from_checkpoint = False,
|
123 |
+
no_cuda = False,
|
124 |
+
use_cpu = False,
|
125 |
+
use_mps_device = False,
|
126 |
+
seed = 3407,
|
127 |
+
data_seed = 3407,
|
128 |
+
jit_mode_eval = False,
|
129 |
+
use_ipex = False,
|
130 |
+
bf16 = False,
|
131 |
+
fp16 = False,
|
132 |
+
fp16_opt_level = 'O1',
|
133 |
+
half_precision_backend = 'auto',
|
134 |
+
bf16_full_eval = False,
|
135 |
+
fp16_full_eval = False,
|
136 |
+
tf32 = None,
|
137 |
+
local_rank = -1,
|
138 |
+
ddp_backend = None,
|
139 |
+
tpu_num_cores = None,
|
140 |
+
tpu_metrics_debug = False,
|
141 |
+
debug = '',
|
142 |
+
dataloader_drop_last = False,
|
143 |
+
eval_steps = None,
|
144 |
+
dataloader_num_workers = 0,
|
145 |
+
dataloader_prefetch_factor = None,
|
146 |
+
past_index = -1,
|
147 |
+
run_name = None,
|
148 |
+
disable_tqdm = None,
|
149 |
+
remove_unused_columns = True,
|
150 |
+
label_names = None,
|
151 |
+
load_best_model_at_end = False,
|
152 |
+
metric_for_best_model = None,
|
153 |
+
greater_is_better = None,
|
154 |
+
ignore_data_skip = False,
|
155 |
+
fsdp = '',
|
156 |
+
fsdp_min_num_params = 0,
|
157 |
+
fsdp_config = None,
|
158 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
159 |
+
accelerator_config = None,
|
160 |
+
deepspeed = None,
|
161 |
+
label_smoothing_factor = 0.0,
|
162 |
+
optim = 'adamw_8bit',
|
163 |
+
optim_args = None,
|
164 |
+
adafactor = False,
|
165 |
+
group_by_length = False,
|
166 |
+
length_column_name = 'length',
|
167 |
+
report_to = None,
|
168 |
+
ddp_find_unused_parameters = None,
|
169 |
+
ddp_bucket_cap_mb = None,
|
170 |
+
ddp_broadcast_buffers = None,
|
171 |
+
dataloader_pin_memory = True,
|
172 |
+
dataloader_persistent_workers = False,
|
173 |
+
skip_memory_metrics = True,
|
174 |
+
use_legacy_prediction_loop = False,
|
175 |
+
push_to_hub = False,
|
176 |
+
resume_from_checkpoint = None,
|
177 |
+
hub_model_id = None,
|
178 |
+
hub_strategy = 'every_save',
|
179 |
+
hub_token = None,
|
180 |
+
hub_private_repo = None,
|
181 |
+
hub_always_push = False,
|
182 |
+
gradient_checkpointing = False,
|
183 |
+
gradient_checkpointing_kwargs = None,
|
184 |
+
include_inputs_for_metrics = False,
|
185 |
+
eval_do_concat_batches = True,
|
186 |
+
fp16_backend = 'auto',
|
187 |
+
evaluation_strategy = None,
|
188 |
+
push_to_hub_model_id = None,
|
189 |
+
push_to_hub_organization = None,
|
190 |
+
push_to_hub_token = None,
|
191 |
+
mp_parameters = '',
|
192 |
+
auto_find_batch_size = False,
|
193 |
+
full_determinism = False,
|
194 |
+
torchdynamo = None,
|
195 |
+
ray_scope = 'last',
|
196 |
+
ddp_timeout = 1800,
|
197 |
+
torch_compile = False,
|
198 |
+
torch_compile_backend = None,
|
199 |
+
torch_compile_mode = None,
|
200 |
+
dispatch_batches = None,
|
201 |
+
split_batches = None,
|
202 |
+
include_tokens_per_second = False,
|
203 |
+
include_num_input_tokens_seen = False,
|
204 |
+
neftune_noise_alpha = None,
|
205 |
+
optim_target_modules = None,
|
206 |
+
batch_eval_metrics = False,
|
207 |
+
eval_on_start = False,
|
208 |
+
use_liger_kernel = False,
|
209 |
+
eval_use_gather_object = False,
|
210 |
+
average_tokens_across_devices = False,
|
211 |
+
max_length = 1024,
|
212 |
+
max_prompt_length = 512,
|
213 |
+
max_completion_length = None,
|
214 |
+
disable_dropout = True,
|
215 |
+
step_separator = '\
|
216 |
+
',
|
217 |
+
train_on_last_step_only = False,
|
218 |
+
dataset_num_proc = None,
|
219 |
+
vllm_sampling_params = None,
|
220 |
+
unsloth_num_chunks = -1,
|
221 |
+
**kwargs,
|
222 |
+
):
|
223 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
224 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
225 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
226 |
+
output_dir = 'unsloth_training_checkpoints'
|
227 |
+
save_strategy = 'no'
|
228 |
+
if dataset_num_proc is None:
|
229 |
+
from multiprocessing import cpu_count
|
230 |
+
dataset_num_proc = cpu_count()
|
231 |
+
|
232 |
+
super().__init__(
|
233 |
+
output_dir = output_dir,
|
234 |
+
overwrite_output_dir = overwrite_output_dir,
|
235 |
+
do_train = do_train,
|
236 |
+
do_eval = do_eval,
|
237 |
+
do_predict = do_predict,
|
238 |
+
eval_strategy = eval_strategy,
|
239 |
+
prediction_loss_only = prediction_loss_only,
|
240 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
241 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
242 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
243 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
244 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
245 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
246 |
+
eval_delay = eval_delay,
|
247 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
248 |
+
learning_rate = learning_rate,
|
249 |
+
weight_decay = weight_decay,
|
250 |
+
adam_beta1 = adam_beta1,
|
251 |
+
adam_beta2 = adam_beta2,
|
252 |
+
adam_epsilon = adam_epsilon,
|
253 |
+
max_grad_norm = max_grad_norm,
|
254 |
+
num_train_epochs = num_train_epochs,
|
255 |
+
max_steps = max_steps,
|
256 |
+
lr_scheduler_type = lr_scheduler_type,
|
257 |
+
warmup_ratio = warmup_ratio,
|
258 |
+
warmup_steps = warmup_steps,
|
259 |
+
log_level = log_level,
|
260 |
+
log_level_replica = log_level_replica,
|
261 |
+
log_on_each_node = log_on_each_node,
|
262 |
+
logging_dir = logging_dir,
|
263 |
+
logging_strategy = logging_strategy,
|
264 |
+
logging_first_step = logging_first_step,
|
265 |
+
logging_steps = logging_steps,
|
266 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
267 |
+
save_strategy = save_strategy,
|
268 |
+
save_steps = save_steps,
|
269 |
+
save_total_limit = save_total_limit,
|
270 |
+
save_safetensors = save_safetensors,
|
271 |
+
save_on_each_node = save_on_each_node,
|
272 |
+
save_only_model = save_only_model,
|
273 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
274 |
+
no_cuda = no_cuda,
|
275 |
+
use_cpu = use_cpu,
|
276 |
+
use_mps_device = use_mps_device,
|
277 |
+
seed = seed,
|
278 |
+
data_seed = data_seed,
|
279 |
+
jit_mode_eval = jit_mode_eval,
|
280 |
+
use_ipex = use_ipex,
|
281 |
+
bf16 = bf16,
|
282 |
+
fp16 = fp16,
|
283 |
+
fp16_opt_level = fp16_opt_level,
|
284 |
+
half_precision_backend = half_precision_backend,
|
285 |
+
bf16_full_eval = bf16_full_eval,
|
286 |
+
fp16_full_eval = fp16_full_eval,
|
287 |
+
tf32 = tf32,
|
288 |
+
local_rank = local_rank,
|
289 |
+
ddp_backend = ddp_backend,
|
290 |
+
tpu_num_cores = tpu_num_cores,
|
291 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
292 |
+
debug = debug,
|
293 |
+
dataloader_drop_last = dataloader_drop_last,
|
294 |
+
eval_steps = eval_steps,
|
295 |
+
dataloader_num_workers = dataloader_num_workers,
|
296 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
297 |
+
past_index = past_index,
|
298 |
+
run_name = run_name,
|
299 |
+
disable_tqdm = disable_tqdm,
|
300 |
+
remove_unused_columns = remove_unused_columns,
|
301 |
+
label_names = label_names,
|
302 |
+
load_best_model_at_end = load_best_model_at_end,
|
303 |
+
metric_for_best_model = metric_for_best_model,
|
304 |
+
greater_is_better = greater_is_better,
|
305 |
+
ignore_data_skip = ignore_data_skip,
|
306 |
+
fsdp = fsdp,
|
307 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
308 |
+
fsdp_config = fsdp_config,
|
309 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
310 |
+
accelerator_config = accelerator_config,
|
311 |
+
deepspeed = deepspeed,
|
312 |
+
label_smoothing_factor = label_smoothing_factor,
|
313 |
+
optim = optim,
|
314 |
+
optim_args = optim_args,
|
315 |
+
adafactor = adafactor,
|
316 |
+
group_by_length = group_by_length,
|
317 |
+
length_column_name = length_column_name,
|
318 |
+
report_to = report_to,
|
319 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
320 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
321 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
322 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
323 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
324 |
+
skip_memory_metrics = skip_memory_metrics,
|
325 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
326 |
+
push_to_hub = push_to_hub,
|
327 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
328 |
+
hub_model_id = hub_model_id,
|
329 |
+
hub_strategy = hub_strategy,
|
330 |
+
hub_token = hub_token,
|
331 |
+
hub_private_repo = hub_private_repo,
|
332 |
+
hub_always_push = hub_always_push,
|
333 |
+
gradient_checkpointing = gradient_checkpointing,
|
334 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
335 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
336 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
337 |
+
fp16_backend = fp16_backend,
|
338 |
+
evaluation_strategy = evaluation_strategy,
|
339 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
340 |
+
push_to_hub_organization = push_to_hub_organization,
|
341 |
+
push_to_hub_token = push_to_hub_token,
|
342 |
+
mp_parameters = mp_parameters,
|
343 |
+
auto_find_batch_size = auto_find_batch_size,
|
344 |
+
full_determinism = full_determinism,
|
345 |
+
torchdynamo = torchdynamo,
|
346 |
+
ray_scope = ray_scope,
|
347 |
+
ddp_timeout = ddp_timeout,
|
348 |
+
torch_compile = torch_compile,
|
349 |
+
torch_compile_backend = torch_compile_backend,
|
350 |
+
torch_compile_mode = torch_compile_mode,
|
351 |
+
dispatch_batches = dispatch_batches,
|
352 |
+
split_batches = split_batches,
|
353 |
+
include_tokens_per_second = include_tokens_per_second,
|
354 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
355 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
356 |
+
optim_target_modules = optim_target_modules,
|
357 |
+
batch_eval_metrics = batch_eval_metrics,
|
358 |
+
eval_on_start = eval_on_start,
|
359 |
+
use_liger_kernel = use_liger_kernel,
|
360 |
+
eval_use_gather_object = eval_use_gather_object,
|
361 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
362 |
+
max_length = max_length,
|
363 |
+
max_prompt_length = max_prompt_length,
|
364 |
+
max_completion_length = max_completion_length,
|
365 |
+
disable_dropout = disable_dropout,
|
366 |
+
step_separator = step_separator,
|
367 |
+
train_on_last_step_only = train_on_last_step_only,
|
368 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
369 |
+
self.vllm_sampling_params = vllm_sampling_params
|
370 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
371 |
+
pass
|
372 |
+
|
373 |
+
class _UnslothPRMTrainer(Trainer):
    """Base trainer for Process Reward Models (PRMs).

    Frames PRM training as token classification: each example's stepwise
    completions are tokenized up front (see ``tokenize_row``); every position
    is labelled ``-100`` except the final token of each step, which carries
    that step's 0/1 correctness label.
    """

    _tag_names = ["trl", "prm"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[PRMConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        # A peft_config can only be honoured when peft itself is importable.
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                # Quantized (8-bit / k-bit) models must be prepared before any PEFT use.
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                # NOTE(review): upstream TRL wraps the model in a PeftModel here
                # (`get_peft_model(model, peft_config)`); this generated variant
                # deliberately leaves the model unchanged — Unsloth applies PEFT
                # elsewhere. TODO confirm against the generator template.
                model = model

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default DataCollatorForTokenClassification"
                )
            data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)

        # Tokenize only when the dataset is not already tokenized.
        if "input_ids" not in train_dataset.column_names:
            # The local main process maps first so the other ranks reuse its cache.
            with PartialState().local_main_process_first():
                fn_kwargs = {
                    "tokenizer": processing_class,
                    "step_separator": args.step_separator,
                    "max_length": args.max_length,
                    "max_prompt_length": args.max_prompt_length,
                    "max_completion_length": args.max_completion_length,
                    "train_on_last_step_only": args.train_on_last_step_only,
                }
                train_fn_kwargs = {**fn_kwargs, "is_eval": False}
                train_dataset = train_dataset.map(
                    self.tokenize_row,
                    fn_kwargs=train_fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=train_dataset.features,
                    desc="Tokenizing train dataset",
                    features=features.Features(  # needed to avoid map to cast labels to bool
                        {
                            "labels": features.Sequence(features.Value("int64")),
                            "input_ids": features.Sequence(features.Value("int64")),
                        }
                    ),
                )

                eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
                if eval_dataset is not None:
                    eval_dataset = eval_dataset.map(
                        self.tokenize_row,
                        fn_kwargs=eval_fn_kwargs,
                        num_proc=args.dataset_num_proc,
                        remove_columns=eval_dataset.features,
                        desc="Tokenizing eval dataset",
                        features=features.Features(  # needed to avoid map to cast labels to bool
                            {
                                "labels": features.Sequence(features.Value("int64")),
                                "input_ids": features.Sequence(features.Value("int64")),
                            }
                        ),
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    @staticmethod
    def tokenize_row(
        features,
        tokenizer,
        step_separator,
        max_length,
        max_prompt_length,
        max_completion_length,
        train_on_last_step_only,
        is_eval,
    ):
        r"""
        Tokenize a row of the dataset.

        Args:
            features (`dict[str, str]`):
                Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
            tokenizer (`PreTrainedTokenizerBase`):
                Tokenizer used to process the data.
            step_separator (`str`):
                Separator between steps in the completion.
            max_length (`int` or `None`):
                Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
            max_prompt_length (`int` or `None`):
                Maximum length of the prompt. If `None`, the prompt is not truncated.
            max_completion_length (`int` or `None`):
                Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
            train_on_last_step_only (`bool`):
                Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
                token of the completion.
            is_eval (`bool`):
                Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.

        Returns:
            `dict[str, list[int]]`:
                Tokenized sequences with the keys `"input_ids"`, and `"labels".

        Example:
        ```python
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
        ...             "completions": ["11 is greater than 8.",
        ...                             "Hence, 9.11 > 9.8."],
        ...             "labels": [True, False]}
        >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
        {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
         'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
        ```
        """
        # Tokenize the prompt and completions
        prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        completions_ids = [
            tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
        ]
        if train_on_last_step_only and not is_eval:
            # Only the final step keeps its real label; earlier steps are masked.
            labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
        else:
            labels = [int(label) for label in features["labels"]]

        # Get the ID of the separator token and add it to the completions
        separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
        completions_ids = [completion + separator_ids for completion in completions_ids]

        # Create the label: within each step, only the last token (the
        # separator position) carries the step label; the rest are -100.
        labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]

        # Join the completions and labels steps
        completion_ids = list(chain(*completions_ids))
        labels = list(chain(*labels))

        if tokenizer.bos_token_id is not None:
            prompt_ids = [tokenizer.bos_token_id] + prompt_ids

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            # Keep the prompt tail (rightmost tokens) when truncating.
            prompt_ids = prompt_ids[-max_prompt_length:]
        if max_completion_length is not None:
            completion_ids = completion_ids[:max_completion_length]
            labels = labels[:max_completion_length]

        input_ids = prompt_ids + completion_ids
        # Prompt tokens never contribute to the loss.
        labels = [-100] * len(prompt_ids) + labels

        if max_length is not None:
            input_ids = input_ids[:max_length]
            labels = labels[:max_length]

        return {"input_ids": input_ids, "labels": labels}

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        # Only rank zero writes the card.
        if not self.is_world_process_zero():
            return

        # A local path in _name_or_path means the base model is not a Hub id.
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{uesato2022solving,
            title        = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
            author       = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
            year         = 2022,
            journal      = {arXiv preprint arXiv:2211.14275}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            trainer_name="PRM",
            trainer_citation=citation,
            paper_title="Solving math word problems with process-and outcome-based feedback",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
class UnslothPRMTrainer(_UnslothPRMTrainer):
    """

    Initialize PRMTrainer.

    Args:
        model (`transformers.PreTrainedModel`):
            The model to train, preferably an `AutoModelForTokenClassification`.
        args (`PRMConfig`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be used.
        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
            The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

    """
    def __init__(
        self,
        model = None,
        args = None,
        data_collator = None,
        train_dataset = None,
        eval_dataset = None,
        processing_class = None,
        model_init = None,
        compute_metrics = None,
        callbacks = None,
        preprocess_logits_for_metrics = None,
        peft_config = None,
        **kwargs
    ):
        # Fall back to a default config so the trainer can run standalone.
        if args is None: args = UnslothPRMConfig()

        # --- Resolve mixed-precision flags against the model's actual dtype ---
        use_bf16 = getattr(args, 'bf16', False)
        use_fp16 = getattr(args, 'fp16', False)
        force_float32 = False
        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
            print('Unsloth: Switching to float32 training since model cannot work with float16')
            force_float32 = True
        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
        dtype = getattr(model.config, 'torch_dtype', None)
        # Fall back to the embedding dtype when the config does not declare one.
        if dtype is None: dtype = model.get_input_embeddings().dtype
        from unsloth_zoo.utils import _get_dtype
        dtype = _get_dtype(dtype)
        float16 = dtype == torch.float16
        # Refuse mismatched precision requests (fp16 model + bf16 flag and vice versa).
        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
        if force_float32:
            args.fp16 = False
            args.bf16 = False
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
            # Neither flag set: derive fp16/bf16 from the model dtype.
            args.fp16 = float16
            args.bf16 = not float16
            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'

        # NOTE(review): this reads `args.eval_dataset`, which PRMConfig does not
        # define, so the guard appears to always be False; the `eval_dataset`
        # function argument was probably intended — TODO confirm upstream.
        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
            args.eval_strategy = 'steps'
            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
        if ga_steps is not None and ga_steps > 1:
            from transformers import __version__ as transformers_version
            # transformers <= 4.45.2 has a known gradient-accumulation loss bug.
            if Version(transformers_version) <= Version('4.45.2'):
                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
                    '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
        if getattr(args, 'eval_strategy', 'no') != 'no':
            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
            # Shrink the (default) eval batch size down to the train batch size.
            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps

        # --- Keep full-eval precision flags consistent with training precision ---
        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
        if force_float32:
            args.bf16_full_eval = False
            args.fp16_full_eval = False
        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
            args.bf16_full_eval = True
            args.fp16_full_eval = False
        elif not bf16_full_eval and not fp16_full_eval:
            args.bf16_full_eval = args.bf16
            args.fp16_full_eval = args.fp16

        # Metrics callbacks need real logits, so ask Unsloth to return them.
        _output_logits = False
        if locals().get('compute_metrics', None) is not None: _output_logits = True
        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
        if _output_logits:
            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'

        # Propagate the model's max_seq_length onto args when args lacks one.
        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
            pass
        else:
            model_max_seq_length = getattr(model, 'max_seq_length', None)
            args_max_seq_length = getattr(args, 'max_seq_length', None)
            if args_max_seq_length is None and model_max_seq_length is not None:
                max_seq_length = model.max_seq_length
                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length

        if model is not None and hasattr(model, 'for_training'):
            model.for_training()
        # Right padding is required for training; patch every tokenizer-like object.
        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
        if 'processing_class' in locals():
            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer

        # --- Swap the data collator to match whether labels are precomputed ---
        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
        if not isinstance(data_collator, UnslothVisionDataCollator):
            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
                data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
            elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
                data_collator = DataCollatorForSeq2Seq(__tokenizer)
        else:
            # Vision collator prepares batches itself; disable HF dataset preprocessing.
            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
        if not isinstance(data_collator, UnslothVisionDataCollator):
            # Processor wrappers without a `pad` method delegate to their inner tokenizer.
            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
                if isinstance(data_collator, DataCollatorForSeq2Seq):
                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
                else:
                    data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
        other_metrics = []

        from unsloth_zoo.logging_utils import PatchRLStatistics
        PatchRLStatistics('prm_trainer', other_metrics)

        super().__init__(
            model = model,
            args = args,
            data_collator = data_collator,
            train_dataset = train_dataset,
            eval_dataset = eval_dataset,
            processing_class = processing_class,
            model_init = model_init,
            compute_metrics = compute_metrics,
            callbacks = callbacks,
            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
            peft_config = peft_config,**kwargs)
        # Remove the stock neftune hook; Unsloth re-applies noise via the
        # embedding attribute below instead.
        if hasattr(self, 'neftune_hook_handle'):
            self.neftune_hook_handle.remove()
        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
        if getattr(args, 'neftune_noise_alpha', None) is not None:
            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
        pass
    pass