Zwounds committed
Commit eef22b5 · verified · 1 parent: 53a648d

Upload folder using huggingface_hub

Files changed (38)
  1. .gitattributes +1 -0
  2. .gradio/certificate.pem +31 -0
  3. demo_cpu.py +182 -0
  4. requirements.txt +2 -1
  5. unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
  6. unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
  7. unsloth_compiled_cache/UnslothCPOTrainer.py +1557 -0
  8. unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
  9. unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
  10. unsloth_compiled_cache/UnslothGKDTrainer.py +863 -0
  11. unsloth_compiled_cache/UnslothGRPOTrainer.py +1438 -0
  12. unsloth_compiled_cache/UnslothKTOTrainer.py +1840 -0
  13. unsloth_compiled_cache/UnslothNashMDTrainer.py +955 -0
  14. unsloth_compiled_cache/UnslothORPOTrainer.py +1543 -0
  15. unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1269 -0
  16. unsloth_compiled_cache/UnslothPPOTrainer.py +1259 -0
  17. unsloth_compiled_cache/UnslothPRMTrainer.py +800 -0
  18. unsloth_compiled_cache/UnslothRLOOTrainer.py +1133 -0
  19. unsloth_compiled_cache/UnslothRewardTrainer.py +819 -0
  20. unsloth_compiled_cache/UnslothSFTTrainer.py +1027 -0
  21. unsloth_compiled_cache/UnslothXPOTrainer.py +1010 -0
  22. unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
  23. unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
  24. unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
  25. unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
  26. unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +3 -0
  27. unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
  28. unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
  29. unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
  30. unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
  31. unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
  32. unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
  33. unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
  34. unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
  35. unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
  36. unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
  37. unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
  38. unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
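The subject strings decodable from the base64 above suggest this is the ISRG Root X1 certificate; Gradio caches a root certificate under `.gradio/` when creating share links (the demo below calls `demo.launch(share=True)`), which is likely how it got committed. A minimal stdlib sketch, assuming the file sits at the committed path, to fingerprint the blob and check which root was checked in:

import hashlib
import ssl

# Read the PEM exactly as committed in this diff.
pem = open(".gradio/certificate.pem").read()
# PEM_cert_to_DER_cert strips the BEGIN/END armor and base64-decodes the body.
der = ssl.PEM_cert_to_DER_cert(pem)
# SHA-256 fingerprint of the DER encoding, comparable against the CA's published fingerprints.
print(hashlib.sha256(der).hexdigest())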
demo_cpu.py ADDED
@@ -0,0 +1,182 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import logging
+ import os
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ SYSTEM_INSTRUCTION = """Convert natural language queries into boolean search queries by following these rules:
+
+ 1. FIRST: Remove all meta-terms from this list (they should NEVER appear in output):
+    - articles, papers, research, studies
+    - examining, investigating, analyzing
+    - findings, documents, literature
+    - publications, journals, reviews
+    Example: "Research examining X" → just "X"
+
+ 2. SECOND: Remove generic implied terms that don't add search value:
+    - Remove words like "practices," "techniques," "methods," "approaches," "strategies"
+    - Remove words like "impacts," "effects," "influences," "role," "applications"
+    - For example: "sustainable agriculture practices" → "sustainable agriculture"
+    - For example: "teaching methodologies" → "teaching"
+    - For example: "leadership styles" → "leadership"
+
+ 3. THEN: Format the remaining terms:
+    CRITICAL QUOTING RULES:
+    - Multi-word phrases MUST ALWAYS be in quotes - NO EXCEPTIONS
+    - Examples of correct quoting:
+      - Wrong: machine learning AND deep learning
+      - Right: "machine learning" AND "deep learning"
+      - Wrong: natural language processing
+      - Right: "natural language processing"
+    - Single words must NEVER have quotes (e.g., science, research, learning)
+    - Use AND to connect required concepts
+    - Use OR with parentheses for alternatives
+    - Provide ONLY the correct Boolean query
+ """
+
+ def load_model():
+     """Load the model and tokenizer."""
+     logger.info("Loading model on CPU...")
+     model = AutoModelForCausalLM.from_pretrained(
+         "../boolean_model_llama/merged_llama",
+         torch_dtype=torch.float32,
+         device_map="cpu",
+         trust_remote_code=True
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         "../boolean_model_llama/merged_llama",
+         trust_remote_code=True
+     )
+
+     logger.info("Model loaded successfully")
+     return model, tokenizer
+
+ def extract_response(text: str) -> str:
+     """Extract the actual boolean query from the model's response."""
+     # Split by newlines and get last non-empty line
+     lines = [line.strip() for line in text.split('\n') if line.strip()]
+     return lines[-1] if lines else ""
+
+ def get_boolean_query(query: str, model_tuple=None) -> str:
+     """Generate boolean query from natural language."""
+     if model_tuple is None:
+         return ""
+     model, tokenizer = model_tuple
+
+     # Format the conversation
+     conversation = [
+         {"role": "system", "content": SYSTEM_INSTRUCTION},
+         {"role": "user", "content": query}
+     ]
+
+     # Apply chat template and tokenize
+     encoded = tokenizer.apply_chat_template(
+         conversation,
+         return_tensors="pt"
+     )
+
+     # Create attention mask
+     attention_mask = torch.ones_like(encoded)
+     inputs = {
+         "input_ids": encoded,
+         "attention_mask": attention_mask
+     }
+
+     # Generate response
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=64,
+         temperature=0.1,
+         pad_token_id=tokenizer.eos_token_id
+     )
+
+     # Decode and extract response
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return extract_response(response)
+
+ # Example queries demonstrating various cases
+ examples = [
+     # Testing removal of meta-terms
+     ["Find research papers examining the long-term effects of meditation on brain structure"],
+
+     # Testing removal of generic implied terms (practices, techniques, methods)
+     ["Articles about deep learning techniques for natural language processing tasks"],
+
+     # Testing removal of impact/effect terms
+     ["Studies on the impact of early childhood nutrition on cognitive development"],
+
+     # Testing handling of technology applications
+     ["Information on virtual reality applications in architectural design and urban planning"],
+
+     # Testing proper OR relationship with parentheses
+     ["Research on electric vehicles adoption in urban environments or rural communities"],
+
+     # Testing proper quoting of multi-word concepts only
+     ["Articles on biodiversity loss in coral reefs and rainforest ecosystems"],
+
+     # Testing removal of strategy/approach terms
+     ["Studies about different teaching approaches for children with learning disabilities"],
+
+     # Testing complex OR relationships
+     ["Research examining social media influence on political polarization or public discourse"],
+
+     # Testing implied terms in specific industries
+     ["Articles about implementation strategies for blockchain in supply chain management or financial services"],
+
+     # Testing qualifiers that don't add search value
+     ["Research on effective leadership styles in multicultural organizations"],
+
+     # Testing removal of multiple implied terms
+     ["Studies on the effects of microplastic pollution techniques on marine ecosystem health"],
+
+     # Testing domain-specific implied terms
+     ["Articles about successful cybersecurity protection methods for critical infrastructure"],
+
+     # Testing generalized vs specific concepts
+     ["Research papers on quantum computing algorithms for cryptography or optimization problems"],
+
+     # Testing implied terms in outcome descriptions
+     ["Studies examining the relationship between sleep quality and academic performance outcomes"],
+
+     # Testing complex nesting of concepts
+     ["Articles about renewable energy integration challenges in developing countries or island nations"]
+ ]
+
+ # Load model and tokenizer globally
+ logger.info("Initializing model...")
+ model_tuple = load_model()
+
+ # Create Gradio interface
+ title = "Natural Language to Boolean Search (CPU Version)"
+ description = """Convert natural language queries into boolean search expressions. The model will:
+
+ 1. Remove search-related terms (like 'articles', 'research', etc.)
+ 2. Handle generic implied terms (like 'practices', 'methods')
+ 3. Format concepts using proper boolean syntax:
+    - Multi-word phrases in quotes
+    - Single words without quotes
+    - AND to connect required concepts
+    - OR with parentheses for alternatives
+ """
+
+ demo = gr.Interface(
+     fn=lambda x: get_boolean_query(x, model_tuple),
+     inputs=[
+         gr.Textbox(
+             label="Enter your natural language query",
+             placeholder="e.g., I'm looking for information about climate change and renewable energy"
+         )
+     ],
+     outputs=gr.Textbox(label="Boolean Search Query"),
+     title=title,
+     description=description,
+     examples=examples,
+     theme=gr.themes.Soft()
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
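Note on the post-processing above: the decoded generation echoes the chat template, so `extract_response` keeps only the last non-empty line. A minimal sketch of that step with a hypothetical decoded string (the real output depends on the fine-tuned checkpoint loaded above):

# Post-processing helper as defined in demo_cpu.py.
def extract_response(text: str) -> str:
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    return lines[-1] if lines else ""

# Hypothetical decoded generation: template/prompt lines first, boolean query last.
decoded = 'system: ...\nuser: Articles about deep learning techniques ...\n"deep learning" AND "natural language processing"'
assert extract_response(decoded) == '"deep learning" AND "natural language processing"'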
requirements.txt CHANGED
@@ -1,2 +1,3 @@
  gradio>=4.0.0
- llama-cpp-python>=0.2.11
+ transformers>=4.0.0
+ torch>=1.0.0
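Both new pins are very loose (`torch>=1.0.0` predates essentially every API the demo uses), so the versions that actually resolve can vary widely between environments. A small stdlib sketch, assuming the three packages above are installed, to print what was picked:

from importlib.metadata import version

# Report the resolved versions for the packages pinned in requirements.txt.
for pkg in ("gradio", "transformers", "torch"):
    print(pkg, version(pkg))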
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,637 @@
+ """
+ 2025.3.15
+ 2025.3.17
+ 4.50.0.dev0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+     "epilogue_fusion" : True,
+     "max_autotune" : False,
+     "shape_padding" : True,
+     "trace.enabled" : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+     logits = logits.to(torch.float32)
+     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+     # loop to reduce peak mem consumption
+     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+     logsumexp_values = torch.logsumexp(logits, dim = -1)
+     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+     return per_token_logps
+
+ @dataclass
+ class UnslothAlignPropConfig(AlignPropConfig):
+     """
+
+     Configuration class for the [`AlignPropTrainer`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
+             Name of this experiment (defaults to the file name without the extension).
+         run_name (`str`, *optional*, defaults to `""`):
+             Name of this run.
+         seed (`int`, *optional*, defaults to `0`):
+             Random seed for reproducibility.
+         log_with (`str` or `None`, *optional*, defaults to `None`):
+             Log with either `"wandb"` or `"tensorboard"`. Check
+             [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
+         log_image_freq (`int`, *optional*, defaults to `1`):
+             Frequency for logging images.
+         tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the tracker (e.g., `wandb_project`).
+         accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the accelerator.
+         project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
+             Keyword arguments for the accelerator project config (e.g., `logging_dir`).
+         tracker_project_name (`str`, *optional*, defaults to `"trl"`):
+             Name of project to use for tracking.
+         logdir (`str`, *optional*, defaults to `"logs"`):
+             Top-level logging directory for checkpoint saving.
+         num_epochs (`int`, *optional*, defaults to `100`):
+             Number of epochs to train.
+         save_freq (`int`, *optional*, defaults to `1`):
+             Number of epochs between saving model checkpoints.
+         num_checkpoint_limit (`int`, *optional*, defaults to `5`):
+             Number of checkpoints to keep before overwriting old ones.
+         mixed_precision (`str`, *optional*, defaults to `"fp16"`):
+             Mixed precision training.
+         allow_tf32 (`bool`, *optional*, defaults to `True`):
+             Allow `tf32` on Ampere GPUs.
+         resume_from (`str`, *optional*, defaults to `""`):
+             Path to resume training from a checkpoint.
+         sample_num_steps (`int`, *optional*, defaults to `50`):
+             Number of sampler inference steps.
+         sample_eta (`float`, *optional*, defaults to `1.0`):
+             Eta parameter for the DDIM sampler.
+         sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
+             Classifier-free guidance weight.
+         train_batch_size (`int`, *optional*, defaults to `1`):
+             Batch size for training.
+         train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
+             Whether to use the 8bit Adam optimizer from `bitsandbytes`.
+         train_learning_rate (`float`, *optional*, defaults to `1e-3`):
+             Learning rate.
+         train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
+             Beta1 for Adam optimizer.
+         train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
+             Beta2 for Adam optimizer.
+         train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
+             Weight decay for Adam optimizer.
+         train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
+             Epsilon value for Adam optimizer.
+         train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
+             Number of gradient accumulation steps.
+         train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
+             Maximum gradient norm for gradient clipping.
+         negative_prompts (`str` or `None`, *optional*, defaults to `None`):
+             Comma-separated list of prompts to use as negative examples.
+         truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
+             If `True`, randomized truncation to different diffusion timesteps is used.
+         truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
+             Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
+         truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
+             Range of diffusion timesteps for randomized truncated backpropagation.
+         push_to_hub (`bool`, *optional*, defaults to `False`):
+             Whether to push the final model to the Hub.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     def __init__(
+         self,
+         exp_name = 'demo',
+         run_name = '',
+         seed = 3407,
+         log_with = None,
+         log_image_freq = 1,
+         tracker_project_name = 'trl',
+         logdir = 'logs',
+         num_epochs = 100,
+         save_freq = 1,
+         num_checkpoint_limit = 5,
+         mixed_precision = 'fp16',
+         allow_tf32 = True,
+         resume_from = '',
+         sample_num_steps = 50,
+         sample_eta = 1.0,
+         sample_guidance_scale = 5.0,
+         train_batch_size = 1,
+         train_use_8bit_adam = False,
+         train_learning_rate = 5e-05,
+         train_adam_beta1 = 0.9,
+         train_adam_beta2 = 0.999,
+         train_adam_weight_decay = 0.01,
+         train_adam_epsilon = 1e-08,
+         train_gradient_accumulation_steps = 2,
+         train_max_grad_norm = 1.0,
+         negative_prompts = None,
+         truncated_backprop_rand = True,
+         truncated_backprop_timestep = 49,
+         push_to_hub = False,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         **kwargs,
+     ):
+
+         super().__init__(
+             exp_name = exp_name,
+             run_name = run_name,
+             seed = seed,
+             log_with = log_with,
+             log_image_freq = log_image_freq,
+             tracker_project_name = tracker_project_name,
+             logdir = logdir,
+             num_epochs = num_epochs,
+             save_freq = save_freq,
+             num_checkpoint_limit = num_checkpoint_limit,
+             mixed_precision = mixed_precision,
+             allow_tf32 = allow_tf32,
+             resume_from = resume_from,
+             sample_num_steps = sample_num_steps,
+             sample_eta = sample_eta,
+             sample_guidance_scale = sample_guidance_scale,
+             train_batch_size = train_batch_size,
+             train_use_8bit_adam = train_use_8bit_adam,
+             train_learning_rate = train_learning_rate,
+             train_adam_beta1 = train_adam_beta1,
+             train_adam_beta2 = train_adam_beta2,
+             train_adam_weight_decay = train_adam_weight_decay,
+             train_adam_epsilon = train_adam_epsilon,
+             train_gradient_accumulation_steps = train_gradient_accumulation_steps,
+             train_max_grad_norm = train_max_grad_norm,
+             negative_prompts = negative_prompts,
+             truncated_backprop_rand = truncated_backprop_rand,
+             truncated_backprop_timestep = truncated_backprop_timestep,
+             push_to_hub = push_to_hub, **kwargs)
+         self.vllm_sampling_params = vllm_sampling_params
+         self.unsloth_num_chunks = unsloth_num_chunks
+ pass
+
+ class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
+     """"""
+
+     _tag_names = ["trl", "alignprop"]
+
+     def __init__(
+         self,
+         config: AlignPropConfig,
+         reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
+         prompt_function: Callable[[], tuple[str, Any]],
+         sd_pipeline: DDPOStableDiffusionPipeline,
+         image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
+     ):
+         if image_samples_hook is None:
+             warn("No image_samples_hook provided; no images will be logged")
+
+         self.prompt_fn = prompt_function
+         self.reward_fn = reward_function
+         self.config = config
+         self.image_samples_callback = image_samples_hook
+
+         accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
+
+         if self.config.resume_from:
+             self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
+             if "checkpoint_" not in os.path.basename(self.config.resume_from):
+                 # get the most recent checkpoint in this directory
+                 checkpoints = list(
+                     filter(
+                         lambda x: "checkpoint_" in x,
+                         os.listdir(self.config.resume_from),
+                     )
+                 )
+                 if len(checkpoints) == 0:
+                     raise ValueError(f"No checkpoints found in {self.config.resume_from}")
+                 checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
+                 self.config.resume_from = os.path.join(
+                     self.config.resume_from,
+                     f"checkpoint_{checkpoint_numbers[-1]}",
+                 )
+
+                 accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
+
+         self.accelerator = Accelerator(
+             log_with=self.config.log_with,
+             mixed_precision=self.config.mixed_precision,
+             project_config=accelerator_project_config,
+             # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+             # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+             # the total number of optimizer steps to accumulate across.
+             gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
+             **self.config.accelerator_kwargs,
+         )
+
+         is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+
+         if self.accelerator.is_main_process:
+             self.accelerator.init_trackers(
+                 self.config.tracker_project_name,
+                 config=dict(alignprop_trainer_config=config.to_dict())
+                 if not is_using_tensorboard
+                 else config.to_dict(),
+                 init_kwargs=self.config.tracker_kwargs,
+             )
+
+         logger.info(f"\n{config}")
+
+         set_seed(self.config.seed, device_specific=True)
+
+         self.sd_pipeline = sd_pipeline
+
+         self.sd_pipeline.set_progress_bar_config(
+             position=1,
+             disable=not self.accelerator.is_local_main_process,
+             leave=False,
+             desc="Timestep",
+             dynamic_ncols=True,
+         )
+
+         # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+         # as these weights are only used for inference, keeping weights in full precision is not required.
+         if self.accelerator.mixed_precision == "fp16":
+             inference_dtype = torch.float16
+         elif self.accelerator.mixed_precision == "bf16":
+             inference_dtype = torch.bfloat16
+         else:
+             inference_dtype = torch.float32
+
+         self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
+         self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
+         self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
+
+         trainable_layers = self.sd_pipeline.get_trainable_layers()
+
+         self.accelerator.register_save_state_pre_hook(self._save_model_hook)
+         self.accelerator.register_load_state_pre_hook(self._load_model_hook)
+
+         # Enable TF32 for faster training on Ampere GPUs,
+         # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+         if self.config.allow_tf32:
+             torch.backends.cuda.matmul.allow_tf32 = True
+
+         self.optimizer = self._setup_optimizer(
+             trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
+         )
+
+         self.neg_prompt_embed = self.sd_pipeline.text_encoder(
+             self.sd_pipeline.tokenizer(
+                 [""] if self.config.negative_prompts is None else self.config.negative_prompts,
+                 return_tensors="pt",
+                 padding="max_length",
+                 truncation=True,
+                 max_length=self.sd_pipeline.tokenizer.model_max_length,
+             ).input_ids.to(self.accelerator.device)
+         )[0]
+
+         # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+         # more memory
+         self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
+
+         if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
+             unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+             self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
+         else:
+             self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+
+         if config.resume_from:
+             logger.info(f"Resuming from {config.resume_from}")
+             self.accelerator.load_state(config.resume_from)
+             self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
+         else:
+             self.first_epoch = 0
+
+     def compute_rewards(self, prompt_image_pairs):
+         reward, reward_metadata = self.reward_fn(
+             prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
+         )
+         return reward
+
+     def step(self, epoch: int, global_step: int):
+         """
+         Perform a single step of training.
+
+         Args:
+             epoch (int): The current epoch.
+             global_step (int): The current global step.
+
+         Side Effects:
+             - Model weights are updated
+             - Logs the statistics to the accelerator trackers.
+             - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
+
+         Returns:
+             global_step (int): The updated global step.
+         """
+         info = defaultdict(list)
+
+         self.sd_pipeline.unet.train()
+
+         for _ in range(self.config.train_gradient_accumulation_steps):
+             with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
+                 prompt_image_pairs = self._generate_samples(
+                     batch_size=self.config.train_batch_size,
+                 )
+
+                 rewards = self.compute_rewards(prompt_image_pairs)
+
+                 prompt_image_pairs["rewards"] = rewards
+
+                 rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
+
+                 loss = self.calculate_loss(rewards)
+
+                 self.accelerator.backward(loss)
+
+                 if self.accelerator.sync_gradients:
+                     self.accelerator.clip_grad_norm_(
+                         self.trainable_layers.parameters()
+                         if not isinstance(self.trainable_layers, list)
+                         else self.trainable_layers,
+                         self.config.train_max_grad_norm,
+                     )
+
+                 self.optimizer.step()
+                 self.optimizer.zero_grad()
+
+             info["reward_mean"].append(rewards_vis.mean())
+             info["reward_std"].append(rewards_vis.std())
+             info["loss"].append(loss.item())
+
+         # Checks if the accelerator has performed an optimization step behind the scenes
+         if self.accelerator.sync_gradients:
+             # log training-related stuff
+             info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
+             info = self.accelerator.reduce(info, reduction="mean")
+             info.update({"epoch": epoch})
+             self.accelerator.log(info, step=global_step)
+             global_step += 1
+             info = defaultdict(list)
+         else:
+             raise ValueError(
+                 "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
+             )
+         # Logs generated images
+         if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
+             self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
+
+         if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
+             self.accelerator.save_state()
+
+         return global_step
+
+     def calculate_loss(self, rewards):
+         """
+         Calculate the loss for a batch of an unpacked sample
+
+         Args:
+             rewards (torch.Tensor):
+                 Differentiable reward scalars for each generated image, shape: [batch_size]
+
+         Returns:
+             loss (torch.Tensor)
+             (all of these are of shape (1,))
+         """
+         # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
+         loss = 10.0 - (rewards).mean()
+         return loss
+
+     def loss(
+         self,
+         advantages: torch.Tensor,
+         clip_range: float,
+         ratio: torch.Tensor,
+     ):
+         unclipped_loss = -advantages * ratio
+         clipped_loss = -advantages * torch.clamp(
+             ratio,
+             1.0 - clip_range,
+             1.0 + clip_range,
+         )
+         return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
+
+     def _setup_optimizer(self, trainable_layers_parameters):
+         if self.config.train_use_8bit_adam:
+             import bitsandbytes
+
+             optimizer_cls = bitsandbytes.optim.AdamW8bit
+         else:
+             optimizer_cls = torch.optim.AdamW
+
+         return optimizer_cls(
+             trainable_layers_parameters,
+             lr=self.config.train_learning_rate,
+             betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
+             weight_decay=self.config.train_adam_weight_decay,
+             eps=self.config.train_adam_epsilon,
+         )
+
+     def _save_model_hook(self, models, weights, output_dir):
+         self.sd_pipeline.save_checkpoint(models, weights, output_dir)
+         weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
+
+     def _load_model_hook(self, models, input_dir):
+         self.sd_pipeline.load_checkpoint(models, input_dir)
+         models.pop()  # ensures that accelerate doesn't try to handle loading of the model
+
+     def _generate_samples(self, batch_size, with_grad=True, prompts=None):
+         """
+         Generate samples from the model
+
+         Args:
+             batch_size (int): Batch size to use for sampling
+             with_grad (bool): Whether the generated RGBs should have gradients attached to it.
+
+         Returns:
+             prompt_image_pairs (dict[Any])
+         """
+         prompt_image_pairs = {}
+
+         sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
+
+         if prompts is None:
+             prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
+         else:
+             prompt_metadata = [{} for _ in range(batch_size)]
+
+         prompt_ids = self.sd_pipeline.tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+             max_length=self.sd_pipeline.tokenizer.model_max_length,
+         ).input_ids.to(self.accelerator.device)
+
+         prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
+
+         if with_grad:
+             sd_output = self.sd_pipeline.rgb_with_grad(
+                 prompt_embeds=prompt_embeds,
+                 negative_prompt_embeds=sample_neg_prompt_embeds,
+                 num_inference_steps=self.config.sample_num_steps,
+                 guidance_scale=self.config.sample_guidance_scale,
+                 eta=self.config.sample_eta,
+                 truncated_backprop_rand=self.config.truncated_backprop_rand,
+                 truncated_backprop_timestep=self.config.truncated_backprop_timestep,
+                 truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
+                 output_type="pt",
+             )
+         else:
+             sd_output = self.sd_pipeline(
+                 prompt_embeds=prompt_embeds,
+                 negative_prompt_embeds=sample_neg_prompt_embeds,
+                 num_inference_steps=self.config.sample_num_steps,
+                 guidance_scale=self.config.sample_guidance_scale,
+                 eta=self.config.sample_eta,
+                 output_type="pt",
+             )
+
+         images = sd_output.images
+
+         prompt_image_pairs["images"] = images
+         prompt_image_pairs["prompts"] = prompts
+         prompt_image_pairs["prompt_metadata"] = prompt_metadata
+
+         return prompt_image_pairs
+
+     def train(self, epochs: Optional[int] = None):
+         """
+         Train the model for a given number of epochs
+         """
+         global_step = 0
+         if epochs is None:
+             epochs = self.config.num_epochs
+         for epoch in range(self.first_epoch, epochs):
+             global_step = self.step(epoch, global_step)
+
+     def _save_pretrained(self, save_directory):
+         self.sd_pipeline.save_pretrained(save_directory)
+         self.create_model_card()
+
+     def create_model_card(
+         self,
+         model_name: Optional[str] = None,
+         dataset_name: Optional[str] = None,
+         tags: Union[str, list[str], None] = None,
+     ):
+         """
+         Creates a draft of a model card using the information available to the `Trainer`.
+
+         Args:
+             model_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the model.
+             dataset_name (`str` or `None`, *optional*, defaults to `None`):
+                 Name of the dataset used for training.
+             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
+                 Tags to be associated with the model card.
+         """
+         if not self.is_world_process_zero():
+             return
+
+         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+             base_model = self.model.config._name_or_path
+         else:
+             base_model = None
+
+         tags = tags or []
+         if isinstance(tags, str):
+             tags = [tags]
+
+         if hasattr(self.model.config, "unsloth_version"):
+             tags.append("unsloth")
+
+         citation = textwrap.dedent("""\
+         @article{prabhudesai2024aligning,
+             title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
+             author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
+             year = 2024,
+             eprint = {arXiv:2310.03739}
+         }""")
+
+         model_card = generate_model_card(
+             base_model=base_model,
+             model_name=model_name,
+             hub_model_id=self.hub_model_id,
+             dataset_name=dataset_name,
+             tags=tags,
+             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+             comet_url=get_comet_experiment_url(),
+             trainer_name="AlignProp",
+             trainer_citation=citation,
+             paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
+             paper_id="2310.03739",
+         )
+
+         model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+ class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
+     """
+
+     The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
+     Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
+     As of now only Stable Diffusion based pipelines are supported
+
+     Attributes:
+         config (`AlignPropConfig`):
+             Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
+         reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
+             Reward function to be used
+         prompt_function (`Callable[[], tuple[str, Any]]`):
+             Function to generate prompts to guide model
+         sd_pipeline (`DDPOStableDiffusionPipeline`):
+             Stable Diffusion pipeline to be used for training.
+         image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
+             Hook to be called to log images
+
+     """
+     def __init__(
+         self,
+         config,
+         reward_function,
+         prompt_function,
+         sd_pipeline,
+         image_samples_hook = None,
+         **kwargs
+     ):
+         if config is None: config = UnslothAlignPropConfig()
+         other_metrics = []
+
+         from unsloth_zoo.logging_utils import PatchRLStatistics
+         PatchRLStatistics('alignprop_trainer', other_metrics)
+
+         super().__init__(
+             config = config,
+             reward_function = reward_function,
+             prompt_function = prompt_function,
+             sd_pipeline = sd_pipeline,
+             image_samples_hook = image_samples_hook, **kwargs)
+
+ pass
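Each of these cached trainer files opens with the same compiled `selective_log_softmax` helper. A quick numerical sketch (CPU, with the `@torch.compile` decorator omitted) confirming it matches the naive `log_softmax` + `gather` formulation it replaces, while avoiding materializing the full log-softmax tensor:

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # Same math as the compiled helper above, minus the @torch.compile decorator.
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)  # log_softmax(x_i) = x_i - logsumexp(x)

logits = torch.randn(2, 5, 32)        # (batch, seq_len, vocab) dummy logits
index = torch.randint(0, 32, (2, 5))  # token ids whose log-probs we want
reference = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(selective_log_softmax(logits, index), reference, atol=1e-5)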
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1824 @@
+ """
+ 2025.3.15
+ 2025.3.17
+ 4.50.0.dev0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
+
+
+ import os
+ from typing import *
+ from dataclasses import dataclass, field
+ from packaging.version import Version
+ import torch
+ import numpy as np
+ from contextlib import nullcontext
+ from torch.nn import functional as F
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
+
+ torch_compile_options = {
+     "epilogue_fusion" : True,
+     "max_autotune" : False,
+     "shape_padding" : True,
+     "trace.enabled" : False,
+     "triton.cudagraphs" : False,
+ }
+
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
+ def selective_log_softmax(logits, index):
+     logits = logits.to(torch.float32)
+     selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
+     # loop to reduce peak mem consumption
+     # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+     logsumexp_values = torch.logsumexp(logits, dim = -1)
+     per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
+     return per_token_logps
+
+ @dataclass
+ class UnslothBCOConfig(BCOConfig):
+     """
+
+     Configuration class for the [`BCOTrainer`].
+
+     Using [`~transformers.HfArgumentParser`] we can turn this class into
+     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+     command line.
+
+     Parameters:
+         max_length (`int` or `None`, *optional*, defaults to `1024`):
+             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
+             to use the default data collator.
+         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+             Maximum length of the prompt. This argument is required if you want to use the default data collator.
+         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
+             Maximum length of the completion. This argument is required if you want to use the default data collator
+             and your model is an encoder-decoder.
+         beta (`float`, *optional*, defaults to `0.1`):
+             Parameter controlling the deviation from the reference model. Higher β means less deviation from the
+             reference model.
+         label_pad_token_id (`int`, *optional*, defaults to `-100`):
+             Label pad token id. This argument is required if you want to use the default data collator.
+         padding_value (`int` or `None`, *optional*, defaults to `None`):
+             Padding value to use. If `None`, the padding value of the tokenizer is used.
+         truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
+             Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
+             This argument is required if you want to use the default data collator.
+         disable_dropout (`bool`, *optional*, defaults to `True`):
+             Whether to disable dropout in the model and reference model.
+         generate_during_eval (`bool`, *optional*, defaults to `False`):
+             If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
+             during evaluation.
+         is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
+             When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
+             you need to specify if the model returned by the callable is an encoder-decoder model.
+         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
+             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
+             useful when training without the reference model to reduce the total GPU memory needed.
+         model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
+             string.
+         ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
+             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference
+             model from a string.
+         dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
+             Number of processes to use for processing the dataset.
+         prompt_sample_size (`int`, *optional*, defaults to `1024`):
+             Number of prompts that are fed to density ratio classifier.
+         min_density_ratio (`float`, *optional*, defaults to `0.5`):
+             Minimum value of the density ratio. The estimated density ratio is clamped to this value.
+         max_density_ratio (`float`, *optional*, defaults to `10.0`):
+             Maximum value of the density ratio. The estimated density ratio is clamped to this value.
+
+     """
+     vllm_sampling_params: Optional[Any] = field(
+         default = None,
+         metadata = {'help': 'vLLM SamplingParams'},
+     )
+     unsloth_num_chunks : Optional[int] = field(
+         default = -1,
+         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
+     )
+     def __init__(
+         self,
+         output_dir = None,
+         overwrite_output_dir = None,
+         do_train = False,
+         do_eval = False,
+         do_predict = False,
+         eval_strategy = 'no',
+         prediction_loss_only = False,
+         per_device_train_batch_size = 4,
+         per_device_eval_batch_size = 4,
+         per_gpu_train_batch_size = None,
+         per_gpu_eval_batch_size = None,
+         gradient_accumulation_steps = 2,
+         eval_accumulation_steps = 2,
+         eval_delay = 0,
+         torch_empty_cache_steps = 250,
+         learning_rate = 5e-05,
+         weight_decay = 0.01,
+         adam_beta1 = 0.9,
+         adam_beta2 = 0.999,
+         adam_epsilon = 1e-08,
+         max_grad_norm = 1.0,
+         num_train_epochs = 3.0,
+         max_steps = -1,
+         lr_scheduler_type = 'linear',
+         warmup_ratio = 0.1,
+         warmup_steps = 0,
+         log_level = 'passive',
+         log_level_replica = 'warning',
+         log_on_each_node = True,
+         logging_dir = None,
+         logging_strategy = 'steps',
+         logging_first_step = False,
+         logging_steps = 1,
+         logging_nan_inf_filter = False,
+         save_strategy = 'steps',
+         save_steps = 500,
+         save_total_limit = None,
+         save_safetensors = True,
+         save_on_each_node = False,
+         save_only_model = False,
+         restore_callback_states_from_checkpoint = False,
+         no_cuda = False,
+         use_cpu = False,
+         use_mps_device = False,
+         seed = 3407,
+         data_seed = 3407,
+         jit_mode_eval = False,
+         use_ipex = False,
+         bf16 = False,
+         fp16 = False,
+         fp16_opt_level = 'O1',
+         half_precision_backend = 'auto',
+         bf16_full_eval = False,
+         fp16_full_eval = False,
+         tf32 = None,
+         local_rank = -1,
+         ddp_backend = None,
+         tpu_num_cores = None,
+         tpu_metrics_debug = False,
+         debug = '',
+         dataloader_drop_last = False,
+         eval_steps = None,
+         dataloader_num_workers = 0,
+         dataloader_prefetch_factor = None,
+         past_index = -1,
+         run_name = None,
+         disable_tqdm = None,
+         remove_unused_columns = True,
+         label_names = None,
+         load_best_model_at_end = False,
+         metric_for_best_model = None,
+         greater_is_better = None,
+         ignore_data_skip = False,
+         fsdp = '',
+         fsdp_min_num_params = 0,
+         fsdp_config = None,
+         tp_size = 0,
+         fsdp_transformer_layer_cls_to_wrap = None,
+         accelerator_config = None,
+         deepspeed = None,
+         label_smoothing_factor = 0.0,
+         optim = 'adamw_8bit',
+         optim_args = None,
+         adafactor = False,
+         group_by_length = False,
+         length_column_name = 'length',
+         report_to = None,
+         ddp_find_unused_parameters = None,
+         ddp_bucket_cap_mb = None,
+         ddp_broadcast_buffers = None,
+         dataloader_pin_memory = True,
+         dataloader_persistent_workers = False,
+         skip_memory_metrics = True,
+         use_legacy_prediction_loop = False,
+         push_to_hub = False,
+         resume_from_checkpoint = None,
+         hub_model_id = None,
+         hub_strategy = 'every_save',
+         hub_token = None,
+         hub_private_repo = None,
+         hub_always_push = False,
+         gradient_checkpointing = False,
+         gradient_checkpointing_kwargs = None,
+         include_inputs_for_metrics = False,
+         eval_do_concat_batches = True,
+         fp16_backend = 'auto',
+         evaluation_strategy = None,
+         push_to_hub_model_id = None,
+         push_to_hub_organization = None,
+         push_to_hub_token = None,
+         mp_parameters = '',
+         auto_find_batch_size = False,
+         full_determinism = False,
+         torchdynamo = None,
+         ray_scope = 'last',
+         ddp_timeout = 1800,
+         torch_compile = False,
+         torch_compile_backend = None,
+         torch_compile_mode = None,
+         dispatch_batches = None,
+         split_batches = None,
+         include_tokens_per_second = False,
+         include_num_input_tokens_seen = False,
+         neftune_noise_alpha = None,
+         optim_target_modules = None,
+         batch_eval_metrics = False,
+         eval_on_start = False,
+         use_liger_kernel = False,
+         eval_use_gather_object = False,
+         average_tokens_across_devices = False,
+         max_length = 1024,
+         max_prompt_length = 512,
+         max_completion_length = None,
+         beta = 0.1,
+         label_pad_token_id = -100,
+         padding_value = None,
+         truncation_mode = 'keep_end',
+         disable_dropout = True,
+         generate_during_eval = False,
+         is_encoder_decoder = None,
+         precompute_ref_log_probs = False,
+         model_init_kwargs = None,
+         ref_model_init_kwargs = None,
+         dataset_num_proc = None,
+         prompt_sample_size = 1024,
+         min_density_ratio = 0.5,
+         max_density_ratio = 10.0,
+         vllm_sampling_params = None,
+         unsloth_num_chunks = -1,
+         **kwargs,
+     ):
+         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
+         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
+         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
+             output_dir = 'unsloth_training_checkpoints'
+             save_strategy = 'no'
+         if dataset_num_proc is None:
+             from multiprocessing import cpu_count
+             dataset_num_proc = cpu_count()
+
+         super().__init__(
+             output_dir = output_dir,
+             overwrite_output_dir = overwrite_output_dir,
+             do_train = do_train,
+             do_eval = do_eval,
+             do_predict = do_predict,
+             eval_strategy = eval_strategy,
+             prediction_loss_only = prediction_loss_only,
+             per_device_train_batch_size = per_device_train_batch_size,
+             per_device_eval_batch_size = per_device_eval_batch_size,
+             per_gpu_train_batch_size = per_gpu_train_batch_size,
+             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
+             gradient_accumulation_steps = gradient_accumulation_steps,
+             eval_accumulation_steps = eval_accumulation_steps,
+             eval_delay = eval_delay,
+             torch_empty_cache_steps = torch_empty_cache_steps,
+             learning_rate = learning_rate,
+             weight_decay = weight_decay,
+             adam_beta1 = adam_beta1,
+             adam_beta2 = adam_beta2,
+             adam_epsilon = adam_epsilon,
+             max_grad_norm = max_grad_norm,
+             num_train_epochs = num_train_epochs,
+             max_steps = max_steps,
+             lr_scheduler_type = lr_scheduler_type,
+             warmup_ratio = warmup_ratio,
+             warmup_steps = warmup_steps,
+             log_level = log_level,
+             log_level_replica = log_level_replica,
+             log_on_each_node = log_on_each_node,
+             logging_dir = logging_dir,
+             logging_strategy = logging_strategy,
+             logging_first_step = logging_first_step,
+             logging_steps = logging_steps,
+             logging_nan_inf_filter = logging_nan_inf_filter,
+             save_strategy = save_strategy,
+             save_steps = save_steps,
+             save_total_limit = save_total_limit,
+             save_safetensors = save_safetensors,
+             save_on_each_node = save_on_each_node,
+             save_only_model = save_only_model,
+             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
+             no_cuda = no_cuda,
+             use_cpu = use_cpu,
+             use_mps_device = use_mps_device,
+             seed = seed,
+             data_seed = data_seed,
+             jit_mode_eval = jit_mode_eval,
+             use_ipex = use_ipex,
+             bf16 = bf16,
+             fp16 = fp16,
+             fp16_opt_level = fp16_opt_level,
+             half_precision_backend = half_precision_backend,
+             bf16_full_eval = bf16_full_eval,
+             fp16_full_eval = fp16_full_eval,
+             tf32 = tf32,
+             local_rank = local_rank,
+             ddp_backend = ddp_backend,
+             tpu_num_cores = tpu_num_cores,
+             tpu_metrics_debug = tpu_metrics_debug,
+             debug = debug,
+             dataloader_drop_last = dataloader_drop_last,
+             eval_steps = eval_steps,
+             dataloader_num_workers = dataloader_num_workers,
+             dataloader_prefetch_factor = dataloader_prefetch_factor,
+             past_index = past_index,
+             run_name = run_name,
+             disable_tqdm = disable_tqdm,
+             remove_unused_columns = remove_unused_columns,
+             label_names = label_names,
+             load_best_model_at_end = load_best_model_at_end,
+             metric_for_best_model = metric_for_best_model,
+             greater_is_better = greater_is_better,
+             ignore_data_skip = ignore_data_skip,
+             fsdp = fsdp,
+             fsdp_min_num_params = fsdp_min_num_params,
+             fsdp_config = fsdp_config,
+             tp_size = tp_size,
+             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
+             accelerator_config = accelerator_config,
+             deepspeed = deepspeed,
+             label_smoothing_factor = label_smoothing_factor,
+             optim = optim,
+             optim_args = optim_args,
+             adafactor = adafactor,
+             group_by_length = group_by_length,
+             length_column_name = length_column_name,
+             report_to = report_to,
+             ddp_find_unused_parameters = ddp_find_unused_parameters,
+             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
+             ddp_broadcast_buffers = ddp_broadcast_buffers,
+             dataloader_pin_memory = dataloader_pin_memory,
+             dataloader_persistent_workers = dataloader_persistent_workers,
+             skip_memory_metrics = skip_memory_metrics,
+             use_legacy_prediction_loop = use_legacy_prediction_loop,
+             push_to_hub = push_to_hub,
+             resume_from_checkpoint = resume_from_checkpoint,
+             hub_model_id = hub_model_id,
+             hub_strategy = hub_strategy,
+             hub_token = hub_token,
+             hub_private_repo = hub_private_repo,
+             hub_always_push = hub_always_push,
+             gradient_checkpointing = gradient_checkpointing,
+             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
+             include_inputs_for_metrics = include_inputs_for_metrics,
+             eval_do_concat_batches = eval_do_concat_batches,
+             fp16_backend = fp16_backend,
+             evaluation_strategy = evaluation_strategy,
+             push_to_hub_model_id = push_to_hub_model_id,
+             push_to_hub_organization = push_to_hub_organization,
+             push_to_hub_token = push_to_hub_token,
+             mp_parameters = mp_parameters,
+             auto_find_batch_size = auto_find_batch_size,
+             full_determinism = full_determinism,
+             torchdynamo = torchdynamo,
+             ray_scope = ray_scope,
+             ddp_timeout = ddp_timeout,
+             torch_compile = torch_compile,
+             torch_compile_backend = torch_compile_backend,
+             torch_compile_mode = torch_compile_mode,
+             dispatch_batches = dispatch_batches,
+             split_batches = split_batches,
+             include_tokens_per_second = include_tokens_per_second,
+             include_num_input_tokens_seen = include_num_input_tokens_seen,
+             neftune_noise_alpha = neftune_noise_alpha,
+             optim_target_modules = optim_target_modules,
+             batch_eval_metrics = batch_eval_metrics,
+             eval_on_start = eval_on_start,
+             use_liger_kernel = use_liger_kernel,
+             eval_use_gather_object = eval_use_gather_object,
398
+ average_tokens_across_devices = average_tokens_across_devices,
399
+ max_length = max_length,
400
+ max_prompt_length = max_prompt_length,
401
+ max_completion_length = max_completion_length,
402
+ beta = beta,
403
+ label_pad_token_id = label_pad_token_id,
404
+ padding_value = padding_value,
405
+ truncation_mode = truncation_mode,
406
+ disable_dropout = disable_dropout,
407
+ generate_during_eval = generate_during_eval,
408
+ is_encoder_decoder = is_encoder_decoder,
409
+ precompute_ref_log_probs = precompute_ref_log_probs,
410
+ model_init_kwargs = model_init_kwargs,
411
+ ref_model_init_kwargs = ref_model_init_kwargs,
412
+ dataset_num_proc = dataset_num_proc,
413
+ prompt_sample_size = prompt_sample_size,
414
+ min_density_ratio = min_density_ratio,
415
+ max_density_ratio = max_density_ratio,**kwargs)
416
+ self.vllm_sampling_params = vllm_sampling_params
417
+ self.unsloth_num_chunks = unsloth_num_chunks
418
+ pass
419
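+
+ # A minimal usage sketch (illustrative only, not part of the generated class; it assumes
+ # the config subclass defined above is named `UnslothBCOConfig`, following this file's
+ # naming pattern). The values below are assumptions chosen to satisfy the learning-rate
+ # guards in `__init__`:
+ #
+ # config = UnslothBCOConfig(
+ #     output_dir = "bco_output",
+ #     per_device_train_batch_size = 2,
+ #     gradient_accumulation_steps = 4,
+ #     learning_rate = 5e-5,   # must lie in [1e-7, 1] or the guards above raise
+ #     beta = 0.1,             # scales the implicit reward log-ratio
+ #     max_length = 1024,
+ #     max_prompt_length = 512,
+ # )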
+
+ class _UnslothBCOTrainer(Trainer):
+ r""""""
+
+ _tag_names = ["trl", "bco"]
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module, str] = None,
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+ args: BCOConfig = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+ processing_class: Optional[
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+ ] = None,
+ data_collator: Optional[DataCollator] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ peft_config: Optional[dict] = None,
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
+ model_adapter_name: Optional[str] = None,
+ ref_adapter_name: Optional[str] = None,
+ embedding_func: Optional[Callable] = None,
+ embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ ):
+ if not is_sklearn_available():
+ raise ImportError(
+ "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
+ )
+
+ if type(args) is TrainingArguments:
+ raise ValueError("Please use `BCOConfig` instead of `TrainingArguments`.")
+
+ if not isinstance(model, str) and ref_model is model:
+ raise ValueError(
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+ "same as `model`, you must pass a copy of it, or `None` if you use peft."
+ )
+
+ if args.model_init_kwargs is None:
+ model_init_kwargs = {}
+ elif not isinstance(model, str):
+ raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
+ else:
+ model_init_kwargs = args.model_init_kwargs
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if torch_dtype is not None:
+ # Convert to `torch.dtype` if a str is passed
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ torch_dtype = getattr(torch, torch_dtype)
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected either a string representing a `torch.dtype` or 'auto', but got {torch_dtype}."
+ )
+ model_init_kwargs["torch_dtype"] = torch_dtype
+
+ if args.ref_model_init_kwargs is None:
+ ref_model_init_kwargs = {}
+ elif not isinstance(ref_model, str):
+ raise ValueError(
+ "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
+ )
+ else:
+ ref_model_init_kwargs = args.ref_model_init_kwargs
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
+ if torch_dtype is not None:
+ # Convert to `torch.dtype` if a str is passed
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
+ torch_dtype = getattr(torch, torch_dtype)
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected either a string representing a `torch.dtype` or 'auto', but got {torch_dtype}."
+ )
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
+
+ if isinstance(model, str):
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+
+ if isinstance(ref_model, str):
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
+
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
+ # has been called in order to properly call autocast if needed.
+ self._peft_has_been_casted_to_bf16 = False
+
+ if not is_peft_available() and peft_config is not None:
+ raise ValueError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+ # if model is a peft model and we have a peft_config, we merge and unload it first
+ if isinstance(model, PeftModel):
+ model = model.merge_and_unload()
+
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+ _support_gc_kwargs = hasattr(
+ args, "gradient_checkpointing_kwargs"
+ ) and "gradient_checkpointing_kwargs" in list(
+ inspect.signature(prepare_model_for_kbit_training).parameters
+ )
+
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+
+ if _support_gc_kwargs:
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
+ elif getattr(args, "gradient_checkpointing", False):
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ # get peft model with the given config (the compiled Unsloth wrapper leaves the model
+ # unchanged here; the PEFT wrapping itself is applied by Unsloth's own patching)
+ model = model
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(model)
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
+ self._peft_has_been_casted_to_bf16 = True
+
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
+ # fail or completely fail.
+ elif getattr(args, "gradient_checkpointing", False):
+ # For backward compatibility with older versions of transformers
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
+ raise ValueError(
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
+ " Please install `wandb` or `comet-ml` to resolve."
+ )
+
+ if model is not None:
+ self.is_encoder_decoder = model.config.is_encoder_decoder
+ elif args.is_encoder_decoder is None:
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
+ else:
+ self.is_encoder_decoder = args.is_encoder_decoder
+
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+ self.model_adapter_name = model_adapter_name
+ self.ref_adapter_name = ref_adapter_name
+
+ if ref_model:
+ self.ref_model = ref_model
+ elif self.is_peft_model or args.precompute_ref_log_probs:
+ # The `model` with adapters turned off will be used as the reference model
+ self.ref_model = None
+ else:
+ self.ref_model = create_reference_model(model)
+
+ if processing_class is None:
+ raise ValueError(
+ "A processing_class must be specified when using the default DPODataCollatorWithPadding"
+ )
+ if args.max_length is None:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
+ "It will be set to `512` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_length = 512
+ if args.max_length is not None:
+ max_length = args.max_length
+
+ if args.max_prompt_length is None:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
+ "It will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_prompt_length = 128
+ if args.max_prompt_length is not None:
+ max_prompt_length = args.max_prompt_length
+
+ max_completion_length = None
+ if args.max_completion_length is None and self.is_encoder_decoder:
+ warnings.warn(
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init."
+ " It will be set to `128` by default, but you should do it yourself in the future.",
+ UserWarning,
+ )
+ max_completion_length = 128
+ if args.max_completion_length is not None and self.is_encoder_decoder:
+ max_completion_length = args.max_completion_length
+
+ if data_collator is None:
+ data_collator = DPODataCollatorWithPadding(
+ pad_token_id=processing_class.pad_token_id,
+ label_pad_token_id=args.label_pad_token_id,
+ is_encoder_decoder=self.is_encoder_decoder,
+ )
+
+ if args.remove_unused_columns:
+ args.remove_unused_columns = False
+ # warn users
+ warnings.warn(
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig."
+ " We have set it for you, but you should do it yourself in the future.",
+ UserWarning,
+ )
+
+ self.use_dpo_data_collator = True
+ else:
+ self.use_dpo_data_collator = False
+
+ # Disable dropout in the model and reference model
+ if args.disable_dropout:
+ disable_dropout_in_model(model)
+ if self.ref_model is not None:
+ disable_dropout_in_model(self.ref_model)
+
+ self.max_length = max_length
+ self.generate_during_eval = args.generate_during_eval
+ self.label_pad_token_id = args.label_pad_token_id
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
+ self.max_prompt_length = max_prompt_length
+ self.truncation_mode = args.truncation_mode
+ self.max_completion_length = max_completion_length
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
+
+ # Since ref_log_probs are precomputed on the first call to get_train/eval_dataloader,
+ # keep track of that first call so subsequent calls skip the computation
+ self._precomputed_train_ref_log_probs = False
+ self._precomputed_eval_ref_log_probs = False
+
+ # metric
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ # BCO parameter
+ self.beta = args.beta
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
+ warnings.warn(
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
+ "loss.",
+ UserWarning,
+ )
+
+ # Underlying Distribution Matching argument
+ self.embedding_func = embedding_func
+ self.embedding_tokenizer = embedding_tokenizer
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
+ # issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ with PartialState().local_main_process_first():
+ # Apply the chat template if needed
+ train_dataset = train_dataset.map(
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
+ )
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.map(
+ maybe_apply_chat_template,
+ fn_kwargs={"tokenizer": processing_class},
+ num_proc=args.dataset_num_proc,
+ )
+ # Shuffle the datasets
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
+ if eval_dataset is not None:
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
+ # Tokenize and prepare the training datasets
+ train_dataset = train_dataset.map(
+ _tokenize,
+ batched=True,
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
+ num_proc=args.dataset_num_proc,
+ desc="Tokenizing train dataset",
+ )
+
+ # Prepare the datasets
+ fn_kwargs = {
+ "prefix": "",
+ "is_encoder_decoder": self.is_encoder_decoder,
+ "tokenizer": processing_class,
+ "max_length": self.max_length,
+ "truncation_mode": self.truncation_mode,
+ "label_pad_token_id": self.label_pad_token_id,
+ "max_prompt_length": self.max_prompt_length,
+ "max_completion_length": self.max_completion_length,
+ }
+ train_dataset = train_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ desc="Processing tokenized train dataset",
+ )
+
+ if eval_dataset is not None:
+ # Tokenize
+ eval_dataset = eval_dataset.map(
+ _tokenize,
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
+ batched=True,
+ num_proc=args.dataset_num_proc,
+ desc="Tokenizing eval dataset",
+ )
+
+ # Process
+ fn_kwargs = {
+ "prefix": "",
+ "is_encoder_decoder": self.is_encoder_decoder,
+ "tokenizer": processing_class,
+ "max_length": self.max_length,
+ "truncation_mode": self.truncation_mode,
+ "label_pad_token_id": self.label_pad_token_id,
+ "max_prompt_length": self.max_prompt_length,
+ "max_completion_length": self.max_completion_length,
+ }
+ eval_dataset = eval_dataset.map(
+ _process_tokens,
+ fn_kwargs=fn_kwargs,
+ num_proc=args.dataset_num_proc,
+ desc="Processing tokenized eval dataset",
+ )
+
+ desirable = train_dataset.filter(
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
+ )
+ undesirable = train_dataset.filter(
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
+ )
+
+ desirable = desirable.shuffle(seed=args.data_seed)
+ undesirable = undesirable.shuffle(seed=args.data_seed)
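+
+ # Note (descriptive): the desirable/undesirable split above is consumed below only when an
+ # `embedding_func` is provided; sampled prompt embeddings from each split become the training
+ # data for the logistic-regression density-ratio classifier used for UDM weighting.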
+
+ super().__init__(
+ model=model,
+ args=args,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ model_init=model_init,
+ compute_metrics=compute_metrics,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+ )
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ if not hasattr(self, "accelerator"):
+ raise AttributeError(
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
+ )
+
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
+ if self.is_deepspeed_enabled:
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
+ raise ValueError(
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
+ )
+
+ if self.ref_model is None:
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
+ raise ValueError(
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
+ )
+ else:
+ if self.is_deepspeed_enabled:
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ self.running = RunningMoments(accelerator=self.accelerator)
+
+ if self.embedding_func is None:
+ return
+
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
+
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
+ labels = torch.cat(
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
+ )
+
+ self.clf = LogisticRegression(class_weight="balanced").fit(
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
+ )
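+
+ # Sketch of how the fitted classifier is consumed later (see `_get_udm_weight`): for an
+ # undesirable prompt embedding x, p = clf.predict_proba(x)[:, 1] estimates the probability
+ # that x came from the desirable distribution, and the UDM importance weight is the density
+ # ratio p / (1 - p), clamped to [args.min_density_ratio, args.max_density_ratio].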
+
+ @property
+ def match_underlying_distribution(self):
+ return self.embedding_func is not None and self.embedding_tokenizer is not None
+
+ def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Calculates the probability that the given prompt embedding comes from the desirable dataset.
+ The probability is computed on each process and then ensembled across processes.
+ """
+ dtype = prompt_embeddings.dtype
+ device = prompt_embeddings.device
+ rank = self.accelerator.process_index
+
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
+ )
+ sample_size = padded_prompt_embeddings.shape[0]
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
+
+ # cannot predict for all empty values
+ if prompt_embeddings.shape[0] == 0:
+ return torch.tensor([], device=device, dtype=dtype)
+
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
+ prob = self.accelerator.reduce(prob, reduction="mean")
+
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
+ prob = prob[nonzero]
+
+ return prob
+
+ def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
+ """
+ Replaces processing_class.pad_token_id with embedding_tokenizer.pad_token_id
+ and applies self.embedding_func
+ """
+ input_ids = torch.where(
+ input_ids == self.processing_class.pad_token_id,
+ self.embedding_tokenizer.pad_token_id,
+ input_ids,
+ )
+
+ with torch.no_grad():
+ embeddings = self.embedding_func(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+
+ return embeddings
+
+ def _get_prompt_embeddings(
+ self, batch: dict[str, Union[list, torch.LongTensor]]
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+ """Extract embeddings from frozen embedding model"""
+
+ if not self.match_underlying_distribution:
+ return None, None
+
+ embeddings = self._vectorize_prompt(
+ input_ids=batch["embedding_input_ids"],
+ attention_mask=batch["embedding_attention_mask"],
+ )
+
+ chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
+ rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
+
+ chosen_embeddings = embeddings[chosen_idx, ...]
+ rejected_embeddings = embeddings[rejected_idx, ...]
+
+ return (chosen_embeddings, rejected_embeddings)
+
+ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
+ """
+ Sample instances from dataset and get prompt embeddings.
+ Used for density ratio classifier training.
+ """
+ n_samples = min(len(dataset), sample_size)
+ rand_indices = np.random.choice(len(dataset), size=(n_samples,))
+
+ embedding_dataset = dataset.select(rand_indices)
+
+ dataloader_params = {
+ "batch_size": self.args.per_device_train_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
+
+ with torch.no_grad():
+ all_embeddings = torch.empty(0)
+ for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
+ embeddings = self._vectorize_prompt(
+ input_ids=padded_batch["embedding_input_ids"],
+ attention_mask=padded_batch["embedding_attention_mask"],
+ )
+ embeddings = self.accelerator.gather_for_metrics(embeddings)
+ all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
+
+ return all_embeddings
+
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+
+ if model is not None:
+ if hasattr(model, "config"):
+ hidden_size = (
+ max(model.config.hidden_sizes)
+ if getattr(model.config, "hidden_sizes", None)
+ else getattr(model.config, "hidden_size", None)
+ )
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+ config_kwargs.update(
+ {
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+ }
+ )
+
+ # If ZeRO-3 is used, we shard both the active and reference model.
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+ if config_kwargs["zero_optimization"]["stage"] != 3:
+ config_kwargs["zero_optimization"]["stage"] = 0
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+ model.eval()
+ return model
+
+ def _save_optimizer_and_scheduler(self, output_dir):
+ super()._save_optimizer_and_scheduler(output_dir)
+
+ # When saving optimizer and scheduler to checkpoint, save also the running delta object.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+ self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
+
+ if self.match_underlying_distribution:
+ torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
+
+ def _load_optimizer_and_scheduler(self, checkpoint):
+ super()._load_optimizer_and_scheduler(checkpoint)
+
+ if checkpoint is None:
+ return
+ # When loading optimizer and scheduler from checkpoint, also load the running delta object.
+ running_file = os.path.join(checkpoint, RUNNING_NAME)
+ if os.path.isfile(running_file):
+ self.running = RunningMoments.load_from_json(self.accelerator, running_file)
+
+ if self.match_underlying_distribution:
+ clf_file = os.path.join(checkpoint, CLF_NAME)
+ if os.path.isfile(clf_file):
+ self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
+
+ @contextmanager
+ def null_ref_context(self):
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+ with (
+ self.accelerator.unwrap_model(self.model).disable_adapter()
+ if self.is_peft_model and not self.ref_adapter_name
+ else nullcontext()
+ ):
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.ref_adapter_name)
+ yield
+ if self.ref_adapter_name:
+ self.model.set_adapter(self.model_adapter_name or "default")
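+
+ # In the PEFT case the "reference model" is the same network with its adapters disabled
+ # (or with a dedicated `ref_adapter_name` activated), so no second copy of the base
+ # weights is ever materialized.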
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
+ """
+
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_train_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
+ reference_completion_logps = []
+
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
+
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
+ reference_completion_logps.append(reference_completion_logp.cpu())
+
+ self.train_dataset = self.train_dataset.add_column(
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
+ )
+
+ self._precomputed_train_ref_log_probs = True
+
+ return super().get_train_dataloader()
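+
+ # Once precomputed, the reference log-probs persist in the dataset as a "reference_logps"
+ # column, so `get_batch_loss_metrics` can skip the reference-model forward pass on every
+ # subsequent training step.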
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
+
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+ dataloader_params = {
+ "batch_size": self.args.per_device_eval_batch_size,
+ "collate_fn": self.data_collator,
+ "num_workers": self.args.dataloader_num_workers,
+ "pin_memory": self.args.dataloader_pin_memory,
+ "shuffle": False,
+ }
+
+ # prepare dataloader
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
+
+ reference_completion_logps = []
+
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
+
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
+ reference_completion_logps.append(reference_completion_logp.cpu())
+
+ eval_dataset = eval_dataset.add_column(
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
+ )
+
+ # Save the calculated reference_logps to the eval_dataset for subsequent runs
+ if self.eval_dataset is not None:
+ self.eval_dataset = eval_dataset
+ self._precomputed_eval_ref_log_probs = True
+
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
+
+ def compute_reference_log_probs(self, padded_batch: dict) -> torch.FloatTensor:
+ """Computes log probabilities of the reference model for a single padded batch of a BCO-specific dataset."""
+ with torch.no_grad():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ if self.is_encoder_decoder:
+ completion_logits = self.model(
+ padded_batch["prompt_input_ids"],
+ attention_mask=padded_batch["prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
+ labels=padded_batch["completion_labels"],
+ ).logits
+
+ else:
+ completion_logits = self.model(
+ padded_batch["completion_input_ids"],
+ attention_mask=padded_batch["completion_attention_mask"],
+ ).logits
+
+ else:
+ if self.is_encoder_decoder:
+ completion_logits = self.ref_model(
+ padded_batch["prompt_input_ids"],
+ attention_mask=padded_batch["prompt_attention_mask"],
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
+ labels=padded_batch["completion_labels"],
+ ).logits
+
+ else:
+ completion_logits = self.ref_model(
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
+ ).logits
+
+ completion_logps = self.get_batch_logps(
+ completion_logits,
+ padded_batch["completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ return completion_logps
+
+ @staticmethod
+ def get_batch_logps(
+ logits: torch.FloatTensor,
+ labels: torch.LongTensor,
+ average_log_prob: bool = False,
+ label_pad_token_id: int = -100,
+ is_encoder_decoder: bool = False,
+ ) -> torch.FloatTensor:
+ """Compute the log probabilities of the given labels under the given logits.
+
+ Args:
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+ Returns:
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+ """
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
+
+ if not is_encoder_decoder:
+ labels = labels[:, 1:].clone()
+ logits = logits[:, :-1, :]
+ else:
+ # Fixes enc-dec RuntimeError
+ labels = labels.clone()
+
+ loss_mask = labels != label_pad_token_id
+
+ # dummy token; we'll ignore the losses on these tokens later
+ labels[labels == label_pad_token_id] = 0
+
+ per_token_logps = selective_log_softmax(logits, labels)
+
+ if average_log_prob:
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+ else:
+ return (per_token_logps * loss_mask).sum(-1)
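+
+ # Worked micro-example (illustrative): for a decoder-only model with
+ # labels = [[-100, 5, 7]] and label_pad_token_id = -100, the shift drops the first
+ # position, the loss mask keeps tokens 5 and 7, and the method returns
+ # log p(5) + log p(7) for that row (or the mean of the two when average_log_prob=True).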
+
+ def forward(
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ model_kwargs = (
+ {
+ "labels": batch["completion_labels"],
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
+ }
+ if self.is_encoder_decoder
+ else {}
+ )
+ if self.aux_loss_enabled:
+ model_kwargs["output_router_logits"] = True
+
+ outputs = model(
+ batch["completion_input_ids"],
+ attention_mask=batch["completion_attention_mask"],
+ **model_kwargs,
+ )
+ completion_logits = outputs.logits
+
+ completion_logps = self.get_batch_logps(
+ completion_logits,
+ batch["completion_labels"],
+ average_log_prob=False,
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ if completion_logps.shape[0] != len(batch["label"]):
+ raise ValueError(
+ "There is a mismatch between the number of examples in this batch and the number of "
+ "examples for which an output sequence was predicted."
+ )
+
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
+
+ chosen_logps = completion_logps[chosen_idx, ...]
+ rejected_logps = completion_logps[rejected_idx, ...]
+
+ chosen_logits = completion_logits[chosen_idx, ...]
+ rejected_logits = completion_logits[rejected_idx, ...]
+
+ if self.aux_loss_enabled:
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
+ else:
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
+
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
+ min_ratio = self.args.min_density_ratio
+ max_ratio = self.args.max_density_ratio
+
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
+
+ return weight
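+
+ # The 1e-8 term in `_get_udm_weight` guards against division by zero when the classifier
+ # is fully confident (prob_desirable == 1), and the clamp bounds how much any single
+ # undesirable example can be up- or down-weighted.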
+
+ def bco_loss(
+ self,
+ policy_chosen_logps: torch.FloatTensor,
+ policy_rejected_logps: torch.FloatTensor,
+ reference_chosen_logps: torch.FloatTensor,
+ reference_rejected_logps: torch.FloatTensor,
+ chosen_embeddings: Optional[torch.FloatTensor],
+ rejected_embeddings: Optional[torch.FloatTensor],
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """Compute the BCO loss for a batch of policy and reference model log probabilities.
+
+ Args:
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
+ chosen_embeddings: embeddings of desirable prompts
+ rejected_embeddings: embeddings of undesirable prompts
+
+ Returns:
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
+ The losses tensor contains the BCO loss for each example in the batch.
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+ The delta value contains the moving average of all implicit rewards.
+ """
+
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
+ chosen_rewards = self.beta * chosen_logratios
+ else:
+ # lists can't be empty -- if they are, then accelerate.gather will hang
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
+
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
+ rejected_rewards = self.beta * rejected_logratios
+ else:
+ # lists can't be empty -- if they are, then accelerate.gather will hang
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
+
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
+ self.running.update(rewards)
+ delta = self.running.mean
+
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
+ chosen_losses = -F.logsigmoid(chosen_rewards - delta)
+
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
+ rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
+
+ if self.match_underlying_distribution:
+ chosen_weight = torch.ones_like(chosen_losses)
+ rejected_weight = self._get_udm_weight(rejected_embeddings)
+
+ losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
+ else:
+ losses = torch.cat((chosen_losses, rejected_losses), dim=0)
+
+ return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
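+
+ # Closed form of the loss above, matching the BCO formulation: with implicit reward
+ # r(x, y) = beta * (log pi(y|x) - log pi_ref(y|x)) and delta the running mean of rewards,
+ #     chosen:   -log(sigmoid(r - delta))
+ #     rejected: -log(sigmoid(-(r - delta)))
+ # i.e. a binary classifier separating desirable from undesirable completions around delta.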
+
+ def get_batch_loss_metrics(
+ self,
+ model,
+ batch: dict[str, Union[list, torch.LongTensor]],
+ ):
+ """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
+ metrics = {}
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
+
+ forward_output = self.forward(model, batch)
+ (
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_chosen_logits,
+ policy_rejected_logits,
+ ) = forward_output[:4]
+ if self.aux_loss_enabled:
+ aux_loss = forward_output[4]
+
+ # if reference_logps are in the batch use them, otherwise use the reference model
+ if "reference_logps" in batch:
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
+
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
+ else:
+ with torch.no_grad():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ ) = self.forward(self.model, batch)[:4]
+ else:
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ ) = self.forward(self.ref_model, batch)[:4]
+
+ chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
+
+ losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ chosen_embeddings,
+ rejected_embeddings,
+ )
+ metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
+
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
+
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
+
+ if all_num_chosen > 0:
+ metrics["rewards/chosen_sum"] = (
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
+ )
+ metrics["logps/chosen_sum"] = (
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
+ )
+ metrics["logits/chosen_sum"] = (
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
+ )
+ metrics["count/chosen"] = all_num_chosen
+
+ if all_num_rejected > 0:
+ metrics["rewards/rejected_sum"] = (
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
+ )
+ metrics["logps/rejected_sum"] = (
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
+ )
+ metrics["logits/rejected_sum"] = (
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
+ )
+ metrics["count/rejected"] = all_num_rejected
+
+ loss = losses.nanmean()
+ if self.aux_loss_enabled:
+ loss += self.aux_loss_coef * aux_loss
+
+ return loss, metrics
+
+ def compute_loss(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ return_outputs=False,
+ num_items_in_batch=None,
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+
+ with compute_loss_context_manager:
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
+
+ # Make sure to move the loss back to the device on which the `Trainer` class accumulates the loss:
+ loss = loss.to(self.args.device)
+ # force log the metrics
+ if self.accelerator.is_main_process:
+ self.store_metrics(metrics, train_eval="train")
+
+ if return_outputs:
+ return (loss, metrics)
+ return loss
+
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
+ for key, value in metrics.items():
+ self._stored_metrics[train_eval][key].append(value)
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+ return SequentialSampler(self.train_dataset)
+
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
+ """Generate samples from the model and reference model for the given batch of inputs."""
+
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+ with generate_context_manager:
+ policy_output = model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.processing_class.pad_token_id,
+ )
+
+ # if reference_output is in the batch use it, otherwise use the reference model
+ if "reference_output" in batch:
+ reference_output = batch["reference_output"]
+ else:
+ if self.ref_model is None:
+ with self.null_ref_context():
+ reference_output = self.model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.processing_class.pad_token_id,
+ )
+ else:
+ reference_output = self.ref_model.generate(
+ input_ids=batch["prompt_input_ids"],
+ attention_mask=batch["prompt_attention_mask"],
+ max_length=self.max_length,
+ do_sample=True,
+ pad_token_id=self.processing_class.pad_token_id,
+ )
+
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
+
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
+
+ return policy_output_decoded, reference_output_decoded
+
+ def prediction_step(
+ self,
+ model: Union[PreTrainedModel, nn.Module],
+ inputs: dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[list[str]] = None,
+ ):
+ if ignore_keys is None:
+ if hasattr(model, "config"):
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
+ with torch.no_grad(), prediction_context_manager:
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
+
+ # force log the metrics
+ if self.accelerator.is_main_process:
+ self.store_metrics(metrics, train_eval="eval")
+
+ if prediction_loss_only:
+ return (loss.detach(), None, None)
+
+ # logits for the chosen and rejected samples from model
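+ # Caution (descriptive note): `get_batch_loss_metrics` stores summed metrics under keys
+ # like "logits/chosen_sum" rather than "logits/chosen", so this branch (reached only
+ # when `prediction_loss_only=False`) assumes per-batch "logits/chosen" and
+ # "logits/rejected" entries are present and would raise a KeyError otherwise.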
+ logits_dict = {
+ "eval_logits/chosen": metrics["logits/chosen"],
+ "eval_logits/rejected": metrics["logits/rejected"],
+ }
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
+
+ return (loss.detach(), logits, labels)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[list[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Overrides the built-in evaluation loop to store metrics for each batch.
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works with or without labels.
+ """
+
+ # Sample and save to game log if requested (for one batch to save time)
+ if self.generate_during_eval:
+ # Generate random indices within the range of the total number of samples
+ num_samples = len(dataloader.dataset)
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
+
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+ random_batch_dataset = dataloader.dataset.select(random_indices)
+ random_batch = self.data_collator(random_batch_dataset)
+ random_batch = self._prepare_inputs(random_batch)
+
+ target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
+ target_batch = {
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
+ }
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
+
+ table = pd.DataFrame(
+ columns=["Prompt", "Policy", "Ref Model"],
+ data=[
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
+ ],
+ )
+ if "wandb" in self.args.report_to:
+ wandb.log({"game_log": wandb.Table(data=table)})
+
+ if "comet_ml" in self.args.report_to:
+ log_table_to_comet_experiment(
+ name="game_log.csv",
+ table=table,
+ )
+
+ # Base evaluation
+ initial_output = super().evaluation_loop(
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
+ )
+
+ return initial_output
+
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+ """
+ Log `logs` on the various objects watching training, including stored metrics.
+
+ Args:
+ logs (`dict[str, float]`):
+ The values to log.
+ start_time (`float` or `None`, *optional*, defaults to `None`):
+ Start time of the training.
+ """
+ # `logs` contains either 'loss' or 'eval_loss'
+ train_eval = "train" if "loss" in logs else "eval"
+ # train metrics should have no prefix, eval should have 'eval_'
+ prefix = "eval_" if train_eval == "eval" else ""
+ # accumulate average metrics from sums and lengths
+ for split in ["chosen", "rejected"]:
+ if f"count/{split}" in self._stored_metrics[train_eval]:
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
+ for metric in ["rewards", "logps", "logits"]:
+ logs[f"{prefix}{metric}/{split}"] = (
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
+ / count_sum
+ )
+ # delete obsolete metric
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+ del self._stored_metrics[train_eval][f"count/{split}"]
+ # calculate reward margin
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+ # Add averaged stored metrics to logs
+ for key, metrics in self._stored_metrics[train_eval].items():
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
+ del self._stored_metrics[train_eval]
+
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+ return super().log(logs, start_time)
+ else: # transformers<=4.46
+ return super().log(logs)
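+
+ # Worked example of the averaging above (illustrative): if two logged batches contributed
+ # rewards/chosen_sum = [3.0, 1.0] with count/chosen = [4, 4], the reported "rewards/chosen"
+ # is (3.0 + 1.0) / (4 + 4) = 0.5, and "rewards/margins" is then the difference between the
+ # chosen and rejected averages.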
+
1592
+ def create_model_card(
1593
+ self,
1594
+ model_name: Optional[str] = None,
1595
+ dataset_name: Optional[str] = None,
1596
+ tags: Union[str, list[str], None] = None,
1597
+ ):
1598
+ """
1599
+ Creates a draft of a model card using the information available to the `Trainer`.
1600
+
1601
+ Args:
1602
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1603
+ Name of the model.
1604
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1605
+ Name of the dataset used for training.
1606
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1607
+ Tags to be associated with the model card.
1608
+ """
1609
+ if not self.is_world_process_zero():
1610
+ return
1611
+
1612
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1613
+ base_model = self.model.config._name_or_path
1614
+ else:
1615
+ base_model = None
1616
+
1617
+ tags = tags or []
1618
+ if isinstance(tags, str):
1619
+ tags = [tags]
1620
+
1621
+ if hasattr(self.model.config, "unsloth_version"):
1622
+ tags.append("unsloth")
1623
+
1624
+ citation = textwrap.dedent("""\
1625
+ @article{jung2024binary,
1626
+ title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1627
+ author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1628
+ year = 2024,
1629
+ eprint = {arXiv:2404.04656}
1630
+ }""")
1631
+
1632
+ model_card = generate_model_card(
1633
+ base_model=base_model,
1634
+ model_name=model_name,
1635
+ hub_model_id=self.hub_model_id,
1636
+ dataset_name=dataset_name,
1637
+ tags=tags,
1638
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1639
+ comet_url=get_comet_experiment_url(),
1640
+ trainer_name="BCO",
1641
+ trainer_citation=citation,
1642
+ paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1643
+ paper_id="2404.04656",
1644
+ )
1645
+
1646
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1647
+ class UnslothBCOTrainer(_UnslothBCOTrainer):
1648
+ """
1649
+
1650
+ Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1651
+
1652
+ Args:
1653
+ model (`transformers.PreTrainedModel`):
1654
+ The model to train, preferably an `AutoModelForCausalLM`.
1655
+ ref_model (`PreTrainedModelWrapper`):
1656
+ Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
1657
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1658
+ args (`BCOConfig`):
1659
+ The arguments to use for training.
1660
+ train_dataset (`datasets.Dataset`):
1661
+ The dataset to use for training.
1662
+ eval_dataset (`datasets.Dataset`):
1663
+ The dataset to use for evaluation.
1664
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1665
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1666
+ for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
1667
+ reuse the fine-tuned model.
1668
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1669
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1670
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1671
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1672
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1673
+ callbacks (`list[transformers.TrainerCallback]`):
1674
+ The callbacks to use for training.
1675
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1676
+ The optimizer and scheduler to use for training.
1677
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1678
+ The function to use to preprocess the logits before computing the metrics.
1679
+ peft_config (`dict`, defaults to `None`):
1680
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1681
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1682
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1683
+ a dictionary string to metric values.
1684
+ model_adapter_name (`str`, defaults to `None`):
1685
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1686
+ ref_adapter_name (`str`, defaults to `None`):
1687
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1688
+
1689
+ """
1690
+ def __init__(
1691
+ self,
1692
+ model = None,
1693
+ ref_model = None,
1694
+ args = None,
1695
+ train_dataset = None,
1696
+ eval_dataset = None,
1697
+ processing_class = None,
1698
+ data_collator = None,
1699
+ model_init = None,
1700
+ callbacks = None,
1701
+ preprocess_logits_for_metrics = None,
1702
+ peft_config = None,
1703
+ compute_metrics = None,
1704
+ model_adapter_name = None,
1705
+ ref_adapter_name = None,
1706
+ embedding_func = None,
1707
+ embedding_tokenizer = None,
1708
+ **kwargs
1709
+ ):
1710
+ if args is None: args = UnslothBCOConfig()
1711
+ use_bf16 = getattr(args, 'bf16', False)
1712
+ use_fp16 = getattr(args, 'fp16', False)
1713
+ force_float32 = False
1714
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1715
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1716
+ force_float32 = True
1717
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1718
+ dtype = getattr(model.config, 'torch_dtype', None)
1719
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1720
+ from unsloth_zoo.utils import _get_dtype
1721
+ dtype = _get_dtype(dtype)
1722
+ float16 = dtype == torch.float16
1723
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1724
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1725
+ if force_float32:
1726
+ args.fp16 = False
1727
+ args.bf16 = False
1728
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1729
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1730
+ args.fp16 = float16
1731
+ args.bf16 = not float16
1732
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
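+ # Editor's summary (illustrative) of the precision selection above:
+ #   UNSLOTH_FORCE_FLOAT32=1             -> fp16 = bf16 = False, mixed precision off
+ #   model dtype float16  + bf16 = True  -> TypeError (dtype mismatch)
+ #   model dtype bfloat16 + fp16 = True  -> TypeError (dtype mismatch)
+ #   neither fp16 nor bf16 requested     -> fp16/bf16 inferred from the model dtype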
1733
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1734
+ args.eval_strategy = 'steps'
1735
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1736
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1737
+ if ga_steps is not None and ga_steps > 1:
1738
+ from transformers import __version__ as transformers_version
1739
+ if Version(transformers_version) <= Version('4.45.2'):
1740
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1741
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1742
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1743
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1744
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1745
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1746
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1747
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1748
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1749
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1750
+ if force_float32:
1751
+ args.bf16_full_eval = False
1752
+ args.fp16_full_eval = False
1753
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1754
+ args.bf16_full_eval = True
1755
+ args.fp16_full_eval = False
1756
+ elif not bf16_full_eval and not fp16_full_eval:
1757
+ args.bf16_full_eval = args.bf16
1758
+ args.fp16_full_eval = args.fp16
1759
+ _output_logits = False
1760
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1761
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1762
+ if _output_logits:
1763
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1764
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1765
+ pass
1766
+ else:
1767
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1768
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1769
+ if args_max_seq_length is None and model_max_seq_length is not None:
1770
+ max_seq_length = model.max_seq_length
1771
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1772
+ if model is not None and hasattr(model, 'for_training'):
1773
+ model.for_training()
1774
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1775
+ if 'processing_class' in locals():
1776
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1777
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1778
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1779
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1780
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1781
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1782
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1783
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1784
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1785
+ else:
1786
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1787
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1788
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1789
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1790
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1791
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1792
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1793
+ else:
1794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
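+ # Editor's note (illustrative): the collator swap above keeps padding behaviour
+ # consistent -- DataCollatorForSeq2Seq when the dataset already carries 'labels',
+ # DataCollatorForLanguageModeling(mlm=False) otherwise, unwrapping processing
+ # classes that lack a `pad` method down to their inner `.tokenizer`.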
1795
+ other_metrics = []
1796
+
1797
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1798
+ PatchRLStatistics('bco_trainer', other_metrics)
1799
+
1800
+ super().__init__(
1801
+ model = model,
1802
+ ref_model = ref_model,
1803
+ args = args,
1804
+ train_dataset = train_dataset,
1805
+ eval_dataset = eval_dataset,
1806
+ processing_class = processing_class,
1807
+ data_collator = data_collator,
1808
+ model_init = model_init,
1809
+ callbacks = callbacks,
1810
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1811
+ peft_config = peft_config,
1812
+ compute_metrics = compute_metrics,
1813
+ model_adapter_name = model_adapter_name,
1814
+ ref_adapter_name = ref_adapter_name,
1815
+ embedding_func = embedding_func,
1816
+ embedding_tokenizer = embedding_tokenizer,**kwargs)
1817
+ if hasattr(self, 'neftune_hook_handle'):
1818
+ self.neftune_hook_handle.remove()
1819
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1820
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1821
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1822
+ pass
1823
+
1824
+ pass
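+ # Editor's sketch (assumption, not part of the upstream file): minimal wiring
+ # of the wrapper; `model`, `tokenizer` and `dataset` are placeholders.
+ #
+ #     trainer = UnslothBCOTrainer(
+ #         model = model,
+ #         args = UnslothBCOConfig(output_dir = "outputs"),
+ #         train_dataset = dataset,
+ #         processing_class = tokenizer,
+ #     )
+ #     trainer.train()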
unsloth_compiled_cache/UnslothCPOTrainer.py ADDED
@@ -0,0 +1,1557 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
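+ # Editor's sketch (not part of the generated file): a minimal check that
+ # `selective_log_softmax` agrees with plain log_softmax + gather; the shapes
+ # below are illustrative assumptions.
+ def _selective_log_softmax_demo():
+     logits = torch.randn(2, 5, 32)         # (batch, seq_len, vocab)
+     index = torch.randint(0, 32, (2, 5))   # token ids to score
+     expected = F.log_softmax(logits.float(), dim = -1)
+     expected = torch.gather(expected, -1, index.unsqueeze(-1)).squeeze(-1)
+     assert torch.allclose(selective_log_softmax(logits, index), expected, atol = 1e-5)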
42
+ @dataclass
43
+ class UnslothCPOConfig(CPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`CPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
67
+ the [paper](https://huggingface.co/papers/2310.12036).
68
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
69
+ Label smoothing factor. This argument is required if you want to use the default data collator.
70
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
71
+ Type of loss to use. Possible values are:
72
+
73
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
74
+ - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
75
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
76
+ - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
77
+
78
+ disable_dropout (`bool`, *optional*, defaults to `True`):
79
+ Whether to disable dropout in the model.
80
+ cpo_alpha (`float`, *optional*, defaults to `1.0`):
81
+ Weight of the BC regularizer in CPO training.
82
+ simpo_gamma (`float`, *optional*, defaults to `0.5`):
83
+ Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
84
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
85
+ Label pad token id. This argument is required if you want to use the default data collator.
86
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
87
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
88
+ truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
89
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
90
+ This argument is required if you want to use the default data collator.
91
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
92
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
93
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
94
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
95
+ you need to specify if the model returned by the callable is an encoder-decoder model.
96
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
98
+ string.
99
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+
102
+ """
103
+ vllm_sampling_params: Optional[Any] = field(
104
+ default = None,
105
+ metadata = {'help': 'vLLM SamplingParams'},
106
+ )
107
+ unsloth_num_chunks : Optional[int] = field(
108
+ default = -1,
109
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
110
+ )
111
+ def __init__(
112
+ self,
113
+ output_dir = None,
114
+ overwrite_output_dir = None,
115
+ do_train = False,
116
+ do_eval = False,
117
+ do_predict = False,
118
+ eval_strategy = 'no',
119
+ prediction_loss_only = False,
120
+ per_device_train_batch_size = 4,
121
+ per_device_eval_batch_size = 4,
122
+ per_gpu_train_batch_size = None,
123
+ per_gpu_eval_batch_size = None,
124
+ gradient_accumulation_steps = 2,
125
+ eval_accumulation_steps = 2,
126
+ eval_delay = 0,
127
+ torch_empty_cache_steps = 250,
128
+ learning_rate = 5e-05,
129
+ weight_decay = 0.01,
130
+ adam_beta1 = 0.9,
131
+ adam_beta2 = 0.999,
132
+ adam_epsilon = 1e-08,
133
+ max_grad_norm = 1.0,
134
+ num_train_epochs = 3.0,
135
+ max_steps = -1,
136
+ lr_scheduler_type = 'linear',
137
+ warmup_ratio = 0.1,
138
+ warmup_steps = 0,
139
+ log_level = 'passive',
140
+ log_level_replica = 'warning',
141
+ log_on_each_node = True,
142
+ logging_dir = None,
143
+ logging_strategy = 'steps',
144
+ logging_first_step = False,
145
+ logging_steps = 1,
146
+ logging_nan_inf_filter = False,
147
+ save_strategy = 'steps',
148
+ save_steps = 500,
149
+ save_total_limit = None,
150
+ save_safetensors = True,
151
+ save_on_each_node = False,
152
+ save_only_model = False,
153
+ restore_callback_states_from_checkpoint = False,
154
+ no_cuda = False,
155
+ use_cpu = False,
156
+ use_mps_device = False,
157
+ seed = 3407,
158
+ data_seed = 3407,
159
+ jit_mode_eval = False,
160
+ use_ipex = False,
161
+ bf16 = False,
162
+ fp16 = False,
163
+ fp16_opt_level = 'O1',
164
+ half_precision_backend = 'auto',
165
+ bf16_full_eval = False,
166
+ fp16_full_eval = False,
167
+ tf32 = None,
168
+ local_rank = -1,
169
+ ddp_backend = None,
170
+ tpu_num_cores = None,
171
+ tpu_metrics_debug = False,
172
+ debug = '',
173
+ dataloader_drop_last = False,
174
+ eval_steps = None,
175
+ dataloader_num_workers = 0,
176
+ dataloader_prefetch_factor = None,
177
+ past_index = -1,
178
+ run_name = None,
179
+ disable_tqdm = None,
180
+ remove_unused_columns = True,
181
+ label_names = None,
182
+ load_best_model_at_end = False,
183
+ metric_for_best_model = None,
184
+ greater_is_better = None,
185
+ ignore_data_skip = False,
186
+ fsdp = '',
187
+ fsdp_min_num_params = 0,
188
+ fsdp_config = None,
189
+ tp_size = 0,
190
+ fsdp_transformer_layer_cls_to_wrap = None,
191
+ accelerator_config = None,
192
+ deepspeed = None,
193
+ label_smoothing_factor = 0.0,
194
+ optim = 'adamw_8bit',
195
+ optim_args = None,
196
+ adafactor = False,
197
+ group_by_length = False,
198
+ length_column_name = 'length',
199
+ report_to = None,
200
+ ddp_find_unused_parameters = None,
201
+ ddp_bucket_cap_mb = None,
202
+ ddp_broadcast_buffers = None,
203
+ dataloader_pin_memory = True,
204
+ dataloader_persistent_workers = False,
205
+ skip_memory_metrics = True,
206
+ use_legacy_prediction_loop = False,
207
+ push_to_hub = False,
208
+ resume_from_checkpoint = None,
209
+ hub_model_id = None,
210
+ hub_strategy = 'every_save',
211
+ hub_token = None,
212
+ hub_private_repo = None,
213
+ hub_always_push = False,
214
+ gradient_checkpointing = False,
215
+ gradient_checkpointing_kwargs = None,
216
+ include_inputs_for_metrics = False,
217
+ eval_do_concat_batches = True,
218
+ fp16_backend = 'auto',
219
+ evaluation_strategy = None,
220
+ push_to_hub_model_id = None,
221
+ push_to_hub_organization = None,
222
+ push_to_hub_token = None,
223
+ mp_parameters = '',
224
+ auto_find_batch_size = False,
225
+ full_determinism = False,
226
+ torchdynamo = None,
227
+ ray_scope = 'last',
228
+ ddp_timeout = 1800,
229
+ torch_compile = False,
230
+ torch_compile_backend = None,
231
+ torch_compile_mode = None,
232
+ dispatch_batches = None,
233
+ split_batches = None,
234
+ include_tokens_per_second = False,
235
+ include_num_input_tokens_seen = False,
236
+ neftune_noise_alpha = None,
237
+ optim_target_modules = None,
238
+ batch_eval_metrics = False,
239
+ eval_on_start = False,
240
+ use_liger_kernel = False,
241
+ eval_use_gather_object = False,
242
+ average_tokens_across_devices = False,
243
+ max_length = 1024,
244
+ max_prompt_length = 512,
245
+ max_completion_length = None,
246
+ beta = 0.1,
247
+ label_smoothing = 0.0,
248
+ loss_type = 'sigmoid',
249
+ disable_dropout = True,
250
+ cpo_alpha = 1.0,
251
+ simpo_gamma = 0.5,
252
+ label_pad_token_id = -100,
253
+ padding_value = None,
254
+ truncation_mode = 'keep_end',
255
+ generate_during_eval = False,
256
+ is_encoder_decoder = None,
257
+ model_init_kwargs = None,
258
+ dataset_num_proc = None,
259
+ vllm_sampling_params = None,
260
+ unsloth_num_chunks = -1,
261
+ **kwargs,
262
+ ):
263
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is less than 1e-7, which is too small! Consider increasing it, otherwise gradient updates will be close to 0!')
264
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
265
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
266
+ output_dir = 'unsloth_training_checkpoints'
267
+ save_strategy = 'no'
268
+ if dataset_num_proc is None:
269
+ from multiprocessing import cpu_count
270
+ dataset_num_proc = cpu_count()
271
+
272
+ super().__init__(
273
+ output_dir = output_dir,
274
+ overwrite_output_dir = overwrite_output_dir,
275
+ do_train = do_train,
276
+ do_eval = do_eval,
277
+ do_predict = do_predict,
278
+ eval_strategy = eval_strategy,
279
+ prediction_loss_only = prediction_loss_only,
280
+ per_device_train_batch_size = per_device_train_batch_size,
281
+ per_device_eval_batch_size = per_device_eval_batch_size,
282
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
283
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
284
+ gradient_accumulation_steps = gradient_accumulation_steps,
285
+ eval_accumulation_steps = eval_accumulation_steps,
286
+ eval_delay = eval_delay,
287
+ torch_empty_cache_steps = torch_empty_cache_steps,
288
+ learning_rate = learning_rate,
289
+ weight_decay = weight_decay,
290
+ adam_beta1 = adam_beta1,
291
+ adam_beta2 = adam_beta2,
292
+ adam_epsilon = adam_epsilon,
293
+ max_grad_norm = max_grad_norm,
294
+ num_train_epochs = num_train_epochs,
295
+ max_steps = max_steps,
296
+ lr_scheduler_type = lr_scheduler_type,
297
+ warmup_ratio = warmup_ratio,
298
+ warmup_steps = warmup_steps,
299
+ log_level = log_level,
300
+ log_level_replica = log_level_replica,
301
+ log_on_each_node = log_on_each_node,
302
+ logging_dir = logging_dir,
303
+ logging_strategy = logging_strategy,
304
+ logging_first_step = logging_first_step,
305
+ logging_steps = logging_steps,
306
+ logging_nan_inf_filter = logging_nan_inf_filter,
307
+ save_strategy = save_strategy,
308
+ save_steps = save_steps,
309
+ save_total_limit = save_total_limit,
310
+ save_safetensors = save_safetensors,
311
+ save_on_each_node = save_on_each_node,
312
+ save_only_model = save_only_model,
313
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
314
+ no_cuda = no_cuda,
315
+ use_cpu = use_cpu,
316
+ use_mps_device = use_mps_device,
317
+ seed = seed,
318
+ data_seed = data_seed,
319
+ jit_mode_eval = jit_mode_eval,
320
+ use_ipex = use_ipex,
321
+ bf16 = bf16,
322
+ fp16 = fp16,
323
+ fp16_opt_level = fp16_opt_level,
324
+ half_precision_backend = half_precision_backend,
325
+ bf16_full_eval = bf16_full_eval,
326
+ fp16_full_eval = fp16_full_eval,
327
+ tf32 = tf32,
328
+ local_rank = local_rank,
329
+ ddp_backend = ddp_backend,
330
+ tpu_num_cores = tpu_num_cores,
331
+ tpu_metrics_debug = tpu_metrics_debug,
332
+ debug = debug,
333
+ dataloader_drop_last = dataloader_drop_last,
334
+ eval_steps = eval_steps,
335
+ dataloader_num_workers = dataloader_num_workers,
336
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
337
+ past_index = past_index,
338
+ run_name = run_name,
339
+ disable_tqdm = disable_tqdm,
340
+ remove_unused_columns = remove_unused_columns,
341
+ label_names = label_names,
342
+ load_best_model_at_end = load_best_model_at_end,
343
+ metric_for_best_model = metric_for_best_model,
344
+ greater_is_better = greater_is_better,
345
+ ignore_data_skip = ignore_data_skip,
346
+ fsdp = fsdp,
347
+ fsdp_min_num_params = fsdp_min_num_params,
348
+ fsdp_config = fsdp_config,
349
+ tp_size = tp_size,
350
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
351
+ accelerator_config = accelerator_config,
352
+ deepspeed = deepspeed,
353
+ label_smoothing_factor = label_smoothing_factor,
354
+ optim = optim,
355
+ optim_args = optim_args,
356
+ adafactor = adafactor,
357
+ group_by_length = group_by_length,
358
+ length_column_name = length_column_name,
359
+ report_to = report_to,
360
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
361
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
362
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
363
+ dataloader_pin_memory = dataloader_pin_memory,
364
+ dataloader_persistent_workers = dataloader_persistent_workers,
365
+ skip_memory_metrics = skip_memory_metrics,
366
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
367
+ push_to_hub = push_to_hub,
368
+ resume_from_checkpoint = resume_from_checkpoint,
369
+ hub_model_id = hub_model_id,
370
+ hub_strategy = hub_strategy,
371
+ hub_token = hub_token,
372
+ hub_private_repo = hub_private_repo,
373
+ hub_always_push = hub_always_push,
374
+ gradient_checkpointing = gradient_checkpointing,
375
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
376
+ include_inputs_for_metrics = include_inputs_for_metrics,
377
+ eval_do_concat_batches = eval_do_concat_batches,
378
+ fp16_backend = fp16_backend,
379
+ evaluation_strategy = evaluation_strategy,
380
+ push_to_hub_model_id = push_to_hub_model_id,
381
+ push_to_hub_organization = push_to_hub_organization,
382
+ push_to_hub_token = push_to_hub_token,
383
+ mp_parameters = mp_parameters,
384
+ auto_find_batch_size = auto_find_batch_size,
385
+ full_determinism = full_determinism,
386
+ torchdynamo = torchdynamo,
387
+ ray_scope = ray_scope,
388
+ ddp_timeout = ddp_timeout,
389
+ torch_compile = torch_compile,
390
+ torch_compile_backend = torch_compile_backend,
391
+ torch_compile_mode = torch_compile_mode,
392
+ dispatch_batches = dispatch_batches,
393
+ split_batches = split_batches,
394
+ include_tokens_per_second = include_tokens_per_second,
395
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
396
+ neftune_noise_alpha = neftune_noise_alpha,
397
+ optim_target_modules = optim_target_modules,
398
+ batch_eval_metrics = batch_eval_metrics,
399
+ eval_on_start = eval_on_start,
400
+ use_liger_kernel = use_liger_kernel,
401
+ eval_use_gather_object = eval_use_gather_object,
402
+ average_tokens_across_devices = average_tokens_across_devices,
403
+ max_length = max_length,
404
+ max_prompt_length = max_prompt_length,
405
+ max_completion_length = max_completion_length,
406
+ beta = beta,
407
+ label_smoothing = label_smoothing,
408
+ loss_type = loss_type,
409
+ disable_dropout = disable_dropout,
410
+ cpo_alpha = cpo_alpha,
411
+ simpo_gamma = simpo_gamma,
412
+ label_pad_token_id = label_pad_token_id,
413
+ padding_value = padding_value,
414
+ truncation_mode = truncation_mode,
415
+ generate_during_eval = generate_during_eval,
416
+ is_encoder_decoder = is_encoder_decoder,
417
+ model_init_kwargs = model_init_kwargs,
418
+ dataset_num_proc = dataset_num_proc,**kwargs)
419
+ self.vllm_sampling_params = vllm_sampling_params
420
+ self.unsloth_num_chunks = unsloth_num_chunks
421
+ pass
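+ # Editor's sketch (assumption, not part of the generated file): typical
+ # construction of this config; all field values below are illustrative.
+ def _unsloth_cpo_config_demo():
+     return UnslothCPOConfig(
+         output_dir = "outputs",
+         per_device_train_batch_size = 2,
+         gradient_accumulation_steps = 4,
+         learning_rate = 5e-5,
+         loss_type = "sigmoid",
+         beta = 0.1,
+     )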
422
+
423
+ class _UnslothCPOTrainer(Trainer):
424
+ r""""""
425
+
426
+ _tag_names = ["trl", "cpo"]
427
+
428
+ def __init__(
429
+ self,
430
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
431
+ args: Optional[CPOConfig] = None,
432
+ data_collator: Optional[DataCollator] = None,
433
+ train_dataset: Optional[Dataset] = None,
434
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
435
+ processing_class: Optional[
436
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
437
+ ] = None,
438
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
439
+ callbacks: Optional[list[TrainerCallback]] = None,
440
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
441
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
442
+ peft_config: Optional[dict] = None,
443
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
444
+ ):
445
+ if args.model_init_kwargs is None:
446
+ model_init_kwargs = {}
447
+ elif not isinstance(model, str):
448
+ raise ValueError("You passed `model_init_kwargs` to the CPOTrainer, but your model is already instantiated.")
449
+ else:
450
+ model_init_kwargs = args.model_init_kwargs
451
+ torch_dtype = model_init_kwargs.get("torch_dtype")
452
+ if torch_dtype is not None:
453
+ # Convert to `torch.dtype` if an str is passed
454
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
455
+ torch_dtype = getattr(torch, torch_dtype)
456
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
457
+ raise ValueError(
458
+ f"Invalid `torch_dtype` passed to the CPOConfig. Expected a `torch.dtype`, a string representing one, or 'auto', but got {torch_dtype}."
459
+ )
460
+ model_init_kwargs["torch_dtype"] = torch_dtype
461
+
462
+ if isinstance(model, str):
463
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
464
+
465
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
466
+ # has been called in order to properly call autocast if needed.
467
+ self._peft_has_been_casted_to_bf16 = False
468
+
469
+ if not is_peft_available() and peft_config is not None:
470
+ raise ValueError(
471
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
472
+ )
473
+ elif is_peft_available() and peft_config is not None:
474
+ # if model is a peft model and we have a peft_config, we merge and unload it first
475
+ if isinstance(model, PeftModel):
476
+ model = model.merge_and_unload()
477
+
478
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
479
+ _support_gc_kwargs = hasattr(
480
+ args, "gradient_checkpointing_kwargs"
481
+ ) and "gradient_checkpointing_kwargs" in list(
482
+ inspect.signature(prepare_model_for_kbit_training).parameters
483
+ )
484
+
485
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
486
+
487
+ if _support_gc_kwargs:
488
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
489
+
490
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
491
+ elif getattr(args, "gradient_checkpointing", False):
492
+ # For backward compatibility with older versions of transformers
493
+ if hasattr(model, "enable_input_require_grads"):
494
+ model.enable_input_require_grads()
495
+ else:
496
+
497
+ def make_inputs_require_grad(module, input, output):
498
+ output.requires_grad_(True)
499
+
500
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
501
+
502
+ # get peft model with the given config
503
+ model = model
504
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
505
+ peft_module_casting_to_bf16(model)
506
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
507
+ self._peft_has_been_casted_to_bf16 = True
508
+
509
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
510
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
511
+ # fail or completely fail.
512
+ elif getattr(args, "gradient_checkpointing", False):
513
+ # For backward compatibility with older versions of transformers
514
+ if hasattr(model, "enable_input_require_grads"):
515
+ model.enable_input_require_grads()
516
+ else:
517
+
518
+ def make_inputs_require_grad(module, input, output):
519
+ output.requires_grad_(True)
520
+
521
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
522
+
523
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
524
+ raise ValueError(
525
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
526
+ " Please install `wandb` or `comet-ml` to resolve."
527
+ )
528
+
529
+ if model is not None:
530
+ self.is_encoder_decoder = model.config.is_encoder_decoder
531
+ elif args.is_encoder_decoder is None:
532
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
533
+ else:
534
+ self.is_encoder_decoder = args.is_encoder_decoder
535
+
536
+ if self.is_encoder_decoder:
537
+ self.decoder_start_token_id = model.config.decoder_start_token_id
538
+ self.pad_token_id = model.config.pad_token_id
539
+
540
+ if processing_class is None:
541
+ raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
542
+ if args.max_length is None:
543
+ warnings.warn(
544
+ "`max_length` is not set in the CPOConfig's init,"
545
+ " so it will default to `512`; you should set it yourself in the future.",
546
+ UserWarning,
547
+ )
548
+ max_length = 512
549
+ else:
550
+ max_length = args.max_length
551
+ if args.max_prompt_length is None:
552
+ warnings.warn(
553
+ "`max_prompt_length` is not set in the CPOConfig's init,"
554
+ " so it will default to `128`; you should set it yourself in the future.",
555
+ UserWarning,
556
+ )
557
+ max_prompt_length = 128
558
+ else:
559
+ max_prompt_length = args.max_prompt_length
560
+
561
+ if args.max_completion_length is None and self.is_encoder_decoder:
562
+ warnings.warn(
563
+ "When using an encoder-decoder architecture, you should set `max_completion_length` in the CPOConfig's init;"
564
+ " it will default to `128`, but you should set it yourself in the future.",
565
+ UserWarning,
566
+ )
567
+ max_completion_length = 128
568
+ else:
569
+ max_completion_length = args.max_completion_length
570
+
571
+ if data_collator is None:
572
+ data_collator = DPODataCollatorWithPadding(
573
+ pad_token_id=processing_class.pad_token_id,
574
+ label_pad_token_id=args.label_pad_token_id,
575
+ is_encoder_decoder=self.is_encoder_decoder,
576
+ )
577
+
578
+ if args.remove_unused_columns:
579
+ args.remove_unused_columns = False
580
+ # warn users
581
+ warnings.warn(
582
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
583
+ " we have set it for you, but you should do it yourself in the future.",
584
+ UserWarning,
585
+ )
586
+
587
+ self.use_dpo_data_collator = True
588
+ else:
589
+ self.use_dpo_data_collator = False
590
+
591
+ # Disable dropout in the model
592
+ if args.disable_dropout:
593
+ disable_dropout_in_model(model)
594
+
595
+ self.max_length = max_length
596
+ self.generate_during_eval = args.generate_during_eval
597
+ self.label_pad_token_id = args.label_pad_token_id
598
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
599
+ self.max_prompt_length = max_prompt_length
600
+ self.truncation_mode = args.truncation_mode
601
+ self.max_completion_length = max_completion_length
602
+ self.processing_class = processing_class
603
+
604
+ if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
605
+ warnings.warn(
606
+ f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
607
+ "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
608
+ UserWarning,
609
+ )
610
+ if args.loss_type == "kto_pair":
611
+ raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
612
+
613
+ self.beta = args.beta
614
+ self.label_smoothing = args.label_smoothing
615
+ self.loss_type = args.loss_type
616
+ self.cpo_alpha = args.cpo_alpha
617
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
618
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
619
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
620
+ warnings.warn(
621
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
622
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
623
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
624
+ "loss.",
625
+ UserWarning,
626
+ )
627
+
628
+ if args.loss_type == "simpo":
629
+ self.simpo_gamma = args.simpo_gamma
630
+
631
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
632
+
633
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
634
+ # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
635
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
636
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
637
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
638
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
639
+ # that the warning has already been issued.
640
+ model.warnings_issued["estimate_tokens"] = True
641
+
642
+ # Compute that only on the main process for faster data processing.
643
+ # see: https://github.com/huggingface/trl/pull/1255
644
+ with PartialState().local_main_process_first():
645
+ # Extract the prompt if needed, and apply the chat template if needed
646
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
647
+ train_dataset = train_dataset.map(
648
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
649
+ )
650
+ if eval_dataset is not None:
651
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
652
+ eval_dataset = eval_dataset.map(
653
+ maybe_apply_chat_template,
654
+ fn_kwargs={"tokenizer": processing_class},
655
+ num_proc=args.dataset_num_proc,
656
+ )
657
+
658
+ # tokenize the dataset
659
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
660
+ if eval_dataset is not None:
661
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
662
+
663
+ super().__init__(
664
+ model=model,
665
+ args=args,
666
+ data_collator=data_collator,
667
+ train_dataset=train_dataset,
668
+ eval_dataset=eval_dataset,
669
+ processing_class=processing_class,
670
+ model_init=model_init,
671
+ compute_metrics=compute_metrics,
672
+ callbacks=callbacks,
673
+ optimizers=optimizers,
674
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
675
+ )
676
+
677
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
678
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
679
+ # self.model_accepts_loss_kwargs to False to enable scaling.
680
+ self.model_accepts_loss_kwargs = False
681
+
682
+ # Add tags for models that have been loaded with the correct transformers version
683
+ if hasattr(self.model, "add_model_tags"):
684
+ self.model.add_model_tags(self._tag_names)
685
+
686
+ if not hasattr(self, "accelerator"):
687
+ raise AttributeError(
688
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
689
+ )
690
+
691
+ def build_tokenized_answer(self, prompt, answer):
692
+ """
693
+ The Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
694
+ It does, however, ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
695
+ Reference:
696
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
697
+ """
698
+
699
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
700
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
701
+
702
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
703
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
704
+
705
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
706
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
707
+
708
+ # Prepare input tokens for token by token comparison
709
+ full_input_ids = np.array(full_tokenized["input_ids"])
710
+
711
+ if len(full_input_ids) != len(full_concat_input_ids):
712
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
713
+
714
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
715
+ # can be merged together when tokenizing prompt+answer. This could result
716
+ # in the last token from the prompt being different when tokenized on its own
717
+ # vs when done as prompt+answer.
718
+ response_token_ids_start_idx = len(prompt_input_ids)
719
+
720
+ # If the tokenized prompt differs from the start of the tokenized prompt+answer,
721
+ # the last prompt token must have changed due to merging.
722
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
723
+ response_token_ids_start_idx -= 1
724
+
725
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
726
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
727
+
728
+ if len(prompt_input_ids) != len(prompt_attention_mask):
729
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
730
+
731
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
732
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
733
+
734
+ return dict(
735
+ prompt_input_ids=prompt_input_ids,
736
+ prompt_attention_mask=prompt_attention_mask,
737
+ input_ids=answer_input_ids,
738
+ attention_mask=answer_attention_mask,
739
+ )
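+ # Editor's illustration (hypothetical token ids): if enc(prompt) = [5, 9] but
+ # enc(prompt + answer) = [5, 8, 3] because the tokenizer merged the boundary,
+ # the split index moves back by one, giving prompt ids [5] and answer ids
+ # [8, 3] -- which preserves enc(a) + enc(a + b)[len(enc(a)):].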
740
+
741
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
742
+ """Tokenize a single row from a CPO specific dataset.
743
+
744
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
745
+ in case the prompt + chosen or prompt + rejected responses are too long. First
746
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
747
+
748
+ We also create the labels for the chosen/rejected responses, which are of length equal to
749
+ the sum of the length of the prompt and the chosen/rejected response, with
750
+ label_pad_token_id for the prompt tokens.
751
+ """
752
+ batch = {}
753
+ prompt = feature["prompt"]
754
+ chosen = feature["chosen"]
755
+ rejected = feature["rejected"]
756
+
757
+ if not self.is_encoder_decoder:
758
+ # Check issues below for more details
759
+ # 1. https://github.com/huggingface/trl/issues/907
760
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
761
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
762
+
763
+ if not isinstance(prompt, str):
764
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
765
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
766
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
767
+
768
+ if not isinstance(chosen, str):
769
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
770
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
771
+
772
+ if not isinstance(rejected, str):
773
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
774
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
775
+
776
+ # Last prompt token might get merged by tokenizer and
777
+ # it should not be included for generation if that happens
778
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
779
+
780
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
781
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
782
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
783
+
784
+ for k, v in prompt_tokens.items():
785
+ prompt_tokens[k] = v[:prompt_len_input_ids]
786
+
787
+ # Make sure the prompts differ by at most one token,
788
+ # and that their lengths differ by at most 1
789
+ num_diff_tokens = sum(
790
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
791
+ )
792
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
793
+ if num_diff_tokens > 1 or num_diff_len > 1:
794
+ raise ValueError(
795
+ "Chosen and rejected prompt_input_ids might only differ on the "
796
+ "last token due to tokenizer merge ops."
797
+ )
798
+
799
+ # add BOS token to head of prompt. Avoid adding if it's already there
800
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
801
+ self.processing_class.bos_token_id,
802
+ prompt_len_input_ids,
803
+ prompt_tokens,
804
+ chosen_prompt_len_input_ids,
805
+ chosen_tokens,
806
+ rejected_prompt_len_input_ids,
807
+ rejected_tokens,
808
+ )
809
+
810
+ # add EOS token to end of answer. Avoid adding if it's already there
811
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
812
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
813
+ )
814
+
815
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
816
+
817
+ # if combined sequence is too long, truncate the prompt
818
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
819
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
820
+ if self.truncation_mode == "keep_start":
821
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
822
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
823
+ elif self.truncation_mode == "keep_end":
824
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
825
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
826
+ else:
827
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
828
+
829
+ # if that's still too long, truncate the response
830
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
831
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
832
+ for k in ["input_ids", "attention_mask"]:
833
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
834
+
835
+ # Create labels
836
+ chosen_sequence_tokens = {
837
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
838
+ }
839
+ rejected_sequence_tokens = {
840
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
841
+ }
842
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
843
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
844
+ self.label_pad_token_id
845
+ ] * len(chosen_tokens["prompt_input_ids"])
846
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
847
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
848
+ self.label_pad_token_id
849
+ ] * len(rejected_tokens["prompt_input_ids"])
850
+
851
+ for k, toks in {
852
+ "chosen_": chosen_sequence_tokens,
853
+ "rejected_": rejected_sequence_tokens,
854
+ "": prompt_tokens,
855
+ }.items():
856
+ for type_key, tokens in toks.items():
857
+ if type_key == "token_type_ids":
858
+ continue
859
+ batch[f"{k}{type_key}"] = tokens
860
+
861
+ else:
862
+ chosen_tokens = self.processing_class(
863
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
864
+ )
865
+ rejected_tokens = self.processing_class(
866
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
867
+ )
868
+ prompt_tokens = self.processing_class(
869
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
870
+ )
871
+
872
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
873
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
874
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
875
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
876
+
877
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
878
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
879
+ labels=torch.tensor(batch["rejected_labels"])
880
+ )
881
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
882
+ labels=torch.tensor(batch["chosen_labels"])
883
+ )
884
+
885
+ return batch
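+ # Editor's note (illustrative numbers): truncation order in tokenize_row --
+ # with max_length=10, max_prompt_length=6, an 8-token prompt and a 5-token
+ # longer response: 8 + 5 > 10 so the prompt is cut to 6 (keep_end), then
+ # 6 + 5 > 10 still, so responses are cut to 10 - 6 = 4 tokens.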
886
+
887
+ @staticmethod
888
+ def concatenated_inputs(
889
+ batch: dict[str, Union[list, torch.LongTensor]],
890
+ is_encoder_decoder: bool = False,
891
+ label_pad_token_id: int = -100,
892
+ padding_value: int = 0,
893
+ device: Optional[torch.device] = None,
894
+ ) -> dict[str, torch.LongTensor]:
895
+ """Concatenate the chosen and rejected inputs into a single tensor.
896
+
897
+ Args:
898
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
899
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
900
+ label_pad_token_id: The label pad token id.
901
+ padding_value: The padding value to use for the concatenated input_ids.
902
+ device: The device for the concatenated inputs.
903
+
904
+ Returns:
905
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
906
+ """
907
+ concatenated_batch = {}
908
+
909
+ if is_encoder_decoder:
910
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
911
+ else:
912
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
913
+
914
+ for k in batch:
915
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
916
+ if "labels" in k or is_encoder_decoder:
917
+ pad_value = label_pad_token_id
918
+ elif k.endswith("_input_ids"):
919
+ pad_value = padding_value
920
+ elif k.endswith("_attention_mask"):
921
+ pad_value = 0
922
+ concatenated_key = k.replace("chosen", "concatenated")
923
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
924
+ for k in batch:
925
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
926
+ if "labels" in k or is_encoder_decoder:
927
+ pad_value = label_pad_token_id
928
+ elif k.endswith("_input_ids"):
929
+ pad_value = padding_value
930
+ elif k.endswith("_attention_mask"):
931
+ pad_value = 0
932
+ concatenated_key = k.replace("rejected", "concatenated")
933
+ concatenated_batch[concatenated_key] = torch.cat(
934
+ (
935
+ concatenated_batch[concatenated_key],
936
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
937
+ ),
938
+ dim=0,
939
+ ).to(device=device)
940
+
941
+ if is_encoder_decoder:
942
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
943
+ concatenated_batch["concatenated_attention_mask"] = (
944
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
945
+ )
946
+
947
+ return concatenated_batch
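+ # Editor's note (illustrative shapes): with chosen_input_ids of shape (B, Lc)
+ # and rejected_input_ids of shape (B, Lr), both are padded to L = max(Lc, Lr)
+ # and stacked along the batch dim, so concatenated_input_ids is (2B, L) with
+ # the chosen rows first -- one forward pass scores both completions.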
948
+
949
+ def cpo_loss(
950
+ self,
951
+ policy_chosen_logps: torch.FloatTensor,
952
+ policy_rejected_logps: torch.FloatTensor,
953
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
954
+ """Compute the CPO loss for a batch of policy and reference model log probabilities.
955
+
956
+ Args:
957
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
958
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
959
+
960
+ Returns:
961
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
962
+ The losses tensor contains the CPO loss for each example in the batch.
963
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
964
+ """
965
+ logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
966
+
967
+ # Beta is a temperature parameter for the CPO loss, typically in the range 0.1 to 0.5.
968
+ # Unlike DPO, CPO uses no reference model. The label_smoothing parameter encodes our uncertainty about the labels and
969
+ # yields a conservative CPO loss.
970
+
971
+ if self.loss_type == "simpo":
972
+ gamma_logratios = self.simpo_gamma / self.beta
973
+ logits = logits - gamma_logratios
974
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
975
+ losses = (
976
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
977
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
978
+ )
979
+ elif self.loss_type == "sigmoid":
980
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
981
+ losses = (
982
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
983
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
984
+ )
985
+ elif self.loss_type == "hinge":
986
+ losses = torch.relu(1 - self.beta * logits)
987
+ elif self.loss_type == "ipo":
988
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
989
+ losses = (logits - 1 / (2 * self.beta)) ** 2
990
+ else:
991
+ raise ValueError(
992
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
993
+ )
994
+
995
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
996
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
997
+
998
+ return losses, chosen_rewards, rejected_rewards
999
+
1000
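As a standalone illustration of the "sigmoid" branch above (Equation 3 of the CPO paper with label smoothing), here is a sketch with illustrative values for beta and label_smoothing, which are CPOConfig fields in the real trainer:

import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0  # illustrative values
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -9.0])

logits = policy_chosen_logps - policy_rejected_logps
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
print(losses)  # smaller loss where the chosen response already wins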
+ @staticmethod
1001
+ def get_batch_logps(
1002
+ logits: torch.FloatTensor,
1003
+ labels: torch.LongTensor,
1004
+ average_log_prob: bool = False,
1005
+ label_pad_token_id: int = -100,
1006
+ is_encoder_decoder: bool = False,
1007
+ ) -> torch.FloatTensor:
1008
+ """Compute the log probabilities of the given labels under the given logits.
1009
+
1010
+ Args:
1011
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1012
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1013
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1014
+ label_pad_token_id: The label pad token id.
1015
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
1016
+
1017
+ Returns:
1018
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1019
+ """
1020
+ if logits.shape[:-1] != labels.shape:
1021
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1022
+
1023
+ if not is_encoder_decoder:
1024
+ labels = labels[:, 1:].clone()
1025
+ logits = logits[:, :-1, :]
1026
+ loss_mask = labels != label_pad_token_id
1027
+
1028
+ # dummy token; we'll ignore the losses on these tokens later
1029
+ labels[labels == label_pad_token_id] = 0
1030
+
1031
+ per_token_logps = selective_log_softmax(logits, labels)
1032
+
1033
+ if average_log_prob:
1034
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1035
+ else:
1036
+ return (per_token_logps * loss_mask).sum(-1)
1037
+
1038
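A quick sanity sketch of what get_batch_logps computes for a decoder-only model: shift by one position, mask the -100 labels, gather per-token log-probs, and sum. All values are arbitrary.

import torch

torch.manual_seed(0)
logits = torch.randn(1, 4, 10)            # (batch, seq, vocab)
labels = torch.tensor([[1, 3, -100, 2]])  # -100 marks ignored positions

shift_logits, shift_labels = logits[:, :-1, :], labels[:, 1:].clone()
mask = shift_labels != -100
shift_labels[~mask] = 0  # dummy token, masked out below
per_token = torch.log_softmax(shift_logits, dim=-1).gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
print((per_token * mask).sum(-1))  # summed log-prob of the non-masked tokens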
+ def concatenated_forward(
1039
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1040
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1041
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1042
+
1043
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1044
+ """
1045
+ concatenated_batch = self.concatenated_inputs(
1046
+ batch,
1047
+ is_encoder_decoder=self.is_encoder_decoder,
1048
+ label_pad_token_id=self.label_pad_token_id,
1049
+ padding_value=self.padding_value,
1050
+ device=self.accelerator.device,
1051
+ )
1052
+ len_chosen = batch["chosen_labels"].shape[0]
1053
+
1054
+ model_kwargs = (
1055
+ {
1056
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1057
+ }
1058
+ if self.is_encoder_decoder
1059
+ else {}
1060
+ )
1061
+
1062
+ if self.aux_loss_enabled:
1063
+ model_kwargs["output_router_logits"] = True
1064
+
1065
+ outputs = model(
1066
+ concatenated_batch["concatenated_input_ids"],
1067
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1068
+ use_cache=False,
1069
+ **model_kwargs,
1070
+ )
1071
+ all_logits = outputs.logits
1072
+
1073
+ def cross_entropy_loss(logits, labels):
1074
+ if not self.is_encoder_decoder:
1075
+ # Shift so that tokens < n predict n
1076
+ logits = logits[..., :-1, :].contiguous()
1077
+ labels = labels[..., 1:].contiguous()
1078
+ # Flatten the tokens
1079
+ loss_fct = nn.CrossEntropyLoss()
1080
+ logits = logits.view(-1, logits.shape[-1])
1081
+ labels = labels.view(-1)
1082
+ # Enable model parallelism
1083
+ labels = labels.to(logits.device)
1084
+ loss = loss_fct(logits, labels)
1085
+ return loss
1086
+
1087
+ labels = concatenated_batch["concatenated_labels"].clone()
1088
+
1089
+ if self.cpo_alpha == 0:
1090
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)
1091
+ else:
1092
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1093
+
1094
+ all_logps = self.get_batch_logps(
1095
+ all_logits,
1096
+ concatenated_batch["concatenated_labels"],
1097
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
1098
+ is_encoder_decoder=self.is_encoder_decoder,
1099
+ label_pad_token_id=self.label_pad_token_id,
1100
+ )
1101
+
1102
+ chosen_logps = all_logps[:len_chosen]
1103
+ rejected_logps = all_logps[len_chosen:]
1104
+
1105
+ chosen_logits = all_logits[:len_chosen]
1106
+ rejected_logits = all_logits[len_chosen:]
1107
+
1108
+ if self.aux_loss_enabled:
1109
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
1110
+
1111
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1112
+
1113
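The split at the end of concatenated_forward relies only on row order: after one forward pass over the concatenated batch, the first len_chosen rows belong to the chosen responses and the rest to the rejected ones. A toy illustration:

import torch

len_chosen = 2
all_logps = torch.tensor([-3.1, -2.8, -4.0, -3.5])  # illustrative values
chosen_logps, rejected_logps = all_logps[:len_chosen], all_logps[len_chosen:]
print(chosen_logps)    # tensor([-3.1000, -2.8000])
print(rejected_logps)  # tensor([-4.0000, -3.5000])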
+ def get_batch_loss_metrics(
1114
+ self,
1115
+ model,
1116
+ batch: dict[str, Union[list, torch.LongTensor]],
1117
+ train_eval: Literal["train", "eval"] = "train",
1118
+ ):
1119
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1120
+ metrics = {}
1121
+
1122
+ forward_output = self.concatenated_forward(model, batch)
1123
+ (
1124
+ policy_chosen_logps,
1125
+ policy_rejected_logps,
1126
+ policy_chosen_logits,
1127
+ policy_rejected_logits,
1128
+ policy_nll_loss,
1129
+ ) = forward_output[:5]
1130
+ if self.aux_loss_enabled:
1131
+ aux_loss = forward_output[5]
1132
+
1133
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1134
+ policy_chosen_logps,
1135
+ policy_rejected_logps,
1136
+ )
1137
+
1138
+ loss = losses.mean() + self.cpo_alpha * policy_nll_loss
1139
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1140
+
1141
+ prefix = "eval_" if train_eval == "eval" else ""
1142
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
1143
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
1144
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
1145
+ metrics[f"{prefix}rewards/margins"] = (
1146
+ self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
1147
+ )
1148
+ metrics[f"{prefix}logps/rejected"] = (
1149
+ self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
1150
+ )
1151
+ metrics[f"{prefix}logps/chosen"] = (
1152
+ self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
1153
+ )
1154
+ metrics[f"{prefix}logits/rejected"] = (
1155
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
1156
+ )
1157
+ metrics[f"{prefix}logits/chosen"] = (
1158
+ self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
1159
+ )
1160
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
1161
+
1162
+ if self.aux_loss_enabled:
1163
+ loss += self.aux_loss_coef * aux_loss
1164
+
1165
+ return loss, metrics
1166
+
1167
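The final training objective combines the mean preference loss with the NLL of the chosen responses, weighted by cpo_alpha (a CPOConfig field; 0.5 below is just an example). A worked example with illustrative numbers:

import torch

cpo_alpha = 0.5
losses = torch.tensor([0.60, 0.72])  # per-example preference losses
policy_nll_loss = torch.tensor(1.8)  # NLL on the chosen half of the batch
loss = losses.mean() + cpo_alpha * policy_nll_loss
print(loss)  # tensor(1.5600)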
+ def compute_loss(
1168
+ self,
1169
+ model: Union[PreTrainedModel, nn.Module],
1170
+ inputs: dict[str, Union[torch.Tensor, Any]],
1171
+ return_outputs=False,
1172
+ num_items_in_batch=None,
1173
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1174
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1175
+
1176
+ with compute_loss_context_manager:
1177
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1178
+
1179
+ # force log the metrics
1180
+ self.store_metrics(metrics, train_eval="train")
1181
+
1182
+ if return_outputs:
1183
+ return (loss, metrics)
1184
+ return loss
1185
+
1186
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1187
+ """Generate samples from the model and reference model for the given batch of inputs."""
1188
+
1189
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1190
+ # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1191
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1192
+
1193
+ with generate_context_manager:
1194
+ policy_output = model.generate(
1195
+ input_ids=batch["prompt_input_ids"],
1196
+ attention_mask=batch["prompt_attention_mask"],
1197
+ max_length=self.max_length,
1198
+ do_sample=True,
1199
+ pad_token_id=self.processing_class.pad_token_id,
1200
+ )
1201
+
1202
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1203
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1204
+
1205
+ return policy_output_decoded
1206
+
1207
+ def prediction_step(
1208
+ self,
1209
+ model: Union[PreTrainedModel, nn.Module],
1210
+ inputs: dict[str, Union[torch.Tensor, Any]],
1211
+ prediction_loss_only: bool,
1212
+ ignore_keys: Optional[list[str]] = None,
1213
+ ):
1214
+ if ignore_keys is None:
1215
+ if hasattr(model, "config"):
1216
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1217
+ else:
1218
+ ignore_keys = []
1219
+
1220
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1221
+
1222
+ with torch.no_grad(), prediction_context_manager:
1223
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1224
+
1225
+ # force log the metrics
1226
+ self.store_metrics(metrics, train_eval="eval")
1227
+
1228
+ if prediction_loss_only:
1229
+ return (loss.detach(), None, None)
1230
+
1231
+ # logits for the chosen and rejected samples from model
1232
+ logits_dict = {
1233
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1234
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1235
+ }
1236
+ # metrics store Python floats (via .item()), so build the tensor directly
+ logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
1237
+ logits = torch.tensor(logits, device=self.accelerator.device)
1238
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1239
+
1240
+ return (loss.detach(), logits, labels)
1241
+
1242
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1243
+ for key, value in metrics.items():
1244
+ self._stored_metrics[train_eval][key].append(value)
1245
+
1246
+ def evaluation_loop(
1247
+ self,
1248
+ dataloader: DataLoader,
1249
+ description: str,
1250
+ prediction_loss_only: Optional[bool] = None,
1251
+ ignore_keys: Optional[list[str]] = None,
1252
+ metric_key_prefix: str = "eval",
1253
+ ) -> EvalLoopOutput:
1254
+ """
1255
+ Overriding built-in evaluation loop to store metrics for each batch.
1256
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1257
+
1258
+ Works both with or without labels.
1259
+ """
1260
+
1261
+ # Sample and save to game log if requested (for one batch to save time)
1262
+ if self.generate_during_eval:
1263
+ # Generate random indices within the range of the total number of samples
1264
+ num_samples = len(dataloader.dataset)
1265
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1266
+
1267
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1268
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1269
+ random_batch = self.data_collator(random_batch_dataset)
1270
+ random_batch = self._prepare_inputs(random_batch)
1271
+
1272
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1273
+
1274
+ table = pd.DataFrame(
1275
+ columns=["Prompt", "Policy"],
1276
+ data=[
1277
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1278
+ ],
1279
+ )
1280
+ if "wandb" in self.args.report_to:
1281
+ wandb.log({"game_log": wandb.Table(data=table)})
1282
+
1283
+ if "comet_ml" in self.args.report_to:
1284
+ log_table_to_comet_experiment(
1285
+ name="game_log.csv",
1286
+ table=table,
1287
+ )
1288
+
1289
+ # Base evaluation
1290
+ initial_output = super().evaluation_loop(
1291
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1292
+ )
1293
+
1294
+ return initial_output
1295
+
1296
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1297
+ """
1298
+ Log `logs` on the various objects watching training, including stored metrics.
1299
+
1300
+ Args:
1301
+ logs (`dict[str, float]`):
1302
+ The values to log.
1303
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1304
+ Start time of the training.
1305
+ """
1306
+ # logs either has 'loss' or 'eval_loss'
1307
+ train_eval = "train" if "loss" in logs else "eval"
1308
+ # Add averaged stored metrics to logs
1309
+ for key, metrics in self._stored_metrics[train_eval].items():
1310
+ logs[key] = torch.tensor(metrics).mean().item()
1311
+ del self._stored_metrics[train_eval]
1312
+
1313
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1314
+ return super().log(logs, start_time)
1315
+ else: # transformers<=4.46
1316
+ return super().log(logs)
1317
+
1318
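The log override above drains the stored per-step metrics and reduces each list to its mean before delegating to the base Trainer. A minimal sketch of that reduction:

import torch

stored = {"rewards/margins": [0.1, 0.3, 0.2]}  # hypothetical stored values
logs = {key: torch.tensor(values).mean().item() for key, values in stored.items()}
print(logs)  # {'rewards/margins': 0.20...}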
+ def _shift_right(self, input_ids):
1319
+ if self.decoder_start_token_id is None:
1320
+ raise ValueError(
1321
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1322
+ )
1323
+
1324
+ # shift inputs to the right
1325
+ if is_torch_fx_proxy(input_ids):
1326
+ # Item assignment is not supported natively for proxies.
1327
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1328
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1329
+ else:
1330
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1331
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1332
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1333
+
1334
+ if self.pad_token_id is None:
1335
+ raise ValueError("model.config.pad_token_id has to be defined.")
1336
+ # replace possible -100 values in labels by `pad_token_id`
1337
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1338
+
1339
+ return shifted_input_ids
1340
+
1341
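A toy demonstration of _shift_right: the decoder inputs are the labels shifted one position to the right, starting from decoder_start_token_id, with any remaining -100 replaced by pad_token_id. The token ids here are arbitrary, and 0 serves as both special ids purely for illustration.

import torch

decoder_start_token_id, pad_token_id = 0, 0  # illustrative ids
labels = torch.tensor([[5, -100, 7, 8]])

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)  # tensor([[0, 5, 0, 7]])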
+ def create_model_card(
1342
+ self,
1343
+ model_name: Optional[str] = None,
1344
+ dataset_name: Optional[str] = None,
1345
+ tags: Union[str, list[str], None] = None,
1346
+ ):
1347
+ """
1348
+ Creates a draft of a model card using the information available to the `Trainer`.
1349
+
1350
+ Args:
1351
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1352
+ Name of the model.
1353
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1354
+ Name of the dataset used for training.
1355
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1356
+ Tags to be associated with the model card.
1357
+ """
1358
+ if not self.is_world_process_zero():
1359
+ return
1360
+
1361
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1362
+ base_model = self.model.config._name_or_path
1363
+ else:
1364
+ base_model = None
1365
+
1366
+ tags = tags or []
1367
+ if isinstance(tags, str):
1368
+ tags = [tags]
1369
+
1370
+ if hasattr(self.model.config, "unsloth_version"):
1371
+ tags.append("unsloth")
1372
+
1373
+ citation = textwrap.dedent("""\
1374
+ @inproceedings{xu2024contrastive,
1375
+ title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1376
+ author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1377
+ year = 2024,
1378
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1379
+ publisher = {OpenReview.net},
1380
+ url = {https://openreview.net/forum?id=51iwkioZpn}
1381
+ }""")
1382
+
1383
+ model_card = generate_model_card(
1384
+ base_model=base_model,
1385
+ model_name=model_name,
1386
+ hub_model_id=self.hub_model_id,
1387
+ dataset_name=dataset_name,
1388
+ tags=tags,
1389
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1390
+ comet_url=get_comet_experiment_url(),
1391
+ trainer_name="CPO",
1392
+ trainer_citation=citation,
1393
+ paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1394
+ paper_id="2401.08417",
1395
+ )
1396
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1397
+ class UnslothCPOTrainer(_UnslothCPOTrainer):
1398
+ """
1399
+
1400
+ Initialize CPOTrainer.
1401
+
1402
+ Args:
1403
+ model (`transformers.PreTrainedModel`):
1404
+ The model to train, preferably an `AutoModelForCausalLM`.
1405
+ args (`CPOConfig`):
1406
+ The CPO config arguments to use for training.
1407
+ data_collator (`transformers.DataCollator`):
1408
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
1409
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1410
+ train_dataset (`datasets.Dataset`):
1411
+ The dataset to use for training.
1412
+ eval_dataset (`datasets.Dataset`):
1413
+ The dataset to use for evaluation.
1414
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1415
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1416
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1417
+ reuse the fine-tuned model.
1418
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1419
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1420
+ callbacks (`list[transformers.TrainerCallback]`):
1421
+ The callbacks to use for training.
1422
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1423
+ The optimizer and scheduler to use for training.
1424
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1425
+ The function to use to preprocess the logits before computing the metrics.
1426
+ peft_config (`dict`, defaults to `None`):
1427
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1428
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1429
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1430
+ a dictionary mapping metric names to values.
1431
+
1432
+ """
1433
+ def __init__(
1434
+ self,
1435
+ model = None,
1436
+ args = None,
1437
+ data_collator = None,
1438
+ train_dataset = None,
1439
+ eval_dataset = None,
1440
+ processing_class = None,
1441
+ model_init = None,
1442
+ callbacks = None,
1443
+ preprocess_logits_for_metrics = None,
1444
+ peft_config = None,
1445
+ compute_metrics = None,
1446
+ **kwargs
1447
+ ):
1448
+ if args is None: args = UnslothCPOConfig()
1449
+ use_bf16 = getattr(args, 'bf16', False)
1450
+ use_fp16 = getattr(args, 'fp16', False)
1451
+ force_float32 = False
1452
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1453
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1454
+ force_float32 = True
1455
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1456
+ dtype = getattr(model.config, 'torch_dtype', None)
1457
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1458
+ from unsloth_zoo.utils import _get_dtype
1459
+ dtype = _get_dtype(dtype)
1460
+ float16 = dtype == torch.float16
1461
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1462
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1463
+ if force_float32:
1464
+ args.fp16 = False
1465
+ args.bf16 = False
1466
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1467
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1468
+ args.fp16 = float16
1469
+ args.bf16 = not float16
1470
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1471
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1472
+ args.eval_strategy = 'steps'
1473
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1474
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1475
+ if ga_steps is not None and ga_steps > 1:
1476
+ from transformers import __version__ as transformers_version
1477
+ if Version(transformers_version) <= Version('4.45.2'):
1478
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1479
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1480
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1481
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1482
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1483
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1484
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1485
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1486
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1487
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1488
+ if force_float32:
1489
+ args.bf16_full_eval = False
1490
+ args.fp16_full_eval = False
1491
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1492
+ args.bf16_full_eval = True
1493
+ args.fp16_full_eval = False
1494
+ elif not bf16_full_eval and not fp16_full_eval:
1495
+ args.bf16_full_eval = args.bf16
1496
+ args.fp16_full_eval = args.fp16
1497
+ _output_logits = False
1498
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1499
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1500
+ if _output_logits:
1501
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1502
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1503
+ pass
1504
+ else:
1505
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1506
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1507
+ if args_max_seq_length is None and model_max_seq_length is not None:
1508
+ max_seq_length = model.max_seq_length
1509
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1510
+ if model is not None and hasattr(model, 'for_training'):
1511
+ model.for_training()
1512
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1513
+ if 'processing_class' in locals():
1514
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1515
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1516
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1517
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1518
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1519
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1520
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1521
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1522
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1523
+ else:
1524
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1525
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1526
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1527
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1528
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1529
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1530
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1531
+ else:
1532
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1533
+ other_metrics = []
1534
+
1535
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1536
+ PatchRLStatistics('cpo_trainer', other_metrics)
1537
+
1538
+ super().__init__(
1539
+ model = model,
1540
+ args = args,
1541
+ data_collator = data_collator,
1542
+ train_dataset = train_dataset,
1543
+ eval_dataset = eval_dataset,
1544
+ processing_class = processing_class,
1545
+ model_init = model_init,
1546
+ callbacks = callbacks,
1547
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1548
+ peft_config = peft_config,
1549
+ compute_metrics = compute_metrics,**kwargs)
1550
+ if hasattr(self, 'neftune_hook_handle'):
1551
+ self.neftune_hook_handle.remove()
1552
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1553
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1554
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1555
+ pass
1556
+
1557
+ pass
unsloth_compiled_cache/UnslothDDPOTrainer.py ADDED
@@ -0,0 +1,872 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
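A quick numerical check of the identity in the comment above, log_softmax(x)_i = x_i - logsumexp(x):

import torch

x = torch.tensor([[2.0, 0.5, -1.0]])
idx = torch.tensor([1])
direct = torch.log_softmax(x, dim=-1)[0, idx]
via_identity = x[0, idx] - torch.logsumexp(x, dim=-1)
print(torch.allclose(direct, via_identity))  # True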
+ @dataclass
43
+ class UnslothDDPOConfig(DDPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`DDPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (by default, the file name without its extension).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed.
59
+ log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
60
+ Log with either 'wandb' or 'tensorboard', check
61
+ https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
62
+ tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
63
+ Keyword arguments for the tracker (e.g. wandb_project).
64
+ accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the accelerator.
66
+ project_kwargs (`Dict`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator project config (e.g. `logging_dir`).
68
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
69
+ Name of project to use for tracking.
70
+ logdir (`str`, *optional*, defaults to `"logs"`):
71
+ Top-level logging directory for checkpoint saving.
72
+ num_epochs (`int`, *optional*, defaults to `100`):
73
+ Number of epochs to train.
74
+ save_freq (`int`, *optional*, defaults to `1`):
75
+ Number of epochs between saving model checkpoints.
76
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
77
+ Number of checkpoints to keep before overwriting old ones.
78
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
79
+ Mixed precision training.
80
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
81
+ Allow `tf32` on Ampere GPUs.
82
+ resume_from (`str`, *optional*, defaults to `""`):
83
+ Resume training from a checkpoint.
84
+ sample_num_steps (`int`, *optional*, defaults to `50`):
85
+ Number of sampler inference steps.
86
+ sample_eta (`float`, *optional*, defaults to `1.0`):
87
+ Eta parameter for the DDIM sampler.
88
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
89
+ Classifier-free guidance weight.
90
+ sample_batch_size (`int`, *optional*, defaults to `1`):
91
+ Batch size (per GPU) to use for sampling.
92
+ sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
93
+ Number of batches to sample per epoch.
94
+ train_batch_size (`int`, *optional*, defaults to `1`):
95
+ Batch size (per GPU) to use for training.
96
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
97
+ Use 8bit Adam optimizer from bitsandbytes.
98
+ train_learning_rate (`float`, *optional*, defaults to `3e-4`):
99
+ Learning rate.
100
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
101
+ Adam beta1.
102
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
103
+ Adam beta2.
104
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
105
+ Adam weight decay.
106
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
107
+ Adam epsilon.
108
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
109
+ Number of gradient accumulation steps.
110
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
111
+ Maximum gradient norm for gradient clipping.
112
+ train_num_inner_epochs (`int`, *optional*, defaults to `1`):
113
+ Number of inner epochs per outer epoch.
114
+ train_cfg (`bool`, *optional*, defaults to `True`):
115
+ Whether to use classifier-free guidance during training.
116
+ train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
117
+ Clip advantages to the range [-train_adv_clip_max, train_adv_clip_max].
118
+ train_clip_range (`float`, *optional*, defaults to `1e-4`):
119
+ PPO clip range.
120
+ train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
121
+ Fraction of timesteps to train on.
122
+ per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
123
+ Whether to track statistics for each prompt separately.
124
+ per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
125
+ Number of reward values to store in the buffer for each prompt.
126
+ per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
127
+ Minimum number of reward values to store in the buffer.
128
+ async_reward_computation (`bool`, *optional*, defaults to `False`):
129
+ Whether to compute rewards asynchronously.
130
+ max_workers (`int`, *optional*, defaults to `2`):
131
+ Maximum number of workers to use for async reward computation.
132
+ negative_prompts (`str`, *optional*, defaults to `""`):
133
+ Comma-separated list of prompts to use as negative examples.
134
+ push_to_hub (`bool`, *optional*, defaults to `False`):
135
+ Whether to push the final model checkpoint to the Hub.
136
+
137
+ """
138
+ vllm_sampling_params: Optional[Any] = field(
139
+ default = None,
140
+ metadata = {'help': 'vLLM SamplingParams'},
141
+ )
142
+ unsloth_num_chunks : Optional[int] = field(
143
+ default = -1,
144
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
145
+ )
146
+ def __init__(
147
+ self,
148
+ exp_name = 'demo',
149
+ run_name = '',
150
+ seed = 3407,
151
+ log_with = None,
152
+ tracker_project_name = 'trl',
153
+ logdir = 'logs',
154
+ num_epochs = 100,
155
+ save_freq = 1,
156
+ num_checkpoint_limit = 5,
157
+ mixed_precision = 'fp16',
158
+ allow_tf32 = True,
159
+ resume_from = '',
160
+ sample_num_steps = 50,
161
+ sample_eta = 1.0,
162
+ sample_guidance_scale = 5.0,
163
+ sample_batch_size = 1,
164
+ sample_num_batches_per_epoch = 2,
165
+ train_batch_size = 1,
166
+ train_use_8bit_adam = False,
167
+ train_learning_rate = 5e-05,
168
+ train_adam_beta1 = 0.9,
169
+ train_adam_beta2 = 0.999,
170
+ train_adam_weight_decay = 0.01,
171
+ train_adam_epsilon = 1e-08,
172
+ train_gradient_accumulation_steps = 2,
173
+ train_max_grad_norm = 1.0,
174
+ train_num_inner_epochs = 1,
175
+ train_cfg = True,
176
+ train_adv_clip_max = 5.0,
177
+ train_clip_range = 0.0001,
178
+ train_timestep_fraction = 1.0,
179
+ per_prompt_stat_tracking = False,
180
+ per_prompt_stat_tracking_buffer_size = 16,
181
+ per_prompt_stat_tracking_min_count = 16,
182
+ async_reward_computation = False,
183
+ max_workers = 2,
184
+ negative_prompts = '',
185
+ push_to_hub = False,
186
+ vllm_sampling_params = None,
187
+ unsloth_num_chunks = -1,
188
+ **kwargs,
189
+ ):
190
+
191
+ super().__init__(
192
+ exp_name = exp_name,
193
+ run_name = run_name,
194
+ seed = seed,
195
+ log_with = log_with,
196
+ tracker_project_name = tracker_project_name,
197
+ logdir = logdir,
198
+ num_epochs = num_epochs,
199
+ save_freq = save_freq,
200
+ num_checkpoint_limit = num_checkpoint_limit,
201
+ mixed_precision = mixed_precision,
202
+ allow_tf32 = allow_tf32,
203
+ resume_from = resume_from,
204
+ sample_num_steps = sample_num_steps,
205
+ sample_eta = sample_eta,
206
+ sample_guidance_scale = sample_guidance_scale,
207
+ sample_batch_size = sample_batch_size,
208
+ sample_num_batches_per_epoch = sample_num_batches_per_epoch,
209
+ train_batch_size = train_batch_size,
210
+ train_use_8bit_adam = train_use_8bit_adam,
211
+ train_learning_rate = train_learning_rate,
212
+ train_adam_beta1 = train_adam_beta1,
213
+ train_adam_beta2 = train_adam_beta2,
214
+ train_adam_weight_decay = train_adam_weight_decay,
215
+ train_adam_epsilon = train_adam_epsilon,
216
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
217
+ train_max_grad_norm = train_max_grad_norm,
218
+ train_num_inner_epochs = train_num_inner_epochs,
219
+ train_cfg = train_cfg,
220
+ train_adv_clip_max = train_adv_clip_max,
221
+ train_clip_range = train_clip_range,
222
+ train_timestep_fraction = train_timestep_fraction,
223
+ per_prompt_stat_tracking = per_prompt_stat_tracking,
224
+ per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
225
+ per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
226
+ async_reward_computation = async_reward_computation,
227
+ max_workers = max_workers,
228
+ negative_prompts = negative_prompts,
229
+ push_to_hub = push_to_hub,**kwargs)
230
+ self.vllm_sampling_params = vllm_sampling_params
231
+ self.unsloth_num_chunks = unsloth_num_chunks
232
+ pass
233
+
234
+ class _UnslothDDPOTrainer(PyTorchModelHubMixin):
235
+ """"""
236
+
237
+ _tag_names = ["trl", "ddpo"]
238
+
239
+ def __init__(
240
+ self,
241
+ config: DDPOConfig,
242
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
243
+ prompt_function: Callable[[], tuple[str, Any]],
244
+ sd_pipeline: DDPOStableDiffusionPipeline,
245
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
246
+ ):
247
+ if image_samples_hook is None:
248
+ warn("No image_samples_hook provided; no images will be logged")
249
+
250
+ self.prompt_fn = prompt_function
251
+ self.reward_fn = reward_function
252
+ self.config = config
253
+ self.image_samples_callback = image_samples_hook
254
+
255
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
256
+
257
+ if self.config.resume_from:
258
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
259
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
260
+ # get the most recent checkpoint in this directory
261
+ checkpoints = list(
262
+ filter(
263
+ lambda x: "checkpoint_" in x,
264
+ os.listdir(self.config.resume_from),
265
+ )
266
+ )
267
+ if len(checkpoints) == 0:
268
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
269
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
270
+ self.config.resume_from = os.path.join(
271
+ self.config.resume_from,
272
+ f"checkpoint_{checkpoint_numbers[-1]}",
273
+ )
274
+
275
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
276
+
277
+ # number of timesteps within each trajectory to train on
278
+ self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
279
+
280
+ self.accelerator = Accelerator(
281
+ log_with=self.config.log_with,
282
+ mixed_precision=self.config.mixed_precision,
283
+ project_config=accelerator_project_config,
284
+ # we always accumulate gradients across timesteps; we want train_gradient_accumulation_steps to be the
285
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
286
+ # the total number of optimizer steps to accumulate across.
287
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
288
+ **self.config.accelerator_kwargs,
289
+ )
290
+
291
+ is_okay, message = self._config_check()
292
+ if not is_okay:
293
+ raise ValueError(message)
294
+
295
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
296
+
297
+ if self.accelerator.is_main_process:
298
+ self.accelerator.init_trackers(
299
+ self.config.tracker_project_name,
300
+ config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
301
+ init_kwargs=self.config.tracker_kwargs,
302
+ )
303
+
304
+ logger.info(f"\n{config}")
305
+
306
+ set_seed(self.config.seed, device_specific=True)
307
+
308
+ self.sd_pipeline = sd_pipeline
309
+
310
+ self.sd_pipeline.set_progress_bar_config(
311
+ position=1,
312
+ disable=not self.accelerator.is_local_main_process,
313
+ leave=False,
314
+ desc="Timestep",
315
+ dynamic_ncols=True,
316
+ )
317
+
318
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
319
+ # as these weights are only used for inference, keeping weights in full precision is not required.
320
+ if self.accelerator.mixed_precision == "fp16":
321
+ inference_dtype = torch.float16
322
+ elif self.accelerator.mixed_precision == "bf16":
323
+ inference_dtype = torch.bfloat16
324
+ else:
325
+ inference_dtype = torch.float32
326
+
327
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
328
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
329
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
330
+
331
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
332
+
333
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
334
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
335
+
336
+ # Enable TF32 for faster training on Ampere GPUs,
337
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
338
+ if self.config.allow_tf32:
339
+ torch.backends.cuda.matmul.allow_tf32 = True
340
+
341
+ self.optimizer = self._setup_optimizer(
342
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
343
+ )
344
+
345
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
346
+ self.sd_pipeline.tokenizer(
347
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
348
+ return_tensors="pt",
349
+ padding="max_length",
350
+ truncation=True,
351
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
352
+ ).input_ids.to(self.accelerator.device)
353
+ )[0]
354
+
355
+ if config.per_prompt_stat_tracking:
356
+ self.stat_tracker = PerPromptStatTracker(
357
+ config.per_prompt_stat_tracking_buffer_size,
358
+ config.per_prompt_stat_tracking_min_count,
359
+ )
360
+
361
+ # NOTE: for some reason, autocast is necessary for non-LoRA training; for LoRA training it isn't needed and it uses
362
+ # more memory
363
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
364
+
365
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
366
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
367
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
368
+ else:
369
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
370
+
371
+ if self.config.async_reward_computation:
372
+ self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
373
+
374
+ if config.resume_from:
375
+ logger.info(f"Resuming from {config.resume_from}")
376
+ self.accelerator.load_state(config.resume_from)
377
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
378
+ else:
379
+ self.first_epoch = 0
380
+
381
+ def compute_rewards(self, prompt_image_pairs, is_async=False):
382
+ if not is_async:
383
+ rewards = []
384
+ for images, prompts, prompt_metadata in prompt_image_pairs:
385
+ reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
386
+ rewards.append(
387
+ (
388
+ torch.as_tensor(reward, device=self.accelerator.device),
389
+ reward_metadata,
390
+ )
391
+ )
392
+ else:
393
+ rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
394
+ rewards = [
395
+ (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
396
+ for reward, reward_metadata in rewards
397
+ ]
398
+
399
+ return zip(*rewards)
400
+
401
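compute_rewards returns zip(*rewards), which transposes a list of (reward, metadata) pairs into one tuple of rewards and one tuple of metadata entries. A tiny pure-Python illustration:

pairs = [(0.7, {"id": 0}), (0.3, {"id": 1})]  # hypothetical (reward, metadata) pairs
rewards, metadata = zip(*pairs)
print(rewards)   # (0.7, 0.3)
print(metadata)  # ({'id': 0}, {'id': 1})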
+ def step(self, epoch: int, global_step: int):
402
+ """
403
+ Perform a single step of training.
404
+
405
+ Args:
406
+ epoch (int): The current epoch.
407
+ global_step (int): The current global step.
408
+
409
+ Side Effects:
410
+ - Model weights are updated
411
+ - Logs the statistics to the accelerator trackers.
412
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
413
+
414
+ Returns:
415
+ global_step (int): The updated global step.
416
+
417
+ """
418
+ samples, prompt_image_data = self._generate_samples(
419
+ iterations=self.config.sample_num_batches_per_epoch,
420
+ batch_size=self.config.sample_batch_size,
421
+ )
422
+
423
+ # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
424
+ samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
425
+ rewards, rewards_metadata = self.compute_rewards(
426
+ prompt_image_data, is_async=self.config.async_reward_computation
427
+ )
428
+
429
+ for i, image_data in enumerate(prompt_image_data):
430
+ image_data.extend([rewards[i], rewards_metadata[i]])
431
+
432
+ if self.image_samples_callback is not None:
433
+ self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
434
+
435
+ rewards = torch.cat(rewards)
436
+ rewards = self.accelerator.gather(rewards).cpu().numpy()
437
+
438
+ self.accelerator.log(
439
+ {
440
+ "reward": rewards,
441
+ "epoch": epoch,
442
+ "reward_mean": rewards.mean(),
443
+ "reward_std": rewards.std(),
444
+ },
445
+ step=global_step,
446
+ )
447
+
448
+ if self.config.per_prompt_stat_tracking:
449
+ # gather the prompts across processes
450
+ prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
451
+ prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
452
+ advantages = self.stat_tracker.update(prompts, rewards)
453
+ else:
454
+ advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
455
+
456
+ # ungather advantages; keep the entries corresponding to the samples on this process
457
+ samples["advantages"] = (
458
+ torch.as_tensor(advantages)
459
+ .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
460
+ .to(self.accelerator.device)
461
+ )
462
+
463
+ del samples["prompt_ids"]
464
+
465
+ total_batch_size, num_timesteps = samples["timesteps"].shape
466
+
467
+ for inner_epoch in range(self.config.train_num_inner_epochs):
468
+ # shuffle samples along batch dimension
469
+ perm = torch.randperm(total_batch_size, device=self.accelerator.device)
470
+ samples = {k: v[perm] for k, v in samples.items()}
471
+
472
+ # shuffle along time dimension independently for each sample
473
+ # build an independent timestep permutation for each sample; it is applied via the advanced indexing below
474
+ perms = torch.stack(
475
+ [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
476
+ )
477
+
478
+ for key in ["timesteps", "latents", "next_latents", "log_probs"]:
479
+ samples[key] = samples[key][
480
+ torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
481
+ perms,
482
+ ]
483
+
484
+ original_keys = samples.keys()
485
+ original_values = samples.values()
486
+ # rebatch them as user defined train_batch_size is different from sample_batch_size
487
+ reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
488
+
489
+ # Transpose the list of original values
490
+ transposed_values = zip(*reshaped_values)
491
+ # Create new dictionaries for each row of transposed values
492
+ samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
493
+
494
+ self.sd_pipeline.unet.train()
495
+ global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
496
+ # ensure optimization step at the end of the inner epoch
497
+ if not self.accelerator.sync_gradients:
498
+ raise ValueError(
499
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
500
+ )
501
+
502
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
503
+ self.accelerator.save_state()
504
+
505
+ return global_step
506
+
507
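The timestep shuffle inside step uses advanced indexing: row i of perms reorders sample i's trajectory independently of the others. A sketch with toy shapes:

import torch

total_batch_size, num_timesteps = 2, 4
x = torch.arange(total_batch_size * num_timesteps).reshape(total_batch_size, num_timesteps)
perms = torch.stack([torch.randperm(num_timesteps) for _ in range(total_batch_size)])
shuffled = x[torch.arange(total_batch_size)[:, None], perms]
print(shuffled)  # row i is x[i] permuted by perms[i]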
+ def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
508
+ """
509
+ Calculate the loss for a batch of an unpacked sample
510
+
511
+ Args:
512
+ latents (torch.Tensor):
513
+ The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
514
+ timesteps (torch.Tensor):
515
+ The timesteps sampled from the diffusion model, shape: [batch_size]
516
+ next_latents (torch.Tensor):
517
+ The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
518
+ log_probs (torch.Tensor):
519
+ The log probabilities of the latents, shape: [batch_size]
520
+ advantages (torch.Tensor):
521
+ The advantages of the latents, shape: [batch_size]
522
+ embeds (torch.Tensor):
523
+ The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
524
+ Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
525
+
526
+ Returns:
527
+ loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
528
+ (all of these are of shape (1,))
529
+ """
530
+ with self.autocast():
531
+ if self.config.train_cfg:
532
+ noise_pred = self.sd_pipeline.unet(
533
+ torch.cat([latents] * 2),
534
+ torch.cat([timesteps] * 2),
535
+ embeds,
536
+ ).sample
537
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
538
+ noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
539
+ noise_pred_text - noise_pred_uncond
540
+ )
541
+ else:
542
+ noise_pred = self.sd_pipeline.unet(
543
+ latents,
544
+ timesteps,
545
+ embeds,
546
+ ).sample
547
+ # compute the log prob of next_latents given latents under the current model
548
+
549
+ scheduler_step_output = self.sd_pipeline.scheduler_step(
550
+ noise_pred,
551
+ timesteps,
552
+ latents,
553
+ eta=self.config.sample_eta,
554
+ prev_sample=next_latents,
555
+ )
556
+
557
+ log_prob = scheduler_step_output.log_probs
558
+
559
+ advantages = torch.clamp(
560
+ advantages,
561
+ -self.config.train_adv_clip_max,
562
+ self.config.train_adv_clip_max,
563
+ )
564
+
565
+ ratio = torch.exp(log_prob - log_probs)
566
+
567
+ loss = self.loss(advantages, self.config.train_clip_range, ratio)
568
+
569
+ approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
570
+
571
+ clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
572
+
573
+ return loss, approx_kl, clipfrac
574
+
575
+ def loss(
576
+ self,
577
+ advantages: torch.Tensor,
578
+ clip_range: float,
579
+ ratio: torch.Tensor,
580
+ ):
581
+ unclipped_loss = -advantages * ratio
582
+ clipped_loss = -advantages * torch.clamp(
583
+ ratio,
584
+ 1.0 - clip_range,
585
+ 1.0 + clip_range,
586
+ )
587
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
588
+
589
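The loss method above is the standard PPO clipped surrogate. A standalone sketch with illustrative numbers (train_clip_range defaults to 1e-4 in the config):

import torch

clip_range = 1e-4
advantages = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.2, 0.8])  # new/old probability ratios

unclipped = -advantages * ratio
clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
loss = torch.mean(torch.maximum(unclipped, clipped))
print(loss)  # the pessimistic (larger) term per element, averaged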
+ def _setup_optimizer(self, trainable_layers_parameters):
590
+ if self.config.train_use_8bit_adam:
591
+ import bitsandbytes
592
+
593
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
594
+ else:
595
+ optimizer_cls = torch.optim.AdamW
596
+
597
+ return optimizer_cls(
598
+ trainable_layers_parameters,
599
+ lr=self.config.train_learning_rate,
600
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
601
+ weight_decay=self.config.train_adam_weight_decay,
602
+ eps=self.config.train_adam_epsilon,
603
+ )
604
+
605
+ def _save_model_hook(self, models, weights, output_dir):
606
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
607
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
608
+
609
+ def _load_model_hook(self, models, input_dir):
610
+ self.sd_pipeline.load_checkpoint(models, input_dir)
611
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
612
+
613
+ def _generate_samples(self, iterations, batch_size):
614
+ """
615
+ Generate samples from the model
616
+
617
+ Args:
618
+ iterations (int): Number of iterations to generate samples for
619
+ batch_size (int): Batch size to use for sampling
620
+
621
+ Returns:
622
+ samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
623
+ """
624
+ samples = []
625
+ prompt_image_pairs = []
626
+ self.sd_pipeline.unet.eval()
627
+
628
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
629
+
630
+ for _ in range(iterations):
631
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
632
+
633
+ prompt_ids = self.sd_pipeline.tokenizer(
634
+ prompts,
635
+ return_tensors="pt",
636
+ padding="max_length",
637
+ truncation=True,
638
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
639
+ ).input_ids.to(self.accelerator.device)
640
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
641
+
642
+ with self.autocast():
643
+ sd_output = self.sd_pipeline(
644
+ prompt_embeds=prompt_embeds,
645
+ negative_prompt_embeds=sample_neg_prompt_embeds,
646
+ num_inference_steps=self.config.sample_num_steps,
647
+ guidance_scale=self.config.sample_guidance_scale,
648
+ eta=self.config.sample_eta,
649
+ output_type="pt",
650
+ )
651
+
652
+ images = sd_output.images
653
+ latents = sd_output.latents
654
+ log_probs = sd_output.log_probs
655
+
656
+ latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
657
+ log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
658
+ timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
659
+
660
+ samples.append(
661
+ {
662
+ "prompt_ids": prompt_ids,
663
+ "prompt_embeds": prompt_embeds,
664
+ "timesteps": timesteps,
665
+ "latents": latents[:, :-1], # each entry is the latent before timestep t
666
+ "next_latents": latents[:, 1:], # each entry is the latent after timestep t
667
+ "log_probs": log_probs,
668
+ "negative_prompt_embeds": sample_neg_prompt_embeds,
669
+ }
670
+ )
671
+ prompt_image_pairs.append([images, prompts, prompt_metadata])
672
+
673
+ return samples, prompt_image_pairs
674
+
675
+ def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
676
+ """
677
+ Train on a batch of samples. This is the main training segment.
678
+
679
+ Args:
680
+ inner_epoch (int): The current inner epoch
681
+ epoch (int): The current epoch
682
+ global_step (int): The current global step
683
+ batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
684
+
685
+ Side Effects:
686
+ - Model weights are updated
687
+ - Logs the statistics to the accelerator trackers.
688
+
689
+ Returns:
690
+ global_step (int): The updated global step
691
+ """
692
+ info = defaultdict(list)
693
+ for _i, sample in enumerate(batched_samples):
694
+ if self.config.train_cfg:
695
+ # concat negative prompts to sample prompts to avoid two forward passes
696
+ embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
697
+ else:
698
+ embeds = sample["prompt_embeds"]
699
+
700
+ for j in range(self.num_train_timesteps):
701
+ with self.accelerator.accumulate(self.sd_pipeline.unet):
702
+ loss, approx_kl, clipfrac = self.calculate_loss(
703
+ sample["latents"][:, j],
704
+ sample["timesteps"][:, j],
705
+ sample["next_latents"][:, j],
706
+ sample["log_probs"][:, j],
707
+ sample["advantages"],
708
+ embeds,
709
+ )
710
+ info["approx_kl"].append(approx_kl)
711
+ info["clipfrac"].append(clipfrac)
712
+ info["loss"].append(loss)
713
+
714
+ self.accelerator.backward(loss)
715
+ if self.accelerator.sync_gradients:
716
+ self.accelerator.clip_grad_norm_(
717
+ self.trainable_layers.parameters()
718
+ if not isinstance(self.trainable_layers, list)
719
+ else self.trainable_layers,
720
+ self.config.train_max_grad_norm,
721
+ )
722
+ self.optimizer.step()
723
+ self.optimizer.zero_grad()
724
+
725
+ # Checks if the accelerator has performed an optimization step behind the scenes
726
+ if self.accelerator.sync_gradients:
727
+ # log training-related stuff
728
+ info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
729
+ info = self.accelerator.reduce(info, reduction="mean")
730
+ info.update({"epoch": epoch, "inner_epoch": inner_epoch})
731
+ self.accelerator.log(info, step=global_step)
732
+ global_step += 1
733
+ info = defaultdict(list)
734
+ return global_step
735
+
736
+ def _config_check(self) -> tuple[bool, str]:
737
+ samples_per_epoch = (
738
+ self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
739
+ )
740
+ total_train_batch_size = (
741
+ self.config.train_batch_size
742
+ * self.accelerator.num_processes
743
+ * self.config.train_gradient_accumulation_steps
744
+ )
745
+
746
+ if not self.config.sample_batch_size >= self.config.train_batch_size:
747
+ return (
748
+ False,
749
+ f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
750
+ )
751
+ if not self.config.sample_batch_size % self.config.train_batch_size == 0:
752
+ return (
753
+ False,
754
+ f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
755
+ )
756
+ if not samples_per_epoch % total_train_batch_size == 0:
757
+ return (
758
+ False,
759
+ f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
760
+ )
761
+ return True, ""
762
+
763
+ def train(self, epochs: Optional[int] = None):
764
+ """
765
+ Train the model for a given number of epochs
766
+ """
767
+ global_step = 0
768
+ if epochs is None:
769
+ epochs = self.config.num_epochs
770
+ for epoch in range(self.first_epoch, epochs):
771
+ global_step = self.step(epoch, global_step)
772
+
773
+ def _save_pretrained(self, save_directory):
774
+ self.sd_pipeline.save_pretrained(save_directory)
775
+ self.create_model_card()
776
+
777
+ def create_model_card(
778
+ self,
779
+ model_name: Optional[str] = None,
780
+ dataset_name: Optional[str] = None,
781
+ tags: Union[str, list[str], None] = None,
782
+ ):
783
+ """
784
+ Creates a draft of a model card using the information available to the `Trainer`.
785
+
786
+ Args:
787
+ model_name (`str` or `None`, *optional*, defaults to `None`):
788
+ Name of the model.
789
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
790
+ Name of the dataset used for training.
791
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
792
+ Tags to be associated with the model card.
793
+ """
794
+ if not self.is_world_process_zero():
795
+ return
796
+
797
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
798
+ base_model = self.model.config._name_or_path
799
+ else:
800
+ base_model = None
801
+
802
+ tags = tags or []
803
+ if isinstance(tags, str):
804
+ tags = [tags]
805
+
806
+ if hasattr(self.model.config, "unsloth_version"):
807
+ tags.append("unsloth")
808
+
809
+ citation = textwrap.dedent("""\
810
+ @inproceedings{black2024training,
811
+ title = {{Training Diffusion Models with Reinforcement Learning}},
812
+ author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
813
+ year = 2024,
814
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
815
+ publisher = {OpenReview.net},
816
+ url = {https://openreview.net/forum?id=YCWjhGrJFD},
817
+ }""")
818
+
819
+ model_card = generate_model_card(
820
+ base_model=base_model,
821
+ model_name=model_name,
822
+ hub_model_id=self.hub_model_id,
823
+ dataset_name=dataset_name,
824
+ tags=tags,
825
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
826
+ comet_url=get_comet_experiment_url(),
827
+ trainer_name="DDPO",
828
+ trainer_citation=citation,
829
+ paper_title="Training Diffusion Models with Reinforcement Learning",
830
+ paper_id="2305.13301",
831
+ )
832
+
833
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
834
+ class UnslothDDPOTrainer(_UnslothDDPOTrainer):
835
+ """
836
+
837
+ The DDPOTrainer uses Denoising Diffusion Policy Optimization (DDPO) to optimise diffusion models.
838
+ Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
839
+ As of now only Stable Diffusion based pipelines are supported
840
+
841
+ Attributes:
842
+ **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
843
+ details.
844
+ **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
845
+ **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
846
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
847
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
848
+
849
+ """
850
+ def __init__(
851
+ self,
852
+ config,
853
+ reward_function,
854
+ prompt_function,
855
+ sd_pipeline,
856
+ image_samples_hook = None,
857
+ **kwargs
858
+ ):
859
+ if config is None: config = UnslothDDPOConfig()
860
+ other_metrics = []
861
+
862
+ from unsloth_zoo.logging_utils import PatchRLStatistics
863
+ PatchRLStatistics('ddpo_trainer', other_metrics)
864
+
865
+ super().__init__(
866
+ config = config,
867
+ reward_function = reward_function,
868
+ prompt_function = prompt_function,
869
+ sd_pipeline = sd_pipeline,
870
+ image_samples_hook = image_samples_hook,**kwargs)
871
+
872
+ pass
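For readers skimming the diff: the clipped objective implemented by `loss` and the `approx_kl`/`clipfrac` statistics computed in `calculate_loss` above can be reproduced standalone. A minimal sketch, assuming an illustrative batch of 8 and a clip range of 0.2 (none of these values come from the trainer config):

import torch

def clipped_surrogate_loss(advantages, ratio, clip_range):
    # Same structure as DDPOTrainer.loss above: element-wise maximum of the
    # unclipped and clipped policy-gradient terms, averaged over the batch.
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    return torch.mean(torch.maximum(unclipped, clipped))

advantages = torch.randn(8)                   # illustrative advantages
log_prob_new = torch.randn(8)                 # log-probs under the current policy
log_prob_old = torch.randn(8)                 # log-probs under the sampling policy
ratio = torch.exp(log_prob_new - log_prob_old)

loss = clipped_surrogate_loss(advantages, ratio, clip_range=0.2)
approx_kl = 0.5 * torch.mean((log_prob_new - log_prob_old) ** 2)   # KL proxy, as in calculate_loss
clipfrac = torch.mean((torch.abs(ratio - 1.0) > 0.2).float())      # fraction of clipped ratios
print(loss.item(), approx_kl.item(), clipfrac.item())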
unsloth_compiled_cache/UnslothDPOTrainer.py ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/UnslothGKDTrainer.py ADDED
@@ -0,0 +1,863 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
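The `selective_log_softmax` helper above scores the chosen token ids without materializing the full log-softmax output. A minimal equivalence check, with illustrative shapes that are not taken from the trainer:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)             # (batch, seq, vocab) -- illustrative
index = torch.randint(0, 11, (2, 5))       # token ids to score

# Reference: full log-softmax, then gather the chosen tokens.
reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

# The trick used above: log_softmax(x)_i = x_i - logsumexp(x), so gather first.
gathered = logits.float().gather(-1, index.unsqueeze(-1)).squeeze(-1)
per_token_logps = gathered - torch.logsumexp(logits.float(), dim=-1)

assert torch.allclose(reference, per_token_logps, atol=1e-5)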
42
+ @dataclass
43
+ class UnslothGKDConfig(GKDConfig):
44
+ """
45
+
46
+ Configuration class for [`GKDTrainer`].
47
+
48
+ Args:
49
+ temperature (`float`, *optional*, defaults to `0.9`):
50
+ Temperature for sampling. The higher the temperature, the more random the completions.
51
+ lmbda (`float`, *optional*, defaults to `0.5`):
52
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
53
+ student-generated outputs).
54
+ beta (`float`, *optional*, defaults to `0.5`):
55
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
56
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
57
+ max_new_tokens (`int`, *optional*, defaults to `128`):
58
+ Maximum number of tokens to generate per completion.
59
+ teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
60
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
61
+ being trained.
62
+ teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
63
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
64
+ from a string.
65
+ disable_dropout (`bool`, *optional*, defaults to `True`):
66
+ Whether to disable dropout in the model.
67
+ seq_kd (`bool`, *optional*, defaults to `False`):
68
+ Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
69
+ on teacher-generated output).
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ evaluation_strategy = None,
189
+ push_to_hub_model_id = None,
190
+ push_to_hub_organization = None,
191
+ push_to_hub_token = None,
192
+ mp_parameters = '',
193
+ auto_find_batch_size = False,
194
+ full_determinism = False,
195
+ torchdynamo = None,
196
+ ray_scope = 'last',
197
+ ddp_timeout = 1800,
198
+ torch_compile = False,
199
+ torch_compile_backend = None,
200
+ torch_compile_mode = None,
201
+ dispatch_batches = None,
202
+ split_batches = None,
203
+ include_tokens_per_second = False,
204
+ include_num_input_tokens_seen = False,
205
+ neftune_noise_alpha = None,
206
+ optim_target_modules = None,
207
+ batch_eval_metrics = False,
208
+ eval_on_start = False,
209
+ use_liger_kernel = False,
210
+ eval_use_gather_object = False,
211
+ average_tokens_across_devices = False,
212
+ model_init_kwargs = None,
213
+ use_liger = False,
214
+ dataset_text_field = 'text',
215
+ dataset_kwargs = None,
216
+ dataset_num_proc = None,
217
+ max_seq_length = None,
218
+ packing = False,
219
+ eval_packing = None,
220
+ dataset_batch_size = None,
221
+ num_of_sequences = None,
222
+ chars_per_token = None,
223
+ temperature = 0.9,
224
+ lmbda = 0.5,
225
+ beta = 0.5,
226
+ max_new_tokens = 128,
227
+ teacher_model_name_or_path = None,
228
+ teacher_model_init_kwargs = None,
229
+ disable_dropout = True,
230
+ seq_kd = False,
231
+ vllm_sampling_params = None,
232
+ unsloth_num_chunks = -1,
233
+ **kwargs,
234
+ ):
235
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
236
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
237
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
238
+ output_dir = 'unsloth_training_checkpoints'
239
+ save_strategy = 'no'
240
+ if dataset_num_proc is None:
241
+ from multiprocessing import cpu_count
242
+ dataset_num_proc = cpu_count()
243
+
244
+ super().__init__(
245
+ output_dir = output_dir,
246
+ overwrite_output_dir = overwrite_output_dir,
247
+ do_train = do_train,
248
+ do_eval = do_eval,
249
+ do_predict = do_predict,
250
+ eval_strategy = eval_strategy,
251
+ prediction_loss_only = prediction_loss_only,
252
+ per_device_train_batch_size = per_device_train_batch_size,
253
+ per_device_eval_batch_size = per_device_eval_batch_size,
254
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
255
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
256
+ gradient_accumulation_steps = gradient_accumulation_steps,
257
+ eval_accumulation_steps = eval_accumulation_steps,
258
+ eval_delay = eval_delay,
259
+ torch_empty_cache_steps = torch_empty_cache_steps,
260
+ learning_rate = learning_rate,
261
+ weight_decay = weight_decay,
262
+ adam_beta1 = adam_beta1,
263
+ adam_beta2 = adam_beta2,
264
+ adam_epsilon = adam_epsilon,
265
+ max_grad_norm = max_grad_norm,
266
+ num_train_epochs = num_train_epochs,
267
+ max_steps = max_steps,
268
+ lr_scheduler_type = lr_scheduler_type,
269
+ warmup_ratio = warmup_ratio,
270
+ warmup_steps = warmup_steps,
271
+ log_level = log_level,
272
+ log_level_replica = log_level_replica,
273
+ log_on_each_node = log_on_each_node,
274
+ logging_dir = logging_dir,
275
+ logging_strategy = logging_strategy,
276
+ logging_first_step = logging_first_step,
277
+ logging_steps = logging_steps,
278
+ logging_nan_inf_filter = logging_nan_inf_filter,
279
+ save_strategy = save_strategy,
280
+ save_steps = save_steps,
281
+ save_total_limit = save_total_limit,
282
+ save_safetensors = save_safetensors,
283
+ save_on_each_node = save_on_each_node,
284
+ save_only_model = save_only_model,
285
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
286
+ no_cuda = no_cuda,
287
+ use_cpu = use_cpu,
288
+ use_mps_device = use_mps_device,
289
+ seed = seed,
290
+ data_seed = data_seed,
291
+ jit_mode_eval = jit_mode_eval,
292
+ use_ipex = use_ipex,
293
+ bf16 = bf16,
294
+ fp16 = fp16,
295
+ fp16_opt_level = fp16_opt_level,
296
+ half_precision_backend = half_precision_backend,
297
+ bf16_full_eval = bf16_full_eval,
298
+ fp16_full_eval = fp16_full_eval,
299
+ tf32 = tf32,
300
+ local_rank = local_rank,
301
+ ddp_backend = ddp_backend,
302
+ tpu_num_cores = tpu_num_cores,
303
+ tpu_metrics_debug = tpu_metrics_debug,
304
+ debug = debug,
305
+ dataloader_drop_last = dataloader_drop_last,
306
+ eval_steps = eval_steps,
307
+ dataloader_num_workers = dataloader_num_workers,
308
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
309
+ past_index = past_index,
310
+ run_name = run_name,
311
+ disable_tqdm = disable_tqdm,
312
+ remove_unused_columns = remove_unused_columns,
313
+ label_names = label_names,
314
+ load_best_model_at_end = load_best_model_at_end,
315
+ metric_for_best_model = metric_for_best_model,
316
+ greater_is_better = greater_is_better,
317
+ ignore_data_skip = ignore_data_skip,
318
+ fsdp = fsdp,
319
+ fsdp_min_num_params = fsdp_min_num_params,
320
+ fsdp_config = fsdp_config,
321
+ tp_size = tp_size,
322
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
323
+ accelerator_config = accelerator_config,
324
+ deepspeed = deepspeed,
325
+ label_smoothing_factor = label_smoothing_factor,
326
+ optim = optim,
327
+ optim_args = optim_args,
328
+ adafactor = adafactor,
329
+ group_by_length = group_by_length,
330
+ length_column_name = length_column_name,
331
+ report_to = report_to,
332
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
333
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
334
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
335
+ dataloader_pin_memory = dataloader_pin_memory,
336
+ dataloader_persistent_workers = dataloader_persistent_workers,
337
+ skip_memory_metrics = skip_memory_metrics,
338
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
339
+ push_to_hub = push_to_hub,
340
+ resume_from_checkpoint = resume_from_checkpoint,
341
+ hub_model_id = hub_model_id,
342
+ hub_strategy = hub_strategy,
343
+ hub_token = hub_token,
344
+ hub_private_repo = hub_private_repo,
345
+ hub_always_push = hub_always_push,
346
+ gradient_checkpointing = gradient_checkpointing,
347
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
348
+ include_inputs_for_metrics = include_inputs_for_metrics,
349
+ eval_do_concat_batches = eval_do_concat_batches,
350
+ fp16_backend = fp16_backend,
351
+ evaluation_strategy = evaluation_strategy,
352
+ push_to_hub_model_id = push_to_hub_model_id,
353
+ push_to_hub_organization = push_to_hub_organization,
354
+ push_to_hub_token = push_to_hub_token,
355
+ mp_parameters = mp_parameters,
356
+ auto_find_batch_size = auto_find_batch_size,
357
+ full_determinism = full_determinism,
358
+ torchdynamo = torchdynamo,
359
+ ray_scope = ray_scope,
360
+ ddp_timeout = ddp_timeout,
361
+ torch_compile = torch_compile,
362
+ torch_compile_backend = torch_compile_backend,
363
+ torch_compile_mode = torch_compile_mode,
364
+ dispatch_batches = dispatch_batches,
365
+ split_batches = split_batches,
366
+ include_tokens_per_second = include_tokens_per_second,
367
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
368
+ neftune_noise_alpha = neftune_noise_alpha,
369
+ optim_target_modules = optim_target_modules,
370
+ batch_eval_metrics = batch_eval_metrics,
371
+ eval_on_start = eval_on_start,
372
+ use_liger_kernel = use_liger_kernel,
373
+ eval_use_gather_object = eval_use_gather_object,
374
+ average_tokens_across_devices = average_tokens_across_devices,
375
+ model_init_kwargs = model_init_kwargs,
376
+ use_liger = use_liger,
377
+ dataset_text_field = dataset_text_field,
378
+ dataset_kwargs = dataset_kwargs,
379
+ dataset_num_proc = dataset_num_proc,
380
+ max_seq_length = max_seq_length,
381
+ packing = packing,
382
+ eval_packing = eval_packing,
383
+ dataset_batch_size = dataset_batch_size,
384
+ num_of_sequences = num_of_sequences,
385
+ chars_per_token = chars_per_token,
386
+ temperature = temperature,
387
+ lmbda = lmbda,
388
+ beta = beta,
389
+ max_new_tokens = max_new_tokens,
390
+ teacher_model_name_or_path = teacher_model_name_or_path,
391
+ teacher_model_init_kwargs = teacher_model_init_kwargs,
392
+ disable_dropout = disable_dropout,
393
+ seq_kd = seq_kd,**kwargs)
394
+ self.vllm_sampling_params = vllm_sampling_params
395
+ self.unsloth_num_chunks = unsloth_num_chunks
396
+ pass
397
+
398
+ class _UnslothGKDTrainer(SFTTrainer):
399
+ _tag_names = ["trl", "gkd"]
400
+
401
+ def __init__(
402
+ self,
403
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
404
+ teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
405
+ args: Optional[GKDConfig] = None,
406
+ data_collator: Optional[DataCollator] = None, # type: ignore
407
+ train_dataset: Optional[Dataset] = None,
408
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
409
+ processing_class: Optional[
410
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
411
+ ] = None,
412
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
413
+ callbacks: Optional[list[TrainerCallback]] = None,
414
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
415
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
416
+ peft_config: Optional["PeftConfig"] = None,
417
+ formatting_func: Optional[Callable] = None,
418
+ ):
419
+ # add remove_unused_columns=False to the dataclass args
420
+ args.remove_unused_columns = False
421
+ data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
422
+
423
+ super().__init__(
424
+ model,
425
+ args=args,
426
+ data_collator=data_collator,
427
+ train_dataset=train_dataset,
428
+ eval_dataset=eval_dataset,
429
+ processing_class=processing_class,
430
+ compute_metrics=compute_metrics,
431
+ callbacks=callbacks,
432
+ optimizers=optimizers,
433
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
434
+ peft_config=peft_config,
435
+ formatting_func=formatting_func,
436
+ )
437
+
438
+ if args.teacher_model_init_kwargs is None:
439
+ teacher_model_init_kwargs = {}
440
+ elif not isinstance(teacher_model, str):
441
+ raise ValueError(
442
+ "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
443
+ )
444
+ else:
445
+ teacher_model_init_kwargs = args.teacher_model_init_kwargs
446
+ teacher_model_init_kwargs["torch_dtype"] = (
447
+ teacher_model_init_kwargs["torch_dtype"]
448
+ if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
449
+ else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
450
+ )
451
+
452
+ if isinstance(teacher_model, str):
453
+ if args.use_liger:
454
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM  # not in the module imports above; liger_kernel is an optional dependency
+ teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
455
+ else:
456
+ teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
457
+
458
+ # Disable dropout in the model
459
+ if args.disable_dropout:
460
+ disable_dropout_in_model(self.model)
461
+
462
+ if self.is_deepspeed_enabled:
463
+ self.teacher_model = self._prepare_deepspeed(teacher_model)
464
+ else:
465
+ self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
466
+
467
+ self.lmbda = args.lmbda
468
+ self.beta = args.beta
469
+ self.temperature = args.temperature
470
+ self.seq_kd = args.seq_kd
471
+
472
+ self.generation_config = GenerationConfig(
473
+ max_new_tokens=args.max_new_tokens,
474
+ temperature=args.temperature,
475
+ do_sample=True,
476
+ top_k=0,
477
+ use_cache=False if args.gradient_checkpointing else True,
478
+ pad_token_id=self.processing_class.pad_token_id,
479
+ )
480
+ # Set custom EOS tokens if they are specified by the model's generation
481
+ # config. This is important for models with the Llama 3 chat template,
482
+ # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
483
+ # turns or messages.
484
+ if (
485
+ hasattr(self.model.generation_config, "eos_token_id")
486
+ and self.model.generation_config.eos_token_id is not None
487
+ ):
488
+ self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
489
+
490
+ def _prepare_dataset(self, dataset, *args):
491
+ # SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
492
+ # need to keep the messages column as it is. We use the following workaround to keep the messages column.
493
+ dataset = dataset.add_column("_messages", dataset["messages"])
494
+ dataset = super()._prepare_dataset(dataset, *args)
495
+ dataset = dataset.rename_column("_messages", "messages")
496
+ return dataset
497
+
498
+ @staticmethod
499
+ def generalized_jsd_loss(
500
+ student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
501
+ ):
502
+ """
503
+ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
504
+ of https://huggingface.co/papers/2306.13649 for the definition.
505
+
506
+ Args:
507
+ student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
508
+ teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
509
+ labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
510
+ beta: Interpolation coefficient between 0 and 1 (default: 0.5)
511
+ temperature: Softmax temperature (default: 1.0)
512
+ reduction: Specifies the reduction to apply to the output (default: 'batchmean')
513
+
514
+ Returns:
515
+ loss: Scalar tensor with the generalized JSD loss
516
+ """
517
+
518
+ # Apply temperature scaling
519
+ student_logits = student_logits / temperature
520
+ teacher_logits = teacher_logits / temperature
521
+
522
+ # Compute log probabilities for student and probabilities for teacher
523
+ student_log_probs = F.log_softmax(student_logits, dim=-1)
524
+ teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
525
+
526
+ # Compute the log of the mixture distribution
527
+ # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
528
+ beta = torch.tensor(beta, dtype=student_log_probs.dtype)
529
+ mixture_log_probs = torch.logsumexp(
530
+ torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
531
+ dim=0,
532
+ )
533
+
534
+ # Compute KL divergences using F.kl_div
535
+ # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
536
+ kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
537
+ kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
538
+
539
+ # Compute the Generalized Jensen-Shannon Divergence
540
+ jsd = beta * kl_teacher + (1 - beta) * kl_student
541
+
542
+ # Masking
543
+ if labels is not None:
544
+ mask = labels != -100
545
+ jsd = jsd[mask]
546
+
547
+ # Apply reduction
548
+ if reduction == "batchmean":
549
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
550
+ elif reduction == "sum":
551
+ return jsd.sum()
552
+ elif reduction == "mean":
553
+ return jsd.mean()
554
+ else:
555
+ return jsd
556
+
557
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
558
+ # compute student output
559
+ outputs_student = model(
560
+ input_ids=inputs["input_ids"],
561
+ attention_mask=inputs["attention_mask"],
562
+ )
563
+
564
+ # compute teacher output in eval mode
565
+ self.teacher_model.eval()
566
+ with torch.no_grad():
567
+ outputs_teacher = self.teacher_model(
568
+ input_ids=inputs["input_ids"],
569
+ attention_mask=inputs["attention_mask"],
570
+ )
571
+
572
+ # slice the logits for the generated tokens using the inputs["prompts"] lengths
573
+ prompt_lengths = inputs["prompts"].shape[1]
574
+ shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
575
+ shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
576
+ shifted_labels = inputs["labels"][:, prompt_lengths:]
577
+
578
+ # compute loss
579
+ loss = self.generalized_jsd_loss(
580
+ student_logits=shifted_student_logits,
581
+ teacher_logits=shifted_teacher_logits,
582
+ labels=shifted_labels,
583
+ beta=self.beta,
584
+ )
585
+
586
+ # empty cache
587
+ empty_cache()
588
+
589
+ # Return loss
590
+ return (loss, outputs_student) if return_outputs else loss
591
+
592
+ @staticmethod
593
+ def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
594
+ # Generate output with respect to the prompt only
595
+ generated_outputs = model.generate(
596
+ input_ids=inputs["prompts"],
597
+ attention_mask=inputs.get("prompt_attention_mask", None),
598
+ generation_config=generation_config,
599
+ return_dict_in_generate=True,
600
+ )
601
+
602
+ # Get the generated token IDs
603
+ generated_tokens = generated_outputs.sequences
604
+ # Calculate new attention mask
605
+ new_attention_mask = torch.ones_like(generated_tokens)
606
+ new_labels = generated_tokens.clone()
607
+
608
+ # If there's pad_token_id, set attention mask to 0 for padding tokens
609
+ if pad_token_id is not None:
610
+ new_labels[new_labels == pad_token_id] = -100
611
+ new_attention_mask[generated_tokens == pad_token_id] = 0
612
+
613
+ return generated_tokens, new_attention_mask, new_labels
614
+
615
+ def training_step(
616
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
617
+ ) -> torch.Tensor:
618
+ """
619
+ Perform a training step for the Generalized Knowledge Distillation (GKD) model.
620
+
621
+ This method implements the on-policy learning approach described in the GKD paper.
622
+ With probability `self.lmbda`, it generates new responses using the student model,
623
+ which are then used for training instead of the original inputs.
624
+ """
625
+ if self.seq_kd:
626
+ with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
627
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
628
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
629
+ )
630
+ inputs["input_ids"] = new_input_ids
631
+ inputs["attention_mask"] = new_attention_mask
632
+ inputs["labels"] = new_labels
633
+ if random.random() <= self.lmbda:
634
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
635
+ new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
636
+ unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
637
+ )
638
+ inputs["input_ids"] = new_input_ids
639
+ inputs["attention_mask"] = new_attention_mask
640
+ inputs["labels"] = new_labels
641
+
642
+ loss = super().training_step(model, inputs, num_items_in_batch)
643
+ return loss
644
+
645
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
646
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
647
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
648
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
649
+
650
+ if model is not None:
651
+ if hasattr(model, "config"):
652
+ hidden_size = (
653
+ max(model.config.hidden_sizes)
654
+ if getattr(model.config, "hidden_sizes", None)
655
+ else getattr(model.config, "hidden_size", None)
656
+ )
657
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
658
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
659
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
660
+ config_kwargs.update(
661
+ {
662
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
663
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
664
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
665
+ }
666
+ )
667
+
668
+ # If ZeRO-3 is used, we shard both the active and reference model.
669
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
670
+ if config_kwargs["zero_optimization"]["stage"] != 3:
671
+ config_kwargs["zero_optimization"]["stage"] = 0
672
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
673
+ model.eval()
674
+ return model
675
+
676
+ def create_model_card(
677
+ self,
678
+ model_name: Optional[str] = None,
679
+ dataset_name: Optional[str] = None,
680
+ tags: Union[str, list[str], None] = None,
681
+ ):
682
+ """
683
+ Creates a draft of a model card using the information available to the `Trainer`.
684
+
685
+ Args:
686
+ model_name (`str` or `None`, *optional*, defaults to `None`):
687
+ Name of the model.
688
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
689
+ Name of the dataset used for training.
690
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
691
+ Tags to be associated with the model card.
692
+ """
693
+ if not self.is_world_process_zero():
694
+ return
695
+
696
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
697
+ base_model = self.model.config._name_or_path
698
+ else:
699
+ base_model = None
700
+
701
+ tags = tags or []
702
+ if isinstance(tags, str):
703
+ tags = [tags]
704
+
705
+ if hasattr(self.model.config, "unsloth_version"):
706
+ tags.append("unsloth")
707
+
708
+ citation = textwrap.dedent("""\
709
+ @inproceedings{agarwal2024on-policy,
710
+ title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
711
+ author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
712
+ year = 2024,
713
+ booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
714
+ publisher = {OpenReview.net},
715
+ url = {https://openreview.net/forum?id=3zKtaqxLhW},
716
+ }""")
717
+
718
+ model_card = generate_model_card(
719
+ base_model=base_model,
720
+ model_name=model_name,
721
+ hub_model_id=self.hub_model_id,
722
+ dataset_name=dataset_name,
723
+ tags=tags,
724
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
725
+ comet_url=get_comet_experiment_url(),
726
+ trainer_name="GKD",
727
+ trainer_citation=citation,
728
+ paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
729
+ paper_id="2306.13649",
730
+ )
731
+
732
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
733
+ class UnslothGKDTrainer(_UnslothGKDTrainer):
734
+ """
735
+
736
+ """
737
+ def __init__(
738
+ self,
739
+ model = None,
740
+ teacher_model = None,
741
+ args = None,
742
+ data_collator = None,
743
+ train_dataset = None,
744
+ eval_dataset = None,
745
+ processing_class = None,
746
+ compute_metrics = None,
747
+ callbacks = None,
748
+ preprocess_logits_for_metrics = None,
749
+ peft_config = None,
750
+ formatting_func = None,
751
+ **kwargs
752
+ ):
753
+ if args is None: args = UnslothGKDConfig()
754
+ use_bf16 = getattr(args, 'bf16', False)
755
+ use_fp16 = getattr(args, 'fp16', False)
756
+ force_float32 = False
757
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
758
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
759
+ force_float32 = True
760
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
761
+ dtype = getattr(model.config, 'torch_dtype', None)
762
+ if dtype is None: dtype = model.get_input_embeddings().dtype
763
+ from unsloth_zoo.utils import _get_dtype
764
+ dtype = _get_dtype(dtype)
765
+ float16 = dtype == torch.float16
766
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
767
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
768
+ if force_float32:
769
+ args.fp16 = False
770
+ args.bf16 = False
771
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
772
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
773
+ args.fp16 = float16
774
+ args.bf16 = not float16
775
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
776
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
777
+ args.eval_strategy = 'steps'
778
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
779
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
780
+ if ga_steps is not None and ga_steps > 1:
781
+ from transformers import __version__ as transformers_version
782
+ if Version(transformers_version) <= Version('4.45.2'):
783
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
784
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
785
+ if getattr(args, 'eval_strategy', 'no') != 'no':
786
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
787
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
788
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
789
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
790
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
791
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
792
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
793
+ if force_float32:
794
+ args.bf16_full_eval = False
795
+ args.fp16_full_eval = False
796
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
797
+ args.bf16_full_eval = True
798
+ args.fp16_full_eval = False
799
+ elif not bf16_full_eval and not fp16_full_eval:
800
+ args.bf16_full_eval = args.bf16
801
+ args.fp16_full_eval = args.fp16
802
+ _output_logits = False
803
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
804
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
805
+ if _output_logits:
806
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
807
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
808
+ pass
809
+ else:
810
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
811
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
812
+ if args_max_seq_length is None and model_max_seq_length is not None:
813
+ max_seq_length = model.max_seq_length
814
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
815
+ if model is not None and hasattr(model, 'for_training'):
816
+ model.for_training()
817
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
818
+ if 'processing_class' in locals():
819
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
820
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
821
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
822
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
823
+ if not isinstance(data_collator, UnslothVisionDataCollator):
824
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
825
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
826
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
827
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
828
+ else:
829
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
830
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
831
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
832
+ if not isinstance(data_collator, UnslothVisionDataCollator):
833
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
834
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
835
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
836
+ else:
837
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
838
+ other_metrics = []
839
+
840
+ from unsloth_zoo.logging_utils import PatchRLStatistics
841
+ PatchRLStatistics('gkd_trainer', other_metrics)
842
+
843
+ super().__init__(
844
+ model = model,
845
+ teacher_model = teacher_model,
846
+ args = args,
847
+ data_collator = data_collator,
848
+ train_dataset = train_dataset,
849
+ eval_dataset = eval_dataset,
850
+ processing_class = processing_class,
851
+ compute_metrics = compute_metrics,
852
+ callbacks = callbacks,
853
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
854
+ peft_config = peft_config,
855
+ formatting_func = formatting_func,**kwargs)
856
+ if hasattr(self, 'neftune_hook_handle'):
857
+ self.neftune_hook_handle.remove()
858
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
859
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
860
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
861
+ pass
862
+
863
+ pass
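To make `generalized_jsd_loss` above concrete, here is a minimal single-position sketch. The vocabulary size and logits are illustrative assumptions; it mirrors the implementation's choice of mixing beta * p_student + (1 - beta) * p_teacher in log space and swapping the argument order that `F.kl_div` expects:

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    # Standalone version of the masked/batched loss above, for a single position.
    # Use 0 < beta < 1; the log-space mixture is undefined at the endpoints.
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.log_softmax(teacher_logits / temperature, dim=-1)
    b = torch.tensor(beta)
    # log of the mixture beta * p_s + (1 - beta) * p_t, computed in log space
    m = torch.logsumexp(torch.stack([s + torch.log(b), t + torch.log(1 - b)]), dim=0)
    kl_t = F.kl_div(m, t, reduction="sum", log_target=True)  # KL(p_t || m)
    kl_s = F.kl_div(m, s, reduction="sum", log_target=True)  # KL(p_s || m)
    return beta * kl_t + (1 - beta) * kl_s

student = torch.randn(32)   # a single vocab-sized logit vector (size is an assumption)
teacher = torch.randn(32)
print(generalized_jsd(student, teacher).item())   # > 0 for distinct distributions
print(generalized_jsd(student, student).item())   # ~ 0 when the distributions match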
unsloth_compiled_cache/UnslothGRPOTrainer.py ADDED
@@ -0,0 +1,1438 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, GRPOTrainer, Trainer, gather, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+
43
+ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
44
+ # All Unsloth Zoo code licensed under LGPLv3
45
+ old_logits = old_logits.to(torch.float32)
46
+ new_logits = new_logits.to(torch.float32)
47
+ input_ids = input_ids.unsqueeze(-1)
48
+
49
+ # x_i - logsumexp(x_i)
50
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
51
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
52
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
53
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
54
+
55
+ # Reverse KL
56
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
57
+ # Full correct reverse KL divergence?? Missing term maybe?
58
+ # kl_i = torch.exp(new) * kl_i
59
+
60
+ # Below is forward KL (normal KL)
61
+ # kl_i = torch.exp(old) * (old - new)
62
+
63
+ # Must detach - otherwise gradients are not propagated correctly!
64
+ # exp(x - x) == 1
65
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
66
+ loss_i = -(loss_i - beta * kl_i)
67
+
68
+ mask = mask.to(torch.float32)
69
+ n_mask_per_reward = mask.sum(1)
70
+
71
+ # See https://github.com/huggingface/trl/pull/2881
72
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
73
+ loss = loss_per_reward.mean()
74
+ # loss = (loss_i * mask).sum() / mask.sum()
75
+
76
+ # Get metrics as well which are folded
77
+ with torch.inference_mode():
78
+ completion_length = n_mask_per_reward.mean()
79
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
80
+ mean_kl = mean_kl_per_reward.mean()
81
+ pass
82
+ return loss, completion_length, mean_kl
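Two idioms in `grpo_compute_loss` above deserve a note: the `exp(new - new.detach())` factor, which always evaluates to 1 but carries the policy gradient, and the `exp(d) - d - 1` reverse-KL estimator (the so-called k3 estimator). A minimal numeric sketch, with illustrative values not taken from any trainer:

import torch

old = torch.tensor([-1.2, -0.7, -2.0])                        # per-token log-probs, old policy
new = torch.tensor([-1.0, -0.9, -1.5], requires_grad=True)    # per-token log-probs, current policy

# k3 estimator of the reverse KL: exp(d) - d - 1 with d = old - new.
# It is >= 0 and equals 0 exactly when the two policies agree.
d = old - new
kl_i = torch.exp(d) - d - 1.0

# exp(new - new.detach()) evaluates to exp(0) = 1, but its gradient w.r.t.
# `new` is 1, so multiplying by advantages yields the usual
# advantages * grad(log p) policy-gradient direction.
advantages = torch.tensor([0.5, -0.3, 1.0])
loss = -(torch.exp(new - new.detach()) * advantages - 0.04 * kl_i).mean()  # 0.04: illustrative beta
loss.backward()
print(kl_i, new.grad)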
83
+
84
+ class UnslothEfficientGRPO(torch.autograd.Function):
85
+ # All Unsloth Zoo code licensed under LGPLv3
86
+ @staticmethod
87
+ def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
88
+ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
89
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
90
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
91
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
92
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
93
+ loss, completion_length, mean_kl = grpo_compute_loss(
94
+ old_logits, new_logits, input_ids, mask, beta, advantages,
95
+ )
96
+ # Scale loss if needed for mixed precision training
97
+ scaled_loss = loss * scaling
98
+ # Must add .loss.detach otherwise autograd uses 2x VRAM
99
+ return scaled_loss, (loss.detach(), completion_length, mean_kl,)
100
+ pass
101
+
102
+ device = _new_hidden_states.device
103
+ grad_inputs = torch.empty_like(_new_hidden_states)
104
+ accumulated_loss = torch.zeros(1, device = device)
105
+ accumulated_completion_length = torch.zeros(1, device = device)
106
+ accumulated_mean_kl = torch.zeros(1, device = device)
107
+
108
+ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
109
+ (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
110
+ compute_loss,
111
+ argnums = (0,),
112
+ has_aux = True,
113
+ )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
114
+ accumulated_loss .add_(unscaled_loss)
115
+ accumulated_completion_length.add_(chunk_completion_length)
116
+ accumulated_mean_kl .add_(chunk_mean_kl)
117
+ return chunk_grad_input
118
+ pass
119
+
120
+ accumulate_chunk = torch.compile(
121
+ accumulate_chunk,
122
+ fullgraph = True,
123
+ options = torch_compile_options,
124
+ )
125
+
126
+ grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
127
+ new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
128
+ old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
129
+ input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
130
+ mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
131
+ advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
132
+
133
+ # Get mixed precision scaling if seen
134
+ scaling = scaler.get_scale() if scaler is not None else 1.0
135
+
136
+ # Force torch.compile to use dynamic shapes for seqlen dim
137
+ mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
138
+
139
+ for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
140
+ zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
141
+
142
+ mark_dynamic(new_hidden_states_j)
143
+ mark_dynamic(old_hidden_states_j)
144
+ mark_dynamic(input_ids_j)
145
+ mark_dynamic(mask_j)
146
+
147
+ grad_inputs_j.copy_(
148
+ accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
149
+ )
150
+ pass
151
+
152
+ grad_inputs .div_(n_chunks)
153
+ accumulated_loss .div_(n_chunks)
154
+ accumulated_completion_length.div_(n_chunks)
155
+ accumulated_mean_kl .div_(n_chunks)
156
+ ctx.save_for_backward(grad_inputs)
157
+
158
+ return (
159
+ accumulated_loss,
160
+ accumulated_completion_length,
161
+ accumulated_mean_kl,
162
+ )
163
+ pass
164
+
165
+ @staticmethod
166
+ def backward(ctx, grad_output, dcompletion_length, dmean_kl):
167
+ (grad_input,) = ctx.saved_tensors
168
+ return (grad_input, None, None, None, None, None, None, None, None,)
169
+ pass
170
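+ # The Function above precomputes the input gradient during forward() via
+ # torch.func.grad_and_value and simply replays it in backward(), so no autograd
+ # graph has to be retained across chunks. A minimal sketch of the same pattern
+ # (illustrative only, not part of the generated trainer):
+ #
+ # class SquaredSum(torch.autograd.Function):
+ #     @staticmethod
+ #     def forward(ctx, x):
+ #         # grad_and_value returns (d/dx f(x), f(x)) in one pass
+ #         grad, value = torch.func.grad_and_value(lambda t: (t * t).sum())(x)
+ #         ctx.save_for_backward(grad)
+ #         return value
+ #     @staticmethod
+ #     def backward(ctx, grad_output):
+ #         (grad,) = ctx.saved_tensors
+ #         return grad * grad_output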
+
171
+ def grpo_accumulated_loss(
172
+ trainer,
173
+ input_ids,
174
+ logits_to_keep,
175
+ completion_mask,
176
+ advantages,
177
+ n_chunks = -1,
178
+ ):
179
+ # All Unsloth Zoo code licensed under LGPLv3
180
+ bsz, qlen = input_ids.shape
181
+ # Find closest multiple
182
+ factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
183
+ if n_chunks == -1: n_chunks = bsz
184
+ n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
185
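+ # e.g. bsz = 6 gives factors = [1, 2, 3, 6]; a requested n_chunks = 4 is rounded
+ # up to the divisor 6, so every chunk keeps an identical batch shape for compile.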
+
186
+ mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
187
+ os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
188
+
189
+ completion_input_ids = input_ids[:, -logits_to_keep:]
190
+ lm_head = trainer.model.get_output_embeddings().weight
191
+
192
+ with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
193
+ with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
194
+ old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
195
+ pass
196
+
197
+ new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
198
+
199
+ loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
200
+ new_hidden_states, old_hidden_states, lm_head,
201
+ completion_input_ids, completion_mask, advantages, trainer.beta,
202
+ trainer.accelerator.scaler,
203
+ n_chunks,
204
+ )
205
+ return loss, completion_length, mean_kl
206
+
207
+ # Old inefficient code path (unreachable; kept for reference)
208
+ new_logits = torch.matmul(new_hidden_states, lm_head.t())
209
+ new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
210
+ old_logits = torch.matmul(old_hidden_states, lm_head.t())
211
+ old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
212
+ loss, completion_length, mean_kl = grpo_compute_loss(
213
+ old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
214
+ )
215
+ return loss, completion_length, mean_kl
216
+ pass
217
+
218
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
219
+ def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
220
+ # All Unsloth Zoo code licensed under LGPLv3
221
+ old_logits = old_logits.to(torch.float32)
222
+ new_logits = new_logits.to(torch.float32)
223
+ input_ids = input_ids.unsqueeze(-1)
224
+
225
+ # x_i - logsumexp(x_i)
226
+ old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
227
+ new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
228
+ old = old_x - torch.logsumexp(old_logits, dim = -1)
229
+ new = new_x - torch.logsumexp(new_logits, dim = -1)
230
+
231
+ # Reverse KL
232
+ kl_i = torch.exp(old - new) - (old - new) - 1.0
233
+ # Full correct reverse KL divergence?? Missing term maybe?
234
+ # kl_i = torch.exp(new) * kl_i
235
+
236
+ # Below is forward KL (normal KL)
237
+ # kl_i = torch.exp(old) * (old - new)
238
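+ # The active estimator is the k3 form from Schulman's "Approximating KL
+ # Divergence" note: with r = exp(old - new), k3 = r - 1 - log(r), which is
+ # non-negative and, when tokens are sampled from the new policy, an unbiased
+ # estimate of KL(new || old).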
+
239
+ # Must detach - otherwise gradients are not propagated correctly!
240
+ # exp(x - x) == 1
241
+ loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
242
+ loss_i = -(loss_i - beta * kl_i)
243
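+ # Since exp(new - new.detach()) evaluates to 1 but has derivative 1 w.r.t. new,
+ # the per-token gradient is -(advantage - beta * dKL/dnew): a policy-gradient
+ # term plus a KL penalty.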
+
244
+ mask = mask.to(torch.float32)
245
+ n_mask_per_reward = mask.sum(1)
246
+
247
+ # See https://github.com/huggingface/trl/pull/2881
248
+ loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
249
+ loss = loss_per_reward.mean()
250
+ # loss = (loss_i * mask).sum() / mask.sum()
251
+
252
+ # Also compute the logging metrics, folded into the same pass
253
+ with torch.inference_mode():
254
+ completion_length = n_mask_per_reward.mean()
255
+ mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
256
+ mean_kl = mean_kl_per_reward.mean()
257
+ pass
258
+ return loss, completion_length, mean_kl
259
+
260
+ def vLLMSamplingParams(**kwargs):
261
+ from vllm import SamplingParams
262
+ sampling_params = SamplingParams(**kwargs)
263
+ sampling_params._set_kwargs = kwargs
264
+ return sampling_params
265
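+ # Hypothetical usage sketch (requires vllm installed): the wrapper only records
+ # the kwargs so the trainer can re-apply them when rebuilding SamplingParams, e.g.
+ #   args = UnslothGRPOConfig(vllm_sampling_params = vLLMSamplingParams(top_p = 0.9))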
+ @dataclass
266
+ class UnslothGRPOConfig(GRPOConfig):
267
+ """
268
+
269
+ Configuration class for the [`GRPOTrainer`].
270
+
271
+ Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
272
+ [`~transformers.TrainingArguments`] documentation.
273
+
274
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
275
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
276
+ command line.
277
+
278
+ Parameters:
279
+ > Parameters that control the model and reference model
280
+
281
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
282
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
283
+ argument of the [`GRPOTrainer`] is provided as a string.
284
+
285
+ > Parameters that control the data preprocessing
286
+
287
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
288
+ Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
289
+ requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
290
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
291
+ Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
292
+ num_generations (`int` or `None`, *optional*, defaults to `8`):
293
+ Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
294
+ must be divisible by this value.
295
+ temperature (`float`, *optional*, defaults to `0.9`):
296
+ Temperature for sampling. The higher the temperature, the more random the completions.
297
+ max_completion_length (`int` or `None`, *optional*, defaults to `256`):
298
+ Maximum length of the generated completion.
299
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
300
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
301
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
302
+ capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
303
+ with vLLM generation.
304
+
305
+ > Parameters that control generation acceleration powered by vLLM
306
+
307
+ use_vllm (`bool`, *optional*, defaults to `False`):
308
+ Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
309
+ training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
310
+ vllm_device (`str`, *optional*, defaults to `"auto"`):
311
+ Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
312
+ automatically select the next available GPU after the last one used for training. This assumes that
313
+ training has not already occupied all available GPUs. If only one device is available, the device will be
314
+ shared between both training and vLLM.
315
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
316
+ Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
317
+ device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
318
+ improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
319
+ during initialization.
320
+ vllm_dtype (`str`, *optional*, defaults to `"auto"`):
321
+ Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
322
+ based on the model configuration. Find the supported values in the vLLM documentation.
323
+ vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
324
+ If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
325
+ `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
326
+ context size, which might be much larger than the KV cache, leading to inefficiencies.
327
+
328
+ > Parameters that control the training
329
+
330
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
331
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
332
+ [`~transformers.TrainingArguments`].
333
+ beta (`float`, *optional*, defaults to `0.04`):
334
+ KL coefficient.
335
+ reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
336
+ Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
337
+ weighted equally with weight `1.0`.
338
+ sync_ref_model (`bool`, *optional*, defaults to `False`):
339
+ Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
340
+ the `ref_model_mixup_alpha` parameter. This synchronization originates from the
341
+ [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
342
+ ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
343
+ α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
344
+ between the current policy and the previous reference policy during updates. The reference policy is
345
+ updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
346
+ must set `sync_ref_model=True`.
347
+ ref_model_sync_steps (`int`, *optional*, defaults to `64`):
348
+ τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
349
+ frequently the current policy is synchronized with the reference policy. To use this parameter, you must
350
+ set `sync_ref_model=True`.
351
+
352
+ > Parameters that control the logging
353
+
354
+ log_completions (`bool`, *optional*, defaults to `False`):
355
+ Whether to log the completions during training.
356
+
357
+ """
358
+ vllm_sampling_params: Optional[Any] = field(
359
+ default = None,
360
+ metadata = {'help': 'vLLM SamplingParams'},
361
+ )
362
+ unsloth_num_chunks : Optional[int] = field(
363
+ default = -1,
364
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
365
+ )
366
+ def __init__(
367
+ self,
368
+ output_dir = None,
369
+ overwrite_output_dir = None,
370
+ do_train = False,
371
+ do_eval = False,
372
+ do_predict = False,
373
+ eval_strategy = 'no',
374
+ prediction_loss_only = False,
375
+ per_device_train_batch_size = 4,
376
+ per_device_eval_batch_size = 4,
377
+ per_gpu_train_batch_size = None,
378
+ per_gpu_eval_batch_size = None,
379
+ gradient_accumulation_steps = 2,
380
+ eval_accumulation_steps = 2,
381
+ eval_delay = 0,
382
+ torch_empty_cache_steps = 250,
383
+ learning_rate = 5e-05,
384
+ weight_decay = 0.01,
385
+ adam_beta1 = 0.9,
386
+ adam_beta2 = 0.999,
387
+ adam_epsilon = 1e-08,
388
+ max_grad_norm = 1.0,
389
+ num_train_epochs = 3.0,
390
+ max_steps = -1,
391
+ lr_scheduler_type = 'linear',
392
+ warmup_ratio = 0.1,
393
+ warmup_steps = 0,
394
+ log_level = 'passive',
395
+ log_level_replica = 'warning',
396
+ log_on_each_node = True,
397
+ logging_dir = None,
398
+ logging_strategy = 'steps',
399
+ logging_first_step = False,
400
+ logging_steps = 1,
401
+ logging_nan_inf_filter = False,
402
+ save_strategy = 'steps',
403
+ save_steps = 500,
404
+ save_total_limit = None,
405
+ save_safetensors = True,
406
+ save_on_each_node = False,
407
+ save_only_model = False,
408
+ restore_callback_states_from_checkpoint = False,
409
+ no_cuda = False,
410
+ use_cpu = False,
411
+ use_mps_device = False,
412
+ seed = 3407,
413
+ data_seed = 3407,
414
+ jit_mode_eval = False,
415
+ use_ipex = False,
416
+ bf16 = False,
417
+ fp16 = False,
418
+ fp16_opt_level = 'O1',
419
+ half_precision_backend = 'auto',
420
+ bf16_full_eval = False,
421
+ fp16_full_eval = False,
422
+ tf32 = None,
423
+ local_rank = -1,
424
+ ddp_backend = None,
425
+ tpu_num_cores = None,
426
+ tpu_metrics_debug = False,
427
+ debug = '',
428
+ dataloader_drop_last = False,
429
+ eval_steps = None,
430
+ dataloader_num_workers = 0,
431
+ dataloader_prefetch_factor = None,
432
+ past_index = -1,
433
+ run_name = None,
434
+ disable_tqdm = None,
435
+ remove_unused_columns = False,
436
+ label_names = None,
437
+ load_best_model_at_end = False,
438
+ metric_for_best_model = None,
439
+ greater_is_better = None,
440
+ ignore_data_skip = False,
441
+ fsdp = '',
442
+ fsdp_min_num_params = 0,
443
+ fsdp_config = None,
444
+ tp_size = 0,
445
+ fsdp_transformer_layer_cls_to_wrap = None,
446
+ accelerator_config = None,
447
+ deepspeed = None,
448
+ label_smoothing_factor = 0.0,
449
+ optim = 'adamw_8bit',
450
+ optim_args = None,
451
+ adafactor = False,
452
+ group_by_length = False,
453
+ length_column_name = 'length',
454
+ report_to = None,
455
+ ddp_find_unused_parameters = None,
456
+ ddp_bucket_cap_mb = None,
457
+ ddp_broadcast_buffers = None,
458
+ dataloader_pin_memory = True,
459
+ dataloader_persistent_workers = False,
460
+ skip_memory_metrics = True,
461
+ use_legacy_prediction_loop = False,
462
+ push_to_hub = False,
463
+ resume_from_checkpoint = None,
464
+ hub_model_id = None,
465
+ hub_strategy = 'every_save',
466
+ hub_token = None,
467
+ hub_private_repo = None,
468
+ hub_always_push = False,
469
+ gradient_checkpointing = False,
470
+ gradient_checkpointing_kwargs = None,
471
+ include_inputs_for_metrics = False,
472
+ eval_do_concat_batches = True,
473
+ fp16_backend = 'auto',
474
+ evaluation_strategy = None,
475
+ push_to_hub_model_id = None,
476
+ push_to_hub_organization = None,
477
+ push_to_hub_token = None,
478
+ mp_parameters = '',
479
+ auto_find_batch_size = False,
480
+ full_determinism = False,
481
+ torchdynamo = None,
482
+ ray_scope = 'last',
483
+ ddp_timeout = 1800,
484
+ torch_compile = False,
485
+ torch_compile_backend = None,
486
+ torch_compile_mode = None,
487
+ dispatch_batches = None,
488
+ split_batches = None,
489
+ include_tokens_per_second = False,
490
+ include_num_input_tokens_seen = False,
491
+ neftune_noise_alpha = None,
492
+ optim_target_modules = None,
493
+ batch_eval_metrics = False,
494
+ eval_on_start = False,
495
+ use_liger_kernel = False,
496
+ eval_use_gather_object = False,
497
+ average_tokens_across_devices = False,
498
+ model_init_kwargs = None,
499
+ max_prompt_length = 512,
500
+ num_generations = 8,
501
+ temperature = 0.9,
502
+ max_completion_length = 256,
503
+ ds3_gather_for_generation = True,
504
+ use_vllm = False,
505
+ vllm_device = 'auto',
506
+ vllm_gpu_memory_utilization = 0.9,
507
+ vllm_dtype = 'auto',
508
+ vllm_max_model_len = None,
509
+ beta = 0.04,
510
+ reward_weights = None,
511
+ sync_ref_model = False,
512
+ ref_model_mixup_alpha = 0.9,
513
+ ref_model_sync_steps = 64,
514
+ log_completions = False,
515
+ vllm_sampling_params = None,
516
+ unsloth_num_chunks = -1,
517
+ **kwargs,
518
+ ):
519
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
520
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
521
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
522
+ output_dir = 'unsloth_training_checkpoints'
523
+ save_strategy = 'no'
524
+ div = per_device_train_batch_size // num_generations
525
+ if div * num_generations != per_device_train_batch_size:
526
+ print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size from ' + str(per_device_train_batch_size) + ' to `num_generations` = ' + str(num_generations))
527
+ per_device_train_batch_size = num_generations
528
+
529
+ super().__init__(
530
+ output_dir = output_dir,
531
+ overwrite_output_dir = overwrite_output_dir,
532
+ do_train = do_train,
533
+ do_eval = do_eval,
534
+ do_predict = do_predict,
535
+ eval_strategy = eval_strategy,
536
+ prediction_loss_only = prediction_loss_only,
537
+ per_device_train_batch_size = per_device_train_batch_size,
538
+ per_device_eval_batch_size = per_device_eval_batch_size,
539
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
540
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
541
+ gradient_accumulation_steps = gradient_accumulation_steps,
542
+ eval_accumulation_steps = eval_accumulation_steps,
543
+ eval_delay = eval_delay,
544
+ torch_empty_cache_steps = torch_empty_cache_steps,
545
+ learning_rate = learning_rate,
546
+ weight_decay = weight_decay,
547
+ adam_beta1 = adam_beta1,
548
+ adam_beta2 = adam_beta2,
549
+ adam_epsilon = adam_epsilon,
550
+ max_grad_norm = max_grad_norm,
551
+ num_train_epochs = num_train_epochs,
552
+ max_steps = max_steps,
553
+ lr_scheduler_type = lr_scheduler_type,
554
+ warmup_ratio = warmup_ratio,
555
+ warmup_steps = warmup_steps,
556
+ log_level = log_level,
557
+ log_level_replica = log_level_replica,
558
+ log_on_each_node = log_on_each_node,
559
+ logging_dir = logging_dir,
560
+ logging_strategy = logging_strategy,
561
+ logging_first_step = logging_first_step,
562
+ logging_steps = logging_steps,
563
+ logging_nan_inf_filter = logging_nan_inf_filter,
564
+ save_strategy = save_strategy,
565
+ save_steps = save_steps,
566
+ save_total_limit = save_total_limit,
567
+ save_safetensors = save_safetensors,
568
+ save_on_each_node = save_on_each_node,
569
+ save_only_model = save_only_model,
570
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
571
+ no_cuda = no_cuda,
572
+ use_cpu = use_cpu,
573
+ use_mps_device = use_mps_device,
574
+ seed = seed,
575
+ data_seed = data_seed,
576
+ jit_mode_eval = jit_mode_eval,
577
+ use_ipex = use_ipex,
578
+ bf16 = bf16,
579
+ fp16 = fp16,
580
+ fp16_opt_level = fp16_opt_level,
581
+ half_precision_backend = half_precision_backend,
582
+ bf16_full_eval = bf16_full_eval,
583
+ fp16_full_eval = fp16_full_eval,
584
+ tf32 = tf32,
585
+ local_rank = local_rank,
586
+ ddp_backend = ddp_backend,
587
+ tpu_num_cores = tpu_num_cores,
588
+ tpu_metrics_debug = tpu_metrics_debug,
589
+ debug = debug,
590
+ dataloader_drop_last = dataloader_drop_last,
591
+ eval_steps = eval_steps,
592
+ dataloader_num_workers = dataloader_num_workers,
593
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
594
+ past_index = past_index,
595
+ run_name = run_name,
596
+ disable_tqdm = disable_tqdm,
597
+ remove_unused_columns = remove_unused_columns,
598
+ label_names = label_names,
599
+ load_best_model_at_end = load_best_model_at_end,
600
+ metric_for_best_model = metric_for_best_model,
601
+ greater_is_better = greater_is_better,
602
+ ignore_data_skip = ignore_data_skip,
603
+ fsdp = fsdp,
604
+ fsdp_min_num_params = fsdp_min_num_params,
605
+ fsdp_config = fsdp_config,
606
+ tp_size = tp_size,
607
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
608
+ accelerator_config = accelerator_config,
609
+ deepspeed = deepspeed,
610
+ label_smoothing_factor = label_smoothing_factor,
611
+ optim = optim,
612
+ optim_args = optim_args,
613
+ adafactor = adafactor,
614
+ group_by_length = group_by_length,
615
+ length_column_name = length_column_name,
616
+ report_to = report_to,
617
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
618
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
619
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
620
+ dataloader_pin_memory = dataloader_pin_memory,
621
+ dataloader_persistent_workers = dataloader_persistent_workers,
622
+ skip_memory_metrics = skip_memory_metrics,
623
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
624
+ push_to_hub = push_to_hub,
625
+ resume_from_checkpoint = resume_from_checkpoint,
626
+ hub_model_id = hub_model_id,
627
+ hub_strategy = hub_strategy,
628
+ hub_token = hub_token,
629
+ hub_private_repo = hub_private_repo,
630
+ hub_always_push = hub_always_push,
631
+ gradient_checkpointing = gradient_checkpointing,
632
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
633
+ include_inputs_for_metrics = include_inputs_for_metrics,
634
+ eval_do_concat_batches = eval_do_concat_batches,
635
+ fp16_backend = fp16_backend,
636
+ evaluation_strategy = evaluation_strategy,
637
+ push_to_hub_model_id = push_to_hub_model_id,
638
+ push_to_hub_organization = push_to_hub_organization,
639
+ push_to_hub_token = push_to_hub_token,
640
+ mp_parameters = mp_parameters,
641
+ auto_find_batch_size = auto_find_batch_size,
642
+ full_determinism = full_determinism,
643
+ torchdynamo = torchdynamo,
644
+ ray_scope = ray_scope,
645
+ ddp_timeout = ddp_timeout,
646
+ torch_compile = torch_compile,
647
+ torch_compile_backend = torch_compile_backend,
648
+ torch_compile_mode = torch_compile_mode,
649
+ dispatch_batches = dispatch_batches,
650
+ split_batches = split_batches,
651
+ include_tokens_per_second = include_tokens_per_second,
652
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
653
+ neftune_noise_alpha = neftune_noise_alpha,
654
+ optim_target_modules = optim_target_modules,
655
+ batch_eval_metrics = batch_eval_metrics,
656
+ eval_on_start = eval_on_start,
657
+ use_liger_kernel = use_liger_kernel,
658
+ eval_use_gather_object = eval_use_gather_object,
659
+ average_tokens_across_devices = average_tokens_across_devices,
660
+ model_init_kwargs = model_init_kwargs,
661
+ max_prompt_length = max_prompt_length,
662
+ num_generations = num_generations,
663
+ temperature = temperature,
664
+ max_completion_length = max_completion_length,
665
+ ds3_gather_for_generation = ds3_gather_for_generation,
666
+ use_vllm = use_vllm,
667
+ vllm_device = vllm_device,
668
+ vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
669
+ vllm_dtype = vllm_dtype,
670
+ vllm_max_model_len = vllm_max_model_len,
671
+ beta = beta,
672
+ reward_weights = reward_weights,
673
+ sync_ref_model = sync_ref_model,
674
+ ref_model_mixup_alpha = ref_model_mixup_alpha,
675
+ ref_model_sync_steps = ref_model_sync_steps,
676
+ log_completions = log_completions,**kwargs)
677
+ self.vllm_sampling_params = vllm_sampling_params
678
+ self.unsloth_num_chunks = unsloth_num_chunks
679
+ pass
680
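+ # Minimal construction sketch (values are illustrative, not defaults):
+ # args = UnslothGRPOConfig(
+ #     output_dir = "outputs",
+ #     per_device_train_batch_size = 8,   # must be a multiple of num_generations
+ #     num_generations = 8,
+ #     max_prompt_length = 512,
+ #     max_completion_length = 256,
+ #     learning_rate = 5e-6,              # must lie in [1e-7, 1]
+ # )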
+
681
+ class _UnslothGRPOTrainer(Trainer):
682
+ """"""
683
+
684
+ _tag_names = ["trl", "grpo"]
685
+
686
+ def __init__(
687
+ self,
688
+ model: Union[str, PreTrainedModel],
689
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
690
+ args: GRPOConfig = None,
691
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
692
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
693
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
694
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
695
+ callbacks: Optional[list[TrainerCallback]] = None,
696
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
697
+ peft_config: Optional["PeftConfig"] = None,
698
+ ):
699
+
700
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
701
+ # Args
702
+ if args is None:
703
+ model_name = model if isinstance(model, str) else model.config._name_or_path
704
+ model_name = model_name.split("/")[-1]
705
+ args = GRPOConfig(f"{model_name}-GRPO")
706
+
707
+ # Models
708
+ # Trained model
709
+ model_init_kwargs = args.model_init_kwargs or {}
710
+ if isinstance(model, str):
711
+ model_id = model
712
+ torch_dtype = model_init_kwargs.get("torch_dtype")
713
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
714
+ pass # torch_dtype is already a torch.dtype or "auto" or None
715
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
716
+ torch_dtype = getattr(torch, torch_dtype)
717
+ model_init_kwargs["torch_dtype"] = torch_dtype
718
+ else:
719
+ raise ValueError(
720
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
721
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
722
+ )
723
+ # Disable caching if gradient checkpointing is enabled (not supported)
724
+ model_init_kwargs["use_cache"] = (
725
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
726
+ )
727
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
728
+ else:
729
+ model_id = model.config._name_or_path
730
+ if args.model_init_kwargs is not None:
731
+ raise ValueError(
732
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
733
+ "This argument can only be used when the `model` argument is a string."
734
+ )
735
+
736
+ if False:
737
+ model = model
738
+
739
+ # Reference model
740
+ if is_deepspeed_zero3_enabled():
741
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
742
+ elif not is_peft_model(model):
743
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
744
+ self.ref_model = create_reference_model(model)
745
+ else:
746
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
747
+ # to revert to the initial model.
748
+ self.ref_model = None
749
+
750
+ # Processing class
751
+ if processing_class is None:
752
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
753
+
754
+ # Reward functions
755
+ if not isinstance(reward_funcs, list):
756
+ reward_funcs = [reward_funcs]
757
+ for i, reward_func in enumerate(reward_funcs):
758
+ if isinstance(reward_func, str):
759
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
760
+ reward_func, num_labels=1, **model_init_kwargs
761
+ )
762
+ self.reward_funcs = reward_funcs
763
+
764
+ # Reward weights
765
+ if args.reward_weights is not None:
766
+ if len(args.reward_weights) != len(reward_funcs):
767
+ raise ValueError(
768
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
769
+ f"functions ({len(reward_funcs)})"
770
+ )
771
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
772
+ else:
773
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
774
+
775
+ # Reward processing class
776
+ if reward_processing_classes is None:
777
+ reward_processing_classes = [None] * len(reward_funcs)
778
+ elif not isinstance(reward_processing_classes, list):
779
+ reward_processing_classes = [reward_processing_classes]
780
+ else:
781
+ if len(reward_processing_classes) != len(reward_funcs):
782
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
783
+
784
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
785
+ if isinstance(reward_func, PreTrainedModel):
786
+ if reward_processing_class is None:
787
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
788
+ if reward_processing_class.pad_token_id is None:
789
+ reward_processing_class.pad_token = reward_processing_class.eos_token
790
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
791
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
792
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
793
+ reward_processing_classes[i] = reward_processing_class
794
+ self.reward_processing_classes = reward_processing_classes
795
+
796
+ # Data collator
797
+ def data_collator(features): # No data collation is needed in GRPO
798
+ return features
799
+
800
+ # Training arguments
801
+ self.max_prompt_length = args.max_prompt_length
802
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
803
+ self.num_generations = args.num_generations # = G in the GRPO paper
804
+ self.use_vllm = args.use_vllm
805
+
806
+ self.beta = args.beta
807
+
808
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
809
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
810
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
811
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
812
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
813
+ # This acts as a flag to indicate that the warning has already been issued.
814
+ model.warnings_issued["estimate_tokens"] = True
815
+
816
+ # Initialize the metrics
817
+ self._metrics = defaultdict(list)
818
+ self.log_completions = args.log_completions
819
+
820
+ super().__init__(
821
+ model=model,
822
+ args=args,
823
+ data_collator=data_collator,
824
+ train_dataset=train_dataset,
825
+ eval_dataset=eval_dataset,
826
+ processing_class=processing_class,
827
+ callbacks=callbacks,
828
+ optimizers=optimizers,
829
+ )
830
+
831
+ # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
832
+ num_processes = self.accelerator.num_processes
833
+ global_batch_size = args.per_device_train_batch_size * num_processes
834
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
835
+ if self.num_generations not in possible_values:
836
+ raise ValueError(
837
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
838
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
839
+ f"batch size, the valid values for the number of generations are: {possible_values}."
840
+ )
841
+ if self.args.eval_strategy != "no":
842
+ global_batch_size = args.per_device_eval_batch_size * num_processes
843
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
844
+ if self.num_generations not in possible_values:
845
+ raise ValueError(
846
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
847
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
848
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
849
+ )
850
+
851
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
852
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
853
+ # it's safer to set it in all cases.
854
+ set_seed(args.seed, device_specific=True)
855
+
856
+ if self.use_vllm:
857
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
858
+ temperature=args.temperature,
859
+ max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
860
+ else:
861
+ self.generation_config = GenerationConfig(
862
+ max_new_tokens=self.max_completion_length,
863
+ do_sample=True,
864
+ temperature=args.temperature,
865
+ pad_token_id=processing_class.pad_token_id,
866
+ )
867
+
868
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
869
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
870
+ # self.model_accepts_loss_kwargs to False to enable scaling.
871
+ self.model_accepts_loss_kwargs = False
872
+
873
+ # Add tags to the model
874
+ self.model.add_model_tags(self._tag_names)
875
+
876
+ if self.ref_model is not None:
877
+ if self.is_deepspeed_enabled:
878
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
879
+ else:
880
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
881
+
882
+ if args.sync_ref_model:
883
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
884
+
885
+ for i, reward_func in enumerate(self.reward_funcs):
886
+ if isinstance(reward_func, PreTrainedModel):
887
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
888
+
889
+ def _set_signature_columns_if_needed(self):
890
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
891
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
892
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
893
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
894
+ if self._signature_columns is None:
895
+ self._signature_columns = ["prompt"]
896
+
897
+ def _get_train_sampler(self) -> Sampler:
898
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
899
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
900
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
901
+ # preventing discrepancies in group formation.
902
+ return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
903
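+ # Illustrative: with num_generations = 3 the sampler yields each shuffled index
+ # num_generations times in a row (e.g. [2, 2, 2, 0, 0, 0, 1, 1, 1]), so the rows
+ # belonging to one prompt group are contiguous across the global batch.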
+
904
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
905
+ # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
906
+ # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
907
+ # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
908
+ # preventing discrepancies in group formation.
909
+ return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
910
+
911
+ # Get the per-token log probabilities for the completions for the model and the reference model
912
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
913
+ if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
914
+ return None # Unsloth efficient GRPO
915
+ # Otherwise, calculate normally:
916
+ if not hasattr(self, '_autocast_dtype'):
917
+ self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
918
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
919
+ with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
920
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
921
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
922
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
923
+
924
+ input_ids = input_ids[:, -logits_to_keep:]
925
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
926
+ # See https://github.com/huggingface/trl/issues/2770
927
+ logits = logits[:, -logits_to_keep:]
928
+ return logits
929
+ # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
930
+ pass
931
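+ # Note: despite its name, this fast path returns raw completion logits (or None
+ # for the fused Unsloth-efficient path); log-probabilities are only formed
+ # inside grpo_compute_loss_slow / grpo_accumulated_loss.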
+
932
+ def _move_model_to_vllm(self, *args, **kwargs): return None
933
+
934
+ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
935
+ device = self.accelerator.device
936
+ prompts = [x["prompt"] for x in inputs]
937
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
938
+ prompt_inputs = self.processing_class(
939
+ prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
940
+ )
941
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
942
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
943
+
944
+ if self.max_prompt_length is not None:
945
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
946
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
947
+
948
+ # Generate completions using either vLLM or regular generation
949
+ if self.args.use_vllm:
950
+ # First, have main process load weights if needed
951
+ if self.state.global_step != self._last_loaded_step:
952
+ self._move_model_to_vllm()
953
+ self._last_loaded_step = self.state.global_step
954
+
955
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
956
+ all_prompts_text = gather_object(prompts_text)
957
+ if self.accelerator.is_main_process:
958
+ outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
959
+ completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
960
+ else:
961
+ completion_ids = [None] * len(all_prompts_text)
962
+ # Broadcast the completions from the main process to all processes, ensuring each process receives its
963
+ # corresponding slice.
964
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
965
+ process_slice = slice(
966
+ self.accelerator.process_index * len(prompts),
967
+ (self.accelerator.process_index + 1) * len(prompts),
968
+ )
969
+ completion_ids = completion_ids[process_slice]
970
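+ # e.g. with 2 processes and 4 prompts each, the 8 gathered completions are
+ # sliced so process 0 keeps [0:4] and process 1 keeps [4:8].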
+
971
+ # Pad the completions, and concatenate them with the prompts
972
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
973
+ completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
974
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
975
+ else:
976
+ # Regular generation path
977
+ with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
978
+ prompt_completion_ids = unwrapped_model.generate(
979
+ prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
980
+ )
981
+
982
+ # Compute prompt length and extract completion ids
983
+ prompt_length = prompt_ids.size(1)
984
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
985
+ completion_ids = prompt_completion_ids[:, prompt_length:]
986
+
987
+ # Mask everything after the first EOS token
988
+ is_eos = completion_ids == self.processing_class.eos_token_id
989
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
990
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
991
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
992
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
993
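+ # Worked example (illustrative): completion_ids = [[5, EOS, 7]] gives
+ # is_eos = [[F, T, F]], eos_idx = [1], sequence_indices = [[0, 1, 2]], hence
+ # completion_mask = [[1, 1, 0]] -- tokens after the first EOS are masked out.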
+
994
+ # Concatenate prompt_mask with completion_mask for logit computation
995
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
996
+
997
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
998
+
999
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
1000
+ if self.ref_model is not None:
1001
+ ref_per_token_logps = self._get_per_token_logps(
1002
+ self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
1003
+ )
1004
+ else:
1005
+ with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
1006
+ ref_per_token_logps = self._get_per_token_logps(
1007
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
1008
+ )
1009
+
1010
+ # Decode the generated completions
1011
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1012
+ if is_conversational(inputs[0]):
1013
+ completions = []
1014
+ for prompt, completion in zip(prompts, completions_text):
1015
+ bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
1016
+ completions.append([{"role": "assistant", "content": bootstrap + completion}])
1017
+ else:
1018
+ completions = completions_text
1019
+
1020
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
1021
+ for i, (reward_func, reward_processing_class) in enumerate(
1022
+ zip(self.reward_funcs, self.reward_processing_classes)
1023
+ ):
1024
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1025
+ if is_conversational(inputs[0]):
1026
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
1027
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
1028
+ else:
1029
+ texts = [p + c for p, c in zip(prompts, completions)]
1030
+ reward_inputs = reward_processing_class(
1031
+ texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
1032
+ )
1033
+ reward_inputs = super()._prepare_inputs(reward_inputs)
1034
+ with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
1035
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
1036
+ else:
1037
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
1038
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
1039
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
1040
+ output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
1041
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
1042
+
1043
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
1044
+ # completions may be distributed across processes
1045
+ rewards_per_func = gather(rewards_per_func)
1046
+
1047
+ # Apply weights to each reward function's output and sum
1048
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
1049
+
1050
+ # Compute grouped-wise rewards
1051
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
1052
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
1053
+
1054
+ # Normalize the rewards to compute the advantages
1055
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1056
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
1057
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
1058
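+ # Worked example (illustrative): with num_generations = 2 and group rewards
+ # [1., 3.], the group mean is 2. and std ~ 1.414, so advantages ~ [-0.707, +0.707]:
+ # each completion is scored only relative to its own prompt group.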
+
1059
+ # Slice to keep only the local part of the data
1060
+ process_slice = slice(
1061
+ self.accelerator.process_index * len(prompts),
1062
+ (self.accelerator.process_index + 1) * len(prompts),
1063
+ )
1064
+ advantages = advantages[process_slice]
1065
+
1066
+ # Log the metrics
1067
+ reward_per_func = rewards_per_func.mean(0)
1068
+ for i, reward_func in enumerate(self.reward_funcs):
1069
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
1070
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
1071
+ else:
1072
+ reward_func_name = reward_func.__name__
1073
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
1074
+
1075
+ self._metrics["reward"].append(rewards.mean().item())
1076
+ self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
1077
+
1078
+ if (
1079
+ self.log_completions
1080
+ and self.state.global_step % self.args.logging_steps == 0
1081
+ and "wandb" in self.args.report_to
1082
+ ):
1083
+ import pandas as pd
1084
+
1085
+ # For logging
1086
+ table = {
1087
+ "step": [str(self.state.global_step)] * len(rewards),
1088
+ "prompt": gather_object(prompts_text),
1089
+ "completion": gather_object(completions_text),
1090
+ "reward": rewards.tolist(),
1091
+ }
1092
+ df = pd.DataFrame(table)
1093
+
1094
+ if wandb.run is not None and self.accelerator.is_main_process:
1095
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1096
+
1097
+ return {
1098
+ "prompt_ids": prompt_ids,
1099
+ "prompt_mask": prompt_mask,
1100
+ "completion_ids": completion_ids,
1101
+ "completion_mask": completion_mask,
1102
+ "ref_per_token_logps": ref_per_token_logps,
1103
+ "advantages": advantages,
1104
+ }
1105
+
1106
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
1107
+ if return_outputs:
1108
+ raise ValueError("The GRPOTrainer does not support returning outputs")
1109
+ # Compute the per-token log probabilities for the model
1110
+
1111
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
1112
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
1113
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
1114
+ bsz, qlen = input_ids.shape
1115
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
1116
+ # attention_mask = None
1117
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
1118
+ _input_ids = input_ids
1119
+ _logits_to_keep = logits_to_keep
1120
+
1121
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
1122
+
1123
+ # Compute the KL divergence between the model and the reference model
1124
+ ref_per_token_logps = inputs["ref_per_token_logps"]
1125
+ # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
1126
+
1127
+ # x - x.detach() allows for preserving gradients from x
1128
+ advantages = inputs["advantages"]
1129
+ # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
1130
+ # per_token_loss = -(per_token_loss - self.beta * per_token_kl)
1131
+ # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1132
+ input_ids = input_ids[:, -logits_to_keep:]
1133
+ if per_token_logps is not None:
1134
+ loss, completion_length, mean_kl = grpo_compute_loss_slow(
1135
+ ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
1136
+ )
1137
+ else:
1138
+ loss, completion_length, mean_kl = grpo_accumulated_loss(
1139
+ self, _input_ids, logits_to_keep, completion_mask, advantages,
1140
+ n_chunks = self.args.unsloth_num_chunks,
1141
+ )
1142
+
1143
+ # Log the metrics
1144
+ # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
1145
+
1146
+ # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1147
+ # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1148
+
1149
+ if "train" in self._metrics:
1150
+ mode = "eval" if self.control.should_evaluate else "train"
1151
+ self._metrics[mode]["completion_length"].append(completion_length.item())
1152
+ self._metrics[mode]["kl"].append(mean_kl.item())
1153
+ else:
1154
+ self._metrics["completion_length"].append(completion_length.item())
1155
+ self._metrics["kl"].append(mean_kl.item())
1156
+ return loss
1157
+
1158
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
1159
+ inputs = self._prepare_inputs(inputs)
1160
+ with torch.no_grad():
1161
+ with self.compute_loss_context_manager():
1162
+ loss = self.compute_loss(model, inputs)
1163
+ loss = loss.mean().detach()
1164
+ return loss, None, None
1165
+
1166
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1167
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1168
+
1169
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1170
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1171
+ if next(iter(logs.keys())).startswith("eval_"):
1172
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1173
+
1174
+ logs = {**logs, **metrics}
1175
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1176
+ super().log(logs, start_time)
1177
+ else: # transformers<=4.46
1178
+ super().log(logs)
1179
+ self._metrics.clear()
1180
+
1181
+ def create_model_card(
1182
+ self,
1183
+ model_name: Optional[str] = None,
1184
+ dataset_name: Optional[str] = None,
1185
+ tags: Union[str, list[str], None] = None,
1186
+ ):
1187
+ """
1188
+ Creates a draft of a model card using the information available to the `Trainer`.
1189
+
1190
+ Args:
1191
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1192
+ Name of the model.
1193
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1194
+ Name of the dataset used for training.
1195
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1196
+ Tags to be associated with the model card.
1197
+ """
1198
+ if not self.is_world_process_zero():
1199
+ return
1200
+
1201
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1202
+ base_model = self.model.config._name_or_path
1203
+ else:
1204
+ base_model = None
1205
+
1206
+ tags = tags or []
1207
+ if isinstance(tags, str):
1208
+ tags = [tags]
1209
+
1210
+ if hasattr(self.model.config, "unsloth_version"):
1211
+ tags.append("unsloth")
1212
+
1213
+ citation = textwrap.dedent(
1214
+ """\
1215
+ @article{zhihong2024deepseekmath,
1216
+ title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
1217
+ author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
1218
+ year = 2024,
1219
+ eprint = {arXiv:2402.03300},
1220
+ }
1221
+ """
1222
+ )
1223
+
1224
+ model_card = generate_model_card(
1225
+ base_model=base_model,
1226
+ model_name=model_name,
1227
+ hub_model_id=self.hub_model_id,
1228
+ dataset_name=dataset_name,
1229
+ tags=tags,
1230
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1231
+ comet_url=get_comet_experiment_url(),
1232
+ trainer_name="GRPO",
1233
+ trainer_citation=citation,
1234
+ paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
1235
+ paper_id="2402.03300",
1236
+ )
1237
+
1238
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1239
+ class UnslothGRPOTrainer(_UnslothGRPOTrainer):
1240
+ """
1241
+
1242
+ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
1243
+ paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
1244
+
1245
+ Example:
1246
+
1247
+ ```python
1248
+ from datasets import load_dataset
1249
+ from trl import GRPOTrainer
1250
+
1251
+ dataset = load_dataset("trl-lib/tldr", split="train")
1252
+
1253
+ def reward_func(completions, **kwargs):
1254
+ # Dummy reward function that rewards completions with more unique letters.
1255
+ return [float(len(set(completion))) for completion in completions]
1256
+
1257
+ trainer = GRPOTrainer(
1258
+ model="Qwen/Qwen2-0.5B-Instruct",
1259
+ reward_funcs=reward_func,
1260
+ train_dataset=dataset,
1261
+ )
1262
+
1263
+ trainer.train()
1264
+ ```
1265
+
1266
+ Args:
1267
+ model (`Union[str, PreTrainedModel]`):
1268
+ Model to be trained. Can be either:
1269
+
1270
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
1271
+ a path to a *directory* containing model weights saved using
1272
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
1273
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
1274
+ in `args.model_init_kwargs`.
1275
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
1276
+ reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
1277
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
1278
+ functions with the prompts and completions and sum the rewards. Can be either:
1279
+
1280
+ - A single reward function, such as:
1281
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
1282
+ path to a *directory* containing model weights saved using
1283
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
1284
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
1285
+ keyword arguments in `args.model_init_kwargs`.
1286
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
1287
+ - A custom reward function: The function is provided with the prompts and the generated completions,
1288
+ plus any additional columns in the dataset. It should return a list of rewards. For more details, see
1289
+ [Using a custom reward function](#using-a-custom-reward-function).
1290
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
1291
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
1292
+ args ([`GRPOConfig`], *optional*, defaults to `None`):
1293
+ Configuration for this trainer. If `None`, a default configuration is used.
1294
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
1295
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
1296
+ ignored. The format of the samples can be either:
1297
+
1298
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
1299
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
1300
+ and content).
1301
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
1302
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
1303
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
1304
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
1305
+ processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
1306
+ reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
1307
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
1308
+
1309
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
1310
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
1311
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
1312
+ `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
1313
+ For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
1314
+ the corresponding entries in `reward_processing_classes` are ignored.
1315
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
1316
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
1317
+ detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
1318
+
1319
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
1320
+ method.
1321
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
1322
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
1323
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
1324
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
1325
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
1326
+
1327
+ """
1328
+ def __init__(
1329
+ self,
1330
+ model,
1331
+ reward_funcs,
1332
+ args = None,
1333
+ train_dataset = None,
1334
+ eval_dataset = None,
1335
+ processing_class = None,
1336
+ reward_processing_classes = None,
1337
+ callbacks = None,
1338
+ peft_config = None,
1339
+ **kwargs
1340
+ ):
1341
+ if args is None: args = UnslothGRPOConfig()
1342
+ use_bf16 = getattr(args, 'bf16', False)
1343
+ use_fp16 = getattr(args, 'fp16', False)
1344
+ force_float32 = False
1345
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1346
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1347
+ force_float32 = True
1348
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1349
+ dtype = getattr(model.config, 'torch_dtype', None)
1350
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1351
+ from unsloth_zoo.utils import _get_dtype
1352
+ dtype = _get_dtype(dtype)
1353
+ float16 = dtype == torch.float16
1354
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1355
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1356
+ if force_float32:
1357
+ args.fp16 = False
1358
+ args.bf16 = False
1359
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1360
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1361
+ args.fp16 = float16
1362
+ args.bf16 = not float16
1363
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1364
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1365
+ args.eval_strategy = 'steps'
1366
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1367
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1368
+ if ga_steps is not None and ga_steps > 1:
1369
+ from transformers import __version__ as transformers_version
1370
+ if Version(transformers_version) <= Version('4.45.2'):
1371
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1372
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1373
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1374
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1375
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1376
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1377
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1378
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1379
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1380
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1381
+ if force_float32:
1382
+ args.bf16_full_eval = False
1383
+ args.fp16_full_eval = False
1384
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1385
+ args.bf16_full_eval = True
1386
+ args.fp16_full_eval = False
1387
+ elif not bf16_full_eval and not fp16_full_eval:
1388
+ args.bf16_full_eval = args.bf16
1389
+ args.fp16_full_eval = args.fp16
1390
+ _output_logits = False
1391
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1392
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1393
+ if _output_logits:
1394
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1395
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1396
+ pass
1397
+ else:
1398
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1399
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1400
+ if args_max_seq_length is None and model_max_seq_length is not None:
1401
+ max_seq_length = model.max_seq_length
1402
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1403
+ if model is not None and hasattr(model, 'for_training'):
1404
+ model.for_training()
1405
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1406
+ if 'processing_class' in locals():
1407
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1408
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1409
+ other_metrics = []
1410
+ if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
1411
+ else: _reward_funcs = reward_funcs
1412
+ for reward_func in _reward_funcs:
1413
+ try:
1414
+ reward_func_name = reward_func.__name__
1415
+ other_metrics.append(f'rewards/{reward_func_name}')
1416
+ except: pass
1417
+
1418
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1419
+ PatchRLStatistics('grpo_trainer', other_metrics)
1420
+
1421
+ super().__init__(
1422
+ model = model,
1423
+ reward_funcs = reward_funcs,
1424
+ args = args,
1425
+ train_dataset = train_dataset,
1426
+ eval_dataset = eval_dataset,
1427
+ processing_class = processing_class,
1428
+ reward_processing_classes = reward_processing_classes,
1429
+ callbacks = callbacks,
1430
+ peft_config = peft_config,**kwargs)
1431
+ if hasattr(self, 'neftune_hook_handle'):
1432
+ self.neftune_hook_handle.remove()
1433
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1434
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1435
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1436
+ pass
1437
+
1438
+ pass
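# A minimal, illustrative sketch of the custom-reward-function path described in
# the class docstring above: each function receives the prompts, the generated
# completions, and any extra dataset columns, and returns one float per sample.
# The length-based reward below is a toy example, not a recommended reward, and
# `model` / `dataset` are assumed to be defined elsewhere.
def completion_length_reward(prompts, completions, **kwargs):
    return [-float(len(completion)) for completion in completions]

# trainer = UnslothGRPOTrainer(
#     model = model,                          # an already-loaded causal LM
#     reward_funcs = [completion_length_reward],
#     train_dataset = dataset,                # must contain a "prompt" column
# )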
unsloth_compiled_cache/UnslothKTOTrainer.py ADDED
@@ -0,0 +1,1840 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
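# A small sanity-check sketch for selective_log_softmax (assumes torch.compile
# can run in this environment): gathering the selected logit and subtracting
# logsumexp is algebraically identical to indexing into a full log_softmax, but
# it never materializes the full (batch, seq, vocab) log-probability tensor.
if __name__ == "__main__":
    _logits = torch.randn(2, 5, 11)         # (batch, seq, vocab)
    _index  = torch.randint(0, 11, (2, 5))  # token ids whose log-probs we want
    _fast = selective_log_softmax(_logits, _index)
    _slow = F.log_softmax(_logits.float(), dim = -1).gather(-1, _index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(_fast, _slow, atol = 1e-5)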
+ @dataclass
43
+ class UnslothKTOConfig(KTOConfig):
44
+ """
45
+
46
+ Configuration class for the [`KTOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
66
+ reference model.
67
+ loss_type (`str`, *optional*, defaults to `"kto"`):
68
+ Type of loss to use. Possible values are:
69
+
70
+ - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
71
+ - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
72
+
73
+ desirable_weight (`float`, *optional*, defaults to `1.0`):
74
+ Desirable losses are weighted by this factor to counter the unequal number of desirable and undesirable pairs.
75
+ undesirable_weight (`float`, *optional*, defaults to `1.0`):
76
+ Undesirable losses are weighted by this factor to counter the unequal number of desirable and undesirable pairs.
77
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
78
+ Label pad token id. This argument is required if you want to use the default data collator.
79
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
80
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
81
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
82
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
83
+ This argument is required if you want to use the default data collator.
84
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
85
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
86
+ evaluation.
87
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
88
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
89
+ you need to specify if the model returned by the callable is an encoder-decoder model.
90
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
91
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
92
+ useful when training without the reference model to reduce the total GPU memory needed.
93
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
94
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
95
+ string.
96
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
97
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
98
+ from a string.
99
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
100
+ Number of processes to use for processing the dataset.
101
+ disable_dropout (`bool`, *optional*, defaults to `True`):
102
+ Whether to disable dropout in the model and reference model.
103
+
104
+ """
105
+ vllm_sampling_params: Optional[Any] = field(
106
+ default = None,
107
+ metadata = {'help': 'vLLM SamplingParams'},
108
+ )
109
+ unsloth_num_chunks : Optional[int] = field(
110
+ default = -1,
111
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
112
+ )
113
+ def __init__(
114
+ self,
115
+ output_dir = None,
116
+ overwrite_output_dir = None,
117
+ do_train = False,
118
+ do_eval = False,
119
+ do_predict = False,
120
+ eval_strategy = 'no',
121
+ prediction_loss_only = False,
122
+ per_device_train_batch_size = 4,
123
+ per_device_eval_batch_size = 4,
124
+ per_gpu_train_batch_size = None,
125
+ per_gpu_eval_batch_size = None,
126
+ gradient_accumulation_steps = 2,
127
+ eval_accumulation_steps = 2,
128
+ eval_delay = 0,
129
+ torch_empty_cache_steps = 250,
130
+ learning_rate = 5e-05,
131
+ weight_decay = 0.01,
132
+ adam_beta1 = 0.9,
133
+ adam_beta2 = 0.999,
134
+ adam_epsilon = 1e-08,
135
+ max_grad_norm = 1.0,
136
+ num_train_epochs = 3.0,
137
+ max_steps = -1,
138
+ lr_scheduler_type = 'linear',
139
+ warmup_ratio = 0.1,
140
+ warmup_steps = 0,
141
+ log_level = 'passive',
142
+ log_level_replica = 'warning',
143
+ log_on_each_node = True,
144
+ logging_dir = None,
145
+ logging_strategy = 'steps',
146
+ logging_first_step = False,
147
+ logging_steps = 1,
148
+ logging_nan_inf_filter = False,
149
+ save_strategy = 'steps',
150
+ save_steps = 500,
151
+ save_total_limit = None,
152
+ save_safetensors = True,
153
+ save_on_each_node = False,
154
+ save_only_model = False,
155
+ restore_callback_states_from_checkpoint = False,
156
+ no_cuda = False,
157
+ use_cpu = False,
158
+ use_mps_device = False,
159
+ seed = 3407,
160
+ data_seed = 3407,
161
+ jit_mode_eval = False,
162
+ use_ipex = False,
163
+ bf16 = False,
164
+ fp16 = False,
165
+ fp16_opt_level = 'O1',
166
+ half_precision_backend = 'auto',
167
+ bf16_full_eval = False,
168
+ fp16_full_eval = False,
169
+ tf32 = None,
170
+ local_rank = -1,
171
+ ddp_backend = None,
172
+ tpu_num_cores = None,
173
+ tpu_metrics_debug = False,
174
+ debug = '',
175
+ dataloader_drop_last = False,
176
+ eval_steps = None,
177
+ dataloader_num_workers = 0,
178
+ dataloader_prefetch_factor = None,
179
+ past_index = -1,
180
+ run_name = None,
181
+ disable_tqdm = None,
182
+ remove_unused_columns = True,
183
+ label_names = None,
184
+ load_best_model_at_end = False,
185
+ metric_for_best_model = None,
186
+ greater_is_better = None,
187
+ ignore_data_skip = False,
188
+ fsdp = '',
189
+ fsdp_min_num_params = 0,
190
+ fsdp_config = None,
191
+ tp_size = 0,
192
+ fsdp_transformer_layer_cls_to_wrap = None,
193
+ accelerator_config = None,
194
+ deepspeed = None,
195
+ label_smoothing_factor = 0.0,
196
+ optim = 'adamw_8bit',
197
+ optim_args = None,
198
+ adafactor = False,
199
+ group_by_length = False,
200
+ length_column_name = 'length',
201
+ report_to = None,
202
+ ddp_find_unused_parameters = None,
203
+ ddp_bucket_cap_mb = None,
204
+ ddp_broadcast_buffers = None,
205
+ dataloader_pin_memory = True,
206
+ dataloader_persistent_workers = False,
207
+ skip_memory_metrics = True,
208
+ use_legacy_prediction_loop = False,
209
+ push_to_hub = False,
210
+ resume_from_checkpoint = None,
211
+ hub_model_id = None,
212
+ hub_strategy = 'every_save',
213
+ hub_token = None,
214
+ hub_private_repo = None,
215
+ hub_always_push = False,
216
+ gradient_checkpointing = False,
217
+ gradient_checkpointing_kwargs = None,
218
+ include_inputs_for_metrics = False,
219
+ eval_do_concat_batches = True,
220
+ fp16_backend = 'auto',
221
+ evaluation_strategy = None,
222
+ push_to_hub_model_id = None,
223
+ push_to_hub_organization = None,
224
+ push_to_hub_token = None,
225
+ mp_parameters = '',
226
+ auto_find_batch_size = False,
227
+ full_determinism = False,
228
+ torchdynamo = None,
229
+ ray_scope = 'last',
230
+ ddp_timeout = 1800,
231
+ torch_compile = False,
232
+ torch_compile_backend = None,
233
+ torch_compile_mode = None,
234
+ dispatch_batches = None,
235
+ split_batches = None,
236
+ include_tokens_per_second = False,
237
+ include_num_input_tokens_seen = False,
238
+ neftune_noise_alpha = None,
239
+ optim_target_modules = None,
240
+ batch_eval_metrics = False,
241
+ eval_on_start = False,
242
+ use_liger_kernel = False,
243
+ eval_use_gather_object = False,
244
+ average_tokens_across_devices = False,
245
+ max_length = 1024,
246
+ max_prompt_length = 512,
247
+ max_completion_length = None,
248
+ beta = 0.1,
249
+ loss_type = 'kto',
250
+ desirable_weight = 1.0,
251
+ undesirable_weight = 1.0,
252
+ label_pad_token_id = -100,
253
+ padding_value = None,
254
+ truncation_mode = 'keep_end',
255
+ generate_during_eval = False,
256
+ is_encoder_decoder = None,
257
+ disable_dropout = True,
258
+ precompute_ref_log_probs = False,
259
+ model_init_kwargs = None,
260
+ ref_model_init_kwargs = None,
261
+ dataset_num_proc = None,
262
+ vllm_sampling_params = None,
263
+ unsloth_num_chunks = -1,
264
+ **kwargs,
265
+ ):
266
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
267
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
268
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
269
+ output_dir = 'unsloth_training_checkpoints'
270
+ save_strategy = 'no'
271
+ if dataset_num_proc is None:
272
+ from multiprocessing import cpu_count
273
+ dataset_num_proc = cpu_count()
274
+
275
+ super().__init__(
276
+ output_dir = output_dir,
277
+ overwrite_output_dir = overwrite_output_dir,
278
+ do_train = do_train,
279
+ do_eval = do_eval,
280
+ do_predict = do_predict,
281
+ eval_strategy = eval_strategy,
282
+ prediction_loss_only = prediction_loss_only,
283
+ per_device_train_batch_size = per_device_train_batch_size,
284
+ per_device_eval_batch_size = per_device_eval_batch_size,
285
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
286
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
287
+ gradient_accumulation_steps = gradient_accumulation_steps,
288
+ eval_accumulation_steps = eval_accumulation_steps,
289
+ eval_delay = eval_delay,
290
+ torch_empty_cache_steps = torch_empty_cache_steps,
291
+ learning_rate = learning_rate,
292
+ weight_decay = weight_decay,
293
+ adam_beta1 = adam_beta1,
294
+ adam_beta2 = adam_beta2,
295
+ adam_epsilon = adam_epsilon,
296
+ max_grad_norm = max_grad_norm,
297
+ num_train_epochs = num_train_epochs,
298
+ max_steps = max_steps,
299
+ lr_scheduler_type = lr_scheduler_type,
300
+ warmup_ratio = warmup_ratio,
301
+ warmup_steps = warmup_steps,
302
+ log_level = log_level,
303
+ log_level_replica = log_level_replica,
304
+ log_on_each_node = log_on_each_node,
305
+ logging_dir = logging_dir,
306
+ logging_strategy = logging_strategy,
307
+ logging_first_step = logging_first_step,
308
+ logging_steps = logging_steps,
309
+ logging_nan_inf_filter = logging_nan_inf_filter,
310
+ save_strategy = save_strategy,
311
+ save_steps = save_steps,
312
+ save_total_limit = save_total_limit,
313
+ save_safetensors = save_safetensors,
314
+ save_on_each_node = save_on_each_node,
315
+ save_only_model = save_only_model,
316
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
317
+ no_cuda = no_cuda,
318
+ use_cpu = use_cpu,
319
+ use_mps_device = use_mps_device,
320
+ seed = seed,
321
+ data_seed = data_seed,
322
+ jit_mode_eval = jit_mode_eval,
323
+ use_ipex = use_ipex,
324
+ bf16 = bf16,
325
+ fp16 = fp16,
326
+ fp16_opt_level = fp16_opt_level,
327
+ half_precision_backend = half_precision_backend,
328
+ bf16_full_eval = bf16_full_eval,
329
+ fp16_full_eval = fp16_full_eval,
330
+ tf32 = tf32,
331
+ local_rank = local_rank,
332
+ ddp_backend = ddp_backend,
333
+ tpu_num_cores = tpu_num_cores,
334
+ tpu_metrics_debug = tpu_metrics_debug,
335
+ debug = debug,
336
+ dataloader_drop_last = dataloader_drop_last,
337
+ eval_steps = eval_steps,
338
+ dataloader_num_workers = dataloader_num_workers,
339
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
340
+ past_index = past_index,
341
+ run_name = run_name,
342
+ disable_tqdm = disable_tqdm,
343
+ remove_unused_columns = remove_unused_columns,
344
+ label_names = label_names,
345
+ load_best_model_at_end = load_best_model_at_end,
346
+ metric_for_best_model = metric_for_best_model,
347
+ greater_is_better = greater_is_better,
348
+ ignore_data_skip = ignore_data_skip,
349
+ fsdp = fsdp,
350
+ fsdp_min_num_params = fsdp_min_num_params,
351
+ fsdp_config = fsdp_config,
352
+ tp_size = tp_size,
353
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
354
+ accelerator_config = accelerator_config,
355
+ deepspeed = deepspeed,
356
+ label_smoothing_factor = label_smoothing_factor,
357
+ optim = optim,
358
+ optim_args = optim_args,
359
+ adafactor = adafactor,
360
+ group_by_length = group_by_length,
361
+ length_column_name = length_column_name,
362
+ report_to = report_to,
363
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
364
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
365
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
366
+ dataloader_pin_memory = dataloader_pin_memory,
367
+ dataloader_persistent_workers = dataloader_persistent_workers,
368
+ skip_memory_metrics = skip_memory_metrics,
369
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
370
+ push_to_hub = push_to_hub,
371
+ resume_from_checkpoint = resume_from_checkpoint,
372
+ hub_model_id = hub_model_id,
373
+ hub_strategy = hub_strategy,
374
+ hub_token = hub_token,
375
+ hub_private_repo = hub_private_repo,
376
+ hub_always_push = hub_always_push,
377
+ gradient_checkpointing = gradient_checkpointing,
378
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
379
+ include_inputs_for_metrics = include_inputs_for_metrics,
380
+ eval_do_concat_batches = eval_do_concat_batches,
381
+ fp16_backend = fp16_backend,
382
+ evaluation_strategy = evaluation_strategy,
383
+ push_to_hub_model_id = push_to_hub_model_id,
384
+ push_to_hub_organization = push_to_hub_organization,
385
+ push_to_hub_token = push_to_hub_token,
386
+ mp_parameters = mp_parameters,
387
+ auto_find_batch_size = auto_find_batch_size,
388
+ full_determinism = full_determinism,
389
+ torchdynamo = torchdynamo,
390
+ ray_scope = ray_scope,
391
+ ddp_timeout = ddp_timeout,
392
+ torch_compile = torch_compile,
393
+ torch_compile_backend = torch_compile_backend,
394
+ torch_compile_mode = torch_compile_mode,
395
+ dispatch_batches = dispatch_batches,
396
+ split_batches = split_batches,
397
+ include_tokens_per_second = include_tokens_per_second,
398
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
399
+ neftune_noise_alpha = neftune_noise_alpha,
400
+ optim_target_modules = optim_target_modules,
401
+ batch_eval_metrics = batch_eval_metrics,
402
+ eval_on_start = eval_on_start,
403
+ use_liger_kernel = use_liger_kernel,
404
+ eval_use_gather_object = eval_use_gather_object,
405
+ average_tokens_across_devices = average_tokens_across_devices,
406
+ max_length = max_length,
407
+ max_prompt_length = max_prompt_length,
408
+ max_completion_length = max_completion_length,
409
+ beta = beta,
410
+ loss_type = loss_type,
411
+ desirable_weight = desirable_weight,
412
+ undesirable_weight = undesirable_weight,
413
+ label_pad_token_id = label_pad_token_id,
414
+ padding_value = padding_value,
415
+ truncation_mode = truncation_mode,
416
+ generate_during_eval = generate_during_eval,
417
+ is_encoder_decoder = is_encoder_decoder,
418
+ disable_dropout = disable_dropout,
419
+ precompute_ref_log_probs = precompute_ref_log_probs,
420
+ model_init_kwargs = model_init_kwargs,
421
+ ref_model_init_kwargs = ref_model_init_kwargs,
422
+ dataset_num_proc = dataset_num_proc,**kwargs)
423
+ self.vllm_sampling_params = vllm_sampling_params
424
+ self.unsloth_num_chunks = unsloth_num_chunks
425
+ pass
426
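# A hedged sketch of the unpaired "kto" objective that the config above
# parameterizes (beta, desirable_weight, undesirable_weight). This mirrors the
# loss described in the KTO paper and is illustrative only, not the trainer's
# exact implementation; `kl` stands for the batch-level KL estimate.
def kto_loss_sketch(policy_logps, ref_logps, is_desirable, kl, beta = 0.1,
                    desirable_weight = 1.0, undesirable_weight = 1.0):
    log_ratios = policy_logps - ref_logps  # per-example implied rewards
    desirable_losses   = 1 - torch.sigmoid(beta * (log_ratios - kl))
    undesirable_losses = 1 - torch.sigmoid(beta * (kl - log_ratios))
    losses = torch.where(
        is_desirable.bool(),
        desirable_weight * desirable_losses,
        undesirable_weight * undesirable_losses,
    )
    return losses.mean()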
+
427
+ class _UnslothKTOTrainer(Trainer):
428
+ r""""""
429
+
430
+ _tag_names = ["trl", "kto"]
431
+
432
+ def __init__(
433
+ self,
434
+ model: Union[PreTrainedModel, nn.Module, str] = None,
435
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
436
+ args: KTOConfig = None,
437
+ train_dataset: Optional[Dataset] = None,
438
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
439
+ processing_class: Optional[
440
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
441
+ ] = None,
442
+ data_collator: Optional[DataCollator] = None,
443
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
444
+ callbacks: Optional[list[TrainerCallback]] = None,
445
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
446
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
447
+ peft_config: Optional[dict] = None,
448
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
449
+ model_adapter_name: Optional[str] = None,
450
+ ref_adapter_name: Optional[str] = None,
451
+ ):
452
+ if type(args) is TrainingArguments:
453
+ raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")
454
+
455
+ if not isinstance(model, str) and ref_model is model:
456
+ raise ValueError(
457
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
458
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
459
+ )
460
+
461
+ if args.model_init_kwargs is None:
462
+ model_init_kwargs = {}
463
+ elif not isinstance(model, str):
464
+ raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
465
+ else:
466
+ model_init_kwargs = args.model_init_kwargs
467
+ torch_dtype = model_init_kwargs.get("torch_dtype")
468
+ if torch_dtype is not None:
469
+ # Convert to `torch.dtype` if a str is passed
470
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
471
+ torch_dtype = getattr(torch, torch_dtype)
472
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
473
+ raise ValueError(
474
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
475
+ )
476
+ model_init_kwargs["torch_dtype"] = torch_dtype
477
+
478
+ if args.ref_model_init_kwargs is None:
479
+ ref_model_init_kwargs = {}
480
+ elif not isinstance(ref_model, str):
481
+ raise ValueError(
482
+ "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
483
+ )
484
+ else:
485
+ ref_model_init_kwargs = args.ref_model_init_kwargs
486
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
487
+ if torch_dtype is not None:
488
+ # Convert to `torch.dtype` if a str is passed
489
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
490
+ torch_dtype = getattr(torch, torch_dtype)
491
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
492
+ raise ValueError(
493
+ f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
494
+ )
495
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
496
+
497
+ if isinstance(model, str):
498
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
499
+
500
+ if isinstance(ref_model, str):
501
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
502
+
503
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
504
+ # has been called in order to properly call autocast if needed.
505
+ self._peft_has_been_casted_to_bf16 = False
506
+
507
+ if not is_peft_available() and peft_config is not None:
508
+ raise ValueError(
509
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
510
+ )
511
+ elif is_peft_available() and peft_config is not None:
512
+ # if model is a peft model and we have a peft_config, we merge and unload it first
513
+ if isinstance(model, PeftModel):
514
+ model = model.merge_and_unload()
515
+
516
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
517
+ _support_gc_kwargs = hasattr(
518
+ args, "gradient_checkpointing_kwargs"
519
+ ) and "gradient_checkpointing_kwargs" in list(
520
+ inspect.signature(prepare_model_for_kbit_training).parameters
521
+ )
522
+
523
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
524
+
525
+ if _support_gc_kwargs:
526
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
527
+
528
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
529
+ elif getattr(args, "gradient_checkpointing", False):
530
+ # For backward compatibility with older versions of transformers
531
+ if hasattr(model, "enable_input_require_grads"):
532
+ model.enable_input_require_grads()
533
+ else:
534
+
535
+ def make_inputs_require_grad(module, input, output):
536
+ output.requires_grad_(True)
537
+
538
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
539
+
540
+ # get peft model with the given config
541
+ model = model
542
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
543
+ peft_module_casting_to_bf16(model)
544
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
545
+ self._peft_has_been_casted_to_bf16 = True
546
+
547
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
548
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
549
+ # fail or completely fail.
550
+ elif getattr(args, "gradient_checkpointing", False):
551
+ # For backward compatibility with older versions of transformers
552
+ if hasattr(model, "enable_input_require_grads"):
553
+ model.enable_input_require_grads()
554
+ else:
555
+
556
+ def make_inputs_require_grad(module, input, output):
557
+ output.requires_grad_(True)
558
+
559
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
560
+
561
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
562
+ raise ValueError(
563
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
564
+ " Please install `wandb` or `comet-ml` to resolve."
565
+ )
566
+
567
+ if model is not None:
568
+ self.is_encoder_decoder = model.config.is_encoder_decoder
569
+ elif args.is_encoder_decoder is None:
570
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
571
+ else:
572
+ self.is_encoder_decoder = args.is_encoder_decoder
573
+
574
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
575
+ self.model_adapter_name = model_adapter_name
576
+ self.ref_adapter_name = ref_adapter_name
577
+
578
+ if ref_model:
579
+ self.ref_model = ref_model
580
+ elif self.is_peft_model or args.precompute_ref_log_probs:
581
+ # The `model` with adapters turned off will be used as the reference model
582
+ self.ref_model = None
583
+ else:
584
+ self.ref_model = create_reference_model(model)
585
+
586
+ if processing_class is None:
587
+ raise ValueError(
588
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
589
+ )
590
+ if args.max_length is None:
591
+ warnings.warn(
592
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
593
+ " it will be set to `512` by default, but you should do it yourself in the future.",
594
+ UserWarning,
595
+ )
596
+ max_length = 512
597
+ if args.max_length is not None:
598
+ max_length = args.max_length
599
+
600
+ if args.max_prompt_length is None:
601
+ warnings.warn(
602
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
603
+ " it will be set to `128` by default, but you should do it yourself in the future.",
604
+ UserWarning,
605
+ )
606
+ max_prompt_length = 128
607
+ if args.max_prompt_length is not None:
608
+ max_prompt_length = args.max_prompt_length
609
+
610
+ max_completion_length = None
611
+ if args.max_completion_length is None and self.is_encoder_decoder:
612
+ warnings.warn(
613
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
614
+ " it will be set to `128` by default, but you should do it yourself in the future.",
615
+ UserWarning,
616
+ )
617
+ max_completion_length = 128
618
+ if args.max_completion_length is not None and self.is_encoder_decoder:
619
+ max_completion_length = args.max_completion_length
620
+
621
+ if data_collator is None:
622
+ data_collator = DPODataCollatorWithPadding(
623
+ pad_token_id=processing_class.pad_token_id,
624
+ label_pad_token_id=args.label_pad_token_id,
625
+ is_encoder_decoder=self.is_encoder_decoder,
626
+ )
627
+
628
+ if args.remove_unused_columns:
629
+ args.remove_unused_columns = False
630
+ # warn users
631
+ warnings.warn(
632
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
633
+ " we have set it for you, but you should do it yourself in the future.",
634
+ UserWarning,
635
+ )
636
+
637
+ self.use_dpo_data_collator = True
638
+ else:
639
+ self.use_dpo_data_collator = False
640
+
641
+ # Disable dropout in the model and reference model
642
+ if args.disable_dropout:
643
+ disable_dropout_in_model(model)
644
+ if self.ref_model is not None:
645
+ disable_dropout_in_model(self.ref_model)
646
+
647
+ self.loss_type = args.loss_type
648
+ self.max_length = max_length
649
+ self.generate_during_eval = args.generate_during_eval
650
+ self.label_pad_token_id = args.label_pad_token_id
651
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
652
+ self.max_prompt_length = max_prompt_length
653
+ self.truncation_mode = args.truncation_mode
654
+ self.max_completion_length = max_completion_length
655
+ self.processing_class = processing_class
656
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
657
+
658
+ # Not all losses require a KL calculation
659
+ self.calculate_KL = True
660
+ if self.loss_type in ["apo_zero_unpaired"]:
661
+ self.calculate_KL = False
662
+
663
+ # Since ref_log_probs are precomputed on the first call to get_train/eval_dataloader,
664
+ # keep track of that first call to avoid recomputing them on subsequent calls
665
+ self._precomputed_train_ref_log_probs = False
666
+ self._precomputed_eval_ref_log_probs = False
667
+
668
+ # metric
669
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
670
+
671
+ # KTO parameter
672
+ self.beta = args.beta
673
+ self.desirable_weight = args.desirable_weight
674
+ self.undesirable_weight = args.undesirable_weight
675
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
676
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
677
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
678
+ warnings.warn(
679
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
680
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
681
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
682
+ "loss.",
683
+ UserWarning,
684
+ )
685
+
686
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
687
+ # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
688
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
689
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
690
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
691
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
692
+ # issued.
693
+ model.warnings_issued["estimate_tokens"] = True
694
+
695
+ # Compute that only on the main process for faster data processing.
696
+ # see: https://github.com/huggingface/trl/pull/1255
697
+ with PartialState().local_main_process_first():
698
+ # Extract the prompt if needed
699
+ train_dataset = train_dataset.map(
700
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
701
+ )
702
+ # Unpair the dataset if needed
703
+ train_dataset = maybe_unpair_preference_dataset(
704
+ train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
705
+ )
706
+ # Apply the chat template if needed
707
+ train_dataset = train_dataset.map(
708
+ maybe_apply_chat_template,
709
+ fn_kwargs={"tokenizer": processing_class},
710
+ num_proc=args.dataset_num_proc,
711
+ desc="Applying chat template to train dataset",
712
+ )
713
+ if eval_dataset is not None:
714
+ eval_dataset = eval_dataset.map(
715
+ maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
716
+ )
717
+ eval_dataset = maybe_unpair_preference_dataset(
718
+ eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
719
+ )
720
+ eval_dataset = eval_dataset.map(
721
+ maybe_apply_chat_template,
722
+ fn_kwargs={"tokenizer": processing_class},
723
+ num_proc=args.dataset_num_proc,
724
+ desc="Applying chat template to eval dataset",
725
+ )
726
+
727
+ # Tokenize and prepare the training datasets
728
+ train_dataset = train_dataset.map(
729
+ _tokenize,
730
+ batched=True,
731
+ fn_kwargs={"tokenizer": self.processing_class},
732
+ num_proc=args.dataset_num_proc,
733
+ desc="Tokenizing train dataset",
734
+ )
735
+
736
+ fn_kwargs = {
737
+ "prefix": "",
738
+ "is_encoder_decoder": self.is_encoder_decoder,
739
+ "tokenizer": self.processing_class,
740
+ "max_length": self.max_length,
741
+ "truncation_mode": self.truncation_mode,
742
+ "label_pad_token_id": self.label_pad_token_id,
743
+ "max_prompt_length": self.max_prompt_length,
744
+ "max_completion_length": self.max_completion_length,
745
+ }
746
+
747
+ train_dataset = train_dataset.map(
748
+ _process_tokens,
749
+ fn_kwargs=fn_kwargs,
750
+ num_proc=args.dataset_num_proc,
751
+ desc="Processing tokenized train dataset",
752
+ )
753
+
754
+ # Tokenize and prepare the eval datasets
755
+ if eval_dataset is not None:
756
+ eval_dataset = eval_dataset.map(
757
+ _tokenize,
758
+ fn_kwargs={"tokenizer": self.processing_class},
759
+ batched=True,
760
+ num_proc=args.dataset_num_proc,
761
+ desc="Tokenizing eval dataset",
762
+ )
763
+
764
+ eval_dataset = eval_dataset.map(
765
+ _process_tokens,
766
+ fn_kwargs=fn_kwargs,
767
+ num_proc=args.dataset_num_proc,
768
+ desc="Processing tokenized eval dataset",
769
+ )
770
+
771
+ # Get KL datasets if needed
772
+ if self.calculate_KL:
773
+ if args.per_device_train_batch_size <= 1:
774
+ raise ValueError(
775
+ "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
776
+ )
777
+
778
+ # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
779
+ # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
780
+ train_kl_dataset = train_dataset.map(
781
+ _get_kl_dataset,
782
+ batched=True,
783
+ batch_size=args.per_device_train_batch_size,
784
+ num_proc=args.dataset_num_proc,
785
+ desc="Extracting KL train dataset",
786
+ )
787
+
788
+ fn_kwargs["prefix"] = "KL_"
789
+ train_kl_dataset = train_kl_dataset.map(
790
+ _process_tokens,
791
+ fn_kwargs=fn_kwargs,
792
+ num_proc=args.dataset_num_proc,
793
+ remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
794
+ desc="Processing tokenized train KL dataset",
795
+ )
796
+
797
+ # merge the datasets
798
+ train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
799
+
800
+ if eval_dataset is not None:
801
+ # Get KL dataset
802
+ eval_kl_dataset = eval_dataset.map(
803
+ _get_kl_dataset,
804
+ batched=True,
805
+ batch_size=args.per_device_train_batch_size,
806
+ num_proc=args.dataset_num_proc,
807
+ desc="Extracting eval KL dataset",
808
+ )
809
+
810
+ eval_kl_dataset = eval_kl_dataset.map(
811
+ _process_tokens,
812
+ fn_kwargs=fn_kwargs,
813
+ num_proc=args.dataset_num_proc,
814
+ remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
815
+ desc="Processing tokenized eval KL dataset",
816
+ )
817
+
818
+ # merge the datasets
819
+ eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
820
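# A hedged, assumption-level sketch of the mismatched-pair idea in the comment
# above (not TRL's actual _get_kl_dataset): keep prompts in place and rotate the
# completions by one position, so every prompt ends up paired with a completion
# drawn from a different example, which is what the batch-level KL estimate
# needs. Rotation (rather than reversal) also avoids leaving the middle pair
# matched in odd-sized batches.
def make_kl_batch_sketch(batch):
    completions = batch["completion"]
    rotated = completions[-1:] + completions[:-1]  # (y_n, y_1, ..., y_{n-1})
    return {"prompt": batch["prompt"], "completion": rotated}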
+
821
+ # calculate dataset desirability balance
822
+ num_desirable = max(sum(train_dataset["label"]), 1)
823
+ num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
824
+
825
+ if num_desirable != num_undesirable:
826
+ # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
827
+ des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
828
+ des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
829
+ und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
830
+ und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
831
+
832
+ des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
833
+ und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
834
+
835
+ if not (des_weight_in_range or und_weight_in_range):
836
+ warnings.warn(
837
+ "You have different amounts of desirable/positive and undesirable/negative examples but the "
838
+ "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
839
+ f"on your data, we recommend EITHER "
840
+ f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
841
+ f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
842
+ "See the documentation on how to optimally set these weights.",
843
+ UserWarning,
844
+ )
845
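# Worked example of the Eq. (8) bounds computed above, under assumed counts:
# with num_desirable = 300, num_undesirable = 100, and undesirable_weight = 1.0,
#   des_weight_lower_bound = 100 * 1.0 / 300        ~= 0.33
#   des_weight_upper_bound = 100 * 1.0 / 300 * 1.33 ~= 0.44
# so the default desirable_weight = 1.0 falls outside [0.33, 0.44] and the
# warning fires, recommending desirable_weight in that range (or, equivalently,
# undesirable_weight in [300 / 100 / 1.33, 300 / 100] ~= [2.26, 3.0], not both).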
+
846
+ super().__init__(
847
+ model=model,
848
+ args=args,
849
+ data_collator=data_collator,
850
+ train_dataset=train_dataset,
851
+ eval_dataset=eval_dataset,
852
+ processing_class=processing_class,
853
+ model_init=model_init,
854
+ compute_metrics=compute_metrics,
855
+ callbacks=callbacks,
856
+ optimizers=optimizers,
857
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
858
+ )
859
+
860
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
861
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
862
+ # self.model_accepts_loss_kwargs to False to enable scaling.
863
+ self.model_accepts_loss_kwargs = False
864
+
865
+ # Add tags for models that have been loaded with the correct transformers version
866
+ if hasattr(self.model, "add_model_tags"):
867
+ self.model.add_model_tags(self._tag_names)
868
+
869
+ if not hasattr(self, "accelerator"):
870
+ raise AttributeError(
871
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
872
+ )
873
+
874
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
875
+ if self.is_deepspeed_enabled:
876
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
877
+ raise ValueError(
878
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
879
+ )
880
+
881
+ if self.ref_model is None:
882
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
883
+ raise ValueError(
884
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
885
+ )
886
+ else:
887
+ if self.is_deepspeed_enabled:
888
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
889
+ else:
890
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
891
+
892
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
893
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
894
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
895
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
896
+
897
+ if model is not None:
898
+ if hasattr(model, "config"):
899
+ hidden_size = (
900
+ max(model.config.hidden_sizes)
901
+ if getattr(model.config, "hidden_sizes", None)
902
+ else getattr(model.config, "hidden_size", None)
903
+ )
904
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
905
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
906
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
907
+ config_kwargs.update(
908
+ {
909
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
910
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
911
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
912
+ }
913
+ )
914
+
915
+ # If ZeRO-3 is used, we shard both the active and reference model.
916
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
917
+ if config_kwargs["zero_optimization"]["stage"] != 3:
918
+ config_kwargs["zero_optimization"]["stage"] = 0
919
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
920
+ model.eval()
921
+ return model
922
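# Worked example of the ZeRO-3 bucket sizing above, assuming hidden_size = 4096:
#   zero_optimization.reduce_bucket_size                 = 4096 * 4096        = 16_777_216
#   zero_optimization.stage3_param_persistence_threshold = 10 * 4096          = 40_960
#   zero_optimization.stage3_prefetch_bucket_size        = 0.9 * 4096 * 4096 ~= 15_099_494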
+
923
+ @contextmanager
924
+ def null_ref_context(self):
925
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
926
+ with (
927
+ self.accelerator.unwrap_model(self.model).disable_adapter()
928
+ if self.is_peft_model and not self.ref_adapter_name
929
+ else nullcontext()
930
+ ):
931
+ if self.ref_adapter_name:
932
+ self.model.set_adapter(self.ref_adapter_name)
933
+ yield
934
+ if self.ref_adapter_name:
935
+ self.model.set_adapter(self.model_adapter_name or "default")
936
+
937
+ def get_train_dataloader(self) -> DataLoader:
938
+ """
939
+ Returns the training [`~torch.utils.data.DataLoader`].
940
+
941
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
942
+ """
943
+
944
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
945
+ dataloader_params = {
946
+ "batch_size": self.args.per_device_train_batch_size,
947
+ "collate_fn": self.data_collator,
948
+ "num_workers": self.args.dataloader_num_workers,
949
+ "pin_memory": self.args.dataloader_pin_memory,
950
+ "shuffle": False,
951
+ }
952
+
953
+ # prepare dataloader
954
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
955
+ reference_completion_logps = []
956
+ reference_KL_logps = []
957
+
958
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
959
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
960
+
961
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
962
+ reference_completion_logps.append(reference_completion_logp.cpu())
963
+
964
+ if self.calculate_KL:
965
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
966
+ reference_KL_logps.append(reference_KL_logp.cpu())
967
+
968
+ self.train_dataset = self.train_dataset.add_column(
969
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
970
+ )
971
+
972
+ if self.calculate_KL:
973
+ self.train_dataset = self.train_dataset.add_column(
974
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
975
+ )
976
+
977
+ self._precomputed_train_ref_log_probs = True
978
+
979
+ return super().get_train_dataloader()
980
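# Hedged usage sketch: the precomputation above is triggered lazily by the first
# get_train_dataloader() call when the config requests it, so enabling it is
# just a config flag (argument names as defined in UnslothKTOConfig above).
# config = UnslothKTOConfig(
#     precompute_ref_log_probs = True,   # reference logps cached as dataset columns
#     per_device_train_batch_size = 4,   # must be > 1 when the loss needs the KL term
# )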
+
981
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
982
+ """
983
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
984
+
985
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
986
+
987
+ Args:
988
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
989
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
990
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
991
+ """
992
+ if eval_dataset is None and self.eval_dataset is None:
993
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
994
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
995
+
996
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
997
+ dataloader_params = {
998
+ "batch_size": self.args.per_device_eval_batch_size,
999
+ "collate_fn": self.data_collator,
1000
+ "num_workers": self.args.dataloader_num_workers,
1001
+ "pin_memory": self.args.dataloader_pin_memory,
1002
+ "shuffle": False,
1003
+ }
1004
+
1005
+ # prepare dataloader
1006
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1007
+
1008
+ reference_completion_logps = []
1009
+ reference_KL_logps = []
1010
+
1011
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1012
+ reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
1013
+
1014
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1015
+ reference_completion_logps.append(reference_completion_logp.cpu())
1016
+
1017
+ if self.calculate_KL:
1018
+ reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
1019
+ reference_KL_logps.append(reference_KL_logp.cpu())
1020
+
1021
+ eval_dataset = eval_dataset.add_column(
1022
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1023
+ )
1024
+ if self.calculate_KL:
1025
+ eval_dataset = eval_dataset.add_column(
1026
+ name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
1027
+ )
1028
+
1029
+ # Save the calculated reference_logps and reference_KL_logps to the eval_dataset for subsequent runs
1030
+ if self.eval_dataset is not None:
1031
+ self.eval_dataset = eval_dataset
1032
+ self._precomputed_eval_ref_log_probs = True
1033
+
1034
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1035
+
1036
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1037
+ """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
1038
+ with torch.no_grad():
1039
+ if self.ref_model is None:
1040
+ with self.null_ref_context():
1041
+ if self.is_encoder_decoder:
1042
+ completion_logits = self.model(
1043
+ padded_batch["prompt_input_ids"],
1044
+ attention_mask=padded_batch["prompt_attention_mask"],
1045
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1046
+ labels=padded_batch["completion_labels"],
1047
+ ).logits
1048
+
1049
+ if self.calculate_KL:
1050
+ KL_logits = self.model(
1051
+ padded_batch["KL_prompt_input_ids"],
1052
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1053
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1054
+ labels=padded_batch["KL_completion_labels"],
1055
+ ).logits
1056
+ else:
1057
+ completion_logits = self.model(
1058
+ padded_batch["completion_input_ids"],
1059
+ attention_mask=padded_batch["completion_attention_mask"],
1060
+ ).logits
1061
+
1062
+ if self.calculate_KL:
1063
+ KL_logits = self.model(
1064
+ padded_batch["KL_completion_input_ids"],
1065
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1066
+ ).logits
1067
+ else:
1068
+ if self.is_encoder_decoder:
1069
+ completion_logits = self.ref_model(
1070
+ padded_batch["prompt_input_ids"],
1071
+ attention_mask=padded_batch["prompt_attention_mask"],
1072
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1073
+ labels=padded_batch["completion_labels"],
1074
+ ).logits
1075
+
1076
+ if self.calculate_KL:
1077
+ KL_logits = self.ref_model(
1078
+ padded_batch["KL_prompt_input_ids"],
1079
+ attention_mask=padded_batch["KL_prompt_attention_mask"],
1080
+ decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
1081
+ labels=padded_batch["KL_completion_labels"],
1082
+ ).logits
1083
+ else:
1084
+ completion_logits = self.ref_model(
1085
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1086
+ ).logits
1087
+
1088
+ if self.calculate_KL:
1089
+ KL_logits = self.ref_model(
1090
+ padded_batch["KL_completion_input_ids"],
1091
+ attention_mask=padded_batch["KL_completion_attention_mask"],
1092
+ ).logits
1093
+
1094
+ completion_logps = self.get_batch_logps(
1095
+ completion_logits,
1096
+ padded_batch["completion_labels"],
1097
+ average_log_prob=False,
1098
+ is_encoder_decoder=self.is_encoder_decoder,
1099
+ label_pad_token_id=self.label_pad_token_id,
1100
+ )
1101
+
1102
+ if self.calculate_KL:
1103
+ KL_logps = self.get_batch_logps(
1104
+ KL_logits,
1105
+ padded_batch["KL_completion_labels"],
1106
+ average_log_prob=False,
1107
+ is_encoder_decoder=self.is_encoder_decoder,
1108
+ label_pad_token_id=self.label_pad_token_id,
1109
+ )
1110
+ else:
1111
+ KL_logps = None
1112
+
1113
+ return completion_logps, KL_logps
1114
+
1115
+ @staticmethod
1116
+ def get_batch_logps(
1117
+ logits: torch.FloatTensor,
1118
+ labels: torch.LongTensor,
1119
+ average_log_prob: bool = False,
1120
+ label_pad_token_id: int = -100,
1121
+ is_encoder_decoder: bool = False,
1122
+ ) -> torch.FloatTensor:
1123
+ """Compute the log probabilities of the given labels under the given logits.
1124
+
1125
+ Args:
1126
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1127
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1128
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+ label_pad_token_id: Label token id whose positions are excluded from the computation.
+ is_encoder_decoder: Whether the model is an encoder-decoder model; if False, labels and logits are shifted by one position relative to each other.
1129
+
1130
+ Returns:
1131
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1132
+ """
1133
+ if logits.shape[:-1] != labels.shape:
1134
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1135
+
1136
+ if not is_encoder_decoder:
1137
+ labels = labels[:, 1:].clone()
1138
+ logits = logits[:, :-1, :]
1139
+ else:
1140
+ # Cloning fixes a RuntimeError that otherwise occurs for encoder-decoder models
1141
+ labels = labels.clone()
1142
+
1143
+ loss_mask = labels != label_pad_token_id
1144
+
1145
+ # dummy token; we'll ignore the losses on these tokens later
1146
+ labels[labels == label_pad_token_id] = 0
1147
+
1148
+ per_token_logps = selective_log_softmax(logits, labels)
1149
+
1150
+ if average_log_prob:
1151
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1152
+ else:
1153
+ return (per_token_logps * loss_mask).sum(-1)
1154
+
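+ # Illustrative usage of get_batch_logps (a minimal sketch with toy shapes): for
+ # decoder-only models the labels are shifted by one position before scoring, and
+ # label_pad_token_id positions are masked out of the sum.
+ # >>> logits = torch.randn(1, 4, 3)              # (batch, seq_len, vocab)
+ # >>> labels = torch.tensor([[1, 2, 0, -100]])   # -100 = label_pad_token_id
+ # >>> _UnslothKTOTrainer.get_batch_logps(logits, labels).shape
+ # torch.Size([1])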
1155
+ def forward(
1156
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1157
+ ) -> tuple[torch.FloatTensor, ...]:
1158
+ if self.calculate_KL:
1160
+ KL_model_kwargs = (
1161
+ {
1162
+ "input_ids": batch["KL_prompt_input_ids"],
1163
+ "attention_mask": batch["KL_prompt_attention_mask"],
1164
+ "labels": batch["KL_completion_labels"],
1165
+ "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
1166
+ }
1167
+ if self.is_encoder_decoder
1168
+ else {
1169
+ "input_ids": batch["KL_completion_input_ids"],
1170
+ "attention_mask": batch["KL_completion_attention_mask"],
1171
+ }
1172
+ )
1173
+ with torch.no_grad():
1174
+ KL_logits = model(
1175
+ **KL_model_kwargs,
1176
+ ).logits
1177
+
1178
+ KL_logps = self.get_batch_logps(
1179
+ KL_logits,
1180
+ batch["KL_completion_labels"],
1181
+ average_log_prob=False,
1182
+ is_encoder_decoder=self.is_encoder_decoder,
1183
+ label_pad_token_id=self.label_pad_token_id,
1184
+ )
1185
+ else:
1186
+ KL_logps = None
1187
+
1188
+ model_kwargs = (
1189
+ {
1190
+ "labels": batch["completion_labels"],
1191
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1192
+ }
1193
+ if self.is_encoder_decoder
1194
+ else {}
1195
+ )
1196
+ if self.aux_loss_enabled:
1197
+ model_kwargs["output_router_logits"] = True
1198
+
1199
+ outputs = model(
1200
+ batch["completion_input_ids"],
1201
+ attention_mask=batch["completion_attention_mask"],
1202
+ **model_kwargs,
1203
+ )
1204
+ completion_logits = outputs.logits
1205
+
1206
+ completion_logps = self.get_batch_logps(
1207
+ completion_logits,
1208
+ batch["completion_labels"],
1209
+ average_log_prob=False,
1210
+ is_encoder_decoder=self.is_encoder_decoder,
1211
+ label_pad_token_id=self.label_pad_token_id,
1212
+ )
1213
+
1214
+ if completion_logps.shape[0] != len(batch["label"]):
1215
+ raise ValueError(
1216
+ "There is a mismatch between the number of examples in this batch and the number of "
1217
+ "examples for which an output sequence was predicted."
1218
+ )
1219
+
1220
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1221
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1222
+
1223
+ chosen_logps = completion_logps[chosen_idx, ...]
1224
+ rejected_logps = completion_logps[rejected_idx, ...]
1225
+
1226
+ chosen_logits = completion_logits[chosen_idx, ...]
1227
+ rejected_logits = completion_logits[rejected_idx, ...]
1228
+
1229
+ if self.aux_loss_enabled:
1230
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
1231
+ else:
1232
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
1233
+
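+ # Note on the split above (illustrative): with batch["label"] = [True, False, True],
+ # chosen_idx = [0, 2] and rejected_idx = [1], so the chosen and rejected tensors
+ # may have different (possibly zero) lengths within a single unpaired batch.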
1234
+ def kto_loss(
1235
+ self,
1236
+ policy_chosen_logps: torch.FloatTensor,
1237
+ policy_rejected_logps: torch.FloatTensor,
1238
+ policy_KL_logps: torch.FloatTensor,
1239
+ reference_chosen_logps: torch.FloatTensor,
1240
+ reference_rejected_logps: torch.FloatTensor,
1241
+ reference_KL_logps: torch.FloatTensor,
1242
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1243
+ """Compute the KTO loss for a batch of policy and reference model log probabilities.
1244
+
1245
+ Args:
1246
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1247
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1248
+ policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
1249
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1250
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1251
+ reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
1252
+
1253
+ Returns:
1254
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
1255
+ The losses tensor contains the KTO loss for each example in the batch.
1256
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1257
+ The KL tensor contains the detached KL divergence estimate between the policy and reference models.
1258
+ """
1259
+ if self.calculate_KL:
1260
+ kl = (policy_KL_logps - reference_KL_logps).mean().detach()
1261
+ kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
1262
+ else:
1263
+ kl = torch.zeros(1).to(policy_chosen_logps.device)
1264
+
1265
+ # Chosen losses
1266
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1267
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1268
+
1269
+ if self.loss_type == "kto":
1270
+ # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
1271
+ chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
1272
+ elif self.loss_type == "apo_zero_unpaired":
1273
+ # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
1274
+ # Use this loss when you believe the chosen outputs are better than your model's default output
1275
+ chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
1276
+
1277
+ chosen_rewards = self.beta * chosen_logratios.detach()
1278
+
1279
+ else:
1280
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1281
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1282
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1283
+
1284
+ # Rejected losses
1285
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1286
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1287
+
1288
+ if self.loss_type == "kto":
1289
+ rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
1290
+ elif self.loss_type == "apo_zero_unpaired":
1291
+ rejected_losses = F.sigmoid(self.beta * rejected_logratios)
1292
+
1293
+ rejected_rewards = self.beta * rejected_logratios.detach()
1294
+ else:
1295
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1296
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1297
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1298
+
1299
+ losses = torch.cat(
1300
+ (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
1301
+ 0,
1302
+ )
1303
+
1304
+ return losses, chosen_rewards, rejected_rewards, kl
1305
+
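+ # Worked example of the loss above (hypothetical values): with beta = 0.1,
+ # a chosen example with chosen_logratios = 0.5 and kl = 0.2 gives
+ #   1 - sigmoid(0.1 * (0.5 - 0.2)) = 1 - sigmoid(0.03) ≈ 0.4925,
+ # while a rejected example with rejected_logratios = 0.5 gives
+ #   1 - sigmoid(0.1 * (0.2 - 0.5)) = 1 - sigmoid(-0.03) ≈ 0.5075.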
1306
+ def get_batch_loss_metrics(
1307
+ self,
1308
+ model,
1309
+ batch: dict[str, Union[list, torch.LongTensor]],
1310
+ ):
1311
+ """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
1312
+ metrics = {}
1313
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1314
+
1315
+ forward_output = self.forward(model, batch)
1316
+ (
1317
+ policy_chosen_logps,
1318
+ policy_rejected_logps,
1319
+ policy_chosen_logits,
1320
+ policy_rejected_logits,
1321
+ policy_KL_logps,
1322
+ ) = forward_output[:5]
1323
+ if self.aux_loss_enabled:
1324
+ aux_loss = forward_output[5]
1325
+
1326
+ # if reference_logps in batch use them, otherwise use the reference model
1327
+ if "reference_logps" in batch:
1328
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1329
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1330
+
1331
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1332
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1333
+ if self.calculate_KL:
1334
+ reference_KL_logps = batch["reference_KL_logps"]
1335
+ else:
1336
+ reference_KL_logps = None
1337
+ else:
1338
+ with torch.no_grad():
1339
+ if self.ref_model is None:
1340
+ with self.null_ref_context():
1341
+ (
1342
+ reference_chosen_logps,
1343
+ reference_rejected_logps,
1344
+ _,
1345
+ _,
1346
+ reference_KL_logps,
1347
+ ) = self.forward(self.model, batch)[:5]
1348
+ else:
1349
+ (
1350
+ reference_chosen_logps,
1351
+ reference_rejected_logps,
1352
+ _,
1353
+ _,
1354
+ reference_KL_logps,
1355
+ ) = self.forward(self.ref_model, batch)[:5]
1356
+
1357
+ losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
1358
+ policy_chosen_logps,
1359
+ policy_rejected_logps,
1360
+ policy_KL_logps,
1361
+ reference_chosen_logps,
1362
+ reference_rejected_logps,
1363
+ reference_KL_logps,
1364
+ )
1365
+ metrics["kl"] = kl.item()
1366
+
1367
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1368
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1369
+
1370
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1371
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1372
+
1373
+ if all_num_chosen > 0:
1374
+ metrics["rewards/chosen_sum"] = (
1375
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1376
+ )
1377
+ metrics["logps/chosen_sum"] = (
1378
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1379
+ )
1380
+ metrics["logits/chosen_sum"] = (
1381
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1382
+ )
1383
+ metrics["count/chosen"] = all_num_chosen
1384
+
1385
+ if all_num_rejected > 0:
1386
+ metrics["rewards/rejected_sum"] = (
1387
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1388
+ )
1389
+ metrics["logps/rejected_sum"] = (
1390
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1391
+ )
1392
+ metrics["logits/rejected_sum"] = (
1393
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1394
+ )
1395
+ metrics["count/rejected"] = all_num_rejected
1396
+
1397
+ loss = losses.nanmean()
1398
+ if self.aux_loss_enabled:
1399
+ loss += self.aux_loss_coef * aux_loss
1400
+
1401
+ return loss, metrics
1402
+
1403
+ def compute_loss(
1404
+ self,
1405
+ model: Union[PreTrainedModel, nn.Module],
1406
+ inputs: dict[str, Union[torch.Tensor, Any]],
1407
+ return_outputs=False,
1408
+ num_items_in_batch=None,
1409
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1410
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1411
+
1412
+ with compute_loss_context_manager:
1413
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1414
+
1415
+ # Move the loss back to the device of the accumulating loss in the parent `Trainer` class:
1416
+ loss = loss.to(self.args.device)
1417
+ # force log the metrics
1418
+ if self.accelerator.is_main_process:
1419
+ self.store_metrics(metrics, train_eval="train")
1420
+
1421
+ if return_outputs:
1422
+ return (loss, metrics)
1423
+ return loss
1424
+
1425
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1426
+ for key, value in metrics.items():
1427
+ self._stored_metrics[train_eval][key].append(value)
1428
+
1429
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1430
+ if self.train_dataset is None or not has_length(self.train_dataset):
1431
+ return None
1432
+ return SequentialSampler(self.train_dataset)
1433
+
1434
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1435
+ """Generate samples from the model and reference model for the given batch of inputs."""
1436
+
1437
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1438
+ # the torch cuda amp context manager, as some hidden states are silently cast to full precision.
1439
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1440
+
1441
+ with generate_context_manager:
1442
+ policy_output = model.generate(
1443
+ input_ids=batch["prompt_input_ids"],
1444
+ attention_mask=batch["prompt_attention_mask"],
1445
+ max_length=self.max_length,
1446
+ do_sample=True,
1447
+ pad_token_id=self.processing_class.pad_token_id,
1448
+ )
1449
+
1450
+ # if reference_output in batch use that otherwise use the reference model
1451
+ if "reference_output" in batch:
1452
+ reference_output = batch["reference_output"]
1453
+ else:
1454
+ if self.ref_model is None:
1455
+ with self.null_ref_context():
1456
+ reference_output = self.model.generate(
1457
+ input_ids=batch["prompt_input_ids"],
1458
+ attention_mask=batch["prompt_attention_mask"],
1459
+ max_length=self.max_length,
1460
+ do_sample=True,
1461
+ pad_token_id=self.processing_class.pad_token_id,
1462
+ )
1463
+ else:
1464
+ reference_output = self.ref_model.generate(
1465
+ input_ids=batch["prompt_input_ids"],
1466
+ attention_mask=batch["prompt_attention_mask"],
1467
+ max_length=self.max_length,
1468
+ do_sample=True,
1469
+ pad_token_id=self.processing_class.pad_token_id,
1470
+ )
1471
+
1472
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1473
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1474
+
1475
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1476
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1477
+
1478
+ return policy_output_decoded, reference_output_decoded
1479
+
1480
+ def prediction_step(
1481
+ self,
1482
+ model: Union[PreTrainedModel, nn.Module],
1483
+ inputs: dict[str, Union[torch.Tensor, Any]],
1484
+ prediction_loss_only: bool,
1485
+ ignore_keys: Optional[list[str]] = None,
1486
+ ):
1487
+ if ignore_keys is None:
1488
+ if hasattr(model, "config"):
1489
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1490
+ else:
1491
+ ignore_keys = []
1492
+
1493
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1494
+ with torch.no_grad(), prediction_context_manager:
1495
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1496
+
1497
+ # force log the metrics
1498
+ if self.accelerator.is_main_process:
1499
+ self.store_metrics(metrics, train_eval="eval")
1500
+
1501
+ if prediction_loss_only:
1502
+ return (loss.detach(), None, None)
1503
+
1504
+ # logits for the chosen and rejected samples from model
1505
+ logits_dict = {
1506
+ "eval_logits/chosen": metrics["logits/chosen"],
1507
+ "eval_logits/rejected": metrics["logits/rejected"],
1508
+ }
1509
+ logits = torch.tensor(
1510
+ [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
1511
+ )
1512
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1513
+
1514
+ return (loss.detach(), logits, labels)
1515
+
1516
+ def evaluation_loop(
1517
+ self,
1518
+ dataloader: DataLoader,
1519
+ description: str,
1520
+ prediction_loss_only: Optional[bool] = None,
1521
+ ignore_keys: Optional[list[str]] = None,
1522
+ metric_key_prefix: str = "eval",
1523
+ ) -> EvalLoopOutput:
1524
+ """
1525
+ Overrides the built-in evaluation loop to store metrics for each batch.
1526
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1527
+
1528
+ Works both with or without labels.
1529
+ """
1530
+
1531
+ # Sample and save to game log if requested (for one batch to save time)
1532
+ if self.generate_during_eval:
1533
+ # Generate random indices within the range of the total number of samples
1534
+ num_samples = len(dataloader.dataset)
1535
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1536
+
1537
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1538
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1539
+ random_batch = self.data_collator(random_batch_dataset)
1540
+ random_batch = self._prepare_inputs(random_batch)
1541
+
1542
+ target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1543
+ target_batch = {
1544
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
1545
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
1546
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1547
+ }
1548
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1549
+
1550
+ table = pd.DataFrame(
1551
+ columns=["Prompt", "Policy", "Ref Model"],
1552
+ data=[
1553
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1554
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1555
+ ],
1556
+ )
1557
+ if "wandb" in self.args.report_to:
1558
+ wandb.log({"game_log": wandb.Table(data=table)})
1559
+
1560
+ if "comet_ml" in self.args.report_to:
1561
+ log_table_to_comet_experiment(
1562
+ name="game_log.csv",
1563
+ table=table,
1564
+ )
1565
+
1566
+ # Base evaluation
1567
+ initial_output = super().evaluation_loop(
1568
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1569
+ )
1570
+
1571
+ return initial_output
1572
+
1573
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1574
+ """
1575
+ Log `logs` on the various objects watching training, including stored metrics.
1576
+
1577
+ Args:
1578
+ logs (`dict[str, float]`):
1579
+ The values to log.
1580
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1581
+ Start time of the training.
1582
+ """
1583
+ # logs either has 'loss' or 'eval_loss'
1584
+ train_eval = "train" if "loss" in logs else "eval"
1585
+ # train metrics should have no prefix, eval should have 'eval_'
1586
+ prefix = "eval_" if train_eval == "eval" else ""
1587
+ # accumulate average metrics from sums and lengths
1588
+ for split in ["chosen", "rejected"]:
1589
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1590
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1591
+ for metric in ["rewards", "logps", "logits"]:
1592
+ logs[f"{prefix}{metric}/{split}"] = (
1593
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1594
+ / count_sum
1595
+ )
1596
+ # delete obsolete metric
1597
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1598
+ del self._stored_metrics[train_eval][f"count/{split}"]
1599
+ # calculate reward margin
1600
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1601
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1602
+ # Add averaged stored metrics to logs
1603
+ for key, metrics in self._stored_metrics[train_eval].items():
1604
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1605
+ del self._stored_metrics[train_eval]
1606
+
1607
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1608
+ return super().log(logs, start_time)
1609
+ else: # transformers<=4.46
1610
+ return super().log(logs)
1611
+
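+ # Example of the sum/count averaging above (hypothetical stored values):
+ # with rewards/chosen_sum = [4.0, 2.0] and count/chosen = [2, 2] accumulated
+ # between logging steps, the reported rewards/chosen is (4.0 + 2.0) / (2 + 2) = 1.5.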
1612
+ def create_model_card(
1613
+ self,
1614
+ model_name: Optional[str] = None,
1615
+ dataset_name: Optional[str] = None,
1616
+ tags: Union[str, list[str], None] = None,
1617
+ ):
1618
+ """
1619
+ Creates a draft of a model card using the information available to the `Trainer`.
1620
+
1621
+ Args:
1622
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1623
+ Name of the model.
1624
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1625
+ Name of the dataset used for training.
1626
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1627
+ Tags to be associated with the model card.
1628
+ """
1629
+ if not self.is_world_process_zero():
1630
+ return
1631
+
1632
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1633
+ base_model = self.model.config._name_or_path
1634
+ else:
1635
+ base_model = None
1636
+
1637
+ tags = tags or []
1638
+ if isinstance(tags, str):
1639
+ tags = [tags]
1640
+
1641
+ if hasattr(self.model.config, "unsloth_version"):
1642
+ tags.append("unsloth")
1643
+
1644
+ citation = textwrap.dedent("""\
1645
+ @article{ethayarajh2024kto,
1646
+ title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
1647
+ author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
1648
+ year = 2024,
1649
+ eprint = {arXiv:2402.01306},
1650
+ }""")
1651
+
1652
+ model_card = generate_model_card(
1653
+ base_model=base_model,
1654
+ model_name=model_name,
1655
+ hub_model_id=self.hub_model_id,
1656
+ dataset_name=dataset_name,
1657
+ tags=tags,
1658
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1659
+ comet_url=get_comet_experiment_url(),
1660
+ trainer_name="KTO",
1661
+ trainer_citation=citation,
1662
+ paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
1663
+ paper_id="2402.01306",
1664
+ )
1665
+
1666
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1667
+ class UnslothKTOTrainer(_UnslothKTOTrainer):
1668
+ """
1669
+
1670
+ Initialize KTOTrainer.
1671
+
1672
+ Args:
1673
+ model (`transformers.PreTrainedModel`):
1674
+ The model to train, preferably an `AutoModelForCausalLM`.
1675
+ ref_model (`PreTrainedModelWrapper`):
1676
+ Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
1677
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1678
+ args (`KTOConfig`):
1679
+ The arguments to use for training.
1680
+ train_dataset (`datasets.Dataset`):
1681
+ The dataset to use for training.
1682
+ eval_dataset (`datasets.Dataset`):
1683
+ The dataset to use for evaluation.
1684
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1685
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1686
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1687
+ reuse the fine-tuned model.
1688
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1689
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
1690
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1691
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1692
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1693
+ callbacks (`list[transformers.TrainerCallback]`):
1694
+ The callbacks to use for training.
1695
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1696
+ The optimizer and scheduler to use for training.
1697
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1698
+ The function to use to preprocess the logits before computing the metrics.
1699
+ peft_config (`dict`, defaults to `None`):
1700
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1701
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1702
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1703
+ a dictionary string to metric values.
1704
+ model_adapter_name (`str`, defaults to `None`):
1705
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1706
+ ref_adapter_name (`str`, defaults to `None`):
1707
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1708
+
1709
+ """
1710
+ def __init__(
1711
+ self,
1712
+ model = None,
1713
+ ref_model = None,
1714
+ args = None,
1715
+ train_dataset = None,
1716
+ eval_dataset = None,
1717
+ processing_class = None,
1718
+ data_collator = None,
1719
+ model_init = None,
1720
+ callbacks = None,
1721
+ preprocess_logits_for_metrics = None,
1722
+ peft_config = None,
1723
+ compute_metrics = None,
1724
+ model_adapter_name = None,
1725
+ ref_adapter_name = None,
1726
+ **kwargs
1727
+ ):
1728
+ if args is None: args = UnslothKTOConfig()
1729
+ use_bf16 = getattr(args, 'bf16', False)
1730
+ use_fp16 = getattr(args, 'fp16', False)
1731
+ force_float32 = False
1732
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1733
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1734
+ force_float32 = True
1735
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1736
+ dtype = getattr(model.config, 'torch_dtype', None)
1737
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1738
+ from unsloth_zoo.utils import _get_dtype
1739
+ dtype = _get_dtype(dtype)
1740
+ float16 = dtype == torch.float16
1741
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1742
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1743
+ if force_float32:
1744
+ args.fp16 = False
1745
+ args.bf16 = False
1746
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1747
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1748
+ args.fp16 = float16
1749
+ args.bf16 = not float16
1750
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1751
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1752
+ args.eval_strategy = 'steps'
1753
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1754
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1755
+ if ga_steps is not None and ga_steps > 1:
1756
+ from transformers import __version__ as transformers_version
1757
+ if Version(transformers_version) <= Version('4.45.2'):
1758
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1759
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1760
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1761
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1762
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1763
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1764
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1765
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1766
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1767
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1768
+ if force_float32:
1769
+ args.bf16_full_eval = False
1770
+ args.fp16_full_eval = False
1771
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1772
+ args.bf16_full_eval = True
1773
+ args.fp16_full_eval = False
1774
+ elif not bf16_full_eval and not fp16_full_eval:
1775
+ args.bf16_full_eval = args.bf16
1776
+ args.fp16_full_eval = args.fp16
1777
+ _output_logits = False
1778
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1779
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1780
+ if _output_logits:
1781
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1782
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1783
+ pass
1784
+ else:
1785
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1786
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1787
+ if args_max_seq_length is None and model_max_seq_length is not None:
1788
+ max_seq_length = model.max_seq_length
1789
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1790
+ if model is not None and hasattr(model, 'for_training'):
1791
+ model.for_training()
1792
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1793
+ if 'processing_class' in locals():
1794
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1795
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1796
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1797
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1798
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1799
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1800
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1801
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1802
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1803
+ else:
1804
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1805
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1806
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1807
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1808
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1809
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1810
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1811
+ else:
1812
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1813
+ other_metrics = []
1814
+
1815
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1816
+ PatchRLStatistics('kto_trainer', other_metrics)
1817
+
1818
+ super().__init__(
1819
+ model = model,
1820
+ ref_model = ref_model,
1821
+ args = args,
1822
+ train_dataset = train_dataset,
1823
+ eval_dataset = eval_dataset,
1824
+ processing_class = processing_class,
1825
+ data_collator = data_collator,
1826
+ model_init = model_init,
1827
+ callbacks = callbacks,
1828
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1829
+ peft_config = peft_config,
1830
+ compute_metrics = compute_metrics,
1831
+ model_adapter_name = model_adapter_name,
1832
+ ref_adapter_name = ref_adapter_name,**kwargs)
1833
+ if hasattr(self, 'neftune_hook_handle'):
1834
+ self.neftune_hook_handle.remove()
1835
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1836
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1837
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1838
+ pass
1839
+
1840
+ pass
unsloth_compiled_cache/UnslothNashMDTrainer.py ADDED
@@ -0,0 +1,955 @@
 
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
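+ # Illustrative check (toy tensors): the function gathers the log-softmax value
+ # of each selected token id without materializing the full log-softmax tensor.
+ # >>> logits = torch.tensor([[[2.0, 0.0]]])
+ # >>> index = torch.tensor([[0]])
+ # >>> selective_log_softmax(logits, index)   # log(e^2 / (e^2 + e^0)) ≈ -0.1269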
42
+ @dataclass
43
+ class UnslothNashMDConfig(NashMDConfig):
44
+ """
45
+
46
+ Configuration class for the [`NashMDTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
49
+
50
+ Parameters:
51
+ mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
52
+ Logit mixture coefficient for the model and reference model. If a list of floats is provided, the
54
+ coefficient for epoch `i` is the `i`-th entry, and the last entry is reused for all remaining
55
+ epochs.
55
+
56
+ """
57
+ vllm_sampling_params: Optional[Any] = field(
58
+ default = None,
59
+ metadata = {'help': 'vLLM SamplingParams'},
60
+ )
61
+ unsloth_num_chunks : Optional[int] = field(
62
+ default = -1,
63
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
64
+ )
65
+ def __init__(
66
+ self,
67
+ output_dir = None,
68
+ overwrite_output_dir = None,
69
+ do_train = False,
70
+ do_eval = False,
71
+ do_predict = False,
72
+ eval_strategy = 'no',
73
+ prediction_loss_only = False,
74
+ per_device_train_batch_size = 4,
75
+ per_device_eval_batch_size = 4,
76
+ per_gpu_train_batch_size = None,
77
+ per_gpu_eval_batch_size = None,
78
+ gradient_accumulation_steps = 2,
79
+ eval_accumulation_steps = 2,
80
+ eval_delay = 0,
81
+ torch_empty_cache_steps = 250,
82
+ learning_rate = 5e-05,
83
+ weight_decay = 0.01,
84
+ adam_beta1 = 0.9,
85
+ adam_beta2 = 0.999,
86
+ adam_epsilon = 1e-08,
87
+ max_grad_norm = 1.0,
88
+ num_train_epochs = 3.0,
89
+ max_steps = -1,
90
+ lr_scheduler_type = 'linear',
91
+ warmup_ratio = 0.1,
92
+ warmup_steps = 0,
93
+ log_level = 'passive',
94
+ log_level_replica = 'warning',
95
+ log_on_each_node = True,
96
+ logging_dir = None,
97
+ logging_strategy = 'steps',
98
+ logging_first_step = False,
99
+ logging_steps = 1,
100
+ logging_nan_inf_filter = False,
101
+ save_strategy = 'steps',
102
+ save_steps = 500,
103
+ save_total_limit = None,
104
+ save_safetensors = True,
105
+ save_on_each_node = False,
106
+ save_only_model = False,
107
+ restore_callback_states_from_checkpoint = False,
108
+ no_cuda = False,
109
+ use_cpu = False,
110
+ use_mps_device = False,
111
+ seed = 3407,
112
+ data_seed = 3407,
113
+ jit_mode_eval = False,
114
+ use_ipex = False,
115
+ bf16 = False,
116
+ fp16 = False,
117
+ fp16_opt_level = 'O1',
118
+ half_precision_backend = 'auto',
119
+ bf16_full_eval = False,
120
+ fp16_full_eval = False,
121
+ tf32 = None,
122
+ local_rank = -1,
123
+ ddp_backend = None,
124
+ tpu_num_cores = None,
125
+ tpu_metrics_debug = False,
126
+ debug = '',
127
+ dataloader_drop_last = False,
128
+ eval_steps = None,
129
+ dataloader_num_workers = 0,
130
+ dataloader_prefetch_factor = None,
131
+ past_index = -1,
132
+ run_name = None,
133
+ disable_tqdm = None,
134
+ remove_unused_columns = True,
135
+ label_names = None,
136
+ load_best_model_at_end = False,
137
+ metric_for_best_model = None,
138
+ greater_is_better = None,
139
+ ignore_data_skip = False,
140
+ fsdp = '',
141
+ fsdp_min_num_params = 0,
142
+ fsdp_config = None,
143
+ tp_size = 0,
144
+ fsdp_transformer_layer_cls_to_wrap = None,
145
+ accelerator_config = None,
146
+ deepspeed = None,
147
+ label_smoothing_factor = 0.0,
148
+ optim = 'adamw_8bit',
149
+ optim_args = None,
150
+ adafactor = False,
151
+ group_by_length = False,
152
+ length_column_name = 'length',
153
+ report_to = None,
154
+ ddp_find_unused_parameters = None,
155
+ ddp_bucket_cap_mb = None,
156
+ ddp_broadcast_buffers = None,
157
+ dataloader_pin_memory = True,
158
+ dataloader_persistent_workers = False,
159
+ skip_memory_metrics = True,
160
+ use_legacy_prediction_loop = False,
161
+ push_to_hub = False,
162
+ resume_from_checkpoint = None,
163
+ hub_model_id = None,
164
+ hub_strategy = 'every_save',
165
+ hub_token = None,
166
+ hub_private_repo = None,
167
+ hub_always_push = False,
168
+ gradient_checkpointing = False,
169
+ gradient_checkpointing_kwargs = None,
170
+ include_inputs_for_metrics = False,
171
+ eval_do_concat_batches = True,
172
+ fp16_backend = 'auto',
173
+ evaluation_strategy = None,
174
+ push_to_hub_model_id = None,
175
+ push_to_hub_organization = None,
176
+ push_to_hub_token = None,
177
+ mp_parameters = '',
178
+ auto_find_batch_size = False,
179
+ full_determinism = False,
180
+ torchdynamo = None,
181
+ ray_scope = 'last',
182
+ ddp_timeout = 1800,
183
+ torch_compile = False,
184
+ torch_compile_backend = None,
185
+ torch_compile_mode = None,
186
+ dispatch_batches = None,
187
+ split_batches = None,
188
+ include_tokens_per_second = False,
189
+ include_num_input_tokens_seen = False,
190
+ neftune_noise_alpha = None,
191
+ optim_target_modules = None,
192
+ batch_eval_metrics = False,
193
+ eval_on_start = False,
194
+ use_liger_kernel = False,
195
+ eval_use_gather_object = False,
196
+ average_tokens_across_devices = False,
197
+ reward_model_path = None,
198
+ judge = None,
199
+ max_new_tokens = 64,
200
+ max_length = 512,
201
+ temperature = 0.9,
202
+ missing_eos_penalty = None,
203
+ loss_type = 'sigmoid',
204
+ dataset_num_proc = None,
205
+ disable_dropout = True,
206
+ use_vllm = False,
207
+ ds3_gather_for_generation = True,
208
+ vllm_sampling_params = None,
209
+ unsloth_num_chunks = -1,
210
+ **kwargs,
211
+ ):
212
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
213
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
214
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
215
+ output_dir = 'unsloth_training_checkpoints'
216
+ save_strategy = 'no'
217
+ if dataset_num_proc is None:
218
+ from multiprocessing import cpu_count
219
+ dataset_num_proc = cpu_count()
220
+
221
+ super().__init__(
222
+ output_dir = output_dir,
223
+ overwrite_output_dir = overwrite_output_dir,
224
+ do_train = do_train,
225
+ do_eval = do_eval,
226
+ do_predict = do_predict,
227
+ eval_strategy = eval_strategy,
228
+ prediction_loss_only = prediction_loss_only,
229
+ per_device_train_batch_size = per_device_train_batch_size,
230
+ per_device_eval_batch_size = per_device_eval_batch_size,
231
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
232
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
233
+ gradient_accumulation_steps = gradient_accumulation_steps,
234
+ eval_accumulation_steps = eval_accumulation_steps,
235
+ eval_delay = eval_delay,
236
+ torch_empty_cache_steps = torch_empty_cache_steps,
237
+ learning_rate = learning_rate,
238
+ weight_decay = weight_decay,
239
+ adam_beta1 = adam_beta1,
240
+ adam_beta2 = adam_beta2,
241
+ adam_epsilon = adam_epsilon,
242
+ max_grad_norm = max_grad_norm,
243
+ num_train_epochs = num_train_epochs,
244
+ max_steps = max_steps,
245
+ lr_scheduler_type = lr_scheduler_type,
246
+ warmup_ratio = warmup_ratio,
247
+ warmup_steps = warmup_steps,
248
+ log_level = log_level,
249
+ log_level_replica = log_level_replica,
250
+ log_on_each_node = log_on_each_node,
251
+ logging_dir = logging_dir,
252
+ logging_strategy = logging_strategy,
253
+ logging_first_step = logging_first_step,
254
+ logging_steps = logging_steps,
255
+ logging_nan_inf_filter = logging_nan_inf_filter,
256
+ save_strategy = save_strategy,
257
+ save_steps = save_steps,
258
+ save_total_limit = save_total_limit,
259
+ save_safetensors = save_safetensors,
260
+ save_on_each_node = save_on_each_node,
261
+ save_only_model = save_only_model,
262
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
263
+ no_cuda = no_cuda,
264
+ use_cpu = use_cpu,
265
+ use_mps_device = use_mps_device,
266
+ seed = seed,
267
+ data_seed = data_seed,
268
+ jit_mode_eval = jit_mode_eval,
269
+ use_ipex = use_ipex,
270
+ bf16 = bf16,
271
+ fp16 = fp16,
272
+ fp16_opt_level = fp16_opt_level,
273
+ half_precision_backend = half_precision_backend,
274
+ bf16_full_eval = bf16_full_eval,
275
+ fp16_full_eval = fp16_full_eval,
276
+ tf32 = tf32,
277
+ local_rank = local_rank,
278
+ ddp_backend = ddp_backend,
279
+ tpu_num_cores = tpu_num_cores,
280
+ tpu_metrics_debug = tpu_metrics_debug,
281
+ debug = debug,
282
+ dataloader_drop_last = dataloader_drop_last,
283
+ eval_steps = eval_steps,
284
+ dataloader_num_workers = dataloader_num_workers,
285
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
286
+ past_index = past_index,
287
+ run_name = run_name,
288
+ disable_tqdm = disable_tqdm,
289
+ remove_unused_columns = remove_unused_columns,
290
+ label_names = label_names,
291
+ load_best_model_at_end = load_best_model_at_end,
292
+ metric_for_best_model = metric_for_best_model,
293
+ greater_is_better = greater_is_better,
294
+ ignore_data_skip = ignore_data_skip,
295
+ fsdp = fsdp,
296
+ fsdp_min_num_params = fsdp_min_num_params,
297
+ fsdp_config = fsdp_config,
298
+ tp_size = tp_size,
299
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
300
+ accelerator_config = accelerator_config,
301
+ deepspeed = deepspeed,
302
+ label_smoothing_factor = label_smoothing_factor,
303
+ optim = optim,
304
+ optim_args = optim_args,
305
+ adafactor = adafactor,
306
+ group_by_length = group_by_length,
307
+ length_column_name = length_column_name,
308
+ report_to = report_to,
309
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
310
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
311
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
312
+ dataloader_pin_memory = dataloader_pin_memory,
313
+ dataloader_persistent_workers = dataloader_persistent_workers,
314
+ skip_memory_metrics = skip_memory_metrics,
315
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
316
+ push_to_hub = push_to_hub,
317
+ resume_from_checkpoint = resume_from_checkpoint,
318
+ hub_model_id = hub_model_id,
319
+ hub_strategy = hub_strategy,
320
+ hub_token = hub_token,
321
+ hub_private_repo = hub_private_repo,
322
+ hub_always_push = hub_always_push,
323
+ gradient_checkpointing = gradient_checkpointing,
324
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
325
+ include_inputs_for_metrics = include_inputs_for_metrics,
326
+ eval_do_concat_batches = eval_do_concat_batches,
327
+ fp16_backend = fp16_backend,
328
+ evaluation_strategy = evaluation_strategy,
329
+ push_to_hub_model_id = push_to_hub_model_id,
330
+ push_to_hub_organization = push_to_hub_organization,
331
+ push_to_hub_token = push_to_hub_token,
332
+ mp_parameters = mp_parameters,
333
+ auto_find_batch_size = auto_find_batch_size,
334
+ full_determinism = full_determinism,
335
+ torchdynamo = torchdynamo,
336
+ ray_scope = ray_scope,
337
+ ddp_timeout = ddp_timeout,
338
+ torch_compile = torch_compile,
339
+ torch_compile_backend = torch_compile_backend,
340
+ torch_compile_mode = torch_compile_mode,
341
+ dispatch_batches = dispatch_batches,
342
+ split_batches = split_batches,
343
+ include_tokens_per_second = include_tokens_per_second,
344
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
345
+ neftune_noise_alpha = neftune_noise_alpha,
346
+ optim_target_modules = optim_target_modules,
347
+ batch_eval_metrics = batch_eval_metrics,
348
+ eval_on_start = eval_on_start,
349
+ use_liger_kernel = use_liger_kernel,
350
+ eval_use_gather_object = eval_use_gather_object,
351
+ average_tokens_across_devices = average_tokens_across_devices,
352
+ reward_model_path = reward_model_path,
353
+ judge = judge,
354
+ max_new_tokens = max_new_tokens,
355
+ max_length = max_length,
356
+ temperature = temperature,
357
+ missing_eos_penalty = missing_eos_penalty,
358
+ loss_type = loss_type,
359
+ dataset_num_proc = dataset_num_proc,
360
+ disable_dropout = disable_dropout,
361
+ use_vllm = use_vllm,
362
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
363
+ self.vllm_sampling_params = vllm_sampling_params
364
+ self.unsloth_num_chunks = unsloth_num_chunks
365
+ pass
366
+
367
+ class _UnslothNashMDTrainer(OnlineDPOTrainer):
368
+ r""""""
369
+
370
+ _tag_names = ["trl", "nash-md"]
371
+
372
+ def __init__(
373
+ self,
374
+ model: Union[PreTrainedModel, nn.Module] = None,
375
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
376
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
377
+ judge: Optional[BasePairwiseJudge] = None,
378
+ args: Optional[NashMDConfig] = None,
379
+ data_collator: Optional[Callable] = None,
380
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
381
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
382
+ processing_class: Optional[
383
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
384
+ ] = None,
385
+ peft_config: Optional[dict] = None,
386
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
387
+ callbacks: Optional[list[TrainerCallback]] = None,
388
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
389
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
390
+ ) -> None:
391
+ super().__init__(
392
+ model=model,
393
+ ref_model=ref_model,
394
+ reward_model=reward_model,
395
+ judge=judge,
396
+ args=args,
397
+ data_collator=data_collator,
398
+ train_dataset=train_dataset,
399
+ eval_dataset=eval_dataset,
400
+ processing_class=processing_class,
401
+ reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
402
+ peft_config=peft_config,
403
+ compute_metrics=compute_metrics,
404
+ callbacks=callbacks,
405
+ optimizers=optimizers,
406
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
407
+ )
408
+
409
+ self._mixture_coef = self.args.mixture_coef
410
+
411
+ # Overwrite the stats dictionary to include NashMD specific statistics
412
+ self.stats = {
413
+ # Remove "non_score_reward", "rlhf_reward", "scores_margin"
414
+ # Add "mixture_coef"
415
+ "loss/kl": [],
416
+ "objective/entropy": [],
417
+ "loss/score": [],
418
+ "rewards/probabilities": [],
419
+ "rewards/accuracies": [],
420
+ "rewards/margins": [],
421
+ "logps/chosen": [],
422
+ "logps/rejected": [],
423
+ "val/model_contain_eos_token": [],
424
+ "val/ref_contain_eos_token": [],
425
+ "beta": [],
426
+ "mixture_coef": [],
427
+ }
428
+ if self.reward_model is not None:
429
+ self.stats["rewards/chosen"] = []
430
+ self.stats["rewards/rejected"] = []
431
+
432
+ @property
433
+ def mixture_coef(self):
434
+ if isinstance(self._mixture_coef, list):
435
+ epoch = int(self.state.epoch)  # state.epoch is a float; cast before list indexing
436
+ return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
437
+ else:
438
+ return self._mixture_coef
439
+
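+ # Schedule example for the property above (assumed values): with
+ # mixture_coef = [0.5, 0.3, 0.1], epoch 0 uses 0.5, epoch 1 uses 0.3, and
+ # every epoch >= 2 reuses the last entry, 0.1.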
440
+ def _generate_completions(self, model, prompts):
441
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
442
+ model_output = unwrapped_model.generate(
443
+ input_ids=prompts["input_ids"],
444
+ attention_mask=prompts["attention_mask"],
445
+ generation_config=self.generation_config,
446
+ )
447
+
448
+ ref_model = model if self.ref_model is None else self.ref_model
449
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
450
+ mixture_model = GeometricMixtureWrapper(
451
+ model=unwrapped_model,
452
+ ref_model=unwrapped_ref_model,
453
+ generation_config=self.generation_config,
454
+ mixture_coef=self.mixture_coef,
455
+ device=self.accelerator.device,
456
+ )
457
+
458
+ mixture_output = mixture_model.generate(
459
+ input_ids=prompts["input_ids"],
460
+ attention_mask=prompts["attention_mask"],
461
+ generation_config=self.generation_config,
462
+ )
463
+
464
+ return model_output, mixture_output
465
+
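+ # Note on the mixture policy above: GeometricMixtureWrapper samples from a
+ # log-space interpolation of the two policies, roughly
+ #   log p_mix ∝ (1 - mixture_coef) * log p_model + mixture_coef * log p_ref
+ # (up to normalization), so mixture_coef = 0 recovers the model and
+ # mixture_coef = 1 recovers the reference model.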
466
+ def _process_completions(self, model_output, mixture_output, prompts):
467
+ context_length = prompts["input_ids"].shape[1]
468
+
469
+ # Process model completions
470
+ model_completion_ids = model_output[:, context_length:]
471
+ model_completion_ids, model_completion_mask = truncate_right(
472
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
473
+ )
474
+ model_data = {
475
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
476
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
477
+ "raw": prompts["raw"],
478
+ }
479
+
480
+ # Process mixture model completions
481
+ mixture_completion_ids = mixture_output[:, context_length:]
482
+ mixture_completion_ids, mixture_completion_mask = truncate_right(
483
+ mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
484
+ )
485
+ mixture_data = {
486
+ "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
487
+ "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
488
+ "raw": prompts["raw"],
489
+ }
490
+
491
+ return model_data, mixture_data
492
+
493
+ def _compute_rewards(self, model_data, mixture_data, context_length):
494
+ with torch.no_grad():
495
+ _, model_scores, _ = get_reward(
496
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
497
+ )
498
+ _, mixture_scores, _ = get_reward(
499
+ self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
500
+ )
501
+
502
+ # Apply EOS penalty if needed
503
+ if self.args.missing_eos_penalty is not None:
504
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
505
+ mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
506
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
507
+ mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
508
+
509
+ return model_scores, mixture_scores
510
+
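The missing-EOS branch simply shifts the reward of completions that never emitted EOS down by a constant. For example (values assumed): with scores [1.0, 2.0], contains_eos [True, False] and missing_eos_penalty = 1.0, the penalized scores become [1.0, 1.0].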
511
+ def _compute_judge(self, model_data, mixture_data, context_length):
512
+ prompts = model_data["raw"]
513
+ model_data_completions = self.processing_class.batch_decode(
514
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
515
+ )
516
+ model_data_completions = [completion.strip() for completion in model_data_completions]
517
+
518
+ mixture_data_completions = self.processing_class.batch_decode(
519
+ mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
520
+ )
521
+ mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
522
+ if is_conversational({"prompt": prompts[0]}):
523
+ model_data_completions = [
524
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
525
+ ]
526
+ environment = jinja2.Environment()
527
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
528
+ prompts = [template.render(messages=message) for message in prompts]
529
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
530
+
531
+ mixture_data_completions = [
532
+ [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
533
+ ]
534
+ mixture_data_completions = [
535
+ template.render(messages=completion) for completion in mixture_data_completions
536
+ ]
537
+
538
+ probability = self.judge.judge(
539
+ prompts,
540
+ list(zip(model_data_completions, mixture_data_completions)),
541
+ return_scores=True,
542
+ )
543
+ return torch.tensor(probability, device=model_data["input_ids"].device)
544
+
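The judge path expects a pairwise judge whose judge(prompts, completion_pairs, return_scores=True) call returns, per prompt, the probability that the first completion wins. A hedged usage sketch with trl's PairRMJudge (requires the optional llm-blender dependency; the prompt and completions below are made up):

    from trl import PairRMJudge

    judge = PairRMJudge()
    probs = judge.judge(
        ["What color is the sky?"],
        [("blue", "green")],   # (model completion, mixture completion)
        return_scores=True,    # probabilities instead of hard 0/1 ranks
    )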
545
+ def _compute_logprobs(self, model, model_data, context_length):
546
+ def compute_logprobs_for_data(m, data):
547
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
548
+ logits = output.logits[:, context_length - 1 : -1]
549
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
550
+ return token_logprobs
551
+
552
+ # Compute logprobs for model completions under the model
553
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
554
+
555
+ # Compute logprobs of model completions under the reference model
556
+ with torch.no_grad():
557
+ if self.ref_model is None:
558
+ with model.disable_adapter():
559
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
560
+ else:
561
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
562
+
563
+ # Mask padding tokens
564
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
565
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
566
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
567
+
568
+ return (model_logprobs_model_data, ref_logprobs_model_data)
569
+
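On the indexing above: logits at position t score the token at position t + 1, so the slice logits[:, context_length - 1 : -1] lines up with the completion tokens input_ids[:, context_length:]. For instance, with context_length = 3 and a total length of 6, logits[:, 2:5] are the predictions for the tokens at positions 3, 4 and 5.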
570
+ def _compute_losses(
571
+ self,
572
+ model_logprobs_model_data,
573
+ ref_logprobs_model_data,
574
+ probability,
575
+ ):
576
+ # REINFORCE score, using 0.5 as a control variate (baseline)
577
+ score = (probability - 0.5) * model_logprobs_model_data.sum(1)
578
+
579
+ # kl divergence via reinforce
580
+ with torch.no_grad():
581
+ log_ratio = model_logprobs_model_data - ref_logprobs_model_data
582
+ kl_div_log = log_ratio.sum(1)
583
+ kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
584
+
585
+ # final loss
586
+ loss = self.beta * kl_div_loss - score
587
+
588
+ return loss.mean(), score, kl_div_log
589
+
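A hedged restatement of what _compute_losses computes, with p the judged win-probability of the model completion, π_θ the policy and π_ref the reference:

    score  = (p − 0.5) · Σ_t log π_θ(y_t | x, y_<t)        # REINFORCE, 0.5 as baseline
    kl_hat = Σ_t (log π_θ − log π_ref)                     # logged KL estimate
    loss   = β · Σ_t sg[log π_θ − log π_ref] · log π_θ − score

where sg[·] marks the stop-gradient from the no_grad block: minimizing the loss rewards completions that beat the mixture while the β-weighted term keeps the policy close to the reference.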
590
+ def _log_statistics(
591
+ self,
592
+ model_data,
593
+ mixture_data,
594
+ model_logprobs_model_data,
595
+ ref_logprobs_model_data,
596
+ probability,
597
+ score,
598
+ kl_div,
599
+ context_length,
600
+ model_scores=None,
601
+ mixture_scores=None,
602
+ ):
603
+ # Helper function to gather and compute mean
604
+ def gather_mean(tensor):
605
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
606
+
607
+ # Log score
608
+ self.stats["loss/score"].append(gather_mean(score))
609
+ # Log KL divergence
610
+ self.stats["loss/kl"].append(gather_mean(kl_div))
611
+
612
+ # Log logprobs
613
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
614
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
615
+
616
+ self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
617
+ self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
618
+
619
+ # Log rewards
620
+ if self.reward_model is not None:
621
+ self.stats["rewards/chosen"].append(gather_mean(model_scores))
622
+ self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
623
+
624
+ # Log probabilities
625
+ self.stats["rewards/probabilities"].append(gather_mean(probability))
626
+
627
+ # Calculate entropy for model data
628
+ entropy_model_data = -model_logprobs_model_data.sum(1)
629
+ self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
630
+
631
+ # Calculate margins
632
+ margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
633
+ self.stats["rewards/margins"].append(gather_mean(margin))
634
+
635
+ # Calculate accuracy
636
+ accuracy = (margin > 0).float()
637
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy))
638
+
639
+ # Log EOS token statistics
640
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
641
+ mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
642
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
643
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
644
+
645
+ # Log beta and mixture coef
646
+ self.stats["beta"].append(self.beta)
647
+ self.stats["mixture_coef"].append(self.mixture_coef)
648
+
649
+ def training_step(
650
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
651
+ ) -> torch.Tensor:
652
+ model.train()
653
+
654
+ # Apply chat template and tokenize the input
655
+ batch_size = len(next(iter(inputs.values())))
656
+ prompts = inputs["prompt"]
657
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
658
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
659
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
660
+ inputs = self.data_collator(inputs)
661
+
662
+ # only the prompt_* entries are needed from here on
663
+ inputs = self._prepare_inputs(inputs)
664
+ context_length = inputs["prompt_input_ids"].shape[1]
665
+ prompts = {
666
+ "input_ids": inputs["prompt_input_ids"],
667
+ "attention_mask": inputs["prompt_attention_mask"],
668
+ "raw": prompts,
669
+ }
670
+ del inputs
671
+
672
+ # Sample completions from both the model and the model/reference geometric mixture
673
+ model_output, mixture_output = self._generate_completions(model, prompts)
674
+
675
+ # Process model completions
676
+ model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
677
+
678
+ # Compute rewards
679
+ if self.reward_model is not None:
680
+ model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
681
+ # probability of the model data vs the mixture data
682
+ probability = F.sigmoid(model_scores - mixture_scores)
683
+ else:
684
+ model_scores, mixture_scores = None, None
685
+ probability = self._compute_judge(model_data, mixture_data, context_length)
686
+
687
+ # Compute logprobs
688
+ model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
689
+
690
+ # Compute loss
691
+ loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
692
+
693
+ # Log everything
694
+ self._log_statistics(
695
+ model_data,
696
+ mixture_data,
697
+ model_logprobs_model_data.detach(),
698
+ ref_logprobs_model_data,
699
+ probability,
700
+ score.detach(),
701
+ kl_div.detach(),
702
+ context_length,
703
+ model_scores,
704
+ mixture_scores,
705
+ )
706
+
707
+ if (
708
+ self.args.torch_empty_cache_steps is not None
709
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
710
+ ):
711
+ empty_cache()
712
+
713
+ kwargs = {}
714
+ # For LOMO optimizers you need to explicitly use the learning rate
715
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
716
+ kwargs["learning_rate"] = self._get_learning_rate()
717
+
718
+ if self.args.n_gpu > 1:
719
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
720
+
721
+ if self.use_apex:
722
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
723
+ scaled_loss.backward()
724
+ else:
725
+ self.accelerator.backward(loss, **kwargs)
726
+
727
+ return loss.detach() / self.args.gradient_accumulation_steps
728
+
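When a reward model is supplied instead of a judge, the win-probability above is the Bradley–Terry link p = sigmoid(r_model − r_mixture). A quick numeric check (reward values assumed): r_model = 1.2 and r_mixture = 0.2 give p = sigmoid(1.0) ≈ 0.731.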
729
+ def create_model_card(
730
+ self,
731
+ model_name: Optional[str] = None,
732
+ dataset_name: Optional[str] = None,
733
+ tags: Union[str, list[str], None] = None,
734
+ ):
735
+ """
736
+ Creates a draft of a model card using the information available to the `Trainer`.
737
+
738
+ Args:
739
+ model_name (`str` or `None`, *optional*, defaults to `None`):
740
+ Name of the model.
741
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
742
+ Name of the dataset used for training.
743
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
744
+ Tags to be associated with the model card.
745
+ """
746
+ if not self.is_world_process_zero():
747
+ return
748
+
749
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
750
+ base_model = self.model.config._name_or_path
751
+ else:
752
+ base_model = None
753
+
754
+ tags = tags or []
755
+ if isinstance(tags, str):
756
+ tags = [tags]
757
+
758
+ if hasattr(self.model.config, "unsloth_version"):
759
+ tags.append("unsloth")
760
+
761
+ citation = textwrap.dedent("""\
762
+ @inproceedings{munos2024nash,
763
+ title = {{Nash Learning from Human Feedback}},
764
+ author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
765
+ year = 2024,
766
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
767
+ publisher = {OpenReview.net},
768
+ url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
769
+ }""")
770
+
771
+ model_card = generate_model_card(
772
+ base_model=base_model,
773
+ model_name=model_name,
774
+ hub_model_id=self.hub_model_id,
775
+ dataset_name=dataset_name,
776
+ tags=tags,
777
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
778
+ comet_url=get_comet_experiment_url(),
779
+ trainer_name="Nash-MD",
780
+ trainer_citation=citation,
781
+ paper_title="Nash Learning from Human Feedback",
782
+ paper_id="2312.00886",
783
+ )
784
+
785
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
786
+ class UnslothNashMDTrainer(_UnslothNashMDTrainer):
787
+ """
788
+
789
+ Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
790
+
791
+ Args:
792
+ model (`transformers.PreTrainedModel`):
793
+ The model to train, preferably an `AutoModelForCausalLM`.
794
+ ref_model (`PreTrainedModelWrapper`):
795
+ Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
796
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
797
+ reward_model (`transformers.PreTrainedModel`):
798
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
799
+ judge (`BasePairwiseJudge`):
800
+ The judge to use for pairwise comparison of model completions.
801
+ args (`NashMDConfig`):
802
+ The NashMD config arguments to use for training.
803
+ data_collator (`transformers.DataCollator`):
804
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
805
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
806
+ train_dataset (`datasets.Dataset`):
807
+ The dataset to use for training.
808
+ eval_dataset (`datasets.Dataset`):
809
+ The dataset to use for evaluation.
810
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
811
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
812
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
813
+ reuse the fine-tuned model.
814
+ peft_config (`dict`):
815
+ The peft config to use for training.
816
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
817
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
818
+ a dictionary mapping strings to metric values.
819
+ callbacks (`list[transformers.TrainerCallback]`):
820
+ The callbacks to use for training.
821
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
822
+ The optimizer and scheduler to use for training.
823
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
824
+ The function to use to preprocess the logits before computing the metrics.
825
+
826
+ """
827
+ def __init__(
828
+ self,
829
+ model = None,
830
+ ref_model = None,
831
+ reward_model = None,
832
+ judge = None,
833
+ args = None,
834
+ data_collator = None,
835
+ train_dataset = None,
836
+ eval_dataset = None,
837
+ processing_class = None,
838
+ peft_config = None,
839
+ compute_metrics = None,
840
+ callbacks = None,
841
+ preprocess_logits_for_metrics = None,
842
+ **kwargs
843
+ ):
844
+ if args is None: args = UnslothNashMDConfig()
845
+ use_bf16 = getattr(args, 'bf16', False)
846
+ use_fp16 = getattr(args, 'fp16', False)
847
+ force_float32 = False
848
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
849
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
850
+ force_float32 = True
851
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
852
+ dtype = getattr(model.config, 'torch_dtype', None)
853
+ if dtype is None: dtype = model.get_input_embeddings().dtype
854
+ from unsloth_zoo.utils import _get_dtype
855
+ dtype = _get_dtype(dtype)
856
+ float16 = dtype == torch.float16
857
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
858
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
859
+ if force_float32:
860
+ args.fp16 = False
861
+ args.bf16 = False
862
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
863
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
864
+ args.fp16 = float16
865
+ args.bf16 = not float16
866
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
867
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
868
+ args.eval_strategy = 'steps'
869
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
870
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
871
+ if ga_steps is not None and ga_steps > 1:
872
+ from transformers import __version__ as transformers_version
873
+ if Version(transformers_version) <= Version('4.45.2'):
874
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
875
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
876
+ if getattr(args, 'eval_strategy', 'no') != 'no':
877
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
878
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
879
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
880
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
881
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
882
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
883
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
884
+ if force_float32:
885
+ args.bf16_full_eval = False
886
+ args.fp16_full_eval = False
887
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
888
+ args.bf16_full_eval = True
889
+ args.fp16_full_eval = False
890
+ elif not bf16_full_eval and not fp16_full_eval:
891
+ args.bf16_full_eval = args.bf16
892
+ args.fp16_full_eval = args.fp16
893
+ _output_logits = False
894
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
895
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
896
+ if _output_logits:
897
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
898
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
899
+ pass
900
+ else:
901
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
902
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
903
+ if args_max_seq_length is None and model_max_seq_length is not None:
904
+ max_seq_length = model.max_seq_length
905
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
906
+ if model is not None and hasattr(model, 'for_training'):
907
+ model.for_training()
908
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
909
+ if 'processing_class' in locals():
910
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
911
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
912
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
913
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
914
+ if not isinstance(data_collator, UnslothVisionDataCollator):
915
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
916
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
917
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
918
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
919
+ else:
920
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
921
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
922
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
923
+ if not isinstance(data_collator, UnslothVisionDataCollator):
924
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
925
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
926
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
927
+ else:
928
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
929
+ other_metrics = []
930
+
931
+ from unsloth_zoo.logging_utils import PatchRLStatistics
932
+ PatchRLStatistics('nash_md_trainer', other_metrics)
933
+
934
+ super().__init__(
935
+ model = model,
936
+ ref_model = ref_model,
937
+ reward_model = reward_model,
938
+ judge = judge,
939
+ args = args,
940
+ data_collator = data_collator,
941
+ train_dataset = train_dataset,
942
+ eval_dataset = eval_dataset,
943
+ processing_class = processing_class,
944
+ peft_config = peft_config,
945
+ compute_metrics = compute_metrics,
946
+ callbacks = callbacks,
947
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
948
+ if hasattr(self, 'neftune_hook_handle'):
949
+ self.neftune_hook_handle.remove()
950
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
951
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
952
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
953
+ pass
954
+
955
+ pass
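A hedged usage sketch of the wrapper class above; the model, tokenizer and dataset names are placeholders, not part of this commit:

    trainer = UnslothNashMDTrainer(
        model = model,                 # assumed: a causal LM prepared with unsloth
        ref_model = None,              # falls back to a copy of the policy
        judge = judge,                 # or pass reward_model=..., one of the two
        args = UnslothNashMDConfig(output_dir = "outputs", mixture_coef = 0.5),
        train_dataset = train_dataset, # prompt-only dataset
        processing_class = tokenizer,
    )
    trainer.train()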
unsloth_compiled_cache/UnslothORPOTrainer.py ADDED
@@ -0,0 +1,1543 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
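selective_log_softmax leans on the identity log_softmax(x)_i = x_i − logsumexp(x). A tiny sanity check (values assumed):

    import torch

    x = torch.tensor([[2.0, 0.5, -1.0]])   # one row of logits
    idx = torch.tensor([[0]])              # score the first vocabulary entry
    gathered = torch.gather(x, -1, idx).squeeze(-1)
    assert torch.allclose(
        gathered - torch.logsumexp(x, dim=-1),
        torch.log_softmax(x, dim=-1).gather(-1, idx).squeeze(-1),
    )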
42
+ @dataclass
43
+ class UnslothORPOConfig(ORPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`ORPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-6`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
58
+ to use the default data collator.
59
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
60
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
61
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
62
+ Maximum length of the completion. This argument is required if you want to use the default data collator
63
+ and your model is an encoder-decoder.
64
+ beta (`float`, *optional*, defaults to `0.1`):
65
+ Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
66
+ it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
67
+ disable_dropout (`bool`, *optional*, defaults to `True`):
68
+ Whether to disable dropout in the model.
69
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
70
+ Label pad token id. This argument is required if you want to use the default data collator.
71
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
72
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
73
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
74
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
75
+ This argument is required if you want to use the default data collator.
76
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
77
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
78
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
79
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
80
+ you need to specify if the model returned by the callable is an encoder-decoder model.
81
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
82
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
83
+ string.
84
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
85
+ Number of processes to use for processing the dataset.
86
+
87
+ """
88
+ vllm_sampling_params: Optional[Any] = field(
89
+ default = None,
90
+ metadata = {'help': 'vLLM SamplingParams'},
91
+ )
92
+ unsloth_num_chunks : Optional[int] = field(
93
+ default = -1,
94
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
95
+ )
96
+ def __init__(
97
+ self,
98
+ output_dir = None,
99
+ overwrite_output_dir = None,
100
+ do_train = False,
101
+ do_eval = False,
102
+ do_predict = False,
103
+ eval_strategy = 'no',
104
+ prediction_loss_only = False,
105
+ per_device_train_batch_size = 4,
106
+ per_device_eval_batch_size = 4,
107
+ per_gpu_train_batch_size = None,
108
+ per_gpu_eval_batch_size = None,
109
+ gradient_accumulation_steps = 2,
110
+ eval_accumulation_steps = 2,
111
+ eval_delay = 0,
112
+ torch_empty_cache_steps = 250,
113
+ learning_rate = 5e-05,
114
+ weight_decay = 0.01,
115
+ adam_beta1 = 0.9,
116
+ adam_beta2 = 0.999,
117
+ adam_epsilon = 1e-08,
118
+ max_grad_norm = 1.0,
119
+ num_train_epochs = 3.0,
120
+ max_steps = -1,
121
+ lr_scheduler_type = 'linear',
122
+ warmup_ratio = 0.1,
123
+ warmup_steps = 0,
124
+ log_level = 'passive',
125
+ log_level_replica = 'warning',
126
+ log_on_each_node = True,
127
+ logging_dir = None,
128
+ logging_strategy = 'steps',
129
+ logging_first_step = False,
130
+ logging_steps = 1,
131
+ logging_nan_inf_filter = False,
132
+ save_strategy = 'steps',
133
+ save_steps = 500,
134
+ save_total_limit = None,
135
+ save_safetensors = True,
136
+ save_on_each_node = False,
137
+ save_only_model = False,
138
+ restore_callback_states_from_checkpoint = False,
139
+ no_cuda = False,
140
+ use_cpu = False,
141
+ use_mps_device = False,
142
+ seed = 3407,
143
+ data_seed = 3407,
144
+ jit_mode_eval = False,
145
+ use_ipex = False,
146
+ bf16 = False,
147
+ fp16 = False,
148
+ fp16_opt_level = 'O1',
149
+ half_precision_backend = 'auto',
150
+ bf16_full_eval = False,
151
+ fp16_full_eval = False,
152
+ tf32 = None,
153
+ local_rank = -1,
154
+ ddp_backend = None,
155
+ tpu_num_cores = None,
156
+ tpu_metrics_debug = False,
157
+ debug = '',
158
+ dataloader_drop_last = False,
159
+ eval_steps = None,
160
+ dataloader_num_workers = 0,
161
+ dataloader_prefetch_factor = None,
162
+ past_index = -1,
163
+ run_name = None,
164
+ disable_tqdm = None,
165
+ remove_unused_columns = True,
166
+ label_names = None,
167
+ load_best_model_at_end = False,
168
+ metric_for_best_model = None,
169
+ greater_is_better = None,
170
+ ignore_data_skip = False,
171
+ fsdp = '',
172
+ fsdp_min_num_params = 0,
173
+ fsdp_config = None,
174
+ tp_size = 0,
175
+ fsdp_transformer_layer_cls_to_wrap = None,
176
+ accelerator_config = None,
177
+ deepspeed = None,
178
+ label_smoothing_factor = 0.0,
179
+ optim = 'adamw_8bit',
180
+ optim_args = None,
181
+ adafactor = False,
182
+ group_by_length = False,
183
+ length_column_name = 'length',
184
+ report_to = None,
185
+ ddp_find_unused_parameters = None,
186
+ ddp_bucket_cap_mb = None,
187
+ ddp_broadcast_buffers = None,
188
+ dataloader_pin_memory = True,
189
+ dataloader_persistent_workers = False,
190
+ skip_memory_metrics = True,
191
+ use_legacy_prediction_loop = False,
192
+ push_to_hub = False,
193
+ resume_from_checkpoint = None,
194
+ hub_model_id = None,
195
+ hub_strategy = 'every_save',
196
+ hub_token = None,
197
+ hub_private_repo = None,
198
+ hub_always_push = False,
199
+ gradient_checkpointing = False,
200
+ gradient_checkpointing_kwargs = None,
201
+ include_inputs_for_metrics = False,
202
+ eval_do_concat_batches = True,
203
+ fp16_backend = 'auto',
204
+ evaluation_strategy = None,
205
+ push_to_hub_model_id = None,
206
+ push_to_hub_organization = None,
207
+ push_to_hub_token = None,
208
+ mp_parameters = '',
209
+ auto_find_batch_size = False,
210
+ full_determinism = False,
211
+ torchdynamo = None,
212
+ ray_scope = 'last',
213
+ ddp_timeout = 1800,
214
+ torch_compile = False,
215
+ torch_compile_backend = None,
216
+ torch_compile_mode = None,
217
+ dispatch_batches = None,
218
+ split_batches = None,
219
+ include_tokens_per_second = False,
220
+ include_num_input_tokens_seen = False,
221
+ neftune_noise_alpha = None,
222
+ optim_target_modules = None,
223
+ batch_eval_metrics = False,
224
+ eval_on_start = False,
225
+ use_liger_kernel = False,
226
+ eval_use_gather_object = False,
227
+ average_tokens_across_devices = False,
228
+ max_length = 1024,
229
+ max_prompt_length = 512,
230
+ max_completion_length = None,
231
+ beta = 0.1,
232
+ disable_dropout = True,
233
+ label_pad_token_id = -100,
234
+ padding_value = None,
235
+ truncation_mode = 'keep_end',
236
+ generate_during_eval = False,
237
+ is_encoder_decoder = None,
238
+ model_init_kwargs = None,
239
+ dataset_num_proc = None,
240
+ vllm_sampling_params = None,
241
+ unsloth_num_chunks = -1,
242
+ **kwargs,
243
+ ):
244
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
245
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
246
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
247
+ output_dir = 'unsloth_training_checkpoints'
248
+ save_strategy = 'no'
249
+ if dataset_num_proc is None:
250
+ from multiprocessing import cpu_count
251
+ dataset_num_proc = cpu_count()
252
+
253
+ super().__init__(
254
+ output_dir = output_dir,
255
+ overwrite_output_dir = overwrite_output_dir,
256
+ do_train = do_train,
257
+ do_eval = do_eval,
258
+ do_predict = do_predict,
259
+ eval_strategy = eval_strategy,
260
+ prediction_loss_only = prediction_loss_only,
261
+ per_device_train_batch_size = per_device_train_batch_size,
262
+ per_device_eval_batch_size = per_device_eval_batch_size,
263
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
264
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
265
+ gradient_accumulation_steps = gradient_accumulation_steps,
266
+ eval_accumulation_steps = eval_accumulation_steps,
267
+ eval_delay = eval_delay,
268
+ torch_empty_cache_steps = torch_empty_cache_steps,
269
+ learning_rate = learning_rate,
270
+ weight_decay = weight_decay,
271
+ adam_beta1 = adam_beta1,
272
+ adam_beta2 = adam_beta2,
273
+ adam_epsilon = adam_epsilon,
274
+ max_grad_norm = max_grad_norm,
275
+ num_train_epochs = num_train_epochs,
276
+ max_steps = max_steps,
277
+ lr_scheduler_type = lr_scheduler_type,
278
+ warmup_ratio = warmup_ratio,
279
+ warmup_steps = warmup_steps,
280
+ log_level = log_level,
281
+ log_level_replica = log_level_replica,
282
+ log_on_each_node = log_on_each_node,
283
+ logging_dir = logging_dir,
284
+ logging_strategy = logging_strategy,
285
+ logging_first_step = logging_first_step,
286
+ logging_steps = logging_steps,
287
+ logging_nan_inf_filter = logging_nan_inf_filter,
288
+ save_strategy = save_strategy,
289
+ save_steps = save_steps,
290
+ save_total_limit = save_total_limit,
291
+ save_safetensors = save_safetensors,
292
+ save_on_each_node = save_on_each_node,
293
+ save_only_model = save_only_model,
294
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
295
+ no_cuda = no_cuda,
296
+ use_cpu = use_cpu,
297
+ use_mps_device = use_mps_device,
298
+ seed = seed,
299
+ data_seed = data_seed,
300
+ jit_mode_eval = jit_mode_eval,
301
+ use_ipex = use_ipex,
302
+ bf16 = bf16,
303
+ fp16 = fp16,
304
+ fp16_opt_level = fp16_opt_level,
305
+ half_precision_backend = half_precision_backend,
306
+ bf16_full_eval = bf16_full_eval,
307
+ fp16_full_eval = fp16_full_eval,
308
+ tf32 = tf32,
309
+ local_rank = local_rank,
310
+ ddp_backend = ddp_backend,
311
+ tpu_num_cores = tpu_num_cores,
312
+ tpu_metrics_debug = tpu_metrics_debug,
313
+ debug = debug,
314
+ dataloader_drop_last = dataloader_drop_last,
315
+ eval_steps = eval_steps,
316
+ dataloader_num_workers = dataloader_num_workers,
317
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
318
+ past_index = past_index,
319
+ run_name = run_name,
320
+ disable_tqdm = disable_tqdm,
321
+ remove_unused_columns = remove_unused_columns,
322
+ label_names = label_names,
323
+ load_best_model_at_end = load_best_model_at_end,
324
+ metric_for_best_model = metric_for_best_model,
325
+ greater_is_better = greater_is_better,
326
+ ignore_data_skip = ignore_data_skip,
327
+ fsdp = fsdp,
328
+ fsdp_min_num_params = fsdp_min_num_params,
329
+ fsdp_config = fsdp_config,
330
+ tp_size = tp_size,
331
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
332
+ accelerator_config = accelerator_config,
333
+ deepspeed = deepspeed,
334
+ label_smoothing_factor = label_smoothing_factor,
335
+ optim = optim,
336
+ optim_args = optim_args,
337
+ adafactor = adafactor,
338
+ group_by_length = group_by_length,
339
+ length_column_name = length_column_name,
340
+ report_to = report_to,
341
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
342
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
343
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
344
+ dataloader_pin_memory = dataloader_pin_memory,
345
+ dataloader_persistent_workers = dataloader_persistent_workers,
346
+ skip_memory_metrics = skip_memory_metrics,
347
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
348
+ push_to_hub = push_to_hub,
349
+ resume_from_checkpoint = resume_from_checkpoint,
350
+ hub_model_id = hub_model_id,
351
+ hub_strategy = hub_strategy,
352
+ hub_token = hub_token,
353
+ hub_private_repo = hub_private_repo,
354
+ hub_always_push = hub_always_push,
355
+ gradient_checkpointing = gradient_checkpointing,
356
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
357
+ include_inputs_for_metrics = include_inputs_for_metrics,
358
+ eval_do_concat_batches = eval_do_concat_batches,
359
+ fp16_backend = fp16_backend,
360
+ evaluation_strategy = evaluation_strategy,
361
+ push_to_hub_model_id = push_to_hub_model_id,
362
+ push_to_hub_organization = push_to_hub_organization,
363
+ push_to_hub_token = push_to_hub_token,
364
+ mp_parameters = mp_parameters,
365
+ auto_find_batch_size = auto_find_batch_size,
366
+ full_determinism = full_determinism,
367
+ torchdynamo = torchdynamo,
368
+ ray_scope = ray_scope,
369
+ ddp_timeout = ddp_timeout,
370
+ torch_compile = torch_compile,
371
+ torch_compile_backend = torch_compile_backend,
372
+ torch_compile_mode = torch_compile_mode,
373
+ dispatch_batches = dispatch_batches,
374
+ split_batches = split_batches,
375
+ include_tokens_per_second = include_tokens_per_second,
376
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
377
+ neftune_noise_alpha = neftune_noise_alpha,
378
+ optim_target_modules = optim_target_modules,
379
+ batch_eval_metrics = batch_eval_metrics,
380
+ eval_on_start = eval_on_start,
381
+ use_liger_kernel = use_liger_kernel,
382
+ eval_use_gather_object = eval_use_gather_object,
383
+ average_tokens_across_devices = average_tokens_across_devices,
384
+ max_length = max_length,
385
+ max_prompt_length = max_prompt_length,
386
+ max_completion_length = max_completion_length,
387
+ beta = beta,
388
+ disable_dropout = disable_dropout,
389
+ label_pad_token_id = label_pad_token_id,
390
+ padding_value = padding_value,
391
+ truncation_mode = truncation_mode,
392
+ generate_during_eval = generate_during_eval,
393
+ is_encoder_decoder = is_encoder_decoder,
394
+ model_init_kwargs = model_init_kwargs,
395
+ dataset_num_proc = dataset_num_proc,**kwargs)
396
+ self.vllm_sampling_params = vllm_sampling_params
397
+ self.unsloth_num_chunks = unsloth_num_chunks
398
+ pass
399
+
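For reference, a hedged restatement of the ORPO objective from the paper linked in the config docstring above (not code from this commit), with β the λ weight described there:

    odds_θ(y | x) = P_θ(y | x) / (1 − P_θ(y | x))
    L_OR   = −log σ( log( odds_θ(y_w | x) / odds_θ(y_l | x) ) )
    L_ORPO = E[ L_SFT(y_w) + β · L_OR ]

so a single model is trained without a reference: the NLL term fits the chosen response y_w while the odds-ratio term pushes its odds above those of the rejected response y_l.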
400
+ class _UnslothORPOTrainer(Trainer):
401
+ r""""""
402
+
403
+ _tag_names = ["trl", "orpo"]
404
+
405
+ def __init__(
406
+ self,
407
+ model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
408
+ args: Optional[ORPOConfig] = None,
409
+ data_collator: Optional[DataCollator] = None,
410
+ train_dataset: Optional[Dataset] = None,
411
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
412
+ processing_class: Optional[
413
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
414
+ ] = None,
415
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
416
+ callbacks: Optional[list[TrainerCallback]] = None,
417
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
418
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
419
+ peft_config: Optional[dict] = None,
420
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
421
+ ):
422
+ if args.model_init_kwargs is None:
423
+ model_init_kwargs = {}
424
+ elif not isinstance(model, str):
425
+ raise ValueError("You passed model_init_kwargs to the ORPOTrainer, but your model is already instantiated.")
426
+ else:
427
+ model_init_kwargs = args.model_init_kwargs
428
+ torch_dtype = model_init_kwargs.get("torch_dtype")
429
+ if torch_dtype is not None:
430
+ # Convert to `torch.dtype` if a str is passed
431
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
432
+ torch_dtype = getattr(torch, torch_dtype)
433
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
434
+ raise ValueError(
435
+ f"Invalid `torch_dtype` passed to the ORPOConfig. Expected either a `torch.dtype`, its string name, or 'auto', but got {torch_dtype}."
436
+ )
437
+ model_init_kwargs["torch_dtype"] = torch_dtype
438
+
439
+ if isinstance(model, str):
440
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
441
+
442
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
443
+ # has been called in order to properly call autocast if needed.
444
+ self._peft_has_been_casted_to_bf16 = False
445
+
446
+ if not is_peft_available() and peft_config is not None:
447
+ raise ValueError(
448
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
449
+ )
450
+ elif is_peft_available() and peft_config is not None:
451
+ # if model is a peft model and we have a peft_config, we merge and unload it first
452
+ if isinstance(model, PeftModel):
453
+ model = model.merge_and_unload()
454
+
455
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
456
+ _support_gc_kwargs = hasattr(
457
+ args, "gradient_checkpointing_kwargs"
458
+ ) and "gradient_checkpointing_kwargs" in list(
459
+ inspect.signature(prepare_model_for_kbit_training).parameters
460
+ )
461
+
462
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
463
+
464
+ if _support_gc_kwargs:
465
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
466
+
467
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
468
+ elif getattr(args, "gradient_checkpointing", False):
469
+ # For backward compatibility with older versions of transformers
470
+ if hasattr(model, "enable_input_require_grads"):
471
+ model.enable_input_require_grads()
472
+ else:
473
+
474
+ def make_inputs_require_grad(module, input, output):
475
+ output.requires_grad_(True)
476
+
477
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
478
+
479
+ # get peft model with the given config
480
+ model = model
481
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
482
+ peft_module_casting_to_bf16(model)
483
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
484
+ self._peft_has_been_casted_to_bf16 = True
485
+
486
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
487
+ # to explicitly have `requires_grad=True`; otherwise training will either fail
488
+ # silently or fail outright.
489
+ elif getattr(args, "gradient_checkpointing", False):
490
+ # For backward compatibility with older versions of transformers
491
+ if hasattr(model, "enable_input_require_grads"):
492
+ model.enable_input_require_grads()
493
+ else:
494
+
495
+ def make_inputs_require_grad(module, input, output):
496
+ output.requires_grad_(True)
497
+
498
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
499
+
500
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
501
+ raise ValueError(
502
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
503
+ " Please install `wandb` or `comet-ml` to resolve."
504
+ )
505
+
506
+ if model is not None:
507
+ self.is_encoder_decoder = model.config.is_encoder_decoder
508
+ elif args.is_encoder_decoder is None:
509
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
510
+ else:
511
+ self.is_encoder_decoder = args.is_encoder_decoder
512
+
513
+ if self.is_encoder_decoder:
514
+ self.decoder_start_token_id = model.config.decoder_start_token_id
515
+ self.pad_token_id = model.config.pad_token_id
516
+
517
+ if processing_class is None:
518
+ raise ValueError("processing_class must be specified to tokenize an ORPO dataset.")
519
+ if args.max_length is None:
520
+ warnings.warn(
521
+ "`max_length` is not set in the ORPOConfig's init;"
522
+ " it will default to `512`, but you should set it yourself in the future.",
523
+ UserWarning,
524
+ )
525
+ max_length = 512
526
+ else:
527
+ max_length = args.max_length
528
+ if args.max_prompt_length is None:
529
+ warnings.warn(
530
+ "`max_prompt_length` is not set in the ORPOConfig's init;"
531
+ " it will default to `128`, but you should set it yourself in the future.",
532
+ UserWarning,
533
+ )
534
+ max_prompt_length = 128
535
+ else:
536
+ max_prompt_length = args.max_prompt_length
537
+
538
+ if args.max_completion_length is None and self.is_encoder_decoder:
539
+ warnings.warn(
540
+ "When using an encoder-decoder architecture, you should set `max_completion_length` in the ORPOConfig's init;"
541
+ " it will default to `128`, but you should set it yourself in the future.",
542
+ UserWarning,
543
+ )
544
+ self.max_completion_length = 128
545
+ else:
546
+ self.max_completion_length = args.max_completion_length
547
+
548
+ if data_collator is None:
549
+ data_collator = DPODataCollatorWithPadding(
550
+ pad_token_id=processing_class.pad_token_id,
551
+ label_pad_token_id=args.label_pad_token_id,
552
+ is_encoder_decoder=self.is_encoder_decoder,
553
+ )
554
+
555
+ if args.remove_unused_columns:
556
+ args.remove_unused_columns = False
557
+ # warn users
558
+ warnings.warn(
559
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
560
+ " we have set it for you, but you should do it yourself in the future.",
561
+ UserWarning,
562
+ )
563
+
564
+ self.use_dpo_data_collator = True
565
+ else:
566
+ self.use_dpo_data_collator = False
567
+
568
+ # Disable dropout in the model and reference model
569
+ if args.disable_dropout:
570
+ disable_dropout_in_model(model)
571
+
572
+ self.max_length = max_length
573
+ self.generate_during_eval = args.generate_during_eval
574
+ self.label_pad_token_id = args.label_pad_token_id
575
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
576
+ self.max_prompt_length = max_prompt_length
577
+ self.truncation_mode = args.truncation_mode
578
+ self.processing_class = processing_class
579
+
580
+ self.beta = args.beta
581
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
582
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
583
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
584
+ warnings.warn(
585
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
586
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
587
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
588
+ "loss.",
589
+ UserWarning,
590
+ )
591
+
592
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
593
+
594
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
595
+ # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
596
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
597
+ # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
598
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
599
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
600
+ # that the warning has already been issued.
601
+ model.warnings_issued["estimate_tokens"] = True
602
+
603
+ # Compute that only on the main process for faster data processing.
604
+ # see: https://github.com/huggingface/trl/pull/1255
605
+ with PartialState().local_main_process_first():
606
+ # Extract the prompt if needed, and apply the chat template if needed
607
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
608
+ train_dataset = train_dataset.map(
609
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
610
+ )
611
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
612
+ if eval_dataset is not None:
613
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
614
+ eval_dataset = eval_dataset.map(
615
+ maybe_apply_chat_template,
616
+ fn_kwargs={"tokenizer": processing_class},
617
+ num_proc=args.dataset_num_proc,
618
+ )
619
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
620
+
621
+ super().__init__(
622
+ model=model,
623
+ args=args,
624
+ data_collator=data_collator,
625
+ train_dataset=train_dataset,
626
+ eval_dataset=eval_dataset,
627
+ processing_class=processing_class,
628
+ model_init=model_init,
629
+ compute_metrics=compute_metrics,
630
+ callbacks=callbacks,
631
+ optimizers=optimizers,
632
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
633
+ )
634
+
635
+ # Add tags for models that have been loaded with the correct transformers version
636
+ if hasattr(self.model, "add_model_tags"):
637
+ self.model.add_model_tags(self._tag_names)
638
+
639
+ if not hasattr(self, "accelerator"):
640
+ raise AttributeError(
641
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
642
+ )
643
+
644
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
645
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
646
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
647
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
648
+
649
+ if model is not None:
650
+ if hasattr(model, "config"):
651
+ hidden_size = (
652
+ max(model.config.hidden_sizes)
653
+ if getattr(model.config, "hidden_sizes", None)
654
+ else getattr(model.config, "hidden_size", None)
655
+ )
656
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
657
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
658
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
659
+ config_kwargs.update(
660
+ {
661
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
662
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
663
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
664
+ }
665
+ )
666
+
667
+ # If ZeRO-3 is used, we shard both the active and reference model.
668
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
669
+ if config_kwargs["zero_optimization"]["stage"] != 3:
670
+ config_kwargs["zero_optimization"]["stage"] = 0
671
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
672
+ model.eval()
673
+ return model
674
+
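For concreteness, the ZeRO-3 bucket overrides above scale with the hidden size; with hidden_size = 4096 (an assumed value) they come out to:

    reduce_bucket_size                 = 4096 * 4096       = 16_777_216
    stage3_param_persistence_threshold = 10 * 4096         = 40_960
    stage3_prefetch_bucket_size        = 0.9 * 4096 * 4096 ≈ 15_099_494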
675
+ def build_tokenized_answer(self, prompt, answer):
676
+ """
677
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
678
+ It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
679
+ Reference:
680
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
681
+ """
682
+
683
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
684
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
685
+
686
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
687
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
688
+
689
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
690
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
691
+
692
+ # Prepare input tokens for token by token comparison
693
+ full_input_ids = np.array(full_tokenized["input_ids"])
694
+
695
+ if len(full_input_ids) != len(full_concat_input_ids):
696
+ raise ValueError("Prompt input ids and answer input ids should have the same length.")
697
+
698
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
699
+ # can be merged together when tokenizing prompt+answer. This could result
700
+ # on the last token from the prompt being different when tokenized on its own
701
+ # vs when done as prompt+answer.
702
+ response_token_ids_start_idx = len(prompt_input_ids)
703
+
704
+ # If tokenized prompt is different than both prompt+answer, then it means the
705
+ # last token has changed due to merging.
706
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
707
+ response_token_ids_start_idx -= 1
708
+
709
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
710
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
711
+
712
+ if len(prompt_input_ids) != len(prompt_attention_mask):
713
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
714
+
715
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
716
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
717
+
718
+ return dict(
719
+ prompt_input_ids=prompt_input_ids,
720
+ prompt_attention_mask=prompt_attention_mask,
721
+ input_ids=answer_input_ids,
722
+ attention_mask=answer_attention_mask,
723
+ )
724
+
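A hedged illustration of the boundary-merge problem this method guards against (the token ids are made up):

    # prompt = "Q:", answer = "A"
    # enc("Q:")  -> [48, 25]
    # enc("Q:A") -> [48, 2546]   # ":" and "A" merged into one token
    # so the re-tokenized prompt prefix no longer matches enc("Q:"), and
    # response_token_ids_start_idx is shifted back by one to compensate.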
725
+ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
726
+ """Tokenize a single row from a ORPO specific dataset.
727
+
728
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
729
+ in case the prompt + chosen or prompt + rejected responses is/are too long. First
730
+ we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
731
+
732
+ We also create the labels for the chosen/rejected responses, which are of length equal to
733
+ the sum of the length of the prompt and the chosen/rejected response, with
734
+ label_pad_token_id for the prompt tokens.
735
+ """
736
+ batch = {}
737
+ prompt = feature["prompt"]
738
+ chosen = feature["chosen"]
739
+ rejected = feature["rejected"]
740
+
741
+ if not self.is_encoder_decoder:
742
+ # Check issues below for more details
743
+ # 1. https://github.com/huggingface/trl/issues/907
744
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
745
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
746
+
747
+ if not isinstance(prompt, str):
748
+ raise ValueError(f"prompt should be an str but got {type(prompt)}")
749
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
750
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
751
+
752
+ if not isinstance(chosen, str):
753
+ raise ValueError(f"chosen should be an str but got {type(chosen)}")
754
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
755
+
756
+ if not isinstance(rejected, str):
757
+ raise ValueError(f"rejected should be an str but got {type(rejected)}")
758
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
759
+
760
+ # Last prompt token might get merged by tokenizer and
761
+ # it should not be included for generation if that happens
762
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
763
+
764
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
765
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
766
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
767
+
768
+ for k, v in prompt_tokens.items():
769
+ prompt_tokens[k] = v[:prompt_len_input_ids]
770
+
771
+ # Make sure the prompts differ by at most one token,
772
+ # and that their lengths differ by at most 1
773
+ num_diff_tokens = sum(
774
+ [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
775
+ )
776
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
777
+ if num_diff_tokens > 1 or num_diff_len > 1:
778
+ raise ValueError(
779
+ "Chosen and rejected prompt_input_ids might only differ on the "
780
+ "last token due to tokenizer merge ops."
781
+ )
782
+
783
+ # add BOS token to head of prompt. Avoid adding if it's already there
784
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
785
+ self.processing_class.bos_token_id,
786
+ prompt_len_input_ids,
787
+ prompt_tokens,
788
+ chosen_prompt_len_input_ids,
789
+ chosen_tokens,
790
+ rejected_prompt_len_input_ids,
791
+ rejected_tokens,
792
+ )
793
+
794
+ # add EOS token to end of answer. Avoid adding if it's already there
795
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
796
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
797
+ )
798
+
799
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
800
+
801
+ # if combined sequence is too long, truncate the prompt
802
+ for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
803
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
804
+ if self.truncation_mode == "keep_start":
805
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
806
+ answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
807
+ elif self.truncation_mode == "keep_end":
808
+ for k in ["prompt_input_ids", "prompt_attention_mask"]:
809
+ answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
810
+ else:
811
+ raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
812
+
813
+ # if that's still too long, truncate the response
814
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
815
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
816
+ for k in ["input_ids", "attention_mask"]:
817
+ answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
818
+
819
+ # Create labels
820
+ chosen_sequence_tokens = {
821
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
822
+ }
823
+ rejected_sequence_tokens = {
824
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
825
+ }
826
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
827
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
828
+ self.label_pad_token_id
829
+ ] * len(chosen_tokens["prompt_input_ids"])
830
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
831
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
832
+ self.label_pad_token_id
833
+ ] * len(rejected_tokens["prompt_input_ids"])
834
+
835
+ for k, toks in {
836
+ "chosen_": chosen_sequence_tokens,
837
+ "rejected_": rejected_sequence_tokens,
838
+ "": prompt_tokens,
839
+ }.items():
840
+ for type_key, tokens in toks.items():
841
+ if type_key == "token_type_ids":
842
+ continue
843
+ batch[f"{k}{type_key}"] = tokens
844
+
845
+ else:
846
+ chosen_tokens = self.processing_class(
847
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
848
+ )
849
+ rejected_tokens = self.processing_class(
850
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
851
+ )
852
+ prompt_tokens = self.processing_class(
853
+ prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
854
+ )
855
+
856
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
857
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
858
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
859
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
860
+
861
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
862
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
863
+ labels=torch.tensor(batch["rejected_labels"])
864
+ )
865
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
866
+ labels=torch.tensor(batch["chosen_labels"])
867
+ )
868
+
869
+ if is_torch_xla_available():
870
+ # Pad the sequences to global max_length to avoid TorchXLA recompilation
871
+ for k in batch:
872
+ if "labels" in k or self.is_encoder_decoder:
873
+ pad_value = self.label_pad_token_id
874
+ elif k.endswith("_input_ids"):
875
+ pad_value = self.padding_value
876
+ elif k.endswith("_attention_mask"):
877
+ pad_value = 0
878
+ batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
879
+ return batch
880
+
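+ # Truncation sketch (hypothetical lengths): with max_length = 512 and
+ # max_prompt_length = 128, a 600-token prompt plus a 100-token chosen response
+ # first has its prompt cut to 128 tokens ("keep_start" keeps the head,
+ # "keep_end" the tail); 128 + 100 <= 512, so the response is left untouched.
+ # Only if prompt + longer response still exceeded 512 would the responses be
+ # cut to max_length - max_prompt_length = 384 tokens.
+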
881
+ @staticmethod
882
+ def concatenated_inputs(
883
+ batch: dict[str, Union[list, torch.LongTensor]],
884
+ is_encoder_decoder: bool = False,
885
+ label_pad_token_id: int = -100,
886
+ padding_value: int = 0,
887
+ device: Optional[torch.device] = None,
888
+ ) -> dict[str, torch.LongTensor]:
889
+ """Concatenate the chosen and rejected inputs into a single tensor.
890
+
891
+ Args:
892
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
893
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
894
+ label_pad_token_id: The label pad token id.
895
+ padding_value: The padding value to use for the concatenated inputs_ids.
896
+ device: The device for the concatenated inputs.
897
+
898
+ Returns:
899
+ A dictionary containing the concatenated inputs under 'concatenated_*' keys (e.g. 'concatenated_input_ids', 'concatenated_attention_mask', 'concatenated_labels').
900
+ """
901
+ concatenated_batch = {}
902
+
903
+ if is_encoder_decoder:
904
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
905
+ else:
906
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
907
+
908
+ for k in batch:
909
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
910
+ if "labels" in k or is_encoder_decoder:
911
+ pad_value = label_pad_token_id
912
+ elif k.endswith("_input_ids"):
913
+ pad_value = padding_value
914
+ elif k.endswith("_attention_mask"):
915
+ pad_value = 0
916
+ concatenated_key = k.replace("chosen", "concatenated")
917
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
918
+ for k in batch:
919
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
920
+ if "labels" in k or is_encoder_decoder:
921
+ pad_value = label_pad_token_id
922
+ elif k.endswith("_input_ids"):
923
+ pad_value = padding_value
924
+ elif k.endswith("_attention_mask"):
925
+ pad_value = 0
926
+ concatenated_key = k.replace("rejected", "concatenated")
927
+ concatenated_batch[concatenated_key] = torch.cat(
928
+ (
929
+ concatenated_batch[concatenated_key],
930
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
931
+ ),
932
+ dim=0,
933
+ ).to(device=device)
934
+
935
+ if is_encoder_decoder:
936
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
937
+ concatenated_batch["concatenated_attention_mask"] = (
938
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
939
+ )
940
+
941
+ return concatenated_batch
942
+
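+ # Shape sketch (hypothetical batch, decoder-only case): with chosen_input_ids of
+ # shape (B, L1) and rejected_input_ids of shape (B, L2), both are padded to
+ # max(L1, L2) and stacked along dim 0, so concatenated_input_ids has shape
+ # (2 * B, max(L1, L2)), chosen rows first and rejected rows after.
+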
943
+ def odds_ratio_loss(
944
+ self,
945
+ policy_chosen_logps: torch.FloatTensor,
946
+ policy_rejected_logps: torch.FloatTensor,
947
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
948
+ """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
949
+
950
+ Args:
951
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
952
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
953
+
954
+ Returns:
955
+ A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
956
+ The losses tensor contains the ORPO loss for each example in the batch.
957
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
958
+ log_odds_ratio is the mean `log(sigmoid(log_odds))` of the chosen over the rejected responses, for logging purposes.
959
+ log_odds_chosen is the mean log odds of the chosen over the rejected responses, for logging purposes.
960
+ """
961
+
962
+ # Derived from Eqs. (4) and (7) of https://huggingface.co/papers/2403.07691 using log identities and exp(log(P(y|x))) = P(y|x)
963
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
964
+ torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
965
+ )
966
+ ratio = F.logsigmoid(log_odds)
967
+ losses = self.beta * ratio
968
+
969
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
970
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
971
+
972
+ return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
973
+
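+ # Worked example (hedged, hypothetical numbers): with average per-token log-probs
+ # policy_chosen_logps = -0.5 and policy_rejected_logps = -2.0,
+ # log_odds = (-0.5 - (-2.0)) - (log1p(-exp(-0.5)) - log1p(-exp(-2.0)))
+ #          ≈ 1.5 - (log(0.3935) - log(0.8647)) ≈ 1.5 + 0.787 ≈ 2.287
+ # ratio = logsigmoid(2.287) ≈ -0.097 and losses = beta * ratio; since losses.mean()
+ # is subtracted from the NLL loss later, training pushes the chosen odds up.
+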
974
+ @staticmethod
975
+ def get_batch_logps(
976
+ logits: torch.FloatTensor,
977
+ labels: torch.LongTensor,
978
+ average_log_prob: bool = False,
979
+ label_pad_token_id: int = -100,
980
+ is_encoder_decoder: bool = False,
981
+ ) -> torch.FloatTensor:
982
+ """Compute the log probabilities of the given labels under the given logits.
983
+
984
+ Args:
985
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
986
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
987
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
988
+ label_pad_token_id: The label pad token id.
989
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
990
+
991
+ Returns:
992
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
993
+ """
994
+ if logits.shape[:-1] != labels.shape:
995
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
996
+
997
+ if not is_encoder_decoder:
998
+ labels = labels[:, 1:].clone()
999
+ logits = logits[:, :-1, :]
1000
+ loss_mask = labels != label_pad_token_id
1001
+
1002
+ # dummy token; we'll ignore the losses on these tokens later
1003
+ labels = torch.where(labels == label_pad_token_id, 0, labels)
1004
+
1005
+ per_token_logps = selective_log_softmax(logits, labels)
1006
+
1007
+ if average_log_prob:
1008
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1009
+ else:
1010
+ return (per_token_logps * loss_mask).sum(-1)
1011
+
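+ # Toy example (hedged sketch, shapes only):
+ # >>> logits = torch.randn(1, 4, 8)             # (batch, seq, vocab)
+ # >>> labels = torch.tensor([[-100, 2, 5, 1]])  # first position is masked prompt
+ # >>> # After the causal shift, labels[:, 1:] = [2, 5, 1] are all unmasked, so the
+ # >>> # result is the sum (or, with average_log_prob=True, the mean) of 3 per-token
+ # >>> # log-probabilities taken from logits[:, :-1, :].
+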
1012
+ def concatenated_forward(
1013
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1014
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1015
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1016
+
1017
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
1018
+ """
1019
+ concatenated_batch = self.concatenated_inputs(
1020
+ batch,
1021
+ is_encoder_decoder=self.is_encoder_decoder,
1022
+ label_pad_token_id=self.label_pad_token_id,
1023
+ padding_value=self.padding_value,
1024
+ device=self.accelerator.device,
1025
+ )
1026
+ len_chosen = batch["chosen_labels"].shape[0]
1027
+
1028
+ model_kwargs = (
1029
+ {
1030
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1031
+ }
1032
+ if self.is_encoder_decoder
1033
+ else {}
1034
+ )
1035
+
1036
+ if self.aux_loss_enabled:
1037
+ model_kwargs["output_router_logits"] = True
1038
+
1039
+ outputs = model(
1040
+ concatenated_batch["concatenated_input_ids"],
1041
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
1042
+ use_cache=False,
1043
+ **model_kwargs,
1044
+ )
1045
+ all_logits = outputs.logits
1046
+
1047
+ def cross_entropy_loss(logits, labels):
1048
+ if not self.is_encoder_decoder:
1049
+ # Shift so that tokens < n predict n
1050
+ logits = logits[..., :-1, :].contiguous()
1051
+ labels = labels[..., 1:].contiguous()
1052
+ # Flatten the tokens
1053
+ loss_fct = nn.CrossEntropyLoss()
1054
+ logits = logits.view(-1, logits.shape[-1])
1055
+ labels = labels.view(-1)
1056
+ # Enable model parallelism
1057
+ labels = labels.to(logits.device)
1058
+ loss = loss_fct(logits, labels)
1059
+ return loss
1060
+
1061
+ if self.is_encoder_decoder:
1062
+ labels = concatenated_batch["concatenated_labels"].clone()
1063
+ else:
1064
+ labels = concatenated_batch["concatenated_input_ids"].clone()
1065
+ attention_mask = concatenated_batch["concatenated_attention_mask"]
1066
+ labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1067
+ # The ORPO chosen NLL loss is computed over the full prompt and response
1068
+ chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1069
+
1070
+ all_logps = self.get_batch_logps(
1071
+ all_logits,
1072
+ concatenated_batch["concatenated_labels"],
1073
+ average_log_prob=True,
1074
+ is_encoder_decoder=self.is_encoder_decoder,
1075
+ label_pad_token_id=self.label_pad_token_id,
1076
+ )
1077
+
1078
+ chosen_logps = all_logps[:len_chosen]
1079
+ rejected_logps = all_logps[len_chosen:]
1080
+
1081
+ if not self.is_encoder_decoder:
1082
+ chosen_logits = all_logits[:len_chosen, :-1, :]
1083
+ rejected_logits = all_logits[len_chosen:, :-1, :]
1084
+ else:
1085
+ chosen_logits = all_logits[:len_chosen]
1086
+ rejected_logits = all_logits[len_chosen:]
1087
+
1088
+ if self.aux_loss_enabled:
1089
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1090
+
1091
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1092
+
1093
+ def get_batch_loss_metrics(
1094
+ self,
1095
+ model,
1096
+ batch: dict[str, Union[list, torch.LongTensor]],
1097
+ train_eval: Literal["train", "eval"] = "train",
1098
+ ):
1099
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1100
+ metrics = {}
1101
+
1102
+ forward_output = self.concatenated_forward(model, batch)
1103
+ (
1104
+ policy_chosen_logps,
1105
+ policy_rejected_logps,
1106
+ policy_chosen_logits,
1107
+ policy_rejected_logits,
1108
+ policy_nll_loss,
1109
+ ) = forward_output[:5]
1110
+ if self.aux_loss_enabled:
1111
+ aux_loss = forward_output[5]
1112
+
1113
+ losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1114
+ policy_chosen_logps, policy_rejected_logps
1115
+ )
1116
+ # full ORPO loss
1117
+ loss = policy_nll_loss - losses.mean()
1118
+
1119
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
1120
+
1121
+ prefix = "eval_" if train_eval == "eval" else ""
1122
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1123
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1124
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1125
+ metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1126
+ chosen_rewards - rejected_rewards
1127
+ ).mean()
1128
+ metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1129
+ metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1130
+ metrics[f"{prefix}logits/rejected"] = (
1131
+ self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1132
+ )
1133
+ metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1134
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1135
+ metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1136
+ metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1137
+ if is_torch_xla_available():
1138
+ xm.mark_step() # needed because of the .item() calls below
1139
+ for k, v in metrics.items():
1140
+ metrics[k] = v.item()
1141
+ if self.aux_loss_enabled:
1142
+ loss += self.aux_loss_coef * aux_loss
1143
+
1144
+ return loss, metrics
1145
+
1146
+ def compute_loss(
1147
+ self,
1148
+ model: Union[PreTrainedModel, nn.Module],
1149
+ inputs: dict[str, Union[torch.Tensor, Any]],
1150
+ return_outputs=False,
1151
+ num_items_in_batch=None,
1152
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1153
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1154
+
1155
+ with compute_loss_context_manager:
1156
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1157
+
1158
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1159
+ loss = loss.to(self.args.device)
1160
+
1161
+ # force log the metrics
1162
+ self.store_metrics(metrics, train_eval="train")
1163
+
1164
+ if return_outputs:
1165
+ return (loss, metrics)
1166
+ return loss
1167
+
1168
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1169
+ """Generate samples from the model and reference model for the given batch of inputs."""
1170
+
1171
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1172
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1173
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1174
+
1175
+ with generate_context_manager:
1176
+ policy_output = model.generate(
1177
+ input_ids=batch["prompt_input_ids"],
1178
+ attention_mask=batch["prompt_attention_mask"],
1179
+ max_length=self.max_length,
1180
+ do_sample=True,
1181
+ pad_token_id=self.processing_class.pad_token_id,
1182
+ )
1183
+
1184
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1185
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1186
+
1187
+ return policy_output_decoded
1188
+
1189
+ def prediction_step(
1190
+ self,
1191
+ model: Union[PreTrainedModel, nn.Module],
1192
+ inputs: dict[str, Union[torch.Tensor, Any]],
1193
+ prediction_loss_only: bool,
1194
+ ignore_keys: Optional[list[str]] = None,
1195
+ ):
1196
+ if not self.use_dpo_data_collator:
1197
+ warnings.warn(
1198
+ "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1199
+ "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1200
+ )
1201
+ if ignore_keys is None:
1202
+ if hasattr(model, "config"):
1203
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1204
+ else:
1205
+ ignore_keys = []
1206
+
1207
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1208
+
1209
+ with torch.no_grad(), prediction_context_manager:
1210
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1211
+
1212
+ # force log the metrics
1213
+ self.store_metrics(metrics, train_eval="eval")
1214
+
1215
+ if prediction_loss_only:
1216
+ return (loss.detach(), None, None)
1217
+
1218
+ # logits for the chosen and rejected samples from model
1219
+ logits_dict = {
1220
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
1221
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
1222
+ }
1223
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1224
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1225
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1226
+
1227
+ return (loss.detach(), logits, labels)
1228
+
1229
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1230
+ for key, value in metrics.items():
1231
+ self._stored_metrics[train_eval][key].append(value)
1232
+
1233
+ def evaluation_loop(
1234
+ self,
1235
+ dataloader: DataLoader,
1236
+ description: str,
1237
+ prediction_loss_only: Optional[bool] = None,
1238
+ ignore_keys: Optional[list[str]] = None,
1239
+ metric_key_prefix: str = "eval",
1240
+ ) -> EvalLoopOutput:
1241
+ """
1242
+ Overriding built-in evaluation loop to store metrics for each batch.
1243
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1244
+
1245
+ Works both with or without labels.
1246
+ """
1247
+
1248
+ # Sample and save to game log if requested (for one batch to save time)
1249
+ if self.generate_during_eval:
1250
+ # Generate random indices within the range of the total number of samples
1251
+ num_samples = len(dataloader.dataset)
1252
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1253
+
1254
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1255
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1256
+ random_batch = self.data_collator(random_batch_dataset)
1257
+ random_batch = self._prepare_inputs(random_batch)
1258
+
1259
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
1260
+
1261
+ table = pd.DataFrame(
1262
+ columns=["Prompt", "Policy"],
1263
+ data=[
1264
+ [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1265
+ ],
1266
+ )
1267
+ if "wandb" in self.args.report_to:
1268
+ wandb.log({"game_log": wandb.Table(data=table)})
1269
+
1270
+ if "comet_ml" in self.args.report_to:
1271
+ log_table_to_comet_experiment(
1272
+ name="game_log.csv",
1273
+ table=table,
1274
+ )
1275
+
1276
+ # Base evaluation
1277
+ initial_output = super().evaluation_loop(
1278
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1279
+ )
1280
+
1281
+ return initial_output
1282
+
1283
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1284
+ """
1285
+ Log `logs` on the various objects watching training, including stored metrics.
1286
+
1287
+ Args:
1288
+ logs (`dict[str, float]`):
1289
+ The values to log.
1290
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1291
+ Start time of the training.
1292
+ """
1293
+ # logs has either 'loss' or 'eval_loss'
1294
+ train_eval = "train" if "loss" in logs else "eval"
1295
+ # Add averaged stored metrics to logs
1296
+ for key, metrics in self._stored_metrics[train_eval].items():
1297
+ logs[key] = torch.tensor(metrics).mean().item()
1298
+ del self._stored_metrics[train_eval]
1299
+
1300
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1301
+ return super().log(logs, start_time)
1302
+ else: # transformers<=4.46
1303
+ return super().log(logs)
1304
+
1305
+ def _shift_right(self, input_ids):
1306
+ if self.decoder_start_token_id is None:
1307
+ raise ValueError(
1308
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1309
+ )
1310
+
1311
+ # shift inputs to the right
1312
+ if is_torch_fx_proxy(input_ids):
1313
+ # Item assignment is not supported natively for proxies.
1314
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1315
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1316
+ else:
1317
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1318
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1319
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1320
+
1321
+ if self.pad_token_id is None:
1322
+ raise ValueError("model.config.pad_token_id has to be defined.")
1323
+ # replace possible -100 values in labels by `pad_token_id`
1324
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1325
+
1326
+ return shifted_input_ids
1327
+
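+ # Example (hedged sketch): with decoder_start_token_id = 0 and pad_token_id = 1,
+ # >>> self._shift_right(torch.tensor([[5, 6, -100]]))
+ # tensor([[0, 5, 6]])
+ # The final -100 label falls off in the shift here; any -100 that remained after
+ # shifting would be replaced by pad_token_id.
+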
1328
+ def create_model_card(
1329
+ self,
1330
+ model_name: Optional[str] = None,
1331
+ dataset_name: Optional[str] = None,
1332
+ tags: Union[str, list[str], None] = None,
1333
+ ):
1334
+ """
1335
+ Creates a draft of a model card using the information available to the `Trainer`.
1336
+
1337
+ Args:
1338
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1339
+ Name of the model.
1340
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1341
+ Name of the dataset used for training.
1342
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1343
+ Tags to be associated with the model card.
1344
+ """
1345
+ if not self.is_world_process_zero():
1346
+ return
1347
+
1348
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1349
+ base_model = self.model.config._name_or_path
1350
+ else:
1351
+ base_model = None
1352
+
1353
+ tags = tags or []
1354
+ if isinstance(tags, str):
1355
+ tags = [tags]
1356
+
1357
+ if hasattr(self.model.config, "unsloth_version"):
1358
+ tags.append("unsloth")
1359
+
1360
+ citation = textwrap.dedent("""\
1361
+ @article{hong2024orpo,
1362
+ title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
1363
+ author = {Jiwoo Hong and Noah Lee and James Thorne},
1364
+ year = 2024,
1365
+ eprint = {arXiv:2403.07691}
1366
+ }""")
1367
+
1368
+ model_card = generate_model_card(
1369
+ base_model=base_model,
1370
+ model_name=model_name,
1371
+ hub_model_id=self.hub_model_id,
1372
+ dataset_name=dataset_name,
1373
+ tags=tags,
1374
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1375
+ comet_url=get_comet_experiment_url(),
1376
+ trainer_name="ORPO",
1377
+ trainer_citation=citation,
1378
+ paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
1379
+ paper_id="2403.07691",
1380
+ )
1381
+
1382
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1383
+ class UnslothORPOTrainer(_UnslothORPOTrainer):
1384
+ """
1385
+
1386
+ Initialize ORPOTrainer.
1387
+
1388
+ Args:
1389
+ model (`transformers.PreTrainedModel`):
1390
+ The model to train, preferably an `AutoModelForCausalLM`.
1391
+ args (`ORPOConfig`):
1392
+ The ORPO config arguments to use for training.
1393
+ data_collator (`transformers.DataCollator`):
1394
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1395
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1396
+ train_dataset (`datasets.Dataset`):
1397
+ The dataset to use for training.
1398
+ eval_dataset (`datasets.Dataset`):
1399
+ The dataset to use for evaluation.
1400
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1401
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1402
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1403
+ reuse the fine-tuned model.
1404
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1405
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1406
+ callbacks (`list[transformers.TrainerCallback]`):
1407
+ The callbacks to use for training.
1408
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1409
+ The optimizer and scheduler to use for training.
1410
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1411
+ The function to use to preprocess the logits before computing the metrics.
1412
+ peft_config (`dict`, defaults to `None`):
1413
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1414
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1415
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1416
+ a dictionary mapping metric names to metric values.
1417
+
1418
+ """
1419
+ def __init__(
1420
+ self,
1421
+ model = None,
1422
+ args = None,
1423
+ data_collator = None,
1424
+ train_dataset = None,
1425
+ eval_dataset = None,
1426
+ processing_class = None,
1427
+ model_init = None,
1428
+ callbacks = None,
1429
+ preprocess_logits_for_metrics = None,
1430
+ peft_config = None,
1431
+ compute_metrics = None,
1432
+ **kwargs
1433
+ ):
1434
+ if args is None: args = UnslothORPOConfig()
1435
+ use_bf16 = getattr(args, 'bf16', False)
1436
+ use_fp16 = getattr(args, 'fp16', False)
1437
+ force_float32 = False
1438
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1439
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1440
+ force_float32 = True
1441
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1442
+ dtype = getattr(model.config, 'torch_dtype', None)
1443
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1444
+ from unsloth_zoo.utils import _get_dtype
1445
+ dtype = _get_dtype(dtype)
1446
+ float16 = dtype == torch.float16
1447
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1448
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1449
+ if force_float32:
1450
+ args.fp16 = False
1451
+ args.bf16 = False
1452
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1453
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1454
+ args.fp16 = float16
1455
+ args.bf16 = not float16
1456
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1457
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1458
+ args.eval_strategy = 'steps'
1459
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1460
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1461
+ if ga_steps is not None and ga_steps > 1:
1462
+ from transformers import __version__ as transformers_version
1463
+ if Version(transformers_version) <= Version('4.45.2'):
1464
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1465
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1466
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1467
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1468
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1469
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1470
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1471
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1472
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1473
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1474
+ if force_float32:
1475
+ args.bf16_full_eval = False
1476
+ args.fp16_full_eval = False
1477
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1478
+ args.bf16_full_eval = True
1479
+ args.fp16_full_eval = False
1480
+ elif not bf16_full_eval and not fp16_full_eval:
1481
+ args.bf16_full_eval = args.bf16
1482
+ args.fp16_full_eval = args.fp16
1483
+ _output_logits = False
1484
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1485
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1486
+ if _output_logits:
1487
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1488
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1489
+ pass
1490
+ else:
1491
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1492
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1493
+ if args_max_seq_length is None and model_max_seq_length is not None:
1494
+ max_seq_length = model.max_seq_length
1495
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1496
+ if model is not None and hasattr(model, 'for_training'):
1497
+ model.for_training()
1498
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1499
+ if 'processing_class' in locals():
1500
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1501
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1502
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1503
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1504
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1505
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1506
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1507
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1508
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1509
+ else:
1510
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1511
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1512
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1513
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1514
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1515
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1516
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1517
+ else:
1518
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1519
+ other_metrics = []
1520
+
1521
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1522
+ PatchRLStatistics('orpo_trainer', other_metrics)
1523
+
1524
+ super().__init__(
1525
+ model = model,
1526
+ args = args,
1527
+ data_collator = data_collator,
1528
+ train_dataset = train_dataset,
1529
+ eval_dataset = eval_dataset,
1530
+ processing_class = processing_class,
1531
+ model_init = model_init,
1532
+ callbacks = callbacks,
1533
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1534
+ peft_config = peft_config,
1535
+ compute_metrics = compute_metrics,**kwargs)
1536
+ if hasattr(self, 'neftune_hook_handle'):
1537
+ self.neftune_hook_handle.remove()
1538
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1539
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1540
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1541
+ pass
1542
+
1543
+ pass
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py ADDED
@@ -0,0 +1,1269 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, warnings, wraps, F, is_conversational, os, torch)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
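+
+ # Sanity-check sketch (illustrative; equality holds up to float tolerance):
+ # >>> logits = torch.randn(2, 3, 10); index = torch.randint(10, (2, 3))
+ # >>> expected = logits.float().log_softmax(-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
+ # >>> torch.allclose(selective_log_softmax(logits, index), expected, atol = 1e-5)
+ # True
+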
42
+ def vLLMSamplingParams(**kwargs):
43
+ from vllm import SamplingParams
44
+ sampling_params = SamplingParams(**kwargs)
45
+ sampling_params._set_kwargs = kwargs
46
+ return sampling_params
47
+ @dataclass
48
+ class UnslothOnlineDPOConfig(OnlineDPOConfig):
49
+ """
50
+
51
+ Configuration class for the [`OnlineDPOTrainer`].
52
+
53
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
54
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
55
+ command line.
56
+
57
+ Parameters:
58
+ learning_rate (`float`, *optional*, defaults to `5e-7`):
59
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
60
+ [`~transformers.TrainingArguments`].
61
+ reward_model_path (`str` or `None`, *optional*, defaults to `None`):
62
+ Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
63
+ judge (`str` or `None`, *optional*, defaults to `None`):
64
+ Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
65
+ max_new_tokens (`int`, *optional*, defaults to `64`):
66
+ Maximum number of tokens to generate per completion.
67
+ max_length (`int`, *optional*, defaults to `256`):
68
+ Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
69
+ sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
70
+ possible.
71
+ temperature (`float`, *optional*, defaults to `0.9`):
72
+ Temperature for sampling. The higher the temperature, the more random the completions.
73
+ missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
74
+ Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
75
+ the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
76
+ value.
77
+ beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
78
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
79
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
80
+ the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
81
+ selected for each new epoch and the last β is used for the rest of the epochs.
82
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
83
+ Type of loss to use; a brief loss reference sketch follows this docstring. Possible values are:
84
+
85
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
86
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
87
+
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ disable_dropout (`bool`, *optional*, defaults to `True`):
91
+ Whether to disable dropout in the model and reference model.
92
+ use_vllm (`bool`, *optional*, defaults to `False`):
93
+ Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
94
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
95
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
96
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
97
+ capacity of a single GPU, albeit at the cost of slower generation.
98
+
99
+ """
100
+ vllm_sampling_params: Optional[Any] = field(
101
+ default = None,
102
+ metadata = {'help': 'vLLM SamplingParams'},
103
+ )
104
+ unsloth_num_chunks : Optional[int] = field(
105
+ default = -1,
106
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
107
+ )
108
+ def __init__(
109
+ self,
110
+ output_dir = None,
111
+ overwrite_output_dir = None,
112
+ do_train = False,
113
+ do_eval = False,
114
+ do_predict = False,
115
+ eval_strategy = 'no',
116
+ prediction_loss_only = False,
117
+ per_device_train_batch_size = 4,
118
+ per_device_eval_batch_size = 4,
119
+ per_gpu_train_batch_size = None,
120
+ per_gpu_eval_batch_size = None,
121
+ gradient_accumulation_steps = 2,
122
+ eval_accumulation_steps = 2,
123
+ eval_delay = 0,
124
+ torch_empty_cache_steps = 250,
125
+ learning_rate = 5e-05,
126
+ weight_decay = 0.01,
127
+ adam_beta1 = 0.9,
128
+ adam_beta2 = 0.999,
129
+ adam_epsilon = 1e-08,
130
+ max_grad_norm = 1.0,
131
+ num_train_epochs = 3.0,
132
+ max_steps = -1,
133
+ lr_scheduler_type = 'linear',
134
+ warmup_ratio = 0.1,
135
+ warmup_steps = 0,
136
+ log_level = 'passive',
137
+ log_level_replica = 'warning',
138
+ log_on_each_node = True,
139
+ logging_dir = None,
140
+ logging_strategy = 'steps',
141
+ logging_first_step = False,
142
+ logging_steps = 1,
143
+ logging_nan_inf_filter = False,
144
+ save_strategy = 'steps',
145
+ save_steps = 500,
146
+ save_total_limit = None,
147
+ save_safetensors = True,
148
+ save_on_each_node = False,
149
+ save_only_model = False,
150
+ restore_callback_states_from_checkpoint = False,
151
+ no_cuda = False,
152
+ use_cpu = False,
153
+ use_mps_device = False,
154
+ seed = 3407,
155
+ data_seed = 3407,
156
+ jit_mode_eval = False,
157
+ use_ipex = False,
158
+ bf16 = False,
159
+ fp16 = False,
160
+ fp16_opt_level = 'O1',
161
+ half_precision_backend = 'auto',
162
+ bf16_full_eval = False,
163
+ fp16_full_eval = False,
164
+ tf32 = None,
165
+ local_rank = -1,
166
+ ddp_backend = None,
167
+ tpu_num_cores = None,
168
+ tpu_metrics_debug = False,
169
+ debug = '',
170
+ dataloader_drop_last = False,
171
+ eval_steps = None,
172
+ dataloader_num_workers = 0,
173
+ dataloader_prefetch_factor = None,
174
+ past_index = -1,
175
+ run_name = None,
176
+ disable_tqdm = None,
177
+ remove_unused_columns = True,
178
+ label_names = None,
179
+ load_best_model_at_end = False,
180
+ metric_for_best_model = None,
181
+ greater_is_better = None,
182
+ ignore_data_skip = False,
183
+ fsdp = '',
184
+ fsdp_min_num_params = 0,
185
+ fsdp_config = None,
186
+ tp_size = 0,
187
+ fsdp_transformer_layer_cls_to_wrap = None,
188
+ accelerator_config = None,
189
+ deepspeed = None,
190
+ label_smoothing_factor = 0.0,
191
+ optim = 'adamw_8bit',
192
+ optim_args = None,
193
+ adafactor = False,
194
+ group_by_length = False,
195
+ length_column_name = 'length',
196
+ report_to = None,
197
+ ddp_find_unused_parameters = None,
198
+ ddp_bucket_cap_mb = None,
199
+ ddp_broadcast_buffers = None,
200
+ dataloader_pin_memory = True,
201
+ dataloader_persistent_workers = False,
202
+ skip_memory_metrics = True,
203
+ use_legacy_prediction_loop = False,
204
+ push_to_hub = False,
205
+ resume_from_checkpoint = None,
206
+ hub_model_id = None,
207
+ hub_strategy = 'every_save',
208
+ hub_token = None,
209
+ hub_private_repo = None,
210
+ hub_always_push = False,
211
+ gradient_checkpointing = False,
212
+ gradient_checkpointing_kwargs = None,
213
+ include_inputs_for_metrics = False,
214
+ eval_do_concat_batches = True,
215
+ fp16_backend = 'auto',
216
+ evaluation_strategy = None,
217
+ push_to_hub_model_id = None,
218
+ push_to_hub_organization = None,
219
+ push_to_hub_token = None,
220
+ mp_parameters = '',
221
+ auto_find_batch_size = False,
222
+ full_determinism = False,
223
+ torchdynamo = None,
224
+ ray_scope = 'last',
225
+ ddp_timeout = 1800,
226
+ torch_compile = False,
227
+ torch_compile_backend = None,
228
+ torch_compile_mode = None,
229
+ dispatch_batches = None,
230
+ split_batches = None,
231
+ include_tokens_per_second = False,
232
+ include_num_input_tokens_seen = False,
233
+ neftune_noise_alpha = None,
234
+ optim_target_modules = None,
235
+ batch_eval_metrics = False,
236
+ eval_on_start = False,
237
+ use_liger_kernel = False,
238
+ eval_use_gather_object = False,
239
+ average_tokens_across_devices = False,
240
+ reward_model_path = None,
241
+ judge = None,
242
+ max_new_tokens = 64,
243
+ max_length = 512,
244
+ temperature = 0.9,
245
+ missing_eos_penalty = None,
246
+ loss_type = 'sigmoid',
247
+ dataset_num_proc = None,
248
+ disable_dropout = True,
249
+ use_vllm = False,
250
+ ds3_gather_for_generation = True,
251
+ vllm_sampling_params = None,
252
+ unsloth_num_chunks = -1,
253
+ **kwargs,
254
+ ):
255
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
256
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
257
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
258
+ output_dir = 'unsloth_training_checkpoints'
259
+ save_strategy = 'no'
260
+ if dataset_num_proc is None:
261
+ from multiprocessing import cpu_count
262
+ dataset_num_proc = cpu_count()
263
+
264
+ super().__init__(
265
+ output_dir = output_dir,
266
+ overwrite_output_dir = overwrite_output_dir,
267
+ do_train = do_train,
268
+ do_eval = do_eval,
269
+ do_predict = do_predict,
270
+ eval_strategy = eval_strategy,
271
+ prediction_loss_only = prediction_loss_only,
272
+ per_device_train_batch_size = per_device_train_batch_size,
273
+ per_device_eval_batch_size = per_device_eval_batch_size,
274
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
275
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
276
+ gradient_accumulation_steps = gradient_accumulation_steps,
277
+ eval_accumulation_steps = eval_accumulation_steps,
278
+ eval_delay = eval_delay,
279
+ torch_empty_cache_steps = torch_empty_cache_steps,
280
+ learning_rate = learning_rate,
281
+ weight_decay = weight_decay,
282
+ adam_beta1 = adam_beta1,
283
+ adam_beta2 = adam_beta2,
284
+ adam_epsilon = adam_epsilon,
285
+ max_grad_norm = max_grad_norm,
286
+ num_train_epochs = num_train_epochs,
287
+ max_steps = max_steps,
288
+ lr_scheduler_type = lr_scheduler_type,
289
+ warmup_ratio = warmup_ratio,
290
+ warmup_steps = warmup_steps,
291
+ log_level = log_level,
292
+ log_level_replica = log_level_replica,
293
+ log_on_each_node = log_on_each_node,
294
+ logging_dir = logging_dir,
295
+ logging_strategy = logging_strategy,
296
+ logging_first_step = logging_first_step,
297
+ logging_steps = logging_steps,
298
+ logging_nan_inf_filter = logging_nan_inf_filter,
299
+ save_strategy = save_strategy,
300
+ save_steps = save_steps,
301
+ save_total_limit = save_total_limit,
302
+ save_safetensors = save_safetensors,
303
+ save_on_each_node = save_on_each_node,
304
+ save_only_model = save_only_model,
305
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
306
+ no_cuda = no_cuda,
307
+ use_cpu = use_cpu,
308
+ use_mps_device = use_mps_device,
309
+ seed = seed,
310
+ data_seed = data_seed,
311
+ jit_mode_eval = jit_mode_eval,
312
+ use_ipex = use_ipex,
313
+ bf16 = bf16,
314
+ fp16 = fp16,
315
+ fp16_opt_level = fp16_opt_level,
316
+ half_precision_backend = half_precision_backend,
317
+ bf16_full_eval = bf16_full_eval,
318
+ fp16_full_eval = fp16_full_eval,
319
+ tf32 = tf32,
320
+ local_rank = local_rank,
321
+ ddp_backend = ddp_backend,
322
+ tpu_num_cores = tpu_num_cores,
323
+ tpu_metrics_debug = tpu_metrics_debug,
324
+ debug = debug,
325
+ dataloader_drop_last = dataloader_drop_last,
326
+ eval_steps = eval_steps,
327
+ dataloader_num_workers = dataloader_num_workers,
328
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
329
+ past_index = past_index,
330
+ run_name = run_name,
331
+ disable_tqdm = disable_tqdm,
332
+ remove_unused_columns = remove_unused_columns,
333
+ label_names = label_names,
334
+ load_best_model_at_end = load_best_model_at_end,
335
+ metric_for_best_model = metric_for_best_model,
336
+ greater_is_better = greater_is_better,
337
+ ignore_data_skip = ignore_data_skip,
338
+ fsdp = fsdp,
339
+ fsdp_min_num_params = fsdp_min_num_params,
340
+ fsdp_config = fsdp_config,
341
+ tp_size = tp_size,
342
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
343
+ accelerator_config = accelerator_config,
344
+ deepspeed = deepspeed,
345
+ label_smoothing_factor = label_smoothing_factor,
346
+ optim = optim,
347
+ optim_args = optim_args,
348
+ adafactor = adafactor,
349
+ group_by_length = group_by_length,
350
+ length_column_name = length_column_name,
351
+ report_to = report_to,
352
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
353
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
354
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
355
+ dataloader_pin_memory = dataloader_pin_memory,
356
+ dataloader_persistent_workers = dataloader_persistent_workers,
357
+ skip_memory_metrics = skip_memory_metrics,
358
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
359
+ push_to_hub = push_to_hub,
360
+ resume_from_checkpoint = resume_from_checkpoint,
361
+ hub_model_id = hub_model_id,
362
+ hub_strategy = hub_strategy,
363
+ hub_token = hub_token,
364
+ hub_private_repo = hub_private_repo,
365
+ hub_always_push = hub_always_push,
366
+ gradient_checkpointing = gradient_checkpointing,
367
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
368
+ include_inputs_for_metrics = include_inputs_for_metrics,
369
+ eval_do_concat_batches = eval_do_concat_batches,
370
+ fp16_backend = fp16_backend,
371
+ evaluation_strategy = evaluation_strategy,
372
+ push_to_hub_model_id = push_to_hub_model_id,
373
+ push_to_hub_organization = push_to_hub_organization,
374
+ push_to_hub_token = push_to_hub_token,
375
+ mp_parameters = mp_parameters,
376
+ auto_find_batch_size = auto_find_batch_size,
377
+ full_determinism = full_determinism,
378
+ torchdynamo = torchdynamo,
379
+ ray_scope = ray_scope,
380
+ ddp_timeout = ddp_timeout,
381
+ torch_compile = torch_compile,
382
+ torch_compile_backend = torch_compile_backend,
383
+ torch_compile_mode = torch_compile_mode,
384
+ dispatch_batches = dispatch_batches,
385
+ split_batches = split_batches,
386
+ include_tokens_per_second = include_tokens_per_second,
387
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
388
+ neftune_noise_alpha = neftune_noise_alpha,
389
+ optim_target_modules = optim_target_modules,
390
+ batch_eval_metrics = batch_eval_metrics,
391
+ eval_on_start = eval_on_start,
392
+ use_liger_kernel = use_liger_kernel,
393
+ eval_use_gather_object = eval_use_gather_object,
394
+ average_tokens_across_devices = average_tokens_across_devices,
395
+ reward_model_path = reward_model_path,
396
+ judge = judge,
397
+ max_new_tokens = max_new_tokens,
398
+ max_length = max_length,
399
+ temperature = temperature,
400
+ missing_eos_penalty = missing_eos_penalty,
401
+ loss_type = loss_type,
402
+ dataset_num_proc = dataset_num_proc,
403
+ disable_dropout = disable_dropout,
404
+ use_vllm = use_vllm,
405
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
406
+ self.vllm_sampling_params = vllm_sampling_params
407
+ self.unsloth_num_chunks = unsloth_num_chunks
408
+ pass
409
+
410
+ class _UnslothOnlineDPOTrainer(Trainer):
411
+ r""""""
412
+
413
+ _tag_names = ["trl", "online-dpo"]
414
+
415
+ def __init__(
416
+ self,
417
+ model: Union[PreTrainedModel, nn.Module],
418
+ ref_model: Union[PreTrainedModel, nn.Module, None] = None,
419
+ reward_model: Union[PreTrainedModel, nn.Module, None] = None,
420
+ judge: Optional[BasePairwiseJudge] = None,
421
+ args: Optional[OnlineDPOConfig] = None,
422
+ data_collator: Optional[DataCollator] = None,
423
+ train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
424
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
425
+ processing_class: Optional[
426
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
427
+ ] = None,
428
+ reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
429
+ peft_config: Optional[dict] = None,
430
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
431
+ callbacks: Optional[list[TrainerCallback]] = None,
432
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
433
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
434
+ ) -> None:
435
+
436
+ if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
437
+ if ref_model is model:
438
+ raise ValueError(
439
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
440
+ "same as `model`, either omit the `ref_model` argument or pass `None`."
441
+ )
442
+
443
+ self.ref_model = ref_model
444
+
445
+ if reward_model is not None and judge is not None:
446
+ warnings.warn(
447
+ "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
448
+ "Ignoring `judge` and using `reward_model`.",
449
+ UserWarning,
450
+ )
451
+ judge = None
452
+ elif reward_model is None and judge is None:
453
+ raise ValueError("Either `reward_model` or `judge` must be provided.")
454
+
455
+ self.reward_model = reward_model
456
+ self.reward_processing_class = reward_processing_class
457
+ self.judge = judge
458
+
459
+ if args.missing_eos_penalty is not None and judge is not None:
460
+ raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
461
+
462
+ if args is None:
463
+ raise ValueError("`args` must be provided.")
464
+
465
+ # Check that the processing_class is provided
466
+ if processing_class is None:
467
+ raise ValueError("`processing_class` must be provided.")
468
+
469
+ # Convert to PEFT model if peft_config is provided
470
+ if False:
471
+ # Check if PEFT is available
472
+ if not is_peft_available():
473
+ raise ImportError(
474
+ "PEFT is not available and passed `peft_config`. Please install PEFT with "
475
+ "`pip install peft` to use it."
476
+ )
477
+
478
+ # If the model is already a PeftModel, we need to merge and unload it.
479
+ # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
480
+ if isinstance(model, PeftModel):
481
+ model = model.merge_and_unload()
482
+
483
+ # Get peft model with the given config
484
+ model = model
485
+
486
+ # Disable dropout in the model and reference model
487
+ if args.disable_dropout:
488
+ disable_dropout_in_model(model)
489
+ if self.ref_model is not None:
490
+ disable_dropout_in_model(self.ref_model)
491
+
492
+ # Handle the ref_model
493
+ # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
494
+ # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
495
+ # the ref model by copying the model, disabling its gradients, and setting it to evaluation mode.
496
+ if ref_model is None: # No ref model provided, the most common case
497
+ if False:
498
+ self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
499
+ else:
500
+ self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
501
+ else: # rare case, the user provided a ref model
502
+ self.ref_model = ref_model
503
+ self.ref_model.eval()
504
+
505
+ # Disable the gradient and set the reward model in eval mode
506
+ if self.reward_model is not None:
507
+ self.reward_model.eval()
508
+
509
+ # Define the collator if it is not provided
510
+ if data_collator is None:
511
+ data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
512
+
513
+ self.max_length = args.max_length
514
+
515
+ self.stats = {
516
+ "objective/kl": [],
517
+ "objective/entropy": [],
518
+ "objective/non_score_reward": [],
519
+ "rewards/chosen": [],
520
+ "rewards/rejected": [],
521
+ "rewards/accuracies": [],
522
+ "rewards/margins": [],
523
+ "logps/chosen": [],
524
+ "logps/rejected": [],
525
+ "val/contain_eos_token": [],
526
+ "beta": [],
527
+ }
528
+ if self.reward_model is not None:
529
+ self.stats["objective/rlhf_reward"] = []
530
+ self.stats["objective/scores_margin"] = []
531
+ self.stats["objective/scores"] = []
532
+
533
+ if args.use_vllm:
534
+ self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
535
+ n=2, max_tokens=args.max_new_tokens,
536
+ temperature=args.temperature,
537
+ top_k=50,
538
+ top_p=1.0,
539
+ detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
540
+ else:
541
+ self.generation_config = GenerationConfig(
542
+ max_new_tokens=args.max_new_tokens,
543
+ temperature=args.temperature,
544
+ top_k=50,
545
+ top_p=1.0,
546
+ do_sample=True,
547
+ use_cache=False if args.gradient_checkpointing else True,
548
+ )
549
+
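+ # Note: both generation paths sample two completions per prompt (vLLM via n=2
+ # above, HF generate via repeating each prompt twice in `_generate`), since
+ # Online DPO needs a candidate pair per prompt to build chosen/rejected data.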
550
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
551
+ # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
552
+ # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
553
+ # of the input, floating-point operations will not be computed." To suppress this warning, we set the
554
+ # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
555
+ # that the warning has already been issued.
556
+ model.warnings_issued["estimate_tokens"] = True
557
+
558
+ super().__init__(
559
+ model=model,
560
+ args=args,
561
+ data_collator=data_collator,
562
+ train_dataset=train_dataset,
563
+ eval_dataset=eval_dataset,
564
+ processing_class=processing_class,
565
+ compute_metrics=compute_metrics,
566
+ callbacks=callbacks,
567
+ optimizers=optimizers,
568
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
569
+ )
570
+
571
+ # Add tags for models that have been loaded with the correct transformers version
572
+ if hasattr(self.model, "add_model_tags"):
573
+ self.model.add_model_tags(self._tag_names)
574
+
575
+ self._beta = args.beta
576
+
577
+ # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
578
+ if self.is_deepspeed_enabled:
579
+ if self.reward_model is not None:
580
+ self.reward_model = prepare_deepspeed(
581
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
582
+ )
583
+ if self.ref_model is not None:
584
+ self.ref_model = prepare_deepspeed(
585
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
586
+ )
587
+ else:
588
+ if self.ref_model is not None:
589
+ self.ref_model = self.ref_model.to(self.accelerator.device)
590
+ if self.reward_model is not None:
591
+ self.reward_model = self.reward_model.to(self.accelerator.device)
592
+
593
+ @property
594
+ def beta(self):
595
+ if isinstance(self._beta, list):
596
+ epoch = self.state.epoch
597
+ return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
598
+ else:
599
+ return self._beta
600
+
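+ # Example: with a list-valued `beta=[0.1, 0.05]`, epoch 0 uses 0.1, epoch 1
+ # uses 0.05, and any later epoch falls back to the last entry; a scalar beta
+ # is used unchanged for the whole run.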
601
+ @staticmethod
602
+ def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
603
+ """Tokenize a single row from a DPO specific dataset."""
604
+ if not is_encoder_decoder:
605
+ batch = tokenizer(feature["prompt"], add_special_tokens=False)
606
+ # Add BOS token to the start of the prompt, avoiding a duplicate if it's already there
607
+ if tokenizer.bos_token_id is not None:
608
+ prompt_len_input_ids = len(batch["input_ids"])
609
+ if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
610
+ batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
611
+ batch["attention_mask"] = [1] + batch["attention_mask"]
612
+ else:
613
+ batch = tokenizer(feature["prompt"], add_special_tokens=True)
614
+ batch = {f"prompt_{key}": value for key, value in batch.items()}
615
+ return batch
616
+
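+ # Usage sketch (hypothetical tokenizer `tok`): for a decoder-only model,
+ # tokenize_row({"prompt": "Hello"}, False, tok) returns
+ # {"prompt_input_ids": [bos, ...], "prompt_attention_mask": [1, ...]} --
+ # every key is prefixed with "prompt_" and a BOS token is prepended at most once.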
617
+ # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
618
+ @wraps(Trainer.get_train_dataloader)
619
+ def get_train_dataloader(self) -> DataLoader:
620
+ if self.train_dataset is None:
621
+ raise ValueError("Trainer: training requires a train_dataset.")
622
+
623
+ train_dataset = self.train_dataset
624
+ data_collator = self.data_collator
625
+ dataloader_params = {
626
+ "batch_size": self._train_batch_size,
627
+ "collate_fn": data_collator,
628
+ "num_workers": self.args.dataloader_num_workers,
629
+ "pin_memory": self.args.dataloader_pin_memory,
630
+ "persistent_workers": self.args.dataloader_persistent_workers,
631
+ }
632
+
633
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
634
+ dataloader_params["sampler"] = self._get_train_sampler()
635
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
636
+ dataloader_params["worker_init_fn"] = seed_worker
637
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
638
+
639
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
640
+
641
+ # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
642
+ @wraps(Trainer.get_eval_dataloader)
643
+ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
644
+ if eval_dataset is None and self.eval_dataset is None:
645
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
646
+
647
+ # If we have persistent workers, reuse the cached dataloader instead of re-forking workers, especially as eval datasets
648
+ # don't change during training
649
+ dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
650
+ if (
651
+ hasattr(self, "_eval_dataloaders")
652
+ and dataloader_key in self._eval_dataloaders
653
+ and self.args.dataloader_persistent_workers
654
+ ):
655
+ return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
656
+
657
+ eval_dataset = (
658
+ self.eval_dataset[eval_dataset]
659
+ if isinstance(eval_dataset, str)
660
+ else eval_dataset
661
+ if eval_dataset is not None
662
+ else self.eval_dataset
663
+ )
664
+ data_collator = self.data_collator
665
+
666
+ dataloader_params = {
667
+ "batch_size": self.args.eval_batch_size,
668
+ "collate_fn": data_collator,
669
+ "num_workers": self.args.dataloader_num_workers,
670
+ "pin_memory": self.args.dataloader_pin_memory,
671
+ "persistent_workers": self.args.dataloader_persistent_workers,
672
+ }
673
+
674
+ if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
675
+ dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
676
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
677
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
678
+
679
+ # accelerator.free_memory() will destroy the references, so
680
+ # we need to store the non-prepared version
681
+ eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
682
+ if self.args.dataloader_persistent_workers:
683
+ if hasattr(self, "_eval_dataloaders"):
684
+ self._eval_dataloaders[dataloader_key] = eval_dataloader
685
+ else:
686
+ self._eval_dataloaders = {dataloader_key: eval_dataloader}
687
+
688
+ return self.accelerator.prepare(eval_dataloader)
689
+
690
+ def _generate_vllm(self, model, prompts):
691
+ eos_token_id = self.processing_class.eos_token_id
692
+ pad_token_id = self.processing_class.pad_token_id
693
+
694
+ # Load the latest weights
695
+
696
+ pass
697
+
698
+ pass
699
+
700
+ if is_conversational({"prompt": prompts[0]}):
701
+ outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
702
+ else:
703
+ outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
704
+
705
+ completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
706
+ prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
707
+
708
+ # Create mask and pad the prompt and completion
709
+ max_prompt_length = max(len(ids) for ids in prompt_ids)
710
+ prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
711
+ prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
712
+ max_tokens = self.generation_config.max_tokens
713
+ completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
714
+ completion_ids = [
715
+ ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
716
+ for ids in completion_ids
717
+ ]
718
+ completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
719
+
720
+ # Convert to tensors
721
+ prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
722
+ prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
723
+ completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
724
+ completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
725
+
726
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
727
+
728
+ def _generate(self, model, prompts):
729
+ eos_token_id = self.processing_class.eos_token_id
730
+ pad_token_id = self.processing_class.pad_token_id
731
+
732
+ # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
733
+ # policies with different tokenizers / chat templates.
734
+ inputs = [{"prompt": prompt} for prompt in prompts]
735
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
736
+ inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
737
+ inputs = self.data_collator(inputs)
738
+
739
+ # Sample 2 completions per prompt of size `max_new_tokens` from the model
740
+ inputs = self._prepare_inputs(inputs)
741
+ prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
742
+ prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
743
+ with unwrap_model_for_generation(
744
+ model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
745
+ ) as unwrapped_model:
746
+ output = unwrapped_model.generate(
747
+ input_ids=prompt_ids,
748
+ attention_mask=prompt_mask,
749
+ generation_config=self.generation_config,
750
+ )
751
+
752
+ completion_ids = output[:, prompt_ids.size(1) :]
753
+ completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
754
+
755
+ return prompt_ids, prompt_mask, completion_ids, completion_mask
756
+
757
+ def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
758
+ # Get the number of tokens to truncate from prompt
759
+ num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
760
+
761
+ # Truncate left to avoid oom
762
+ prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
763
+ prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
764
+
765
+ # Concat the prompt and completion
766
+ prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
767
+ prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
768
+
769
+ # Get the logprobs of the completions from the model
770
+ output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
771
+
772
+ # There is an offset of 1 because the model predicts the next token
773
+ logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
774
+
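+ # Worked example: with prompt length P and completion length C, logit positions
+ # P-1 .. P+C-2 predict tokens P .. P+C-1, i.e. exactly the completion tokens
+ # gathered below.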
775
+ # Take the completion tokens logprob
776
+ logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
777
+ return logprobs
778
+
779
+ def training_step(
780
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
781
+ ) -> torch.Tensor:
782
+ model.train()
783
+
784
+ prompts = inputs["prompt"]
785
+ batch_size = len(prompts)
786
+
787
+ if self.args.use_vllm:
788
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
789
+ else:
790
+ prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
791
+
792
+ contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
793
+
794
+ logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
795
+ with torch.no_grad():
796
+ if self.ref_model is not None:
797
+ ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
798
+ else: # peft case: we just need to disable the adapter
799
+ with self.model.disable_adapter():
800
+ ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
801
+
802
+ # Decode the completions, and format them if the input is conversational
803
+ device = logprobs.device
804
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
805
+ if is_conversational({"prompt": prompts[0]}):
806
+ completions = [[{"role": "assistant", "content": completion}] for completion in completions]
807
+
808
+ # Get the reward from the reward model or judge
809
+ if self.judge is not None:
810
+ # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
811
+ # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
812
+ # independent of the model's chat template, we use the raw conversation data, and apply our own chat
813
+ # template to it.
814
+ if is_conversational({"prompt": prompts[0]}):
815
+ environment = jinja2.Environment()
816
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
817
+ prompts = [template.render(messages=prompt) for prompt in prompts]
818
+ completions = [template.render(messages=completion) for completion in completions]
819
+
820
+ ranks_of_first_completion = self.judge.judge(
821
+ prompts, list(zip(completions[:batch_size], completions[batch_size:]))
822
+ )
823
+
824
+ # convert ranks to a True/False mask:
825
+ # when rank == 0, it means the first completion is the best
826
+ # when rank == 1, it means the second completion is the best
827
+ mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
828
+ else:
829
+ # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
830
+ # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
831
+ prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
832
+ if is_conversational({"prompt": prompts[0]}):
833
+ examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
834
+ examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
835
+ prompts = [example["prompt"] for example in examples]
836
+ completions = [example["completion"] for example in examples]
837
+
838
+ # Tokenize the prompts
839
+ prompts_ids = self.reward_processing_class(
840
+ prompts, padding=True, return_tensors="pt", padding_side="left"
841
+ )["input_ids"].to(device)
842
+ context_length = prompts_ids.shape[1]
843
+
844
+ # Tokenize the completions
845
+ completions_ids = self.reward_processing_class(
846
+ completions, padding=True, return_tensors="pt", padding_side="right"
847
+ )["input_ids"].to(device)
848
+
849
+ # Concatenate the prompts and completions and get the reward
850
+ prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
851
+ with torch.inference_mode():
852
+ _, scores, _ = get_reward(
853
+ self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
854
+ )
855
+
856
+ # Filter completions: ensure that each sample contains the stop token id.
857
+ # Completions not passing that filter will receive a lower score.
858
+ if self.args.missing_eos_penalty is not None:
859
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
860
+
861
+ # Split the scores in 2 (the prompts of the first half are the same as the second half)
862
+ first_half, second_half = scores.split(batch_size)
863
+
864
+ # Get the indices of the chosen and rejected examples
865
+ mask = first_half >= second_half
866
+
867
+ batch_range = torch.arange(batch_size, device=device)
868
+ chosen_indices = batch_range + (~mask * batch_size)
869
+ rejected_indices = batch_range + (mask * batch_size)
870
+
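+ # Worked example: with batch_size=2 and mask=[True, False], chosen_indices is
+ # [0, 3] and rejected_indices is [2, 1], selecting each prompt's preferred
+ # completion from either the first or the second half of the doubled batch.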
871
+ # Build tensor so that the first half is the chosen examples and the second half the rejected examples
872
+ cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
873
+ cr_logprobs = logprobs[cr_indices]
874
+ cr_ref_logprobs = ref_logprobs[cr_indices]
875
+
876
+ # mask out the padding tokens
877
+ padding_mask = ~completion_mask.bool()
878
+ cr_padding_mask = padding_mask[cr_indices]
879
+
880
+ cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
881
+ cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
882
+
883
+ # Split the chosen and rejected examples
884
+ chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
885
+ chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
886
+ pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
887
+ ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
888
+
889
+ logits = pi_logratios - ref_logratios
890
+
891
+ if self.args.loss_type == "sigmoid":
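+ # "sigmoid" is the standard DPO loss -log(sigmoid(beta * logits)), where logits
+ # is the policy log-ratio minus the reference log-ratio; "ipo" instead regresses
+ # that gap towards 1 / (2 * beta).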
892
+ losses = -F.logsigmoid(self.beta * logits)
893
+ elif self.args.loss_type == "ipo":
894
+ losses = (logits - 1 / (2 * self.beta)) ** 2
895
+ else:
896
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
897
+
898
+ loss = losses.mean()
899
+
900
+ # Log everything
901
+ if self.reward_model is not None:
902
+ scores_margin = scores[chosen_indices] - scores[rejected_indices]
903
+ self.stats["objective/scores_margin"].append(
904
+ self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
905
+ )
906
+ self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
907
+ self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
908
+ self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
909
+ self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
910
+
911
+ kl = logprobs - ref_logprobs
912
+ mean_kl = kl.sum(1).mean()
913
+ self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
914
+ non_score_reward = (-self.beta * kl).sum(1)
915
+ mean_non_score_reward = non_score_reward.mean()
916
+ self.stats["objective/non_score_reward"].append(
917
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
918
+ )
919
+ if self.reward_model is not None:
920
+ rlhf_reward = scores + non_score_reward
921
+ self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
922
+ mean_entropy = -logprobs.sum(1).mean()
923
+ self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
924
+ chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
925
+ gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
926
+ self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
927
+ rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
928
+ gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
929
+ self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
930
+ margin = gathered_chosen_rewards - gathered_rejected_rewards
931
+ self.stats["rewards/margins"].append(margin.mean().item())
932
+ accuracy = margin > 0
933
+ self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
934
+ self.stats["beta"].append(self.beta)
935
+
936
+ if (
937
+ self.args.torch_empty_cache_steps is not None
938
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
939
+ ):
940
+ empty_cache()
941
+
942
+ kwargs = {}
943
+
944
+ # For LOMO optimizers you need to explicitly use the learning rate
945
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
946
+ kwargs["learning_rate"] = self._get_learning_rate()
947
+
948
+ if self.args.n_gpu > 1:
949
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
950
+
951
+ if self.use_apex:
952
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
953
+ scaled_loss.backward()
954
+ else:
955
+ self.accelerator.backward(loss, **kwargs)
956
+
957
+ return loss.detach() / self.args.gradient_accumulation_steps
958
+
959
+ # Same as Trainer._maybe_log_save_evaluate but log our metrics
960
+ # start_time defaults to None to allow compatibility with transformers<=4.46
961
+ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
962
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
963
+ logs: dict[str, float] = {}
964
+
965
+ # all_gather + mean() to get average loss over all processes
966
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
967
+
968
+ # reset tr_loss to zero
969
+ tr_loss -= tr_loss
970
+
971
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
972
+ if grad_norm is not None:
973
+ logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
974
+ logs["learning_rate"] = self._get_learning_rate()
975
+
976
+ # Add our metrics
977
+ for key, val in self.stats.items():
978
+ logs[key] = sum(val) / len(val)
979
+ self.stats = {key: [] for key in self.stats} # reset stats
980
+
981
+ self._total_loss_scalar += tr_loss_scalar
982
+ self._globalstep_last_logged = self.state.global_step
983
+ self.store_flos()
984
+
985
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
986
+ self.log(logs, start_time)
987
+ else: # transformers<=4.46
988
+ self.log(logs)
989
+
990
+ metrics = None
991
+ if self.control.should_evaluate:
992
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
993
+ is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
994
+
995
+ if self.args.save_strategy == "best":
996
+ self.control.should_save = is_new_best_metric
997
+
998
+ if self.control.should_save:
999
+ self._save_checkpoint(model, trial)
1000
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1001
+
1002
+ # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
1003
+ # This can be removed once the minimum transformers version is updated to 4.47.
1004
+ # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
1005
+ def _determine_best_metric(self, metrics, trial):
1006
+ """
1007
+ Determine if the model should be saved based on the evaluation metrics.
1008
+ If args.metric_for_best_model is not set, the loss is used.
1009
+ Returns:
1010
+ bool: True if a new best metric was found, else False
1011
+ """
1012
+ is_new_best_metric = False
1013
+
1014
+ if self.args.metric_for_best_model is not None:
1015
+ metric_to_check = self.args.metric_for_best_model
1016
+
1017
+ if not metric_to_check.startswith("eval_"):
1018
+ metric_to_check = f"eval_{metric_to_check}"
1019
+
1020
+ try:
1021
+ metric_value = metrics[metric_to_check]
1022
+ except KeyError as exc:
1023
+ raise KeyError(
1024
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
1025
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
1026
+ ) from exc
1027
+
1028
+ operator = np.greater if self.args.greater_is_better else np.less
1029
+
1030
+ if self.state.best_metric is None:
1031
+ self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
1032
+
1033
+ if operator(metric_value, self.state.best_metric):
1034
+ run_dir = self._get_output_dir(trial=trial)
1035
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
1036
+ output_dir = os.path.join(run_dir, checkpoint_folder)
1037
+ self.state.best_metric = metric_value
1038
+ self.state.best_model_checkpoint = output_dir
1039
+
1040
+ is_new_best_metric = True
1041
+
1042
+ return is_new_best_metric
1043
+
1044
+ def create_model_card(
1045
+ self,
1046
+ model_name: Optional[str] = None,
1047
+ dataset_name: Optional[str] = None,
1048
+ tags: Union[str, list[str], None] = None,
1049
+ ):
1050
+ """
1051
+ Creates a draft of a model card using the information available to the `Trainer`.
1052
+
1053
+ Args:
1054
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1055
+ Name of the model.
1056
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1057
+ Name of the dataset used for training.
1058
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1059
+ Tags to be associated with the model card.
1060
+ """
1061
+ if not self.is_world_process_zero():
1062
+ return
1063
+
1064
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1065
+ base_model = self.model.config._name_or_path
1066
+ else:
1067
+ base_model = None
1068
+
1069
+ tags = tags or []
1070
+ if isinstance(tags, str):
1071
+ tags = [tags]
1072
+
1073
+ if hasattr(self.model.config, "unsloth_version"):
1074
+ tags.append("unsloth")
1075
+
1076
+ citation = textwrap.dedent("""\
1077
+ @article{guo2024direct,
1078
+ title = {{Direct Language Model Alignment from Online AI Feedback}},
1079
+ author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
1080
+ year = 2024,
1081
+ eprint = {arXiv:2402.04792}
1082
+ }""")
1083
+
1084
+ model_card = generate_model_card(
1085
+ base_model=base_model,
1086
+ model_name=model_name,
1087
+ hub_model_id=self.hub_model_id,
1088
+ dataset_name=dataset_name,
1089
+ tags=tags,
1090
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1091
+ comet_url=get_comet_experiment_url(),
1092
+ trainer_name="Online DPO",
1093
+ trainer_citation=citation,
1094
+ paper_title="Direct Language Model Alignment from Online AI Feedback",
1095
+ paper_id="2402.04792",
1096
+ )
1097
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1098
+ class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
1099
+ """
1100
+
1101
+ Initialize OnlineDPOTrainer.
1102
+
1103
+ Args:
1104
+ model (`transformers.PreTrainedModel` or `torch.nn.Module`):
1105
+ The model to train, preferably an `AutoModelForCausalLM`.
1106
+ ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1107
+ The reference model to use for training. If None is specified, the reference model will be created from
1108
+ the model.
1109
+ reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1110
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
1111
+ judge (`BasePairwiseJudge`):
1112
+ The judge to use for pairwise comparison of model completions.
1113
+ args (`OnlineDPOConfig`):
1114
+ The online DPO config arguments to use for training.
1115
+ data_collator (`transformers.DataCollator`):
1116
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1117
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1118
+ train_dataset (`datasets.Dataset`):
1119
+ The dataset to use for training.
1120
+ eval_dataset (`datasets.Dataset`):
1121
+ The dataset to use for evaluation.
1122
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1123
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1124
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1125
+ reuse the fine-tuned model.
1126
+ peft_config (`dict`):
1127
+ The peft config to use for training.
1128
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1129
+ The function to use to compute the metrics. Must take a `EvalPrediction` and return
1130
+ a dictionary string to metric values.
1131
+ callbacks (`list[transformers.TrainerCallback]`):
1132
+ The callbacks to use for training.
1133
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1134
+ The optimizer and scheduler to use for training.
1135
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1136
+ The function to use to preprocess the logits before computing the metrics.
1137
+
1138
+ """
1139
+ def __init__(
1140
+ self,
1141
+ model,
1142
+ ref_model = None,
1143
+ reward_model = None,
1144
+ judge = None,
1145
+ args = None,
1146
+ data_collator = None,
1147
+ train_dataset = None,
1148
+ eval_dataset = None,
1149
+ processing_class = None,
1150
+ reward_processing_class = None,
1151
+ peft_config = None,
1152
+ compute_metrics = None,
1153
+ callbacks = None,
1154
+ preprocess_logits_for_metrics = None,
1155
+ **kwargs
1156
+ ):
1157
+ if args is None: args = UnslothOnlineDPOConfig()
1158
+ use_bf16 = getattr(args, 'bf16', False)
1159
+ use_fp16 = getattr(args, 'fp16', False)
1160
+ force_float32 = False
1161
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1162
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1163
+ force_float32 = True
1164
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1165
+ dtype = getattr(model.config, 'torch_dtype', None)
1166
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1167
+ from unsloth_zoo.utils import _get_dtype
1168
+ dtype = _get_dtype(dtype)
1169
+ float16 = dtype == torch.float16
1170
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1171
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1172
+ if force_float32:
1173
+ args.fp16 = False
1174
+ args.bf16 = False
1175
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1176
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1177
+ args.fp16 = float16
1178
+ args.bf16 = not float16
1179
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
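+ # In short: float16 weights select fp16 mixed precision and bfloat16 weights
+ # select bf16, unless UNSLOTH_FORCE_FLOAT32 disables mixed precision entirely;
+ # mismatched fp16/bf16 flags were already rejected above.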
1180
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1181
+ args.eval_strategy = 'steps'
1182
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1183
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1184
+ if ga_steps is not None and ga_steps > 1:
1185
+ from transformers import __version__ as transformers_version
1186
+ if Version(transformers_version) <= Version('4.45.2'):
1187
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1188
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1189
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1190
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1191
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1192
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1193
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1194
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1195
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1196
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1197
+ if force_float32:
1198
+ args.bf16_full_eval = False
1199
+ args.fp16_full_eval = False
1200
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1201
+ args.bf16_full_eval = True
1202
+ args.fp16_full_eval = False
1203
+ elif not bf16_full_eval and not fp16_full_eval:
1204
+ args.bf16_full_eval = args.bf16
1205
+ args.fp16_full_eval = args.fp16
1206
+ _output_logits = False
1207
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1208
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1209
+ if _output_logits:
1210
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1211
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1212
+ pass
1213
+ else:
1214
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1215
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1216
+ if args_max_seq_length is None and model_max_seq_length is not None:
1217
+ max_seq_length = model.max_seq_length
1218
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1219
+ if model is not None and hasattr(model, 'for_training'):
1220
+ model.for_training()
1221
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1222
+ if 'processing_class' in locals():
1223
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1224
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1225
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1226
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1227
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1228
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1229
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1230
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1231
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1232
+ else:
1233
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1234
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1235
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1236
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1237
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1238
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1239
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1240
+ else:
1241
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1242
+ other_metrics = []
1243
+
1244
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1245
+ PatchRLStatistics('online_dpo_trainer', other_metrics)
1246
+
1247
+ super().__init__(
1248
+ model = model,
1249
+ ref_model = ref_model,
1250
+ reward_model = reward_model,
1251
+ judge = judge,
1252
+ args = args,
1253
+ data_collator = data_collator,
1254
+ train_dataset = train_dataset,
1255
+ eval_dataset = eval_dataset,
1256
+ processing_class = processing_class,
1257
+ reward_processing_class = reward_processing_class,
1258
+ peft_config = peft_config,
1259
+ compute_metrics = compute_metrics,
1260
+ callbacks = callbacks,
1261
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1262
+ if hasattr(self, 'neftune_hook_handle'):
1263
+ self.neftune_hook_handle.remove()
1264
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1265
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1266
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1267
+ pass
1268
+
1269
+ pass
unsloth_compiled_cache/UnslothPPOTrainer.py ADDED
@@ -0,0 +1,1259 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
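+ # Shape sketch: given logits of shape (B, T, V) and index of shape (B, T), this
+ # returns per-token log-probabilities of shape (B, T), i.e.
+ # log_softmax(logits)[b, t, index[b, t]], without materialising the full
+ # (B, T, V) log-softmax tensor.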
42
+ @dataclass
43
+ class UnslothPPOConfig(PPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`PPOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
54
+ Name of this experiment.
55
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
+ Path to the reward model.
57
+ model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
58
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
59
+ ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
60
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
61
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
62
+ Number of epochs to train.
63
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
64
+ Whether to whiten the rewards.
65
+ kl_coef (`float`, *optional*, defaults to `0.05`):
66
+ KL coefficient.
67
+ cliprange (`float`, *optional*, defaults to `0.2`):
68
+ Clip range.
69
+ vf_coef (`float`, *optional*, defaults to `0.1`):
70
+ Value function coefficient.
71
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
72
+ Clip range for the value function.
73
+ gamma (`float`, *optional*, defaults to `1.0`):
74
+ Discount factor.
75
+ lam (`float`, *optional*, defaults to `0.95`):
76
+ Lambda value for GAE.
77
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
78
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
79
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
80
+ capacity of a single GPU, albeit at the cost of slower generation.
81
+
82
+ """
83
+ vllm_sampling_params: Optional[Any] = field(
84
+ default = None,
85
+ metadata = {'help': 'vLLM SamplingParams'},
86
+ )
87
+ unsloth_num_chunks : Optional[int] = field(
88
+ default = -1,
89
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
90
+ )
91
+ def __init__(
92
+ self,
93
+ output_dir = None,
94
+ overwrite_output_dir = None,
95
+ do_train = False,
96
+ do_eval = False,
97
+ do_predict = False,
98
+ eval_strategy = 'no',
99
+ prediction_loss_only = False,
100
+ per_device_train_batch_size = 4,
101
+ per_device_eval_batch_size = 4,
102
+ per_gpu_train_batch_size = None,
103
+ per_gpu_eval_batch_size = None,
104
+ gradient_accumulation_steps = 2,
105
+ eval_accumulation_steps = 2,
106
+ eval_delay = 0,
107
+ torch_empty_cache_steps = 250,
108
+ learning_rate = 5e-05,
109
+ weight_decay = 0.01,
110
+ adam_beta1 = 0.9,
111
+ adam_beta2 = 0.999,
112
+ adam_epsilon = 1e-08,
113
+ max_grad_norm = 1.0,
114
+ num_train_epochs = 3.0,
115
+ max_steps = -1,
116
+ lr_scheduler_type = 'linear',
117
+ warmup_ratio = 0.1,
118
+ warmup_steps = 0,
119
+ log_level = 'passive',
120
+ log_level_replica = 'warning',
121
+ log_on_each_node = True,
122
+ logging_dir = None,
123
+ logging_strategy = 'steps',
124
+ logging_first_step = False,
125
+ logging_steps = 1,
126
+ logging_nan_inf_filter = False,
127
+ save_strategy = 'steps',
128
+ save_steps = 500,
129
+ save_total_limit = None,
130
+ save_safetensors = True,
131
+ save_on_each_node = False,
132
+ save_only_model = False,
133
+ restore_callback_states_from_checkpoint = False,
134
+ no_cuda = False,
135
+ use_cpu = False,
136
+ use_mps_device = False,
137
+ seed = 3407,
138
+ data_seed = 3407,
139
+ jit_mode_eval = False,
140
+ use_ipex = False,
141
+ bf16 = False,
142
+ fp16 = False,
143
+ fp16_opt_level = 'O1',
144
+ half_precision_backend = 'auto',
145
+ bf16_full_eval = False,
146
+ fp16_full_eval = False,
147
+ tf32 = None,
148
+ local_rank = -1,
149
+ ddp_backend = None,
150
+ tpu_num_cores = None,
151
+ tpu_metrics_debug = False,
152
+ debug = '',
153
+ dataloader_drop_last = False,
154
+ eval_steps = None,
155
+ dataloader_num_workers = 0,
156
+ dataloader_prefetch_factor = None,
157
+ past_index = -1,
158
+ run_name = None,
159
+ disable_tqdm = None,
160
+ remove_unused_columns = True,
161
+ label_names = None,
162
+ load_best_model_at_end = False,
163
+ metric_for_best_model = None,
164
+ greater_is_better = None,
165
+ ignore_data_skip = False,
166
+ fsdp = '',
167
+ fsdp_min_num_params = 0,
168
+ fsdp_config = None,
169
+ tp_size = 0,
170
+ fsdp_transformer_layer_cls_to_wrap = None,
171
+ accelerator_config = None,
172
+ deepspeed = None,
173
+ label_smoothing_factor = 0.0,
174
+ optim = 'adamw_8bit',
175
+ optim_args = None,
176
+ adafactor = False,
177
+ group_by_length = False,
178
+ length_column_name = 'length',
179
+ report_to = None,
180
+ ddp_find_unused_parameters = None,
181
+ ddp_bucket_cap_mb = None,
182
+ ddp_broadcast_buffers = None,
183
+ dataloader_pin_memory = True,
184
+ dataloader_persistent_workers = False,
185
+ skip_memory_metrics = True,
186
+ use_legacy_prediction_loop = False,
187
+ push_to_hub = False,
188
+ resume_from_checkpoint = None,
189
+ hub_model_id = None,
190
+ hub_strategy = 'every_save',
191
+ hub_token = None,
192
+ hub_private_repo = None,
193
+ hub_always_push = False,
194
+ gradient_checkpointing = False,
195
+ gradient_checkpointing_kwargs = None,
196
+ include_inputs_for_metrics = False,
197
+ eval_do_concat_batches = True,
198
+ fp16_backend = 'auto',
199
+ evaluation_strategy = None,
200
+ push_to_hub_model_id = None,
201
+ push_to_hub_organization = None,
202
+ push_to_hub_token = None,
203
+ mp_parameters = '',
204
+ auto_find_batch_size = False,
205
+ full_determinism = False,
206
+ torchdynamo = None,
207
+ ray_scope = 'last',
208
+ ddp_timeout = 1800,
209
+ torch_compile = False,
210
+ torch_compile_backend = None,
211
+ torch_compile_mode = None,
212
+ dispatch_batches = None,
213
+ split_batches = None,
214
+ include_tokens_per_second = False,
215
+ include_num_input_tokens_seen = False,
216
+ neftune_noise_alpha = None,
217
+ optim_target_modules = None,
218
+ batch_eval_metrics = False,
219
+ eval_on_start = False,
220
+ use_liger_kernel = False,
221
+ eval_use_gather_object = False,
222
+ average_tokens_across_devices = False,
223
+ dataset_num_proc = None,
224
+ num_mini_batches = 1,
225
+ total_episodes = None,
226
+ local_rollout_forward_batch_size = 64,
227
+ num_sample_generations = 10,
228
+ response_length = 53,
229
+ stop_token = None,
230
+ stop_token_id = None,
231
+ temperature = 0.7,
232
+ missing_eos_penalty = None,
233
+ sft_model_path = 'EleutherAI/pythia-160m',
234
+ world_size = None,
235
+ num_total_batches = None,
236
+ micro_batch_size = None,
237
+ local_batch_size = None,
238
+ batch_size = None,
239
+ local_mini_batch_size = None,
240
+ mini_batch_size = None,
241
+ exp_name = 'ppo_config',
242
+ reward_model_path = 'EleutherAI/pythia-160m',
243
+ model_adapter_name = None,
244
+ ref_adapter_name = None,
245
+ num_ppo_epochs = 4,
246
+ whiten_rewards = False,
247
+ kl_coef = 0.05,
248
+ cliprange = 0.2,
249
+ vf_coef = 0.1,
250
+ cliprange_value = 0.2,
251
+ gamma = 1.0,
252
+ lam = 0.95,
253
+ ds3_gather_for_generation = True,
254
+ vllm_sampling_params = None,
255
+ unsloth_num_chunks = -1,
256
+ **kwargs,
257
+ ):
258
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
260
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
260
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
261
+ output_dir = 'unsloth_training_checkpoints'
262
+ save_strategy = 'no'
263
+ if dataset_num_proc is None:
264
+ from multiprocessing import cpu_count
265
+ dataset_num_proc = cpu_count()
266
+
267
+ super().__init__(
268
+ output_dir = output_dir,
269
+ overwrite_output_dir = overwrite_output_dir,
270
+ do_train = do_train,
271
+ do_eval = do_eval,
272
+ do_predict = do_predict,
273
+ eval_strategy = eval_strategy,
274
+ prediction_loss_only = prediction_loss_only,
275
+ per_device_train_batch_size = per_device_train_batch_size,
276
+ per_device_eval_batch_size = per_device_eval_batch_size,
277
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
278
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
279
+ gradient_accumulation_steps = gradient_accumulation_steps,
280
+ eval_accumulation_steps = eval_accumulation_steps,
281
+ eval_delay = eval_delay,
282
+ torch_empty_cache_steps = torch_empty_cache_steps,
283
+ learning_rate = learning_rate,
284
+ weight_decay = weight_decay,
285
+ adam_beta1 = adam_beta1,
286
+ adam_beta2 = adam_beta2,
287
+ adam_epsilon = adam_epsilon,
288
+ max_grad_norm = max_grad_norm,
289
+ num_train_epochs = num_train_epochs,
290
+ max_steps = max_steps,
291
+ lr_scheduler_type = lr_scheduler_type,
292
+ warmup_ratio = warmup_ratio,
293
+ warmup_steps = warmup_steps,
294
+ log_level = log_level,
295
+ log_level_replica = log_level_replica,
296
+ log_on_each_node = log_on_each_node,
297
+ logging_dir = logging_dir,
298
+ logging_strategy = logging_strategy,
299
+ logging_first_step = logging_first_step,
300
+ logging_steps = logging_steps,
301
+ logging_nan_inf_filter = logging_nan_inf_filter,
302
+ save_strategy = save_strategy,
303
+ save_steps = save_steps,
304
+ save_total_limit = save_total_limit,
305
+ save_safetensors = save_safetensors,
306
+ save_on_each_node = save_on_each_node,
307
+ save_only_model = save_only_model,
308
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
309
+ no_cuda = no_cuda,
310
+ use_cpu = use_cpu,
311
+ use_mps_device = use_mps_device,
312
+ seed = seed,
313
+ data_seed = data_seed,
314
+ jit_mode_eval = jit_mode_eval,
315
+ use_ipex = use_ipex,
316
+ bf16 = bf16,
317
+ fp16 = fp16,
318
+ fp16_opt_level = fp16_opt_level,
319
+ half_precision_backend = half_precision_backend,
320
+ bf16_full_eval = bf16_full_eval,
321
+ fp16_full_eval = fp16_full_eval,
322
+ tf32 = tf32,
323
+ local_rank = local_rank,
324
+ ddp_backend = ddp_backend,
325
+ tpu_num_cores = tpu_num_cores,
326
+ tpu_metrics_debug = tpu_metrics_debug,
327
+ debug = debug,
328
+ dataloader_drop_last = dataloader_drop_last,
329
+ eval_steps = eval_steps,
330
+ dataloader_num_workers = dataloader_num_workers,
331
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
332
+ past_index = past_index,
333
+ run_name = run_name,
334
+ disable_tqdm = disable_tqdm,
335
+ remove_unused_columns = remove_unused_columns,
336
+ label_names = label_names,
337
+ load_best_model_at_end = load_best_model_at_end,
338
+ metric_for_best_model = metric_for_best_model,
339
+ greater_is_better = greater_is_better,
340
+ ignore_data_skip = ignore_data_skip,
341
+ fsdp = fsdp,
342
+ fsdp_min_num_params = fsdp_min_num_params,
343
+ fsdp_config = fsdp_config,
344
+ tp_size = tp_size,
345
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
346
+ accelerator_config = accelerator_config,
347
+ deepspeed = deepspeed,
348
+ label_smoothing_factor = label_smoothing_factor,
349
+ optim = optim,
350
+ optim_args = optim_args,
351
+ adafactor = adafactor,
352
+ group_by_length = group_by_length,
353
+ length_column_name = length_column_name,
354
+ report_to = report_to,
355
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
356
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
357
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
358
+ dataloader_pin_memory = dataloader_pin_memory,
359
+ dataloader_persistent_workers = dataloader_persistent_workers,
360
+ skip_memory_metrics = skip_memory_metrics,
361
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
362
+ push_to_hub = push_to_hub,
363
+ resume_from_checkpoint = resume_from_checkpoint,
364
+ hub_model_id = hub_model_id,
365
+ hub_strategy = hub_strategy,
366
+ hub_token = hub_token,
367
+ hub_private_repo = hub_private_repo,
368
+ hub_always_push = hub_always_push,
369
+ gradient_checkpointing = gradient_checkpointing,
370
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
371
+ include_inputs_for_metrics = include_inputs_for_metrics,
372
+ eval_do_concat_batches = eval_do_concat_batches,
373
+ fp16_backend = fp16_backend,
374
+ evaluation_strategy = evaluation_strategy,
375
+ push_to_hub_model_id = push_to_hub_model_id,
376
+ push_to_hub_organization = push_to_hub_organization,
377
+ push_to_hub_token = push_to_hub_token,
378
+ mp_parameters = mp_parameters,
379
+ auto_find_batch_size = auto_find_batch_size,
380
+ full_determinism = full_determinism,
381
+ torchdynamo = torchdynamo,
382
+ ray_scope = ray_scope,
383
+ ddp_timeout = ddp_timeout,
384
+ torch_compile = torch_compile,
385
+ torch_compile_backend = torch_compile_backend,
386
+ torch_compile_mode = torch_compile_mode,
387
+ dispatch_batches = dispatch_batches,
388
+ split_batches = split_batches,
389
+ include_tokens_per_second = include_tokens_per_second,
390
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
391
+ neftune_noise_alpha = neftune_noise_alpha,
392
+ optim_target_modules = optim_target_modules,
393
+ batch_eval_metrics = batch_eval_metrics,
394
+ eval_on_start = eval_on_start,
395
+ use_liger_kernel = use_liger_kernel,
396
+ eval_use_gather_object = eval_use_gather_object,
397
+ average_tokens_across_devices = average_tokens_across_devices,
398
+ dataset_num_proc = dataset_num_proc,
399
+ num_mini_batches = num_mini_batches,
400
+ total_episodes = total_episodes,
401
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
402
+ num_sample_generations = num_sample_generations,
403
+ response_length = response_length,
404
+ stop_token = stop_token,
405
+ stop_token_id = stop_token_id,
406
+ temperature = temperature,
407
+ missing_eos_penalty = missing_eos_penalty,
408
+ sft_model_path = sft_model_path,
409
+ world_size = world_size,
410
+ num_total_batches = num_total_batches,
411
+ micro_batch_size = micro_batch_size,
412
+ local_batch_size = local_batch_size,
413
+ batch_size = batch_size,
414
+ local_mini_batch_size = local_mini_batch_size,
415
+ mini_batch_size = mini_batch_size,
416
+ exp_name = exp_name,
417
+ reward_model_path = reward_model_path,
418
+ model_adapter_name = model_adapter_name,
419
+ ref_adapter_name = ref_adapter_name,
420
+ num_ppo_epochs = num_ppo_epochs,
421
+ whiten_rewards = whiten_rewards,
422
+ kl_coef = kl_coef,
423
+ cliprange = cliprange,
424
+ vf_coef = vf_coef,
425
+ cliprange_value = cliprange_value,
426
+ gamma = gamma,
427
+ lam = lam,
428
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
429
+ self.vllm_sampling_params = vllm_sampling_params
430
+ self.unsloth_num_chunks = unsloth_num_chunks
431
+ pass
432
+
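For reference, a small hypothetical sketch of how the config above might be constructed; the learning-rate guard and the checkpoint-directory fallback in `__init__` fire at construction time (the values below are illustrative, not defaults from this file):

```python
# Hypothetical usage sketch for UnslothPPOConfig; values are illustrative.
config = UnslothPPOConfig(
    learning_rate = 5e-6,              # must satisfy 1e-7 <= lr <= 1, else __init__ raises
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    num_mini_batches = 1,
    total_episodes = 10_000,
)
# With output_dir=None and the default save settings, __init__ silently
# redirects checkpoints to 'unsloth_training_checkpoints' and disables saving.
```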
433
+ class _UnslothPPOTrainer(Trainer):
434
+ _tag_names = ["trl", "ppo"]
435
+
436
+ def __init__(
437
+ self,
438
+ args: PPOConfig,
439
+ processing_class: Optional[
440
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
441
+ ],
442
+ model: nn.Module,
443
+ ref_model: Optional[nn.Module],
444
+ reward_model: nn.Module,
445
+ train_dataset: Dataset,
446
+ value_model: Optional[nn.Module] = None,
447
+ data_collator: Optional[DataCollatorWithPadding] = None,
448
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
449
+ # less commonly used
450
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
451
+ callbacks: Optional[list[TrainerCallback]] = None,
452
+ peft_config: Optional["PeftConfig"] = None,
453
+ ) -> None:
454
+ if ref_model is model:
455
+ raise ValueError(
456
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
457
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
458
+ )
459
+
460
+ self.args = args
461
+ self.processing_class = processing_class
462
+ self.policy_model = model
463
+
464
+ # Define the collator if not provided
465
+ if data_collator is None:
466
+ data_collator = DataCollatorWithPadding(self.processing_class)
467
+
468
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
469
+ if args.stop_token and args.stop_token_id:
470
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
471
+ elif args.stop_token:
472
+ if args.stop_token == "eos":
473
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
474
+ else:
475
+ raise ValueError(
476
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
477
+ )
478
+ else:
479
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
480
+
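For reference, the three configurations the stop-token block above accepts (a hypothetical summary; the token id is illustrative):

```python
# The stop-token handling above accepts exactly one of three shapes:
#   PPOConfig(stop_token="eos")     -> truncate responses at the tokenizer's EOS id
#   PPOConfig(stop_token_id=50256)  -> truncate at an explicit token id (50256 is illustrative)
#   PPOConfig()                     -> stop_token_id stays None; no truncation
# Setting both stop_token and stop_token_id raises a ValueError.
```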
481
+ # peft support
482
+ if not is_peft_available() and peft_config is not None:
483
+ raise ImportError(
484
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
485
+ )
486
+ elif is_peft_available() and peft_config is not None:
487
+ # if the model is already a PeftModel and a peft_config is provided, merge and unload it first
488
+ if isinstance(self.policy_model, PeftModel):
489
+ self.policy_model = self.policy_model.merge_and_unload()
490
+
491
+ # get peft model with the given config
492
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
493
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
494
+ peft_module_casting_to_bf16(self.policy_model)
495
+
496
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
497
+ self.model_adapter_name = args.model_adapter_name
498
+ self.ref_adapter_name = args.ref_adapter_name
499
+
500
+ if ref_model:
501
+ self.ref_model = ref_model
502
+ elif self.is_peft_model:
503
+ self.ref_model = None
504
+ else:
505
+ self.ref_model = create_reference_model(self.policy_model)
506
+
507
+ self.reward_model = reward_model
508
+ self.train_dataset = train_dataset
509
+ self.train_dataset_len = len(train_dataset)
510
+ self.value_model = value_model
511
+ self.data_collator = data_collator
512
+ self.eval_dataset = eval_dataset
513
+ self.optimizer, self.lr_scheduler = optimizers
514
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
515
+
516
+ #########
517
+ # calculate various batch sizes
518
+ #########
519
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
520
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
521
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
522
+ self.accelerator = accelerator
523
+ args.world_size = accelerator.num_processes
524
+ args.local_batch_size = (
525
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
526
+ )
527
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
528
+ args.batch_size = int(args.local_batch_size * args.world_size)
529
+ args.mini_batch_size = exact_div(
530
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
531
+ )
532
+ args.local_mini_batch_size = exact_div(
533
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
534
+ )
535
+ if args.whiten_rewards:
536
+ assert (
537
+ args.local_mini_batch_size >= 8
538
+ ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
539
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
540
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
541
+ args.num_total_batches = math.ceil(
542
+ args.total_episodes / args.batch_size
543
+ ) # we may train for more than `total_episodes`
544
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
545
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
546
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
547
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
548
+ if args.num_sample_generations > 0:
549
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
550
+ self.local_dataloader_batch_size = args.local_batch_size
551
+
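To make the batch-size bookkeeping above concrete, here is the same arithmetic on a small hypothetical setup (all numbers illustrative):

```python
# Worked example of the batch-size derivation above; mirrors the fields set on `args`.
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_mini_batches = 2
world_size = 2  # accelerator.num_processes

local_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches  # 16
micro_batch_size = per_device_train_batch_size * world_size   # 8
batch_size = local_batch_size * world_size                    # 32
mini_batch_size = batch_size // num_mini_batches              # 16 (exact_div checks divisibility)
local_mini_batch_size = local_batch_size // num_mini_batches  # 8
```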
552
+ #########
553
+ # setup model, optimizer, and others
554
+ #########
555
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
556
+ if module is not None:
557
+ disable_dropout_in_model(module)
558
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
559
+ self.model.config = self.policy_model.config # needed for pushing to hub
560
+ self.create_optimizer_and_scheduler(
561
+ num_training_steps=args.num_total_batches
562
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
563
+
564
+ #########
565
+ ### trainer specifics
566
+ #########
567
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
568
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
569
+ self.callback_handler = CallbackHandler(
570
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
571
+ )
572
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
573
+ self.control = TrainerControl()
574
+ self.state = OnlineTrainerState(
575
+ is_local_process_zero=self.is_local_process_zero(),
576
+ is_world_process_zero=self.is_world_process_zero(),
577
+ stateful_callbacks=[
578
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
579
+ ],
580
+ )
581
+ self.current_flos = 0
582
+ self.hp_search_backend = None
583
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
584
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
585
+ # Create distant repo and output directory if needed
586
+ self.hub_model_id = None
587
+ if self.args.push_to_hub:
588
+ self.init_hf_repo()
589
+ if self.args.should_save:
590
+ os.makedirs(self.args.output_dir, exist_ok=True)
591
+
592
+ # Add tags for models that have been loaded with the correct transformers version
593
+ if hasattr(self.model, "add_model_tags"):
594
+ self.model.add_model_tags(self._tag_names)
595
+
596
+ #########
597
+ ### setup dataloader
598
+ #########
599
+ self.dataloader = DataLoader(
600
+ self.train_dataset,
601
+ batch_size=self.local_dataloader_batch_size,
602
+ shuffle=True,
603
+ collate_fn=self.data_collator,
604
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
605
+ )
606
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
607
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
608
+ torch.manual_seed(args.seed)
609
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
610
+ torch.manual_seed(self.local_seed) # reset the local seed again
611
+
612
+ self.eval_dataloader = DataLoader(
613
+ self.eval_dataset,
614
+ batch_size=args.per_device_eval_batch_size,
615
+ collate_fn=self.data_collator,
616
+ drop_last=True,
617
+ ) # no need to shuffle eval dataset
618
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
619
+
620
+ if self.is_deepspeed_enabled:
621
+ self.reward_model = prepare_deepspeed(
622
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
623
+ )
624
+
625
+ if self.ref_model is None:
626
+ if not self.is_peft_model:
627
+ raise ValueError("No reference model and model is not a Peft model.")
628
+ else:
629
+ self.ref_model = prepare_deepspeed(
630
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
631
+ )
632
+ else:
633
+ if self.ref_model is None:
634
+ if not self.is_peft_model:
635
+ raise ValueError("No reference model and model is not a Peft model.")
636
+ else:
637
+ self.ref_model = self.ref_model.to(self.accelerator.device)
638
+ self.reward_model = self.reward_model.to(self.accelerator.device)
639
+
640
+ def get_train_dataloader(self) -> DataLoader:
641
+ return self.dataloader
642
+
643
+ def get_eval_dataloader(self) -> DataLoader:
644
+ return self.eval_dataloader
645
+
646
+ @contextmanager
647
+ def null_ref_context(self):
648
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
649
+ with (
650
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
651
+ if self.is_peft_model and not self.ref_adapter_name
652
+ else nullcontext()
653
+ ):
654
+ if self.ref_adapter_name:
655
+ self.model.policy.set_adapter(self.ref_adapter_name)
656
+ yield
657
+ if self.ref_adapter_name:
658
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
659
+
660
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
661
+ backup_model = self.model
662
+ self.model = self.model.policy # save only the policy
663
+
664
+ if self.is_deepspeed_enabled:
665
+ backup_deepspeed = self.deepspeed
666
+ self.deepspeed = self.model
667
+
668
+ super().save_model(output_dir, _internal_call)
669
+
670
+ self.model = backup_model
671
+
672
+ if self.is_deepspeed_enabled:
673
+ self.deepspeed = backup_deepspeed
674
+
675
+ def train(self):
676
+ args = self.args
677
+ accelerator = self.accelerator
678
+ optimizer = self.optimizer
679
+ model = self.model
680
+ ref_policy = self.ref_model
681
+ reward_model = self.reward_model
682
+ processing_class = self.processing_class
683
+ dataloader = self.dataloader
684
+ device = accelerator.device
685
+
686
+ def repeat_generator():
687
+ while True:
688
+ yield from dataloader
689
+
690
+ iter_dataloader = iter(repeat_generator())
691
+ generation_config = GenerationConfig(
692
+ max_new_tokens=args.response_length,
693
+ temperature=(args.temperature + 1e-7),
694
+ top_k=0.0,
695
+ top_p=1.0,
696
+ do_sample=True,
697
+ )
698
+
699
+ accelerator.print("===training policy===")
700
+ start_time = time.time()
701
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
702
+ approxkl_stats = torch.zeros(stats_shape, device=device)
703
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
704
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
705
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
706
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
707
+ entropy_stats = torch.zeros(stats_shape, device=device)
708
+ ratio_stats = torch.zeros(stats_shape, device=device)
709
+ model.train()
710
+
711
+ # trainer state initialization
712
+ self.state.global_step = 0
713
+ self.state.episode = 0
714
+ self.state.max_steps = args.num_total_batches * args.num_mini_batches
715
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
716
+ # Compute absolute values for logging, eval, and save if given as ratio
717
+ if args.logging_steps is not None:
718
+ if args.logging_steps < 1:
719
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
720
+ else:
721
+ self.state.logging_steps = args.logging_steps
722
+ if args.eval_steps is not None:
723
+ if args.eval_steps < 1:
724
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
725
+ else:
726
+ self.state.eval_steps = args.eval_steps
727
+ if args.save_steps is not None:
728
+ if args.save_steps < 1:
729
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
730
+ else:
731
+ self.state.save_steps = args.save_steps
732
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
733
+
734
+ # backward compatibility
735
+ if self.is_deepspeed_enabled:
736
+ self.deepspeed = self.model
737
+ self.model_wrapped = self.model
738
+
739
+ for update in range(1, args.num_total_batches + 1):
740
+ self.state.episode += 1 * args.batch_size
741
+ data = next(iter_dataloader)
742
+ with torch.no_grad():
743
+ queries = data["input_ids"].to(device)
744
+ context_length = queries.shape[1]
745
+ responses = []
746
+ postprocessed_responses = []
747
+ logprobs = []
748
+ ref_logprobs = []
749
+ scores = []
750
+ sequence_lengths = []
751
+ values = []
752
+ with unwrap_model_for_generation(
753
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
754
+ ) as unwrapped_model:
755
+ query_responses, logitss = batch_generation(
756
+ unwrapped_model.policy,
757
+ queries,
758
+ args.local_rollout_forward_batch_size,
759
+ processing_class.pad_token_id,
760
+ generation_config,
761
+ )
762
+
763
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
764
+ query = queries[i : i + args.local_rollout_forward_batch_size]
765
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
766
+ response = query_response[:, context_length:]
767
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
768
+ logprob = selective_log_softmax(logits, response)
769
+ del logits
770
+ torch.cuda.empty_cache()
771
+
772
+ if ref_policy is None:
773
+ with self.null_ref_context():
774
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
775
+ else:
776
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
777
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
778
+ ref_logits /= args.temperature + 1e-7
779
+ ref_logprob = selective_log_softmax(ref_logits, response)
780
+ del ref_output, ref_logits
781
+ torch.cuda.empty_cache()
782
+
783
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
784
+ postprocessed_response = response
785
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
786
+ postprocessed_response = truncate_response(
787
+ self.stop_token_id, processing_class.pad_token_id, response
788
+ )
789
+
790
+ # Response Processing 2. run reward model on the truncated responses
791
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
792
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
793
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
794
+ full_value, _, _ = get_reward(
795
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
796
+ )
797
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
798
+ _, score, _ = get_reward(
799
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
800
+ )
801
+
802
+ responses.append(response)
803
+ postprocessed_responses.append(postprocessed_response)
804
+ logprobs.append(logprob)
805
+ ref_logprobs.append(ref_logprob)
806
+ sequence_lengths.append(sequence_length)
807
+ scores.append(score)
808
+ values.append(value)
809
+ responses = torch.cat(responses, 0)
810
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
811
+ logprobs = torch.cat(logprobs, 0)
812
+ ref_logprobs = torch.cat(ref_logprobs, 0)
813
+ sequence_lengths = torch.cat(sequence_lengths, 0)
814
+ scores = torch.cat(scores, 0)
815
+ values = torch.cat(values, 0)
816
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
817
+ torch.cuda.empty_cache()
818
+ gc.collect()
819
+
820
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
821
+ # Completions not passing that filter will receive a lower score.
822
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
823
+ if self.args.missing_eos_penalty is not None:
824
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
825
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
826
+
827
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
828
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
829
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
830
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
831
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
832
+ sequence_lengths_p1 = sequence_lengths + 1
833
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
834
+ values = torch.masked_fill(values, padding_mask_p1, 0)
835
+
836
+ # 4. compute rewards
837
+ kl = logprobs - ref_logprobs
838
+ non_score_reward = -args.kl_coef * kl
839
+ rewards = non_score_reward.clone()
840
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
841
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
842
+ rewards[[actual_start, actual_end]] += scores
843
+
844
+ # 5. whiten rewards
845
+ if args.whiten_rewards:
846
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
847
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
848
+
849
+ # 6. compute advantages and returns
850
+ lastgaelam = 0
851
+ advantages_reversed = []
852
+ gen_length = responses.shape[1]
853
+ for t in reversed(range(gen_length)):
854
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
855
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
856
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
857
+ advantages_reversed.append(lastgaelam)
858
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
859
+ returns = advantages + values
860
+ advantages = masked_whiten(advantages, ~padding_mask)
861
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
862
+ torch.cuda.empty_cache()
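The loop above is standard Generalized Advantage Estimation: delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1}. A minimal standalone sketch with toy values (no padding mask):

```python
import torch

# Toy GAE sketch matching the recurrence above: reward only on the last token.
gamma, lam = 1.0, 0.95
rewards = torch.tensor([[0.0, 0.0, 1.0]])
values  = torch.tensor([[0.5, 0.6, 0.7]])

lastgaelam = 0.0
advantages_reversed = []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)  # shape (1, 3)
returns = advantages + values
```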
863
+
864
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
865
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
866
+ b_inds = np.random.permutation(args.local_batch_size)
867
+ minibatch_idx = 0
868
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
869
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
870
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
871
+ gradient_accumulation_idx = 0
872
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
873
+ with accelerator.accumulate(model):
874
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
875
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
876
+ mb_advantage = advantages[micro_batch_inds]
877
+ mb_responses = responses[micro_batch_inds]
878
+ mb_query_responses = query_responses[micro_batch_inds]
879
+ mb_logprobs = logprobs[micro_batch_inds]
880
+ mb_return = returns[micro_batch_inds]
881
+ mb_values = values[micro_batch_inds]
882
+
883
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
884
+ logits = output.logits[:, context_length - 1 : -1]
885
+ logits /= args.temperature + 1e-7
886
+ new_logprobs = selective_log_softmax(logits, mb_responses)
887
+ new_logprobs = torch.masked_fill(
888
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
889
+ )
890
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
891
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
892
+ vpredclipped = torch.clamp(
893
+ vpred,
894
+ mb_values - args.cliprange_value,
895
+ mb_values + args.cliprange_value,
896
+ )
897
+ vf_losses1 = torch.square(vpred - mb_return)
898
+ vf_losses2 = torch.square(vpredclipped - mb_return)
899
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
900
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
901
+ vf_clipfrac = masked_mean(
902
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
903
+ )
904
+ logprobs_diff = new_logprobs - mb_logprobs
905
+ ratio = torch.exp(logprobs_diff)
906
+ pg_losses = -mb_advantage * ratio
907
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
908
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
909
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
910
+ loss = pg_loss + args.vf_coef * vf_loss
911
+ accelerator.backward(loss)
912
+ optimizer.step()
913
+ optimizer.zero_grad()
914
+ with torch.no_grad():
915
+ pg_clipfrac = masked_mean(
916
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
917
+ )
918
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
919
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
920
+ approxkl = 0.5 * (logprobs_diff**2).mean()
921
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
922
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
923
+ pg_clipfrac
924
+ )
925
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
926
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
927
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
928
+ vf_clipfrac
929
+ )
930
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
931
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
932
+ gradient_accumulation_idx += 1
933
+ minibatch_idx += 1
934
+ # del everything and empty cache
935
+ # fmt: off
936
+ del (
937
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
938
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
939
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
940
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
941
+ )
942
+ # fmt: on
943
+ torch.cuda.empty_cache()
944
+ with torch.no_grad():
945
+ mean_kl = kl.sum(1).mean()
946
+ mean_entropy = (-logprobs).sum(1).mean()
947
+ mean_non_score_reward = non_score_reward.sum(1).mean()
948
+ rlhf_reward = mean_non_score_reward + scores.mean()
949
+ eps = int(self.state.episode / (time.time() - start_time))
950
+ metrics = {}
951
+ metrics["eps"] = eps
952
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
953
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
954
+ metrics["objective/non_score_reward"] = (
955
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
956
+ )
957
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
958
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
959
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
960
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
961
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
962
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
963
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
964
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
965
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
966
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
967
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
968
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
969
+ metrics["episode"] = self.state.episode
970
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
971
+ self.state.global_step += 1
972
+ self.log(metrics)
973
+
974
+ self.lr_scheduler.step()
975
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
976
+ if self.control.should_save:
977
+ self._save_checkpoint(model, trial=None)
978
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
979
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
980
+ torch.cuda.empty_cache()
981
+ gc.collect()
982
+
983
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
984
+ self.generate_completions(sampling=True)
985
+ torch.cuda.empty_cache()
986
+ del (
987
+ query_responses,
988
+ responses,
989
+ postprocessed_responses,
990
+ logprobs,
991
+ ref_logprobs,
992
+ values,
993
+ sequence_lengths,
994
+ contain_eos_token,
995
+ sequence_lengths_p1,
996
+ response_idxs,
997
+ padding_mask,
998
+ padding_mask_p1,
999
+ rewards,
1000
+ actual_start,
1001
+ actual_end,
1002
+ advantages,
1003
+ returns,
1004
+ )
1005
+ torch.cuda.empty_cache()
1006
+
1007
+ # HF trainer specifics
1008
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
1009
+ if self.control.should_save:
1010
+ self._save_checkpoint(model, trial=None, metrics=None)
1011
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1012
+
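For reference, the inner update above optimizes the standard PPO clipped surrogate, loss = mean(max(-A * r, -A * clip(r, 1 - eps, 1 + eps))), plus a clipped value loss; per-token rewards are shaped as -kl_coef * (logprob - ref_logprob), with the reward-model score added at the last non-padding token. A minimal standalone sketch of the policy term with toy numbers:

```python
import torch

# Toy sketch of the clipped PPO policy loss computed in `train` above.
cliprange = 0.2
advantages   = torch.tensor([ 1.0, -0.5])
new_logprobs = torch.tensor([-1.0, -2.0])
old_logprobs = torch.tensor([-1.2, -1.9])

ratio = torch.exp(new_logprobs - old_logprobs)  # pi_new / pi_old per token
pg_losses  = -advantages * ratio                # unclipped surrogate
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()  # the real code uses a masked mean
```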
1013
+ def generate_completions(self, sampling: bool = False):
1014
+ args = self.args
1015
+ processing_class = self.processing_class
1016
+ generation_config = GenerationConfig(
1017
+ max_new_tokens=self.args.response_length,
1018
+ temperature=(0.01 + 1e-7),
1019
+ top_k=0.0,
1020
+ top_p=1.0,
1021
+ do_sample=True,
1022
+ )
1023
+
1024
+ table = defaultdict(list)
1025
+ with unwrap_model_for_generation(
1026
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
1027
+ ) as unwrapped_model:
1028
+ for batch in self.eval_dataloader:
1029
+ query = batch["input_ids"]
1030
+ with torch.no_grad():
1031
+ context_length = query.shape[1]
1032
+ query_response, _ = batch_generation(
1033
+ unwrapped_model.policy,
1034
+ query,
1035
+ query.shape[0],
1036
+ processing_class.pad_token_id,
1037
+ generation_config,
1038
+ )
1039
+ response = query_response[:, context_length:]
1040
+ postprocessed_response = response
1041
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
1042
+ postprocessed_response = truncate_response(
1043
+ self.stop_token_id, processing_class.pad_token_id, response
1044
+ )
1045
+ table["query"].extend(
1046
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
1047
+ )
1048
+ table["model response"].extend(
1049
+ gather_object(processing_class.batch_decode(postprocessed_response))
1050
+ )
1051
+
1052
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
1053
+ _, score, _ = get_reward(
1054
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
1055
+ )
1056
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
1057
+
1058
+ if sampling:
1059
+ break
1060
+ df = pd.DataFrame(table)
1061
+
1062
+ if self.accelerator.is_main_process:
1063
+ print_rich_table(df.iloc[0 : 0 + 5])
1064
+ if "wandb" in args.report_to:
1065
+ import wandb
1066
+
1067
+ if wandb.run is not None:
1068
+ wandb.log({"completions": wandb.Table(dataframe=df)})
1069
+
1070
+ if "comet_ml" in args.report_to:
1071
+ log_table_to_comet_experiment(
1072
+ name="completions.csv",
1073
+ table=df,
1074
+ )
1075
+
1076
+ def create_model_card(
1077
+ self,
1078
+ model_name: Optional[str] = None,
1079
+ dataset_name: Optional[str] = None,
1080
+ tags: Union[str, list[str], None] = None,
1081
+ ):
1082
+ """
1083
+ Creates a draft of a model card using the information available to the `Trainer`.
1084
+
1085
+ Args:
1086
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1087
+ Name of the model.
1088
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1089
+ Name of the dataset used for training.
1090
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1091
+ Tags to be associated with the model card.
1092
+ """
1093
+ if not self.is_world_process_zero():
1094
+ return
1095
+
1096
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1097
+ base_model = self.model.config._name_or_path
1098
+ else:
1099
+ base_model = None
1100
+
1101
+ tags = tags or []
1102
+ if isinstance(tags, str):
1103
+ tags = [tags]
1104
+
1105
+ if hasattr(self.model.config, "unsloth_version"):
1106
+ tags.append("unsloth")
1107
+
1108
+ citation = textwrap.dedent("""\
1109
+ @article{mziegler2019fine-tuning,
1110
+ title = {{Fine-Tuning Language Models from Human Preferences}},
1111
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
1112
+ year = 2019,
1113
+ eprint = {arXiv:1909.08593}
1114
+ }""")
1115
+
1116
+ model_card = generate_model_card(
1117
+ base_model=base_model,
1118
+ model_name=model_name,
1119
+ hub_model_id=self.hub_model_id,
1120
+ dataset_name=dataset_name,
1121
+ tags=tags,
1122
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1123
+ comet_url=get_comet_experiment_url(),
1124
+ trainer_name="PPO",
1125
+ trainer_citation=citation,
1126
+ paper_title="Fine-Tuning Language Models from Human Preferences",
1127
+ paper_id="1909.08593",
1128
+ )
1129
+
1130
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1131
+ class UnslothPPOTrainer(_UnslothPPOTrainer):
1132
+ """
1133
+
1134
+ """
1135
+ def __init__(
1136
+ self,
1137
+ args,
1138
+ processing_class,
1139
+ model,
1140
+ ref_model,
1141
+ reward_model,
1142
+ train_dataset,
1143
+ value_model = None,
1144
+ data_collator = None,
1145
+ eval_dataset = None,
1146
+ callbacks = None,
1147
+ peft_config = None,
1148
+ **kwargs
1149
+ ):
1150
+ if args is None: args = UnslothPPOConfig()
1151
+ use_bf16 = getattr(args, 'bf16', False)
1152
+ use_fp16 = getattr(args, 'fp16', False)
1153
+ force_float32 = False
1154
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1155
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1156
+ force_float32 = True
1157
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1158
+ dtype = getattr(model.config, 'torch_dtype', None)
1159
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1160
+ from unsloth_zoo.utils import _get_dtype
1161
+ dtype = _get_dtype(dtype)
1162
+ float16 = dtype == torch.float16
1163
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1164
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1165
+ if force_float32:
1166
+ args.fp16 = False
1167
+ args.bf16 = False
1168
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1169
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1170
+ args.fp16 = float16
1171
+ args.bf16 = not float16
1172
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1173
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1174
+ args.eval_strategy = 'steps'
1175
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1176
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1177
+ if ga_steps is not None and ga_steps > 1:
1178
+ from transformers import __version__ as transformers_version
1179
+ if Version(transformers_version) <= Version('4.45.2'):
1180
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1181
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1182
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1183
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1184
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1185
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1186
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1187
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1188
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1189
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1190
+ if force_float32:
1191
+ args.bf16_full_eval = False
1192
+ args.fp16_full_eval = False
1193
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1194
+ args.bf16_full_eval = True
1195
+ args.fp16_full_eval = False
1196
+ elif not bf16_full_eval and not fp16_full_eval:
1197
+ args.bf16_full_eval = args.bf16
1198
+ args.fp16_full_eval = args.fp16
1199
+ _output_logits = False
1200
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1201
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1202
+ if _output_logits:
1203
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1204
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1205
+ pass
1206
+ else:
1207
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1208
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1209
+ if args_max_seq_length is None and model_max_seq_length is not None:
1210
+ max_seq_length = model.max_seq_length
1211
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1212
+ if model is not None and hasattr(model, 'for_training'):
1213
+ model.for_training()
1214
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1215
+ if 'processing_class' in locals():
1216
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1217
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1218
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1219
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1220
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1221
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1222
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1223
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1224
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1225
+ else:
1226
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1227
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1228
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1229
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1230
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1231
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1232
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1233
+ else:
1234
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1235
+ other_metrics = []
1236
+
1237
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1238
+ PatchRLStatistics('ppo_trainer', other_metrics)
1239
+
1240
+ super().__init__(
1241
+ args = args,
1242
+ processing_class = processing_class,
1243
+ model = model,
1244
+ ref_model = ref_model,
1245
+ reward_model = reward_model,
1246
+ train_dataset = train_dataset,
1247
+ value_model = value_model,
1248
+ data_collator = data_collator,
1249
+ eval_dataset = eval_dataset,
1250
+ callbacks = callbacks,
1251
+ peft_config = peft_config,**kwargs)
1252
+ if hasattr(self, 'neftune_hook_handle'):
1253
+ self.neftune_hook_handle.remove()
1254
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1255
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1256
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1257
+ pass
1258
+
1259
+ pass
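A hypothetical end-to-end sketch of wiring the wrapper above together; the model name, toy dataset, and all values are placeholders rather than anything prescribed by this file:

```python
# Hypothetical usage sketch; the model name and toy dataset are placeholders.
from datasets import Dataset
from transformers import (AutoModelForCausalLM,
                          AutoModelForSequenceClassification, AutoTokenizer)

name = "EleutherAI/pythia-160m"
tokenizer = AutoTokenizer.from_pretrained(name)
tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLM.from_pretrained(name)
ref_policy = AutoModelForCausalLM.from_pretrained(name)  # a separate copy, as required above
value_model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=1)
reward_model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=1)

# The trainer expects a pre-tokenized dataset with an `input_ids` column.
train_dataset = Dataset.from_dict({"input_ids": [tokenizer("Hello there")["input_ids"]] * 64})

trainer = UnslothPPOTrainer(
    args = UnslothPPOConfig(output_dir = "ppo_out",
                            per_device_train_batch_size = 4,
                            total_episodes = 256),
    processing_class = tokenizer,
    model = policy,
    ref_model = ref_policy,
    reward_model = reward_model,
    value_model = value_model,
    train_dataset = train_dataset,
    eval_dataset = train_dataset.select(range(8)),  # the eval dataloader is built unconditionally
)
trainer.train()
```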
unsloth_compiled_cache/UnslothPRMTrainer.py ADDED
@@ -0,0 +1,800 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
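`selective_log_softmax` above computes per-token log-probabilities as gathered_logit - logsumexp(logits), avoiding materializing the full log-softmax. A quick standalone equivalence check:

```python
import torch
import torch.nn.functional as F

# Sanity check: gather-then-logsumexp equals gathering from a full log_softmax.
logits = torch.randn(2, 5, 10)         # (batch, seq, vocab)
index = torch.randint(0, 10, (2, 5))   # selected token ids

expected = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
actual = torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, -1)
assert torch.allclose(expected, actual, atol=1e-6)
```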
42
+ @dataclass
43
+ class UnslothPRMConfig(PRMConfig):
44
+ """
45
+
46
+ Configuration class for the [`PRMTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ learning_rate (`float`, *optional*, defaults to `1e-5`):
54
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
55
+ [`~transformers.TrainingArguments`].
56
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
57
+ Maximum length of the sequences (prompt + completion) used for truncation.
58
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
59
+ Maximum length of the prompt used for truncation.
60
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
61
+ Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
62
+ disable_dropout (`bool`, *optional*, defaults to `True`):
63
+ Whether to disable dropout in the model.
64
+ step_separator (`str`, *optional*, defaults to `"\n"`):
65
+ Separator used to separate each step of the reasoning process.
66
+ train_on_last_step_only (`bool`, *optional*, defaults to `False`):
67
+ Whether to train only on the last step.
68
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
69
+ Number of processes to use for processing the dataset.
70
+
71
+ """
72
+ vllm_sampling_params: Optional[Any] = field(
73
+ default = None,
74
+ metadata = {'help': 'vLLM SamplingParams'},
75
+ )
76
+ unsloth_num_chunks : Optional[int] = field(
77
+ default = -1,
78
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
79
+ )
80
+ def __init__(
81
+ self,
82
+ output_dir = None,
83
+ overwrite_output_dir = None,
84
+ do_train = False,
85
+ do_eval = False,
86
+ do_predict = False,
87
+ eval_strategy = 'no',
88
+ prediction_loss_only = False,
89
+ per_device_train_batch_size = 4,
90
+ per_device_eval_batch_size = 4,
91
+ per_gpu_train_batch_size = None,
92
+ per_gpu_eval_batch_size = None,
93
+ gradient_accumulation_steps = 2,
94
+ eval_accumulation_steps = 2,
95
+ eval_delay = 0,
96
+ torch_empty_cache_steps = 250,
97
+ learning_rate = 5e-05,
98
+ weight_decay = 0.01,
99
+ adam_beta1 = 0.9,
100
+ adam_beta2 = 0.999,
101
+ adam_epsilon = 1e-08,
102
+ max_grad_norm = 1.0,
103
+ num_train_epochs = 3.0,
104
+ max_steps = -1,
105
+ lr_scheduler_type = 'linear',
106
+ warmup_ratio = 0.1,
107
+ warmup_steps = 0,
108
+ log_level = 'passive',
109
+ log_level_replica = 'warning',
110
+ log_on_each_node = True,
111
+ logging_dir = None,
112
+ logging_strategy = 'steps',
113
+ logging_first_step = False,
114
+ logging_steps = 1,
115
+ logging_nan_inf_filter = False,
116
+ save_strategy = 'steps',
117
+ save_steps = 500,
118
+ save_total_limit = None,
119
+ save_safetensors = True,
120
+ save_on_each_node = False,
121
+ save_only_model = False,
122
+ restore_callback_states_from_checkpoint = False,
123
+ no_cuda = False,
124
+ use_cpu = False,
125
+ use_mps_device = False,
126
+ seed = 3407,
127
+ data_seed = 3407,
128
+ jit_mode_eval = False,
129
+ use_ipex = False,
130
+ bf16 = False,
131
+ fp16 = False,
132
+ fp16_opt_level = 'O1',
133
+ half_precision_backend = 'auto',
134
+ bf16_full_eval = False,
135
+ fp16_full_eval = False,
136
+ tf32 = None,
137
+ local_rank = -1,
138
+ ddp_backend = None,
139
+ tpu_num_cores = None,
140
+ tpu_metrics_debug = False,
141
+ debug = '',
142
+ dataloader_drop_last = False,
143
+ eval_steps = None,
144
+ dataloader_num_workers = 0,
145
+ dataloader_prefetch_factor = None,
146
+ past_index = -1,
147
+ run_name = None,
148
+ disable_tqdm = None,
149
+ remove_unused_columns = True,
150
+ label_names = None,
151
+ load_best_model_at_end = False,
152
+ metric_for_best_model = None,
153
+ greater_is_better = None,
154
+ ignore_data_skip = False,
155
+ fsdp = '',
156
+ fsdp_min_num_params = 0,
157
+ fsdp_config = None,
158
+ tp_size = 0,
159
+ fsdp_transformer_layer_cls_to_wrap = None,
160
+ accelerator_config = None,
161
+ deepspeed = None,
162
+ label_smoothing_factor = 0.0,
163
+ optim = 'adamw_8bit',
164
+ optim_args = None,
165
+ adafactor = False,
166
+ group_by_length = False,
167
+ length_column_name = 'length',
168
+ report_to = None,
169
+ ddp_find_unused_parameters = None,
170
+ ddp_bucket_cap_mb = None,
171
+ ddp_broadcast_buffers = None,
172
+ dataloader_pin_memory = True,
173
+ dataloader_persistent_workers = False,
174
+ skip_memory_metrics = True,
175
+ use_legacy_prediction_loop = False,
176
+ push_to_hub = False,
177
+ resume_from_checkpoint = None,
178
+ hub_model_id = None,
179
+ hub_strategy = 'every_save',
180
+ hub_token = None,
181
+ hub_private_repo = None,
182
+ hub_always_push = False,
183
+ gradient_checkpointing = False,
184
+ gradient_checkpointing_kwargs = None,
185
+ include_inputs_for_metrics = False,
186
+ eval_do_concat_batches = True,
187
+ fp16_backend = 'auto',
188
+ evaluation_strategy = None,
189
+ push_to_hub_model_id = None,
190
+ push_to_hub_organization = None,
191
+ push_to_hub_token = None,
192
+ mp_parameters = '',
193
+ auto_find_batch_size = False,
194
+ full_determinism = False,
195
+ torchdynamo = None,
196
+ ray_scope = 'last',
197
+ ddp_timeout = 1800,
198
+ torch_compile = False,
199
+ torch_compile_backend = None,
200
+ torch_compile_mode = None,
201
+ dispatch_batches = None,
202
+ split_batches = None,
203
+ include_tokens_per_second = False,
204
+ include_num_input_tokens_seen = False,
205
+ neftune_noise_alpha = None,
206
+ optim_target_modules = None,
207
+ batch_eval_metrics = False,
208
+ eval_on_start = False,
209
+ use_liger_kernel = False,
210
+ eval_use_gather_object = False,
211
+ average_tokens_across_devices = False,
212
+ max_length = 1024,
213
+ max_prompt_length = 512,
214
+ max_completion_length = None,
215
+ disable_dropout = True,
216
+ step_separator = '\
217
+ ',
218
+ train_on_last_step_only = False,
219
+ dataset_num_proc = None,
220
+ vllm_sampling_params = None,
221
+ unsloth_num_chunks = -1,
222
+ **kwargs,
223
+ ):
224
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
225
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is too large (greater than 1)! Consider decreasing it to 1e-1 or below, otherwise gradient updates will explode!')
226
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
227
+ output_dir = 'unsloth_training_checkpoints'
228
+ save_strategy = 'no'
229
+ if dataset_num_proc is None:
230
+ from multiprocessing import cpu_count
231
+ dataset_num_proc = cpu_count()
232
+
233
+ super().__init__(
234
+ output_dir = output_dir,
235
+ overwrite_output_dir = overwrite_output_dir,
236
+ do_train = do_train,
237
+ do_eval = do_eval,
238
+ do_predict = do_predict,
239
+ eval_strategy = eval_strategy,
240
+ prediction_loss_only = prediction_loss_only,
241
+ per_device_train_batch_size = per_device_train_batch_size,
242
+ per_device_eval_batch_size = per_device_eval_batch_size,
243
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
244
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
245
+ gradient_accumulation_steps = gradient_accumulation_steps,
246
+ eval_accumulation_steps = eval_accumulation_steps,
247
+ eval_delay = eval_delay,
248
+ torch_empty_cache_steps = torch_empty_cache_steps,
249
+ learning_rate = learning_rate,
250
+ weight_decay = weight_decay,
251
+ adam_beta1 = adam_beta1,
252
+ adam_beta2 = adam_beta2,
253
+ adam_epsilon = adam_epsilon,
254
+ max_grad_norm = max_grad_norm,
255
+ num_train_epochs = num_train_epochs,
256
+ max_steps = max_steps,
257
+ lr_scheduler_type = lr_scheduler_type,
258
+ warmup_ratio = warmup_ratio,
259
+ warmup_steps = warmup_steps,
260
+ log_level = log_level,
261
+ log_level_replica = log_level_replica,
262
+ log_on_each_node = log_on_each_node,
263
+ logging_dir = logging_dir,
264
+ logging_strategy = logging_strategy,
265
+ logging_first_step = logging_first_step,
266
+ logging_steps = logging_steps,
267
+ logging_nan_inf_filter = logging_nan_inf_filter,
268
+ save_strategy = save_strategy,
269
+ save_steps = save_steps,
270
+ save_total_limit = save_total_limit,
271
+ save_safetensors = save_safetensors,
272
+ save_on_each_node = save_on_each_node,
273
+ save_only_model = save_only_model,
274
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
275
+ no_cuda = no_cuda,
276
+ use_cpu = use_cpu,
277
+ use_mps_device = use_mps_device,
278
+ seed = seed,
279
+ data_seed = data_seed,
280
+ jit_mode_eval = jit_mode_eval,
281
+ use_ipex = use_ipex,
282
+ bf16 = bf16,
283
+ fp16 = fp16,
284
+ fp16_opt_level = fp16_opt_level,
285
+ half_precision_backend = half_precision_backend,
286
+ bf16_full_eval = bf16_full_eval,
287
+ fp16_full_eval = fp16_full_eval,
288
+ tf32 = tf32,
289
+ local_rank = local_rank,
290
+ ddp_backend = ddp_backend,
291
+ tpu_num_cores = tpu_num_cores,
292
+ tpu_metrics_debug = tpu_metrics_debug,
293
+ debug = debug,
294
+ dataloader_drop_last = dataloader_drop_last,
295
+ eval_steps = eval_steps,
296
+ dataloader_num_workers = dataloader_num_workers,
297
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
298
+ past_index = past_index,
299
+ run_name = run_name,
300
+ disable_tqdm = disable_tqdm,
301
+ remove_unused_columns = remove_unused_columns,
302
+ label_names = label_names,
303
+ load_best_model_at_end = load_best_model_at_end,
304
+ metric_for_best_model = metric_for_best_model,
305
+ greater_is_better = greater_is_better,
306
+ ignore_data_skip = ignore_data_skip,
307
+ fsdp = fsdp,
308
+ fsdp_min_num_params = fsdp_min_num_params,
309
+ fsdp_config = fsdp_config,
310
+ tp_size = tp_size,
311
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
312
+ accelerator_config = accelerator_config,
313
+ deepspeed = deepspeed,
314
+ label_smoothing_factor = label_smoothing_factor,
315
+ optim = optim,
316
+ optim_args = optim_args,
317
+ adafactor = adafactor,
318
+ group_by_length = group_by_length,
319
+ length_column_name = length_column_name,
320
+ report_to = report_to,
321
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
322
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
323
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
324
+ dataloader_pin_memory = dataloader_pin_memory,
325
+ dataloader_persistent_workers = dataloader_persistent_workers,
326
+ skip_memory_metrics = skip_memory_metrics,
327
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
328
+ push_to_hub = push_to_hub,
329
+ resume_from_checkpoint = resume_from_checkpoint,
330
+ hub_model_id = hub_model_id,
331
+ hub_strategy = hub_strategy,
332
+ hub_token = hub_token,
333
+ hub_private_repo = hub_private_repo,
334
+ hub_always_push = hub_always_push,
335
+ gradient_checkpointing = gradient_checkpointing,
336
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
337
+ include_inputs_for_metrics = include_inputs_for_metrics,
338
+ eval_do_concat_batches = eval_do_concat_batches,
339
+ fp16_backend = fp16_backend,
340
+ evaluation_strategy = evaluation_strategy,
341
+ push_to_hub_model_id = push_to_hub_model_id,
342
+ push_to_hub_organization = push_to_hub_organization,
343
+ push_to_hub_token = push_to_hub_token,
344
+ mp_parameters = mp_parameters,
345
+ auto_find_batch_size = auto_find_batch_size,
346
+ full_determinism = full_determinism,
347
+ torchdynamo = torchdynamo,
348
+ ray_scope = ray_scope,
349
+ ddp_timeout = ddp_timeout,
350
+ torch_compile = torch_compile,
351
+ torch_compile_backend = torch_compile_backend,
352
+ torch_compile_mode = torch_compile_mode,
353
+ dispatch_batches = dispatch_batches,
354
+ split_batches = split_batches,
355
+ include_tokens_per_second = include_tokens_per_second,
356
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
357
+ neftune_noise_alpha = neftune_noise_alpha,
358
+ optim_target_modules = optim_target_modules,
359
+ batch_eval_metrics = batch_eval_metrics,
360
+ eval_on_start = eval_on_start,
361
+ use_liger_kernel = use_liger_kernel,
362
+ eval_use_gather_object = eval_use_gather_object,
363
+ average_tokens_across_devices = average_tokens_across_devices,
364
+ max_length = max_length,
365
+ max_prompt_length = max_prompt_length,
366
+ max_completion_length = max_completion_length,
367
+ disable_dropout = disable_dropout,
368
+ step_separator = step_separator,
369
+ train_on_last_step_only = train_on_last_step_only,
370
+ dataset_num_proc = dataset_num_proc,**kwargs)
371
+ self.vllm_sampling_params = vllm_sampling_params
372
+ self.unsloth_num_chunks = unsloth_num_chunks
373
+ pass
374
+
375
+ class _UnslothPRMTrainer(Trainer):
376
+ """"""
377
+
378
+ _tag_names = ["trl", "prm"]
379
+
380
+ def __init__(
381
+ self,
382
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
383
+ args: Optional[PRMConfig] = None,
384
+ data_collator: Optional[DataCollator] = None,
385
+ train_dataset: Optional[Dataset] = None,
386
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
387
+ processing_class: Optional[
388
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
389
+ ] = None,
390
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
391
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
392
+ callbacks: Optional[list[TrainerCallback]] = None,
393
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
394
+ None,
395
+ None,
396
+ ),
397
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
398
+ peft_config: Optional[dict] = None,
399
+ ):
400
+ if not is_peft_available() and peft_config is not None:
401
+ raise ValueError(
402
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
403
+ )
404
+ elif is_peft_available() and peft_config is not None:
405
+ if not isinstance(model, PeftModel):
406
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
407
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
408
+ inspect.signature(prepare_model_for_kbit_training).parameters
409
+ )
410
+
411
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
412
+
413
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
414
+ warnings.warn(
415
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
416
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
417
+ )
418
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
419
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
420
+
421
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
422
+
423
+ model = model  # no-op kept from the generated source; the model is used as-is here
424
+
425
+ # Disable dropout in the model
426
+ if args.disable_dropout:
427
+ disable_dropout_in_model(model)
428
+
429
+ if compute_metrics is None:
430
+ compute_metrics = compute_accuracy
431
+
432
+ if data_collator is None:
433
+ if processing_class is None:
434
+ raise ValueError(
435
+ "A processing_class must be specified when using the default DataCollatorForTokenClassification"
436
+ )
437
+ data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
438
+
439
+ if "input_ids" not in train_dataset.column_names:
440
+ with PartialState().local_main_process_first():
441
+ fn_kwargs = {
442
+ "tokenizer": processing_class,
443
+ "step_separator": args.step_separator,
444
+ "max_length": args.max_length,
445
+ "max_prompt_length": args.max_prompt_length,
446
+ "max_completion_length": args.max_completion_length,
447
+ "train_on_last_step_only": args.train_on_last_step_only,
448
+ }
449
+ train_fn_kwargs = {**fn_kwargs, "is_eval": False}
450
+ train_dataset = train_dataset.map(
451
+ self.tokenize_row,
452
+ fn_kwargs=train_fn_kwargs,
453
+ num_proc=args.dataset_num_proc,
454
+ remove_columns=train_dataset.features,
455
+ desc="Tokenizing train dataset",
456
+ features=features.Features( # needed to avoid map to cast labels to bool
457
+ {
458
+ "labels": features.Sequence(features.Value("int64")),
459
+ "input_ids": features.Sequence(features.Value("int64")),
460
+ }
461
+ ),
462
+ )
463
+
464
+ eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
465
+ if eval_dataset is not None:
466
+ eval_dataset = eval_dataset.map(
467
+ self.tokenize_row,
468
+ fn_kwargs=eval_fn_kwargs,
469
+ num_proc=args.dataset_num_proc,
470
+ remove_columns=eval_dataset.features,
471
+ desc="Tokenizing eval dataset",
472
+ features=features.Features( # needed to avoid map to cast labels to bool
473
+ {
474
+ "labels": features.Sequence(features.Value("int64")),
475
+ "input_ids": features.Sequence(features.Value("int64")),
476
+ }
477
+ ),
478
+ )
479
+
480
+ super().__init__(
481
+ model=model,
482
+ args=args,
483
+ data_collator=data_collator,
484
+ train_dataset=train_dataset,
485
+ eval_dataset=eval_dataset,
486
+ processing_class=processing_class,
487
+ model_init=model_init,
488
+ compute_metrics=compute_metrics,
489
+ callbacks=callbacks,
490
+ optimizers=optimizers,
491
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
492
+ )
493
+
494
+ # Add tags for models that have been loaded with the correct transformers version
495
+ if hasattr(self.model, "add_model_tags"):
496
+ self.model.add_model_tags(self._tag_names)
497
+
498
+ @staticmethod
499
+ def tokenize_row(
500
+ features,
501
+ tokenizer,
502
+ step_separator,
503
+ max_length,
504
+ max_prompt_length,
505
+ max_completion_length,
506
+ train_on_last_step_only,
507
+ is_eval,
508
+ ):
509
+ r"""
510
+ Tokenize a row of the dataset.
511
+
512
+ Args:
513
+ features (`dict[str, str]`):
514
+ Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
515
+ tokenizer (`PreTrainedTokenizerBase`):
516
+ Tokenizer used to process the data.
517
+ step_separator (`str`):
518
+ Separator between steps in the completion.
519
+ max_length (`int` or `None`):
520
+ Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
521
+ max_prompt_length (`int` or `None`):
522
+ Maximum length of the prompt. If `None`, the prompt is not truncated.
523
+ max_completion_length (`int` or `None`):
524
+ Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
525
+ train_on_last_step_only (`bool`):
526
+ Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
527
+ token of the completion.
528
+ is_eval (`bool`):
529
+ Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
530
+
531
+ Returns:
532
+ `dict[str, list[int]]`:
533
+ Tokenized sequences with the keys `"input_ids"` and `"labels"`.
534
+
535
+ Example:
536
+ ```python
537
+ >>> from transformers import AutoTokenizer
538
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
539
+ >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
540
+ ... "completions": ["11 is greater than 8.",
541
+ ... "Hence, 9.11 > 9.8."],
542
+ ... "labels": [True, False]}
543
+ >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_length=None, max_prompt_length=None, max_completion_length=None, train_on_last_step_only=False, is_eval=False)
544
+ {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
545
+ 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
546
+ ```
547
+ """
548
+ # Tokenize the prompt and completions
549
+ prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
550
+ completions_ids = [
551
+ tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
552
+ ]
553
+ if train_on_last_step_only and not is_eval:
554
+ labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
555
+ else:
556
+ labels = [int(label) for label in features["labels"]]
557
+
558
+ # Get the ID of the separator token and add it to the completions
559
+ separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
560
+ completions_ids = [completion + separator_ids for completion in completions_ids]
561
+
562
+ # Create the label
563
+ labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
564
+
565
+ # Join the completions and labels steps
566
+ completion_ids = list(chain(*completions_ids))
567
+ labels = list(chain(*labels))
568
+
569
+ if tokenizer.bos_token_id is not None:
570
+ prompt_ids = [tokenizer.bos_token_id] + prompt_ids
571
+
572
+ # Truncate prompt and completion sequences
573
+ if max_prompt_length is not None:
574
+ prompt_ids = prompt_ids[-max_prompt_length:]
575
+ if max_completion_length is not None:
576
+ completion_ids = completion_ids[:max_completion_length]
577
+ labels = labels[:max_completion_length]
578
+
579
+ input_ids = prompt_ids + completion_ids
580
+ labels = [-100] * len(prompt_ids) + labels
581
+
582
+ if max_length is not None:
583
+ input_ids = input_ids[:max_length]
584
+ labels = labels[:max_length]
585
+
586
+ return {"input_ids": input_ids, "labels": labels}
587
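+ # Illustrative note (not part of the generated trainer): with step_separator="\n"
+ # and completions ["step one.", "step two."] labelled [1, 0], every token's label
+ # is -100 except the separator token that closes each step, which carries that
+ # step's 0/1 label, matching the `[-100] * (len(completion) - 1) + [label]`
+ # construction in tokenize_row above.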
+
588
+ def create_model_card(
589
+ self,
590
+ model_name: Optional[str] = None,
591
+ dataset_name: Optional[str] = None,
592
+ tags: Union[str, list[str], None] = None,
593
+ ):
594
+ """
595
+ Creates a draft of a model card using the information available to the `Trainer`.
596
+
597
+ Args:
598
+ model_name (`str` or `None`, *optional*, defaults to `None`):
599
+ Name of the model.
600
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
601
+ Name of the dataset used for training.
602
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
603
+ Tags to be associated with the model card.
604
+ """
605
+ if not self.is_world_process_zero():
606
+ return
607
+
608
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
609
+ base_model = self.model.config._name_or_path
610
+ else:
611
+ base_model = None
612
+
613
+ tags = tags or []
614
+ if isinstance(tags, str):
615
+ tags = [tags]
616
+
617
+ if hasattr(self.model.config, "unsloth_version"):
618
+ tags.append("unsloth")
619
+
620
+ citation = textwrap.dedent("""\
621
+ @article{uesato2022solving,
622
+ title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
623
+ author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
624
+ year = 2022,
625
+ journal = {arXiv preprint arXiv:2211.14275}
626
+ }""")
627
+
628
+ model_card = generate_model_card(
629
+ base_model=base_model,
630
+ model_name=model_name,
631
+ hub_model_id=self.hub_model_id,
632
+ dataset_name=dataset_name,
633
+ tags=tags,
634
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
635
+ trainer_name="PRM",
636
+ trainer_citation=citation,
637
+ paper_title="Solving math word problems with process-and outcome-based feedback",
638
+ )
639
+
640
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
641
+ class UnslothPRMTrainer(_UnslothPRMTrainer):
642
+ """
643
+
644
+ Initialize PRMTrainer.
645
+
646
+ Args:
647
+ model (`transformers.PreTrainedModel`):
648
+ The model to train, preferably an `AutoModelForTokenClassification`.
649
+ args (`PRMConfig`):
650
+ The arguments to use for training.
651
+ data_collator (`transformers.DataCollator`):
652
+ The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
653
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
654
+ train_dataset (`datasets.Dataset`):
655
+ The dataset to use for training.
656
+ eval_dataset (`datasets.Dataset`):
657
+ The dataset to use for evaluation.
658
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
659
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
660
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
661
+ reuse the fine-tuned model.
662
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
663
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
664
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
665
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
666
+ callbacks (`list[transformers.TrainerCallback]`):
667
+ The callbacks to use for training.
668
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
669
+ The optimizer and scheduler to use for training.
670
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
671
+ The function to use to preprocess the logits before computing the metrics.
672
+ peft_config (`dict`, defaults to `None`):
673
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
674
+
675
+ """
676
+ def __init__(
677
+ self,
678
+ model = None,
679
+ args = None,
680
+ data_collator = None,
681
+ train_dataset = None,
682
+ eval_dataset = None,
683
+ processing_class = None,
684
+ model_init = None,
685
+ compute_metrics = None,
686
+ callbacks = None,
687
+ preprocess_logits_for_metrics = None,
688
+ peft_config = None,
689
+ **kwargs
690
+ ):
691
+ if args is None: args = UnslothPRMConfig()
692
+ use_bf16 = getattr(args, 'bf16', False)
693
+ use_fp16 = getattr(args, 'fp16', False)
694
+ force_float32 = False
695
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
696
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
697
+ force_float32 = True
698
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
699
+ dtype = getattr(model.config, 'torch_dtype', None)
700
+ if dtype is None: dtype = model.get_input_embeddings().dtype
701
+ from unsloth_zoo.utils import _get_dtype
702
+ dtype = _get_dtype(dtype)
703
+ float16 = dtype == torch.float16
704
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
705
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
706
+ if force_float32:
707
+ args.fp16 = False
708
+ args.bf16 = False
709
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
710
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
711
+ args.fp16 = float16
712
+ args.bf16 = not float16
713
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
714
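+ # Example of the selection above (illustrative): a bfloat16 model with neither
+ # bf16 nor fp16 requested ends up with args.bf16 = True and
+ # ACCELERATE_MIXED_PRECISION = 'bf16'; a float16 model gets args.fp16 = True and
+ # 'fp16'; UNSLOTH_FORCE_FLOAT32 = '1' overrides both and disables mixed precision.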
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
715
+ args.eval_strategy = 'steps'
716
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
717
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
718
+ if ga_steps is not None and ga_steps > 1:
719
+ from transformers import __version__ as transformers_version
720
+ if Version(transformers_version) <= Version('4.45.2'):
721
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
722
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
723
+ if getattr(args, 'eval_strategy', 'no') != 'no':
724
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
725
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
726
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
727
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
728
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
729
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
730
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
731
+ if force_float32:
732
+ args.bf16_full_eval = False
733
+ args.fp16_full_eval = False
734
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
735
+ args.bf16_full_eval = True
736
+ args.fp16_full_eval = False
737
+ elif not bf16_full_eval and not fp16_full_eval:
738
+ args.bf16_full_eval = args.bf16
739
+ args.fp16_full_eval = args.fp16
740
+ _output_logits = False
741
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
742
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
743
+ if _output_logits:
744
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
745
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
746
+ pass
747
+ else:
748
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
749
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
750
+ if args_max_seq_length is None and model_max_seq_length is not None:
751
+ max_seq_length = model.max_seq_length
752
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
753
+ if model is not None and hasattr(model, 'for_training'):
754
+ model.for_training()
755
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
756
+ if 'processing_class' in locals():
757
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
758
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
759
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
760
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
761
+ if not isinstance(data_collator, UnslothVisionDataCollator):
762
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
763
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
764
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
765
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
766
+ else:
767
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
768
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
769
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
770
+ if not isinstance(data_collator, UnslothVisionDataCollator):
771
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
772
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
773
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
774
+ else:
775
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
776
+ other_metrics = []
777
+
778
+ from unsloth_zoo.logging_utils import PatchRLStatistics
779
+ PatchRLStatistics('prm_trainer', other_metrics)
780
+
781
+ super().__init__(
782
+ model = model,
783
+ args = args,
784
+ data_collator = data_collator,
785
+ train_dataset = train_dataset,
786
+ eval_dataset = eval_dataset,
787
+ processing_class = processing_class,
788
+ model_init = model_init,
789
+ compute_metrics = compute_metrics,
790
+ callbacks = callbacks,
791
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
792
+ peft_config = peft_config,**kwargs)
793
+ if hasattr(self, 'neftune_hook_handle'):
794
+ self.neftune_hook_handle.remove()
795
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
796
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
797
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
798
+ pass
799
+
800
+ pass
unsloth_compiled_cache/UnslothRLOOTrainer.py ADDED
@@ -0,0 +1,1133 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
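+ # Minimal usage sketch (illustrative, assuming `logits` of shape (B, T, V) and
+ # integer token ids of shape (B, T)): the function gathers the chosen logits and
+ # subtracts logsumexp over the vocabulary, so it returns per-token log-probabilities
+ # without materializing a full log-softmax tensor:
+ #     per_token_logps = selective_log_softmax(logits, input_ids)  # shape (B, T)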
+ @dataclass
43
+ class UnslothRLOOConfig(RLOOConfig):
44
+ """
45
+
46
+ Configuration class for the [`RLOOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
54
+ Name of this experiment.
55
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
56
+ Path to the reward model.
57
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
58
+ Number of epochs to train.
59
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
60
+ Whether to whiten the rewards.
61
+ kl_coef (`float`, *optional*, defaults to `0.05`):
62
+ KL coefficient.
63
+ cliprange (`float`, *optional*, defaults to `0.2`):
64
+ Clip range.
65
+ rloo_k (`int`, *optional*, defaults to `2`):
66
+ REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
67
+ normalize_reward (`bool`, *optional*, defaults to `False`):
68
+ Whether to normalize rewards.
69
+ reward_clip_range (`float`, *optional*, defaults to `10.0`):
70
+ Clip range for rewards.
71
+ normalize_advantage (`bool`, *optional*, defaults to `False`):
72
+ Whether to normalize advantages.
73
+ token_level_kl (`bool`, *optional*, defaults to `False`):
74
+ Whether to use token-level KL penalty or sequence-level KL penalty.
75
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
76
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
77
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
78
+ capacity of a single GPU, albeit at the cost of slower generation.
79
+
80
+ """
81
+ vllm_sampling_params: Optional[Any] = field(
82
+ default = None,
83
+ metadata = {'help': 'vLLM SamplingParams'},
84
+ )
85
+ unsloth_num_chunks : Optional[int] = field(
86
+ default = -1,
87
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
88
+ )
89
+ def __init__(
90
+ self,
91
+ output_dir = None,
92
+ overwrite_output_dir = None,
93
+ do_train = False,
94
+ do_eval = False,
95
+ do_predict = False,
96
+ eval_strategy = 'no',
97
+ prediction_loss_only = False,
98
+ per_device_train_batch_size = 4,
99
+ per_device_eval_batch_size = 4,
100
+ per_gpu_train_batch_size = None,
101
+ per_gpu_eval_batch_size = None,
102
+ gradient_accumulation_steps = 2,
103
+ eval_accumulation_steps = 2,
104
+ eval_delay = 0,
105
+ torch_empty_cache_steps = 250,
106
+ learning_rate = 5e-05,
107
+ weight_decay = 0.01,
108
+ adam_beta1 = 0.9,
109
+ adam_beta2 = 0.999,
110
+ adam_epsilon = 1e-08,
111
+ max_grad_norm = 1.0,
112
+ num_train_epochs = 3.0,
113
+ max_steps = -1,
114
+ lr_scheduler_type = 'linear',
115
+ warmup_ratio = 0.1,
116
+ warmup_steps = 0,
117
+ log_level = 'passive',
118
+ log_level_replica = 'warning',
119
+ log_on_each_node = True,
120
+ logging_dir = None,
121
+ logging_strategy = 'steps',
122
+ logging_first_step = False,
123
+ logging_steps = 1,
124
+ logging_nan_inf_filter = False,
125
+ save_strategy = 'steps',
126
+ save_steps = 500,
127
+ save_total_limit = None,
128
+ save_safetensors = True,
129
+ save_on_each_node = False,
130
+ save_only_model = False,
131
+ restore_callback_states_from_checkpoint = False,
132
+ no_cuda = False,
133
+ use_cpu = False,
134
+ use_mps_device = False,
135
+ seed = 3407,
136
+ data_seed = 3407,
137
+ jit_mode_eval = False,
138
+ use_ipex = False,
139
+ bf16 = False,
140
+ fp16 = False,
141
+ fp16_opt_level = 'O1',
142
+ half_precision_backend = 'auto',
143
+ bf16_full_eval = False,
144
+ fp16_full_eval = False,
145
+ tf32 = None,
146
+ local_rank = -1,
147
+ ddp_backend = None,
148
+ tpu_num_cores = None,
149
+ tpu_metrics_debug = False,
150
+ debug = '',
151
+ dataloader_drop_last = False,
152
+ eval_steps = None,
153
+ dataloader_num_workers = 0,
154
+ dataloader_prefetch_factor = None,
155
+ past_index = -1,
156
+ run_name = None,
157
+ disable_tqdm = None,
158
+ remove_unused_columns = True,
159
+ label_names = None,
160
+ load_best_model_at_end = False,
161
+ metric_for_best_model = None,
162
+ greater_is_better = None,
163
+ ignore_data_skip = False,
164
+ fsdp = '',
165
+ fsdp_min_num_params = 0,
166
+ fsdp_config = None,
167
+ tp_size = 0,
168
+ fsdp_transformer_layer_cls_to_wrap = None,
169
+ accelerator_config = None,
170
+ deepspeed = None,
171
+ label_smoothing_factor = 0.0,
172
+ optim = 'adamw_8bit',
173
+ optim_args = None,
174
+ adafactor = False,
175
+ group_by_length = False,
176
+ length_column_name = 'length',
177
+ report_to = None,
178
+ ddp_find_unused_parameters = None,
179
+ ddp_bucket_cap_mb = None,
180
+ ddp_broadcast_buffers = None,
181
+ dataloader_pin_memory = True,
182
+ dataloader_persistent_workers = False,
183
+ skip_memory_metrics = True,
184
+ use_legacy_prediction_loop = False,
185
+ push_to_hub = False,
186
+ resume_from_checkpoint = None,
187
+ hub_model_id = None,
188
+ hub_strategy = 'every_save',
189
+ hub_token = None,
190
+ hub_private_repo = None,
191
+ hub_always_push = False,
192
+ gradient_checkpointing = False,
193
+ gradient_checkpointing_kwargs = None,
194
+ include_inputs_for_metrics = False,
195
+ eval_do_concat_batches = True,
196
+ fp16_backend = 'auto',
197
+ evaluation_strategy = None,
198
+ push_to_hub_model_id = None,
199
+ push_to_hub_organization = None,
200
+ push_to_hub_token = None,
201
+ mp_parameters = '',
202
+ auto_find_batch_size = False,
203
+ full_determinism = False,
204
+ torchdynamo = None,
205
+ ray_scope = 'last',
206
+ ddp_timeout = 1800,
207
+ torch_compile = False,
208
+ torch_compile_backend = None,
209
+ torch_compile_mode = None,
210
+ dispatch_batches = None,
211
+ split_batches = None,
212
+ include_tokens_per_second = False,
213
+ include_num_input_tokens_seen = False,
214
+ neftune_noise_alpha = None,
215
+ optim_target_modules = None,
216
+ batch_eval_metrics = False,
217
+ eval_on_start = False,
218
+ use_liger_kernel = False,
219
+ eval_use_gather_object = False,
220
+ average_tokens_across_devices = False,
221
+ dataset_num_proc = None,
222
+ num_mini_batches = 1,
223
+ total_episodes = None,
224
+ local_rollout_forward_batch_size = 64,
225
+ num_sample_generations = 10,
226
+ response_length = 53,
227
+ stop_token = None,
228
+ stop_token_id = None,
229
+ temperature = 0.7,
230
+ missing_eos_penalty = None,
231
+ sft_model_path = 'EleutherAI/pythia-160m',
232
+ world_size = None,
233
+ num_total_batches = None,
234
+ micro_batch_size = None,
235
+ local_batch_size = None,
236
+ batch_size = None,
237
+ local_mini_batch_size = None,
238
+ mini_batch_size = None,
239
+ exp_name = 'rloo_config',
240
+ reward_model_path = 'EleutherAI/pythia-160m',
241
+ num_ppo_epochs = 4,
242
+ whiten_rewards = False,
243
+ kl_coef = 0.05,
244
+ cliprange = 0.2,
245
+ rloo_k = 2,
246
+ normalize_reward = False,
247
+ reward_clip_range = 10.0,
248
+ normalize_advantage = False,
249
+ token_level_kl = False,
250
+ ds3_gather_for_generation = True,
251
+ vllm_sampling_params = None,
252
+ unsloth_num_chunks = -1,
253
+ **kwargs,
254
+ ):
255
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
256
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
257
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
258
+ output_dir = 'unsloth_training_checkpoints'
259
+ save_strategy = 'no'
260
+ if dataset_num_proc is None:
261
+ from multiprocessing import cpu_count
262
+ dataset_num_proc = cpu_count()
263
+
264
+ super().__init__(
265
+ output_dir = output_dir,
266
+ overwrite_output_dir = overwrite_output_dir,
267
+ do_train = do_train,
268
+ do_eval = do_eval,
269
+ do_predict = do_predict,
270
+ eval_strategy = eval_strategy,
271
+ prediction_loss_only = prediction_loss_only,
272
+ per_device_train_batch_size = per_device_train_batch_size,
273
+ per_device_eval_batch_size = per_device_eval_batch_size,
274
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
275
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
276
+ gradient_accumulation_steps = gradient_accumulation_steps,
277
+ eval_accumulation_steps = eval_accumulation_steps,
278
+ eval_delay = eval_delay,
279
+ torch_empty_cache_steps = torch_empty_cache_steps,
280
+ learning_rate = learning_rate,
281
+ weight_decay = weight_decay,
282
+ adam_beta1 = adam_beta1,
283
+ adam_beta2 = adam_beta2,
284
+ adam_epsilon = adam_epsilon,
285
+ max_grad_norm = max_grad_norm,
286
+ num_train_epochs = num_train_epochs,
287
+ max_steps = max_steps,
288
+ lr_scheduler_type = lr_scheduler_type,
289
+ warmup_ratio = warmup_ratio,
290
+ warmup_steps = warmup_steps,
291
+ log_level = log_level,
292
+ log_level_replica = log_level_replica,
293
+ log_on_each_node = log_on_each_node,
294
+ logging_dir = logging_dir,
295
+ logging_strategy = logging_strategy,
296
+ logging_first_step = logging_first_step,
297
+ logging_steps = logging_steps,
298
+ logging_nan_inf_filter = logging_nan_inf_filter,
299
+ save_strategy = save_strategy,
300
+ save_steps = save_steps,
301
+ save_total_limit = save_total_limit,
302
+ save_safetensors = save_safetensors,
303
+ save_on_each_node = save_on_each_node,
304
+ save_only_model = save_only_model,
305
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
306
+ no_cuda = no_cuda,
307
+ use_cpu = use_cpu,
308
+ use_mps_device = use_mps_device,
309
+ seed = seed,
310
+ data_seed = data_seed,
311
+ jit_mode_eval = jit_mode_eval,
312
+ use_ipex = use_ipex,
313
+ bf16 = bf16,
314
+ fp16 = fp16,
315
+ fp16_opt_level = fp16_opt_level,
316
+ half_precision_backend = half_precision_backend,
317
+ bf16_full_eval = bf16_full_eval,
318
+ fp16_full_eval = fp16_full_eval,
319
+ tf32 = tf32,
320
+ local_rank = local_rank,
321
+ ddp_backend = ddp_backend,
322
+ tpu_num_cores = tpu_num_cores,
323
+ tpu_metrics_debug = tpu_metrics_debug,
324
+ debug = debug,
325
+ dataloader_drop_last = dataloader_drop_last,
326
+ eval_steps = eval_steps,
327
+ dataloader_num_workers = dataloader_num_workers,
328
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
329
+ past_index = past_index,
330
+ run_name = run_name,
331
+ disable_tqdm = disable_tqdm,
332
+ remove_unused_columns = remove_unused_columns,
333
+ label_names = label_names,
334
+ load_best_model_at_end = load_best_model_at_end,
335
+ metric_for_best_model = metric_for_best_model,
336
+ greater_is_better = greater_is_better,
337
+ ignore_data_skip = ignore_data_skip,
338
+ fsdp = fsdp,
339
+ fsdp_min_num_params = fsdp_min_num_params,
340
+ fsdp_config = fsdp_config,
341
+ tp_size = tp_size,
342
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
343
+ accelerator_config = accelerator_config,
344
+ deepspeed = deepspeed,
345
+ label_smoothing_factor = label_smoothing_factor,
346
+ optim = optim,
347
+ optim_args = optim_args,
348
+ adafactor = adafactor,
349
+ group_by_length = group_by_length,
350
+ length_column_name = length_column_name,
351
+ report_to = report_to,
352
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
353
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
354
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
355
+ dataloader_pin_memory = dataloader_pin_memory,
356
+ dataloader_persistent_workers = dataloader_persistent_workers,
357
+ skip_memory_metrics = skip_memory_metrics,
358
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
359
+ push_to_hub = push_to_hub,
360
+ resume_from_checkpoint = resume_from_checkpoint,
361
+ hub_model_id = hub_model_id,
362
+ hub_strategy = hub_strategy,
363
+ hub_token = hub_token,
364
+ hub_private_repo = hub_private_repo,
365
+ hub_always_push = hub_always_push,
366
+ gradient_checkpointing = gradient_checkpointing,
367
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
368
+ include_inputs_for_metrics = include_inputs_for_metrics,
369
+ eval_do_concat_batches = eval_do_concat_batches,
370
+ fp16_backend = fp16_backend,
371
+ evaluation_strategy = evaluation_strategy,
372
+ push_to_hub_model_id = push_to_hub_model_id,
373
+ push_to_hub_organization = push_to_hub_organization,
374
+ push_to_hub_token = push_to_hub_token,
375
+ mp_parameters = mp_parameters,
376
+ auto_find_batch_size = auto_find_batch_size,
377
+ full_determinism = full_determinism,
378
+ torchdynamo = torchdynamo,
379
+ ray_scope = ray_scope,
380
+ ddp_timeout = ddp_timeout,
381
+ torch_compile = torch_compile,
382
+ torch_compile_backend = torch_compile_backend,
383
+ torch_compile_mode = torch_compile_mode,
384
+ dispatch_batches = dispatch_batches,
385
+ split_batches = split_batches,
386
+ include_tokens_per_second = include_tokens_per_second,
387
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
388
+ neftune_noise_alpha = neftune_noise_alpha,
389
+ optim_target_modules = optim_target_modules,
390
+ batch_eval_metrics = batch_eval_metrics,
391
+ eval_on_start = eval_on_start,
392
+ use_liger_kernel = use_liger_kernel,
393
+ eval_use_gather_object = eval_use_gather_object,
394
+ average_tokens_across_devices = average_tokens_across_devices,
395
+ dataset_num_proc = dataset_num_proc,
396
+ num_mini_batches = num_mini_batches,
397
+ total_episodes = total_episodes,
398
+ local_rollout_forward_batch_size = local_rollout_forward_batch_size,
399
+ num_sample_generations = num_sample_generations,
400
+ response_length = response_length,
401
+ stop_token = stop_token,
402
+ stop_token_id = stop_token_id,
403
+ temperature = temperature,
404
+ missing_eos_penalty = missing_eos_penalty,
405
+ sft_model_path = sft_model_path,
406
+ world_size = world_size,
407
+ num_total_batches = num_total_batches,
408
+ micro_batch_size = micro_batch_size,
409
+ local_batch_size = local_batch_size,
410
+ batch_size = batch_size,
411
+ local_mini_batch_size = local_mini_batch_size,
412
+ mini_batch_size = mini_batch_size,
413
+ exp_name = exp_name,
414
+ reward_model_path = reward_model_path,
415
+ num_ppo_epochs = num_ppo_epochs,
416
+ whiten_rewards = whiten_rewards,
417
+ kl_coef = kl_coef,
418
+ cliprange = cliprange,
419
+ rloo_k = rloo_k,
420
+ normalize_reward = normalize_reward,
421
+ reward_clip_range = reward_clip_range,
422
+ normalize_advantage = normalize_advantage,
423
+ token_level_kl = token_level_kl,
424
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
425
+ self.vllm_sampling_params = vllm_sampling_params
426
+ self.unsloth_num_chunks = unsloth_num_chunks
427
+ pass
428
+
429
+ class _UnslothRLOOTrainer(Trainer):
430
+ _tag_names = ["trl", "rloo"]
431
+
432
+ def __init__(
433
+ self,
434
+ config: RLOOConfig,
435
+ processing_class: Optional[
436
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
437
+ ],
438
+ policy: nn.Module,
439
+ ref_policy: nn.Module,
440
+ reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
441
+ train_dataset: Dataset,
442
+ data_collator: Optional[DataCollatorWithPadding] = None,
443
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
444
+ # less commonly used
445
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
446
+ callbacks: Optional[list[TrainerCallback]] = None,
447
+ ) -> None:
448
+ if ref_policy is policy:
449
+ raise ValueError(
450
+ "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
451
+ "same as `policy`, you must mass a copy of it, or `None` if you use peft."
452
+ )
453
+
454
+ self.args = config
455
+ args = config
456
+ self.processing_class = processing_class
457
+ self.policy = policy
458
+
459
+ # Define the collator if not provided
460
+ if data_collator is None:
461
+ data_collator = DataCollatorWithPadding(self.processing_class)
462
+
463
+ self.policy.generation_config.eos_token_id = (
464
+ None # disable `pad_token_id` and `eos_token_id` because we just want to
465
+ )
466
+ self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
467
+
468
+ self.ref_policy = ref_policy
469
+ self.reward_model = reward_model
470
+ self.train_dataset = train_dataset
471
+ self.train_dataset_len = len(train_dataset)
472
+ self.data_collator = data_collator
473
+ self.eval_dataset = eval_dataset
474
+ self.optimizer, self.lr_scheduler = optimizers
475
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
476
+
477
+ #########
478
+ # calculate various batch sizes
479
+ #########
480
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
481
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
482
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
483
+ self.accelerator = accelerator
484
+ args.world_size = accelerator.num_processes
485
+ args.local_batch_size = (
486
+ args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
487
+ )
488
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
489
+ args.batch_size = int(args.local_batch_size * args.world_size)
490
+ args.mini_batch_size = exact_div(
491
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
492
+ )
493
+ args.local_mini_batch_size = exact_div(
494
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
495
+ )
496
+ args.num_total_batches = math.ceil(
497
+ args.total_episodes / args.batch_size
498
+ ) # we may train for more than `total_episodes`
499
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
500
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
501
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
502
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
503
+ if args.num_sample_generations > 0:
504
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
505
+ self.local_dataloader_batch_size = exact_div(
506
+ args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
507
+ ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
508
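+ # Worked example (illustrative numbers): per_device_train_batch_size = 4,
+ # gradient_accumulation_steps = 2, num_mini_batches = 1, world_size = 1 and
+ # rloo_k = 2 give local_batch_size = 4 * 2 * 1 = 8, batch_size = 8,
+ # mini_batch_size = 8, and local_dataloader_batch_size = 8 // 2 = 4 unique
+ # prompts per rollout batch (each prompt is repeated rloo_k times).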
+
509
+ #########
510
+ # setup model, optimizer, and others
511
+ #########
512
+ for module in [policy, ref_policy, reward_model]:
513
+ if isinstance(module, nn.Module):
514
+ disable_dropout_in_model(module)
515
+ if args.stop_token and args.stop_token == "eos":
516
+ args.stop_token_id = self.processing_class.eos_token_id
517
+ self.model = policy
518
+ self.create_optimizer_and_scheduler(
519
+ num_training_steps=args.num_total_batches
520
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
521
+
522
+ #########
523
+ ### trainer specifics
524
+ #########
525
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
526
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
527
+ self.callback_handler = CallbackHandler(
528
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
529
+ )
530
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
531
+ self.control = TrainerControl()
532
+ self.state = OnlineTrainerState(
533
+ is_local_process_zero=self.is_local_process_zero(),
534
+ is_world_process_zero=self.is_world_process_zero(),
535
+ stateful_callbacks=[
536
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
537
+ ],
538
+ )
539
+
540
+ self.current_flos = 0
541
+ self.hp_search_backend = None
542
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
543
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
544
+ # Create distant repo and output directory if needed
545
+ self.hub_model_id = None
546
+ if self.args.push_to_hub:
547
+ self.init_hf_repo()
548
+ if self.args.should_save:
549
+ os.makedirs(self.args.output_dir, exist_ok=True)
550
+ self.backup_model = None
551
+
552
+ # Add tags for models that have been loaded with the correct transformers version
553
+ if hasattr(self.model, "add_model_tags"):
554
+ self.model.add_model_tags(self._tag_names)
555
+
556
+ #########
557
+ ### setup dataloader
558
+ #########
559
+ self.dataloader = DataLoader(
560
+ self.train_dataset,
561
+ batch_size=self.local_dataloader_batch_size,
562
+ shuffle=True,
563
+ collate_fn=self.data_collator,
564
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
565
+ )
566
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
567
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
568
+ torch.manual_seed(args.seed)
569
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
570
+ torch.manual_seed(self.local_seed) # reset the local seed again
571
+
572
+ self.eval_dataloader = DataLoader(
573
+ self.eval_dataset,
574
+ batch_size=args.per_device_eval_batch_size,
575
+ collate_fn=self.data_collator,
576
+ drop_last=True,
577
+ ) # no need to shuffle eval dataset
578
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
579
+
580
+ if self.is_deepspeed_enabled:
581
+ if isinstance(self.reward_model, nn.Module):
582
+ self.reward_model = prepare_deepspeed(
583
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
584
+ )
585
+ self.ref_policy = prepare_deepspeed(
586
+ self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
587
+ )
588
+ self.deepspeed = self.model
589
+ else:
590
+ self.ref_policy = self.ref_policy.to(self.accelerator.device)
591
+ if isinstance(self.reward_model, nn.Module):
592
+ self.reward_model = self.reward_model.to(self.accelerator.device)
593
+
594
+ def get_train_dataloader(self) -> DataLoader:
595
+ return self.dataloader
596
+
597
+ def get_eval_dataloader(self) -> DataLoader:
598
+ return self.eval_dataloader
599
+
600
+ def train(self):
601
+ args = self.args
602
+ accelerator = self.accelerator
603
+ optimizer = self.optimizer
604
+ model = self.model
605
+ self.model_wrapped = self.model
606
+ ref_policy = self.ref_policy
607
+ reward_model = self.reward_model
608
+ processing_class = self.processing_class
609
+ dataloader = self.dataloader
610
+ device = accelerator.device
611
+
612
+ def repeat_generator():
613
+ while True:
614
+ yield from dataloader
615
+
616
+ iter_dataloader = iter(repeat_generator())
617
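+ # (Illustrative note) repeat_generator cycles the dataloader indefinitely, so
+ # training length is governed by args.num_total_batches rather than dataset epochs.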
+ generation_config = GenerationConfig(
618
+ max_new_tokens=args.response_length,
619
+ temperature=(args.temperature + 1e-7),
620
+ top_k=0.0,
621
+ top_p=1.0,
622
+ do_sample=True,
623
+ )
624
+
625
+ accelerator.print("===training policy===")
626
+ start_time = time.time()
627
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
628
+ approxkl_stats = torch.zeros(stats_shape, device=device)
629
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
630
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
631
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
632
+ entropy_stats = torch.zeros(stats_shape, device=device)
633
+ ratio_stats = torch.zeros(stats_shape, device=device)
634
+ model.train()
635
+
636
+ # trainer state initialization
637
+ self.state.global_step = 0
638
+ self.state.episode = 0
639
+ self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
640
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
641
+ # Compute absolute values for logging, eval, and save if given as ratio
642
+ if args.logging_steps is not None:
643
+ if args.logging_steps < 1:
644
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
645
+ else:
646
+ self.state.logging_steps = args.logging_steps
647
+ if args.eval_steps is not None:
648
+ if args.eval_steps < 1:
649
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
650
+ else:
651
+ self.state.eval_steps = args.eval_steps
652
+ if args.save_steps is not None:
653
+ if args.save_steps < 1:
654
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
655
+ else:
656
+ self.state.save_steps = args.save_steps
657
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
658
+
659
+ for update in range(1, args.num_total_batches + 1):
660
+ self.state.episode += 1 * args.batch_size
661
+ data = next(iter_dataloader)
662
+ with torch.no_grad():
663
+ queries = data["input_ids"].to(device)
664
+ queries = queries.repeat(args.rloo_k, 1)
665
+ context_length = queries.shape[1]
666
+ responses = []
667
+ postprocessed_responses = []
668
+ logprobs = []
669
+ ref_logprobs = []
670
+ scores = []
671
+ sequence_lengths = []
672
+
673
+ # Generate responses and compute logprobs
674
+ with unwrap_model_for_generation(
675
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
676
+ ) as unwrapped_model:
677
+ query_responses, logitss = batch_generation(
678
+ unwrapped_model,
679
+ queries,
680
+ args.local_rollout_forward_batch_size,
681
+ processing_class.pad_token_id,
682
+ generation_config,
683
+ )
684
+
685
+ # Process responses in batches
686
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
687
+ query = queries[i : i + args.local_rollout_forward_batch_size]
688
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
689
+ response = query_response[:, context_length:]
690
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
691
+ logprob = selective_log_softmax(logits, response)
692
+ del logits
693
+ torch.cuda.empty_cache()
694
+
695
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
696
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
697
+ ref_logits /= args.temperature + 1e-7
698
+ ref_logprob = selective_log_softmax(ref_logits, response)
699
+ del ref_output, ref_logits
700
+ torch.cuda.empty_cache()
701
+
702
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
703
+ postprocessed_response = response
704
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
705
+ postprocessed_response = truncate_response(
706
+ args.stop_token_id, processing_class.pad_token_id, response
707
+ )
708
+
709
+ # Response Processing 2. run reward model on the truncated responses
710
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
711
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
712
+
713
+ if isinstance(reward_model, nn.Module):
714
+ _, score, _ = get_reward(
715
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
716
+ )
717
+ else:
718
+ score = torch.tensor(
719
+ reward_model(
720
+ processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
721
+ ),
722
+ dtype=torch.float,
723
+ ).to(device)
724
+
725
+ # Store batch results
726
+ responses.append(response)
727
+ postprocessed_responses.append(postprocessed_response)
728
+ logprobs.append(logprob)
729
+ ref_logprobs.append(ref_logprob)
730
+ sequence_lengths.append(sequence_length)
731
+ scores.append(score)
732
+
733
+ # Concatenate all batched results
734
+ responses = torch.cat(responses, 0)
735
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
736
+ logprobs = torch.cat(logprobs, 0)
737
+ ref_logprobs = torch.cat(ref_logprobs, 0)
738
+ sequence_lengths = torch.cat(sequence_lengths, 0)
739
+ scores = torch.cat(scores, 0)
740
+ del (logprob, ref_logprob, score)
741
+ torch.cuda.empty_cache()
742
+ gc.collect()
743
+
744
+ # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
745
+ # responses not passing that filter will receive a low (fixed) score
746
+ # only query humans on responses that pass that filter
747
+ contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
748
+ if args.missing_eos_penalty is not None:
749
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
750
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
751
+
752
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
753
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
754
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
755
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
756
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
757
+
758
+ # 4. compute rewards
759
+ # Compute KL divergence
760
+ kl = logprobs - ref_logprobs
761
+
762
+ # Normalize rewards
763
+ if args.normalize_reward:
764
+ scores = (scores - scores.mean()) / (scores.std() + 1e-8)
765
+ scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
766
+
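+ # Standardizing and then clamping keeps the reward scale comparable across batches
+ # and bounds outliers to [-reward_clip_range, reward_clip_range] before the KL
+ # penalty is folded in.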
767
+ # Compute total reward with KL penalty
768
+ if args.token_level_kl:
769
+ # Token-level KL penalty: apply KL penalty per token
770
+ kl_reward = -args.kl_coef * kl
771
+
772
+ # Get the index of the last non-padded token for each sequence
773
+ eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
774
+ last_reward = torch.zeros_like(kl)
775
+ # Ensure scores has correct shape and type
776
+ scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
777
+ last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
778
+
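+ # `last_reward` is zero everywhere except at `eos_indices`, where the scalar sequence
+ # score is scattered in. Since `logprobs` and `ref_logprobs` were masked to the same
+ # constant at padded positions, `kl` is exactly 0 there, so the summed reward below
+ # equals the score plus the token-level KL reward over real tokens.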
779
+ # Combine KL reward and last reward
780
+ non_score_reward = kl_reward.sum(1) # Keep this for logging
781
+ reward = last_reward + kl_reward
782
+ rlhf_reward = reward.sum(1) # Sum across sequence length
783
+ else:
784
+ # Sequence-level KL penalty: sum KL across tokens first
785
+ sequence_kl = kl.sum(1)
786
+ non_score_reward = -args.kl_coef * sequence_kl
787
+ rlhf_reward = non_score_reward + scores
788
+
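+ # Note that both branches yield the same total, rlhf_reward = score - kl_coef * sum_t kl_t;
+ # they differ only in how that reward is laid out across tokens (per-token versus one
+ # sequence-level scalar), which affects per-token credit assignment and logging, not
+ # the summed RLOO objective below.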
789
+ # vectorized RLOO advantages implementation
790
+ rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
791
+ baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
792
+ advantages = rlhf_reward - baseline
793
+ advantages = advantages.flatten()
794
+
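+ # Leave-one-out sketch (illustrative numbers): with rloo_k = 2 completions per prompt,
+ # rewards reshaped to (k, num_prompts) such as [[1.0], [3.0]] give each sample the mean
+ # of the other k-1 samples as its baseline, [[3.0], [1.0]], so advantages = [[-2.0], [2.0]]:
+ # an unbiased, variance-reduced REINFORCE estimator.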
795
+ # Normalize advantages
796
+ if args.normalize_advantage:
797
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
798
+
799
+ torch.cuda.empty_cache()
800
+
801
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
802
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
803
+ b_inds = np.random.permutation(args.local_batch_size)
804
+ minibatch_idx = 0
805
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
806
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
807
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
808
+ gradient_accumulation_idx = 0
809
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
810
+ with accelerator.accumulate(model):
811
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
812
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
813
+
814
+ # Get batch data
815
+ mb_advantage = advantages[micro_batch_inds]
816
+ mb_responses = responses[micro_batch_inds]
817
+ mb_query_responses = query_responses[micro_batch_inds]
818
+ mb_logprobs = logprobs[micro_batch_inds]
819
+
820
+ # Forward pass
821
+ output = forward(model, mb_query_responses, processing_class.pad_token_id)
822
+ logits = output.logits[:, context_length - 1 : -1]
823
+ logits /= args.temperature + 1e-7
824
+
825
+ # Compute new logprobs
826
+ new_logprobs = selective_log_softmax(logits, mb_responses)
827
+ new_logprobs = torch.masked_fill(
828
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
829
+ )
830
+
831
+ # Compute probability ratios
832
+ new_ratio = (new_logprobs - mb_logprobs).exp()
833
+ new_logprobs = new_logprobs.sum(1)
834
+ mb_logprobs = mb_logprobs.sum(1)
835
+ logprobs_diff = new_logprobs - mb_logprobs
836
+ ratio = torch.exp(logprobs_diff)
837
+
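+ # `new_ratio` keeps the per-token ratios (used only for the `val/ratio` statistics);
+ # `ratio` is the sequence-level ratio exp(sum_t new_logp - sum_t old_logp) that the
+ # clipped objective uses, since each RLOO advantage is per-sequence. Padded positions
+ # cancel in the difference because both sides were masked to the same constant.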
838
+ # PPO clipped loss
839
+ pg_losses = -mb_advantage * ratio
840
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
841
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
842
+ pg_loss = pg_loss_max.mean()
843
+
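+ # Clipping sketch (illustrative numbers): with cliprange = 0.2, advantage = +1 and
+ # ratio = 1.5, the unclipped term is -1.5 and the clipped term is -1.2, so
+ # max(-1.5, -1.2) = -1.2 bounds the update exactly as in PPO's surrogate objective.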
844
+ # Final loss
845
+ loss = pg_loss
846
+
847
+ # Optimization step
848
+ accelerator.backward(loss)
849
+ optimizer.step()
850
+ optimizer.zero_grad()
851
+
852
+ with torch.no_grad():
853
+ pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
854
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
855
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
856
+ approxkl = 0.5 * (logprobs_diff**2).mean()
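+ # Identities used above: entropy H(p) = logsumexp(logits) - sum_i p_i * logits_i,
+ # and approxkl = 0.5 * E[(new_logp - old_logp)^2] is a second-order ("k2")
+ # KL estimate, computed here on sequence-level log-prob differences.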
857
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
858
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
859
+ pg_clipfrac
860
+ )
861
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
862
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
863
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
864
+ gradient_accumulation_idx += 1
865
+ minibatch_idx += 1
866
+
867
+ # del everything and empty cache
868
+ # fmt: off
869
+ del (
870
+ output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
871
+ pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
872
+ mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
873
+ )
874
+ # fmt: on
875
+ torch.cuda.empty_cache()
876
+
877
+ # Compute metrics
878
+ with torch.no_grad():
879
+ mean_kl = kl.sum(1).mean()
880
+ mean_entropy = (-logprobs).sum(1).mean()
881
+ mean_non_score_reward = non_score_reward.mean()
882
+ eps = int(self.state.episode / (time.time() - start_time))
883
+ metrics = {}
884
+ metrics["eps"] = eps
885
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
886
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
887
+ metrics["objective/non_score_reward"] = (
888
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
889
+ )
890
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
891
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
892
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
893
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
894
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
895
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
896
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
897
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
898
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
899
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
900
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
901
+ metrics["episode"] = self.state.episode
902
+ self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
903
+ self.log(metrics)
904
+ del kl, mean_kl, mean_entropy, scores
905
+
906
+ self.lr_scheduler.step()
907
+ self.state.global_step += 1
908
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
909
+ if self.control.should_save:
910
+ self._save_checkpoint(model, trial=None)
911
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
912
+ torch.cuda.empty_cache()
913
+ gc.collect()
914
+
915
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
916
+ self.generate_completions(sampling=True)
917
+
918
+ # HF trainer specifics
919
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
920
+ if self.control.should_save:
921
+ self._save_checkpoint(model, trial=None, metrics=None)
922
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
923
+
924
+ def generate_completions(self, sampling: bool = False):
925
+ args = self.args
926
+ processing_class = self.processing_class
927
+ generation_config = GenerationConfig(
928
+ max_new_tokens=self.args.response_length,
929
+ temperature=(0.01 + 1e-7),
930
+ top_k=0.0,
931
+ top_p=1.0,
932
+ do_sample=True,
933
+ )
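+ # A temperature of ~0.01 with do_sample=True is effectively greedy decoding while
+ # still exercising the sampling code path used during training rollouts.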
934
+
935
+ table = defaultdict(list)
936
+ with unwrap_model_for_generation(
937
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
938
+ ) as unwrapped_model:
939
+ for batch in self.eval_dataloader:
940
+ query = batch["input_ids"]
941
+ with torch.no_grad():
942
+ context_length = query.shape[1]
943
+ query_response, _ = batch_generation(
944
+ unwrapped_model,
945
+ query,
946
+ query.shape[0],
947
+ processing_class.pad_token_id,
948
+ generation_config,
949
+ )
950
+ response = query_response[:, context_length:]
951
+ postprocessed_response = response
952
+ if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
953
+ postprocessed_response = truncate_response(
954
+ args.stop_token_id, processing_class.pad_token_id, response
955
+ )
956
+ table["query"].extend(
957
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
958
+ )
959
+ table["model response"].extend(
960
+ gather_object(processing_class.batch_decode(postprocessed_response))
961
+ )
962
+
963
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
964
+
965
+ if isinstance(self.reward_model, nn.Module):
966
+ _, score, _ = get_reward(
967
+ self.reward_model,
968
+ postprocessed_query_response,
969
+ processing_class.pad_token_id,
970
+ context_length,
971
+ )
972
+ else:
973
+ score = torch.tensor(
974
+ self.reward_model(
975
+ processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
976
+ ),
977
+ dtype=torch.float,
978
+ ).to(postprocessed_query_response.device)
979
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
980
+
981
+ if sampling:
982
+ break
983
+ df = pd.DataFrame(table)
984
+
985
+ if self.accelerator.is_main_process:
986
+ print_rich_table(df.head(5))
987
+ if "wandb" in args.report_to:
988
+ import wandb
989
+
990
+ if wandb.run is not None:
991
+ wandb.log({"completions": wandb.Table(dataframe=df)})
992
+
993
+ if "comet_ml" in args.report_to:
994
+ log_table_to_comet_experiment(
995
+ name="completions.csv",
996
+ table=df,
997
+ )
998
+
999
+ def create_model_card(
1000
+ self,
1001
+ model_name: Optional[str] = None,
1002
+ dataset_name: Optional[str] = None,
1003
+ tags: Union[str, list[str], None] = None,
1004
+ ):
1005
+ """
1006
+ Creates a draft of a model card using the information available to the `Trainer`.
1007
+
1008
+ Args:
1009
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1010
+ Name of the model.
1011
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1012
+ Name of the dataset used for training.
1013
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1014
+ Tags to be associated with the model card.
1015
+ """
1016
+ if not self.is_world_process_zero():
1017
+ return
1018
+
1019
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1020
+ base_model = self.model.config._name_or_path
1021
+ else:
1022
+ base_model = None
1023
+
1024
+ tags = tags or []
1025
+ if isinstance(tags, str):
1026
+ tags = [tags]
1027
+
1028
+ if hasattr(self.model.config, "unsloth_version"):
1029
+ tags.append("unsloth")
1030
+
1031
+ citation = textwrap.dedent("""\
1032
+ @inproceedings{ahmadian2024back,
1033
+ title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
1034
+ author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
1035
+ year = 2024,
1036
+ booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
1037
+ publisher = {Association for Computational Linguistics},
1038
+ pages = {12248--12267},
1039
+ editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
1040
+ }""")
1041
+
1042
+ model_card = generate_model_card(
1043
+ base_model=base_model,
1044
+ model_name=model_name,
1045
+ hub_model_id=self.hub_model_id,
1046
+ dataset_name=dataset_name,
1047
+ tags=tags,
1048
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1049
+ comet_url=get_comet_experiment_url(),
1050
+ trainer_name="RLOO",
1051
+ trainer_citation=citation,
1052
+ paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
1053
+ paper_id="2402.14740",
1054
+ )
1055
+
1056
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1057
+ class UnslothRLOOTrainer(_UnslothRLOOTrainer):
1058
+ """
1059
+
1060
+ """
1061
+ def __init__(
1062
+ self,
1063
+ config,
1064
+ processing_class,
1065
+ policy,
1066
+ ref_policy,
1067
+ reward_model,
1068
+ train_dataset,
1069
+ data_collator = None,
1070
+ eval_dataset = None,
1071
+ callbacks = None,
1072
+ **kwargs
1073
+ ):
1074
+ if config is None: config = UnslothRLOOConfig()
+ args = config   # the shared template below refers to the config as `args`
+ model = policy  # ...and to the policy being trained as `model`
1075
+ _output_logits = False
1076
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1077
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1078
+ if _output_logits:
1079
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1080
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1081
+ pass
1082
+ else:
1083
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1084
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1085
+ if args_max_seq_length is None and model_max_seq_length is not None:
1086
+ max_seq_length = model.max_seq_length
1087
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1088
+ if model is not None and hasattr(model, 'for_training'):
1089
+ model.for_training()
1090
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1091
+ if 'processing_class' in locals():
1092
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1093
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1094
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1095
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1096
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1097
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1098
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1099
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1100
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1101
+ else:
1102
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1103
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1104
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1105
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1106
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1107
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1108
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1109
+ else:
1110
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1111
+ other_metrics = []
1112
+
1113
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1114
+ PatchRLStatistics('rloo_trainer', other_metrics)
1115
+
1116
+ super().__init__(
1117
+ config = config,
1118
+ processing_class = processing_class,
1119
+ policy = policy,
1120
+ ref_policy = ref_policy,
1121
+ reward_model = reward_model,
1122
+ train_dataset = train_dataset,
1123
+ data_collator = data_collator,
1124
+ eval_dataset = eval_dataset,
1125
+ callbacks = callbacks,**kwargs)
1126
+ if hasattr(self, 'neftune_hook_handle'):
1127
+ self.neftune_hook_handle.remove()
1128
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1129
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1130
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1131
+ pass
1132
+
1133
+ pass
unsloth_compiled_cache/UnslothRewardTrainer.py ADDED
@@ -0,0 +1,819 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, warnings)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
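+
+ # Eager reference for the compiled kernel above -- an illustrative sketch only, not
+ # called by the trainer. It computes the same quantity via log_softmax + gather,
+ # without the memory-conscious logsumexp formulation:
+ def _selective_log_softmax_reference(logits, index):
+     logps = torch.log_softmax(logits.to(torch.float32), dim = -1)
+     return torch.gather(logps, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)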
42
+ @dataclass
43
+ class UnslothRewardConfig(RewardConfig):
44
+ """
45
+
46
+ Configuration class for the [`RewardTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
54
+ Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
55
+ limit. This argument is required if you want to use the default data collator.
56
+ disable_dropout (`bool`, *optional*, defaults to `True`):
57
+ Whether to disable dropout in the model.
58
+ dataset_num_proc (`int`, *optional*, defaults to `None`):
59
+ Number of processes to use for processing the dataset.
60
+ center_rewards_coefficient (`float`, *optional*, defaults to `None`):
61
+ Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
62
+ https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
63
+ remove_unused_columns (`bool`, *optional*, defaults to `False`):
64
+ Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
65
+ the dataset is pretokenized.
66
+
67
+ """
68
+ vllm_sampling_params: Optional[Any] = field(
69
+ default = None,
70
+ metadata = {'help': 'vLLM SamplingParams'},
71
+ )
72
+ unsloth_num_chunks : Optional[int] = field(
73
+ default = -1,
74
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
75
+ )
76
+ def __init__(
77
+ self,
78
+ output_dir = None,
79
+ overwrite_output_dir = None,
80
+ do_train = False,
81
+ do_eval = False,
82
+ do_predict = False,
83
+ eval_strategy = 'no',
84
+ prediction_loss_only = False,
85
+ per_device_train_batch_size = 4,
86
+ per_device_eval_batch_size = 4,
87
+ per_gpu_train_batch_size = None,
88
+ per_gpu_eval_batch_size = None,
89
+ gradient_accumulation_steps = 2,
90
+ eval_accumulation_steps = 2,
91
+ eval_delay = 0,
92
+ torch_empty_cache_steps = 250,
93
+ learning_rate = 5e-05,
94
+ weight_decay = 0.01,
95
+ adam_beta1 = 0.9,
96
+ adam_beta2 = 0.999,
97
+ adam_epsilon = 1e-08,
98
+ max_grad_norm = 1.0,
99
+ num_train_epochs = 3.0,
100
+ max_steps = -1,
101
+ lr_scheduler_type = 'linear',
102
+ warmup_ratio = 0.1,
103
+ warmup_steps = 0,
104
+ log_level = 'passive',
105
+ log_level_replica = 'warning',
106
+ log_on_each_node = True,
107
+ logging_dir = None,
108
+ logging_strategy = 'steps',
109
+ logging_first_step = False,
110
+ logging_steps = 1,
111
+ logging_nan_inf_filter = False,
112
+ save_strategy = 'steps',
113
+ save_steps = 500,
114
+ save_total_limit = None,
115
+ save_safetensors = True,
116
+ save_on_each_node = False,
117
+ save_only_model = False,
118
+ restore_callback_states_from_checkpoint = False,
119
+ no_cuda = False,
120
+ use_cpu = False,
121
+ use_mps_device = False,
122
+ seed = 3407,
123
+ data_seed = 3407,
124
+ jit_mode_eval = False,
125
+ use_ipex = False,
126
+ bf16 = False,
127
+ fp16 = False,
128
+ fp16_opt_level = 'O1',
129
+ half_precision_backend = 'auto',
130
+ bf16_full_eval = False,
131
+ fp16_full_eval = False,
132
+ tf32 = None,
133
+ local_rank = -1,
134
+ ddp_backend = None,
135
+ tpu_num_cores = None,
136
+ tpu_metrics_debug = False,
137
+ debug = '',
138
+ dataloader_drop_last = False,
139
+ eval_steps = None,
140
+ dataloader_num_workers = 0,
141
+ dataloader_prefetch_factor = None,
142
+ past_index = -1,
143
+ run_name = None,
144
+ disable_tqdm = None,
145
+ remove_unused_columns = False,
146
+ label_names = None,
147
+ load_best_model_at_end = False,
148
+ metric_for_best_model = None,
149
+ greater_is_better = None,
150
+ ignore_data_skip = False,
151
+ fsdp = '',
152
+ fsdp_min_num_params = 0,
153
+ fsdp_config = None,
154
+ tp_size = 0,
155
+ fsdp_transformer_layer_cls_to_wrap = None,
156
+ accelerator_config = None,
157
+ deepspeed = None,
158
+ label_smoothing_factor = 0.0,
159
+ optim = 'adamw_8bit',
160
+ optim_args = None,
161
+ adafactor = False,
162
+ group_by_length = False,
163
+ length_column_name = 'length',
164
+ report_to = None,
165
+ ddp_find_unused_parameters = None,
166
+ ddp_bucket_cap_mb = None,
167
+ ddp_broadcast_buffers = None,
168
+ dataloader_pin_memory = True,
169
+ dataloader_persistent_workers = False,
170
+ skip_memory_metrics = True,
171
+ use_legacy_prediction_loop = False,
172
+ push_to_hub = False,
173
+ resume_from_checkpoint = None,
174
+ hub_model_id = None,
175
+ hub_strategy = 'every_save',
176
+ hub_token = None,
177
+ hub_private_repo = None,
178
+ hub_always_push = False,
179
+ gradient_checkpointing = False,
180
+ gradient_checkpointing_kwargs = None,
181
+ include_inputs_for_metrics = False,
182
+ eval_do_concat_batches = True,
183
+ fp16_backend = 'auto',
184
+ evaluation_strategy = None,
185
+ push_to_hub_model_id = None,
186
+ push_to_hub_organization = None,
187
+ push_to_hub_token = None,
188
+ mp_parameters = '',
189
+ auto_find_batch_size = False,
190
+ full_determinism = False,
191
+ torchdynamo = None,
192
+ ray_scope = 'last',
193
+ ddp_timeout = 1800,
194
+ torch_compile = False,
195
+ torch_compile_backend = None,
196
+ torch_compile_mode = None,
197
+ dispatch_batches = None,
198
+ split_batches = None,
199
+ include_tokens_per_second = False,
200
+ include_num_input_tokens_seen = False,
201
+ neftune_noise_alpha = None,
202
+ optim_target_modules = None,
203
+ batch_eval_metrics = False,
204
+ eval_on_start = False,
205
+ use_liger_kernel = False,
206
+ eval_use_gather_object = False,
207
+ average_tokens_across_devices = False,
208
+ max_length = 1024,
209
+ disable_dropout = True,
210
+ dataset_num_proc = None,
211
+ center_rewards_coefficient = None,
212
+ vllm_sampling_params = None,
213
+ unsloth_num_chunks = -1,
214
+ **kwargs,
215
+ ):
216
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (< 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
217
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is too large (> 1)! Consider decreasing it to 1e-1 or lower, otherwise gradient updates will explode!')
218
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
219
+ output_dir = 'unsloth_training_checkpoints'
220
+ save_strategy = 'no'
221
+ if dataset_num_proc is None:
222
+ from multiprocessing import cpu_count
223
+ dataset_num_proc = cpu_count()
224
+
225
+ super().__init__(
226
+ output_dir = output_dir,
227
+ overwrite_output_dir = overwrite_output_dir,
228
+ do_train = do_train,
229
+ do_eval = do_eval,
230
+ do_predict = do_predict,
231
+ eval_strategy = eval_strategy,
232
+ prediction_loss_only = prediction_loss_only,
233
+ per_device_train_batch_size = per_device_train_batch_size,
234
+ per_device_eval_batch_size = per_device_eval_batch_size,
235
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
236
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
237
+ gradient_accumulation_steps = gradient_accumulation_steps,
238
+ eval_accumulation_steps = eval_accumulation_steps,
239
+ eval_delay = eval_delay,
240
+ torch_empty_cache_steps = torch_empty_cache_steps,
241
+ learning_rate = learning_rate,
242
+ weight_decay = weight_decay,
243
+ adam_beta1 = adam_beta1,
244
+ adam_beta2 = adam_beta2,
245
+ adam_epsilon = adam_epsilon,
246
+ max_grad_norm = max_grad_norm,
247
+ num_train_epochs = num_train_epochs,
248
+ max_steps = max_steps,
249
+ lr_scheduler_type = lr_scheduler_type,
250
+ warmup_ratio = warmup_ratio,
251
+ warmup_steps = warmup_steps,
252
+ log_level = log_level,
253
+ log_level_replica = log_level_replica,
254
+ log_on_each_node = log_on_each_node,
255
+ logging_dir = logging_dir,
256
+ logging_strategy = logging_strategy,
257
+ logging_first_step = logging_first_step,
258
+ logging_steps = logging_steps,
259
+ logging_nan_inf_filter = logging_nan_inf_filter,
260
+ save_strategy = save_strategy,
261
+ save_steps = save_steps,
262
+ save_total_limit = save_total_limit,
263
+ save_safetensors = save_safetensors,
264
+ save_on_each_node = save_on_each_node,
265
+ save_only_model = save_only_model,
266
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
267
+ no_cuda = no_cuda,
268
+ use_cpu = use_cpu,
269
+ use_mps_device = use_mps_device,
270
+ seed = seed,
271
+ data_seed = data_seed,
272
+ jit_mode_eval = jit_mode_eval,
273
+ use_ipex = use_ipex,
274
+ bf16 = bf16,
275
+ fp16 = fp16,
276
+ fp16_opt_level = fp16_opt_level,
277
+ half_precision_backend = half_precision_backend,
278
+ bf16_full_eval = bf16_full_eval,
279
+ fp16_full_eval = fp16_full_eval,
280
+ tf32 = tf32,
281
+ local_rank = local_rank,
282
+ ddp_backend = ddp_backend,
283
+ tpu_num_cores = tpu_num_cores,
284
+ tpu_metrics_debug = tpu_metrics_debug,
285
+ debug = debug,
286
+ dataloader_drop_last = dataloader_drop_last,
287
+ eval_steps = eval_steps,
288
+ dataloader_num_workers = dataloader_num_workers,
289
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
290
+ past_index = past_index,
291
+ run_name = run_name,
292
+ disable_tqdm = disable_tqdm,
293
+ remove_unused_columns = remove_unused_columns,
294
+ label_names = label_names,
295
+ load_best_model_at_end = load_best_model_at_end,
296
+ metric_for_best_model = metric_for_best_model,
297
+ greater_is_better = greater_is_better,
298
+ ignore_data_skip = ignore_data_skip,
299
+ fsdp = fsdp,
300
+ fsdp_min_num_params = fsdp_min_num_params,
301
+ fsdp_config = fsdp_config,
302
+ tp_size = tp_size,
303
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
304
+ accelerator_config = accelerator_config,
305
+ deepspeed = deepspeed,
306
+ label_smoothing_factor = label_smoothing_factor,
307
+ optim = optim,
308
+ optim_args = optim_args,
309
+ adafactor = adafactor,
310
+ group_by_length = group_by_length,
311
+ length_column_name = length_column_name,
312
+ report_to = report_to,
313
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
314
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
315
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
316
+ dataloader_pin_memory = dataloader_pin_memory,
317
+ dataloader_persistent_workers = dataloader_persistent_workers,
318
+ skip_memory_metrics = skip_memory_metrics,
319
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
320
+ push_to_hub = push_to_hub,
321
+ resume_from_checkpoint = resume_from_checkpoint,
322
+ hub_model_id = hub_model_id,
323
+ hub_strategy = hub_strategy,
324
+ hub_token = hub_token,
325
+ hub_private_repo = hub_private_repo,
326
+ hub_always_push = hub_always_push,
327
+ gradient_checkpointing = gradient_checkpointing,
328
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
329
+ include_inputs_for_metrics = include_inputs_for_metrics,
330
+ eval_do_concat_batches = eval_do_concat_batches,
331
+ fp16_backend = fp16_backend,
332
+ evaluation_strategy = evaluation_strategy,
333
+ push_to_hub_model_id = push_to_hub_model_id,
334
+ push_to_hub_organization = push_to_hub_organization,
335
+ push_to_hub_token = push_to_hub_token,
336
+ mp_parameters = mp_parameters,
337
+ auto_find_batch_size = auto_find_batch_size,
338
+ full_determinism = full_determinism,
339
+ torchdynamo = torchdynamo,
340
+ ray_scope = ray_scope,
341
+ ddp_timeout = ddp_timeout,
342
+ torch_compile = torch_compile,
343
+ torch_compile_backend = torch_compile_backend,
344
+ torch_compile_mode = torch_compile_mode,
345
+ dispatch_batches = dispatch_batches,
346
+ split_batches = split_batches,
347
+ include_tokens_per_second = include_tokens_per_second,
348
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
349
+ neftune_noise_alpha = neftune_noise_alpha,
350
+ optim_target_modules = optim_target_modules,
351
+ batch_eval_metrics = batch_eval_metrics,
352
+ eval_on_start = eval_on_start,
353
+ use_liger_kernel = use_liger_kernel,
354
+ eval_use_gather_object = eval_use_gather_object,
355
+ average_tokens_across_devices = average_tokens_across_devices,
356
+ max_length = max_length,
357
+ disable_dropout = disable_dropout,
358
+ dataset_num_proc = dataset_num_proc,
359
+ center_rewards_coefficient = center_rewards_coefficient,**kwargs)
360
+ self.vllm_sampling_params = vllm_sampling_params
361
+ self.unsloth_num_chunks = unsloth_num_chunks
362
+ pass
363
+
364
+ class _UnslothRewardTrainer(Trainer):
365
+ _tag_names = ["trl", "reward-trainer"]
366
+
367
+ def __init__(
368
+ self,
369
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
370
+ args: Optional[RewardConfig] = None,
371
+ data_collator: Optional[DataCollator] = None,
372
+ train_dataset: Optional[Dataset] = None,
373
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
374
+ processing_class: Optional[
375
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
376
+ ] = None,
377
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
378
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
379
+ callbacks: Optional[list[TrainerCallback]] = None,
380
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
381
+ None,
382
+ None,
383
+ ),
384
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
385
+ peft_config: Optional[dict] = None,
386
+ ):
387
+ """
388
+ Initialize RewardTrainer.
389
+
390
+ Args:
391
+ model (`transformers.PreTrainedModel`):
392
+ The model to train, preferably an `AutoModelForSequenceClassification`.
393
+ args (`RewardConfig`):
394
+ The arguments to use for training.
395
+ data_collator (`transformers.DataCollator`):
396
+ The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
397
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
398
+ train_dataset (`datasets.Dataset`):
399
+ The dataset to use for training.
400
+ eval_dataset (`datasets.Dataset`):
401
+ The dataset to use for evaluation.
402
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
403
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
404
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
405
+ reuse the fine-tuned model.
406
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
407
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
408
+ compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
409
+ The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
410
+ callbacks (`list[transformers.TrainerCallback]`):
411
+ The callbacks to use for training.
412
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
413
+ The optimizer and scheduler to use for training.
414
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
415
+ The function to use to preprocess the logits before computing the metrics.
416
+ peft_config (`dict`, defaults to `None`):
417
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
418
+ """
419
+ if not is_peft_available() and peft_config is not None:
420
+ raise ValueError(
421
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
422
+ )
423
+ elif is_peft_available() and peft_config is not None:
424
+ if not isinstance(model, PeftModel):
425
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
426
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
427
+ inspect.signature(prepare_model_for_kbit_training).parameters
428
+ )
429
+
430
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
431
+
432
+ if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
433
+ warnings.warn(
434
+ "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
435
+ "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
436
+ UserWarning,
437
+ )
438
+ elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
439
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
440
+
441
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
442
+
443
444
+
445
+ # Disable dropout in the model
446
+ if args.disable_dropout:
447
+ disable_dropout_in_model(model)
448
+
449
+ if compute_metrics is None:
450
+ compute_metrics = compute_accuracy
451
+
452
+ if data_collator is None:
453
+ if processing_class is None:
454
+ raise ValueError(
455
+ "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
456
+ )
457
+
458
+ max_length = args.max_length
459
+
460
+ data_collator = RewardDataCollatorWithPadding(processing_class)
461
+
462
+ if args.remove_unused_columns:
463
+ try: # for bc before https://github.com/huggingface/transformers/pull/25435
464
+ args.remove_unused_columns = False
465
+ except FrozenInstanceError:
466
+ args = replace(args, remove_unused_columns=False)
467
+ # warn users
468
+ warnings.warn(
469
+ "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
470
+ " we have set it for you, but you should do it yourself in the future.",
471
+ UserWarning,
472
+ )
473
+
474
+ self.use_reward_data_collator = True
475
+ else:
476
+ self.use_reward_data_collator = False
477
+
478
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
479
+ # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
480
+ # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
481
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
482
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
483
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
484
+ # issued.
485
+ model.warnings_issued["estimate_tokens"] = True
486
+
487
+ if "input_ids_chosen" not in train_dataset.column_names:
488
+ with PartialState().local_main_process_first():
489
+ fn_kwargs = {"tokenizer": processing_class}
490
+ train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
491
+ train_dataset = train_dataset.map(
492
+ _tokenize,
493
+ batched=True,
494
+ fn_kwargs=fn_kwargs,
495
+ num_proc=args.dataset_num_proc,
496
+ )
497
+ # This filter is important: without it, samples that exceed the model's context length get
499
+ # truncated, which produces a noisy signal because the chosen/rejected distinction can be
500
+ # lost. The downside is that users may be surprised if N samples are missing from training.
500
+ train_dataset = train_dataset.filter(
501
+ lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
502
+ num_proc=args.dataset_num_proc,
503
+ )
504
+ if eval_dataset is not None:
505
+ eval_dataset = eval_dataset.map(
506
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
507
+ )
508
+ eval_dataset = eval_dataset.map(
509
+ _tokenize,
510
+ fn_kwargs=fn_kwargs,
511
+ batched=True,
512
+ num_proc=args.dataset_num_proc,
513
+ )
514
+ # This filter is important: without it, samples that exceed the model's context length get
516
+ # truncated, which produces a noisy signal because the chosen/rejected distinction can be
517
+ # lost. The downside is that users may be surprised if N samples are missing from training.
517
+ eval_dataset = eval_dataset.filter(
518
+ lambda x: len(x["input_ids_chosen"]) <= max_length
519
+ and len(x["input_ids_rejected"]) <= max_length,
520
+ num_proc=args.dataset_num_proc,
521
+ )
522
+
523
+ super().__init__(
524
+ model=model,
525
+ args=args,
526
+ data_collator=data_collator,
527
+ train_dataset=train_dataset,
528
+ eval_dataset=eval_dataset,
529
+ processing_class=processing_class,
530
+ model_init=model_init,
531
+ compute_metrics=compute_metrics,
532
+ callbacks=callbacks,
533
+ optimizers=optimizers,
534
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
535
+ )
536
+
537
+ # Add tags for models that have been loaded with the correct transformers version
538
+ if hasattr(self.model, "add_model_tags"):
539
+ self.model.add_model_tags(self._tag_names)
540
+
541
+ def compute_loss(
542
+ self,
543
+ model: Union[PreTrainedModel, nn.Module],
544
+ inputs: dict[str, Union[torch.Tensor, Any]],
545
+ return_outputs=False,
546
+ num_items_in_batch=None,
547
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
548
+ rewards_chosen = model(
549
+ input_ids=inputs["input_ids_chosen"],
550
+ attention_mask=inputs["attention_mask_chosen"],
551
+ return_dict=True,
552
+ )["logits"]
553
+ rewards_rejected = model(
554
+ input_ids=inputs["input_ids_rejected"],
555
+ attention_mask=inputs["attention_mask_rejected"],
556
+ return_dict=True,
557
+ )["logits"]
558
+ # calculate loss, optionally modulate with margin
559
+ if "margin" in inputs:
560
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
561
+ else:
562
+ loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
563
+
564
+ if self.args.center_rewards_coefficient is not None:
565
+ loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
566
+
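+ # Worked example (illustrative numbers): for rewards_chosen = 1.0 and
+ # rewards_rejected = -1.0, loss = -logsigmoid(2.0) ≈ 0.127; a margin of 0.5
+ # tightens this to -logsigmoid(1.5) ≈ 0.201, demanding a larger reward gap
+ # before the pair is considered well-ranked.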
567
+ if return_outputs:
568
+ return loss, {
569
+ "rewards_chosen": rewards_chosen,
570
+ "rewards_rejected": rewards_rejected,
571
+ }
572
+ return loss
573
+
574
+ def prediction_step(
575
+ self,
576
+ model: Union[PreTrainedModel, nn.Module],
577
+ inputs: dict[str, Union[torch.Tensor, Any]],
578
+ prediction_loss_only: bool,
579
+ ignore_keys: Optional[list[str]] = None,
580
+ ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
581
+ inputs = self._prepare_inputs(inputs)
582
+ if ignore_keys is None:
583
+ if hasattr(self.model, "config"):
584
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
585
+ else:
586
+ ignore_keys = []
587
+
588
+ with torch.no_grad():
589
+ loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
590
+
591
+ if prediction_loss_only:
592
+ return (loss, None, None)
593
+
594
+ loss = loss.detach()
595
+ logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
596
+ logits = nested_detach(logits)
597
+ # Stack accepted against rejected, mean over logits
598
+ # and softmax to get preferences between accepted and rejected to sum to 1
599
+ logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
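+ # Shape walk-through: `logits` is a pair of (batch, 1) reward tensors; stack gives
+ # (2, batch, 1), mean(dim=2) gives (2, batch), softmax(dim=0) normalizes chosen vs.
+ # rejected per example, and .T yields (batch, 2) preference rows summing to 1.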
600
+
601
+ labels = torch.zeros(logits.shape[0])
602
+ labels = self._prepare_inputs(labels)
603
+
604
+ return loss, logits, labels
605
+
606
+ def evaluate(self, *args, **kwargs):
607
+ num_print_samples = kwargs.pop("num_print_samples", 4)
608
+ self.visualize_samples(num_print_samples)
609
+ return super().evaluate(*args, **kwargs)
610
+
611
+ def visualize_samples(self, num_print_samples: int):
612
+ """
613
+ Visualize the reward model's logit predictions.
614
+
615
+ Args:
616
+ num_print_samples (`int`, defaults to `4`):
617
+ The number of samples to print. Set to `-1` to print all samples.
618
+ """
619
+ eval_dataloader = self.get_eval_dataloader()
620
+ table = defaultdict(list)
621
+ for _, inputs in enumerate(eval_dataloader):
622
+ _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
623
+ chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
624
+ rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
625
+ table["chosen_text"].extend(gather_object(chosen_text))
626
+ table["rejected_text"].extend(gather_object(rejected_text))
627
+ table["logits"].extend(
628
+ gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
629
+ )
630
+ if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
631
+ break
632
+ df = pd.DataFrame(table)
633
+ if self.accelerator.process_index == 0:
634
+ print_rich_table(df[:num_print_samples])
635
+ if "wandb" in self.args.report_to:
636
+ import wandb
637
+
638
+ if wandb.run is not None:
639
+ wandb.log({"completions": wandb.Table(dataframe=df)})
640
+
641
+ if "comet_ml" in self.args.report_to:
642
+ log_table_to_comet_experiment(
643
+ name="completions.csv",
644
+ table=df,
645
+ )
646
+
647
+ def create_model_card(
648
+ self,
649
+ model_name: Optional[str] = None,
650
+ dataset_name: Optional[str] = None,
651
+ tags: Union[str, list[str], None] = None,
652
+ ):
653
+ """
654
+ Creates a draft of a model card using the information available to the `Trainer`.
655
+
656
+ Args:
657
+ model_name (`str` or `None`, *optional*, defaults to `None`):
658
+ Name of the model.
659
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
660
+ Name of the dataset used for training.
661
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
662
+ Tags to be associated with the model card.
663
+ """
664
+ if not self.is_world_process_zero():
665
+ return
666
+
667
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
668
+ base_model = self.model.config._name_or_path
669
+ else:
670
+ base_model = None
671
+
672
+ tags = tags or []
673
+ if isinstance(tags, str):
674
+ tags = [tags]
675
+
676
+ if hasattr(self.model.config, "unsloth_version"):
677
+ tags.append("unsloth")
678
+
679
+ model_card = generate_model_card(
680
+ base_model=base_model,
681
+ model_name=model_name,
682
+ hub_model_id=self.hub_model_id,
683
+ dataset_name=dataset_name,
684
+ tags=tags,
685
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
686
+ comet_url=get_comet_experiment_url(),
687
+ trainer_name="Reward",
688
+ )
689
+
690
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
691
+ class UnslothRewardTrainer(_UnslothRewardTrainer):
692
+ """
693
+
694
+ """
695
+ def __init__(
696
+ self,
697
+ model = None,
698
+ args = None,
699
+ data_collator = None,
700
+ train_dataset = None,
701
+ eval_dataset = None,
702
+ processing_class = None,
703
+ model_init = None,
704
+ compute_metrics = None,
705
+ callbacks = None,
706
+ preprocess_logits_for_metrics = None,
707
+ peft_config = None,
708
+ **kwargs
709
+ ):
710
+ if args is None: args = UnslothRewardConfig()
711
+ use_bf16 = getattr(args, 'bf16', False)
712
+ use_fp16 = getattr(args, 'fp16', False)
713
+ force_float32 = False
714
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
715
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
716
+ force_float32 = True
717
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
718
+ dtype = getattr(model.config, 'torch_dtype', None)
719
+ if dtype is None: dtype = model.get_input_embeddings().dtype
720
+ from unsloth_zoo.utils import _get_dtype
721
+ dtype = _get_dtype(dtype)
722
+ float16 = dtype == torch.float16
723
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
724
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
725
+ if force_float32:
726
+ args.fp16 = False
727
+ args.bf16 = False
728
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
729
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
730
+ args.fp16 = float16
731
+ args.bf16 = not float16
732
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
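+ # Precision selection summary: forcing float32 disables mixed precision entirely;
+ # otherwise, when the user sets neither bf16 nor fp16, the flag is inferred from
+ # the model's own dtype (float16 weights -> fp16 autocast, bfloat16 -> bf16).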
733
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
734
+ args.eval_strategy = 'steps'
735
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
736
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
737
+ if ga_steps is not None and ga_steps > 1:
738
+ from transformers import __version__ as transformers_version
739
+ if Version(transformers_version) <= Version('4.45.2'):
740
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
741
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
742
+ if getattr(args, 'eval_strategy', 'no') != 'no':
743
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
744
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
745
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
746
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
747
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
748
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
749
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
750
+ if force_float32:
751
+ args.bf16_full_eval = False
752
+ args.fp16_full_eval = False
753
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
754
+ args.bf16_full_eval = True
755
+ args.fp16_full_eval = False
756
+ elif not bf16_full_eval and not fp16_full_eval:
757
+ args.bf16_full_eval = args.bf16
758
+ args.fp16_full_eval = args.fp16
759
+ _output_logits = False
760
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
761
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
762
+ if _output_logits:
763
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
764
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
765
+ pass
766
+ else:
767
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
768
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
769
+ if args_max_seq_length is None and model_max_seq_length is not None:
770
+ max_seq_length = model.max_seq_length
771
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
772
+ if model is not None and hasattr(model, 'for_training'):
773
+ model.for_training()
774
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
775
+ if 'processing_class' in locals():
776
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
777
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
778
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
779
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
780
+ if not isinstance(data_collator, UnslothVisionDataCollator):
781
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
782
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
783
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
784
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
785
+ else:
786
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
787
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
788
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
789
+ if not isinstance(data_collator, UnslothVisionDataCollator):
790
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
791
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
792
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
793
+ else:
794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
795
+ other_metrics = []
796
+
797
+ from unsloth_zoo.logging_utils import PatchRLStatistics
798
+ PatchRLStatistics('reward_trainer', other_metrics)
799
+
800
+ super().__init__(
801
+ model = model,
802
+ args = args,
803
+ data_collator = data_collator,
804
+ train_dataset = train_dataset,
805
+ eval_dataset = eval_dataset,
806
+ processing_class = processing_class,
807
+ model_init = model_init,
808
+ compute_metrics = compute_metrics,
809
+ callbacks = callbacks,
810
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
811
+ peft_config = peft_config,**kwargs)
812
+ if hasattr(self, 'neftune_hook_handle'):
813
+ self.neftune_hook_handle.remove()
814
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
815
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
816
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
817
+ pass
818
+
819
+ pass
unsloth_compiled_cache/UnslothSFTTrainer.py ADDED
@@ -0,0 +1,1027 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
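For reference, `selective_log_softmax` is numerically identical to taking a full `log_softmax` and gathering the target column; it simply avoids keeping the full `(batch, seq, vocab)` log-softmax tensor alive. A minimal sketch verifying the identity (tensor shapes are illustrative assumptions, not taken from the trainers):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)         # (batch, seq, vocab) - illustrative sizes
index  = torch.randint(0, 10, (2, 5))  # target token ids

# Reference path: materialize the full log-softmax, then gather.
reference = F.log_softmax(logits.float(), dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

# selective_log_softmax path: log_softmax(x_i) = x_i - logsumexp(x).
selected = torch.gather(logits.float(), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
sketch = selected - torch.logsumexp(logits.float(), dim=-1)

assert torch.allclose(reference, sketch, atol=1e-5)
```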
42
+ @dataclass
43
+ class UnslothSFTConfig(SFTConfig):
44
+ """
45
+
46
+ Configuration class for the [`SFTTrainer`].
47
+
48
+ Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
49
+ [`~transformers.TrainingArguments`] documentation.
50
+
51
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
52
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
53
+ command line.
54
+
55
+ Parameters:
56
+ > Parameters that control the model
57
+
58
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
59
+ Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
60
+ argument of the [`SFTTrainer`] is provided as a string.
61
+ use_liger (`bool`, *optional*, defaults to `False`):
62
+ Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
63
+
64
+ > Parameters that control the data preprocessing
65
+
66
+ dataset_text_field (`str`, *optional*, defaults to `"text"`):
67
+ Name of the column that contains text data in the dataset.
68
+ dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
69
+ Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
70
+ `skip_prepare_dataset`.
71
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
72
+ Number of processes to use for processing the dataset.
73
+ max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
74
+ Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
75
+ right.
76
+ If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
77
+ packing (`bool`, *optional*, defaults to `False`):
78
+ Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
79
+ length.
80
+ eval_packing (`bool` or `None`, *optional*, defaults to `None`):
81
+ Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
82
+
83
+ > Parameters that control the training
84
+
85
+ learning_rate (`float`, *optional*, defaults to `2e-5`):
86
+ Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
87
+ [`~transformers.TrainingArguments`].
88
+
89
+ """
90
+ vllm_sampling_params: Optional[Any] = field(
91
+ default = None,
92
+ metadata = {'help': 'vLLM SamplingParams'},
93
+ )
94
+ unsloth_num_chunks : Optional[int] = field(
95
+ default = -1,
96
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
97
+ )
98
+ def __init__(
99
+ self,
100
+ output_dir = None,
101
+ overwrite_output_dir = None,
102
+ do_train = False,
103
+ do_eval = False,
104
+ do_predict = False,
105
+ eval_strategy = 'no',
106
+ prediction_loss_only = False,
107
+ per_device_train_batch_size = 4,
108
+ per_device_eval_batch_size = 4,
109
+ per_gpu_train_batch_size = None,
110
+ per_gpu_eval_batch_size = None,
111
+ gradient_accumulation_steps = 2,
112
+ eval_accumulation_steps = 2,
113
+ eval_delay = 0,
114
+ torch_empty_cache_steps = 250,
115
+ learning_rate = 5e-05,
116
+ weight_decay = 0.01,
117
+ adam_beta1 = 0.9,
118
+ adam_beta2 = 0.999,
119
+ adam_epsilon = 1e-08,
120
+ max_grad_norm = 1.0,
121
+ num_train_epochs = 3.0,
122
+ max_steps = -1,
123
+ lr_scheduler_type = 'linear',
124
+ warmup_ratio = 0.1,
125
+ warmup_steps = 0,
126
+ log_level = 'passive',
127
+ log_level_replica = 'warning',
128
+ log_on_each_node = True,
129
+ logging_dir = None,
130
+ logging_strategy = 'steps',
131
+ logging_first_step = False,
132
+ logging_steps = 1,
133
+ logging_nan_inf_filter = False,
134
+ save_strategy = 'steps',
135
+ save_steps = 500,
136
+ save_total_limit = None,
137
+ save_safetensors = True,
138
+ save_on_each_node = False,
139
+ save_only_model = False,
140
+ restore_callback_states_from_checkpoint = False,
141
+ no_cuda = False,
142
+ use_cpu = False,
143
+ use_mps_device = False,
144
+ seed = 3407,
145
+ data_seed = 3407,
146
+ jit_mode_eval = False,
147
+ use_ipex = False,
148
+ bf16 = False,
149
+ fp16 = False,
150
+ fp16_opt_level = 'O1',
151
+ half_precision_backend = 'auto',
152
+ bf16_full_eval = False,
153
+ fp16_full_eval = False,
154
+ tf32 = None,
155
+ local_rank = -1,
156
+ ddp_backend = None,
157
+ tpu_num_cores = None,
158
+ tpu_metrics_debug = False,
159
+ debug = '',
160
+ dataloader_drop_last = False,
161
+ eval_steps = None,
162
+ dataloader_num_workers = 0,
163
+ dataloader_prefetch_factor = None,
164
+ past_index = -1,
165
+ run_name = None,
166
+ disable_tqdm = None,
167
+ remove_unused_columns = True,
168
+ label_names = None,
169
+ load_best_model_at_end = False,
170
+ metric_for_best_model = None,
171
+ greater_is_better = None,
172
+ ignore_data_skip = False,
173
+ fsdp = '',
174
+ fsdp_min_num_params = 0,
175
+ fsdp_config = None,
176
+ tp_size = 0,
177
+ fsdp_transformer_layer_cls_to_wrap = None,
178
+ accelerator_config = None,
179
+ deepspeed = None,
180
+ label_smoothing_factor = 0.0,
181
+ optim = 'adamw_8bit',
182
+ optim_args = None,
183
+ adafactor = False,
184
+ group_by_length = False,
185
+ length_column_name = 'length',
186
+ report_to = None,
187
+ ddp_find_unused_parameters = None,
188
+ ddp_bucket_cap_mb = None,
189
+ ddp_broadcast_buffers = None,
190
+ dataloader_pin_memory = True,
191
+ dataloader_persistent_workers = False,
192
+ skip_memory_metrics = True,
193
+ use_legacy_prediction_loop = False,
194
+ push_to_hub = False,
195
+ resume_from_checkpoint = None,
196
+ hub_model_id = None,
197
+ hub_strategy = 'every_save',
198
+ hub_token = None,
199
+ hub_private_repo = None,
200
+ hub_always_push = False,
201
+ gradient_checkpointing = False,
202
+ gradient_checkpointing_kwargs = None,
203
+ include_inputs_for_metrics = False,
204
+ eval_do_concat_batches = True,
205
+ fp16_backend = 'auto',
206
+ evaluation_strategy = None,
207
+ push_to_hub_model_id = None,
208
+ push_to_hub_organization = None,
209
+ push_to_hub_token = None,
210
+ mp_parameters = '',
211
+ auto_find_batch_size = False,
212
+ full_determinism = False,
213
+ torchdynamo = None,
214
+ ray_scope = 'last',
215
+ ddp_timeout = 1800,
216
+ torch_compile = False,
217
+ torch_compile_backend = None,
218
+ torch_compile_mode = None,
219
+ dispatch_batches = None,
220
+ split_batches = None,
221
+ include_tokens_per_second = False,
222
+ include_num_input_tokens_seen = False,
223
+ neftune_noise_alpha = None,
224
+ optim_target_modules = None,
225
+ batch_eval_metrics = False,
226
+ eval_on_start = False,
227
+ use_liger_kernel = False,
228
+ eval_use_gather_object = False,
229
+ average_tokens_across_devices = False,
230
+ model_init_kwargs = None,
231
+ use_liger = False,
232
+ dataset_text_field = 'text',
233
+ dataset_kwargs = None,
234
+ dataset_num_proc = None,
235
+ max_seq_length = None,
236
+ packing = False,
237
+ eval_packing = None,
238
+ dataset_batch_size = None,
239
+ num_of_sequences = None,
240
+ chars_per_token = None,
241
+ vllm_sampling_params = None,
242
+ unsloth_num_chunks = -1,
243
+ **kwargs,
244
+ ):
245
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
246
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
247
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
248
+ output_dir = 'unsloth_training_checkpoints'
249
+ save_strategy = 'no'
250
+ if dataset_num_proc is None:
251
+ from multiprocessing import cpu_count
252
+ dataset_num_proc = cpu_count()
253
+
254
+ super().__init__(
255
+ output_dir = output_dir,
256
+ overwrite_output_dir = overwrite_output_dir,
257
+ do_train = do_train,
258
+ do_eval = do_eval,
259
+ do_predict = do_predict,
260
+ eval_strategy = eval_strategy,
261
+ prediction_loss_only = prediction_loss_only,
262
+ per_device_train_batch_size = per_device_train_batch_size,
263
+ per_device_eval_batch_size = per_device_eval_batch_size,
264
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
265
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
266
+ gradient_accumulation_steps = gradient_accumulation_steps,
267
+ eval_accumulation_steps = eval_accumulation_steps,
268
+ eval_delay = eval_delay,
269
+ torch_empty_cache_steps = torch_empty_cache_steps,
270
+ learning_rate = learning_rate,
271
+ weight_decay = weight_decay,
272
+ adam_beta1 = adam_beta1,
273
+ adam_beta2 = adam_beta2,
274
+ adam_epsilon = adam_epsilon,
275
+ max_grad_norm = max_grad_norm,
276
+ num_train_epochs = num_train_epochs,
277
+ max_steps = max_steps,
278
+ lr_scheduler_type = lr_scheduler_type,
279
+ warmup_ratio = warmup_ratio,
280
+ warmup_steps = warmup_steps,
281
+ log_level = log_level,
282
+ log_level_replica = log_level_replica,
283
+ log_on_each_node = log_on_each_node,
284
+ logging_dir = logging_dir,
285
+ logging_strategy = logging_strategy,
286
+ logging_first_step = logging_first_step,
287
+ logging_steps = logging_steps,
288
+ logging_nan_inf_filter = logging_nan_inf_filter,
289
+ save_strategy = save_strategy,
290
+ save_steps = save_steps,
291
+ save_total_limit = save_total_limit,
292
+ save_safetensors = save_safetensors,
293
+ save_on_each_node = save_on_each_node,
294
+ save_only_model = save_only_model,
295
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
296
+ no_cuda = no_cuda,
297
+ use_cpu = use_cpu,
298
+ use_mps_device = use_mps_device,
299
+ seed = seed,
300
+ data_seed = data_seed,
301
+ jit_mode_eval = jit_mode_eval,
302
+ use_ipex = use_ipex,
303
+ bf16 = bf16,
304
+ fp16 = fp16,
305
+ fp16_opt_level = fp16_opt_level,
306
+ half_precision_backend = half_precision_backend,
307
+ bf16_full_eval = bf16_full_eval,
308
+ fp16_full_eval = fp16_full_eval,
309
+ tf32 = tf32,
310
+ local_rank = local_rank,
311
+ ddp_backend = ddp_backend,
312
+ tpu_num_cores = tpu_num_cores,
313
+ tpu_metrics_debug = tpu_metrics_debug,
314
+ debug = debug,
315
+ dataloader_drop_last = dataloader_drop_last,
316
+ eval_steps = eval_steps,
317
+ dataloader_num_workers = dataloader_num_workers,
318
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
319
+ past_index = past_index,
320
+ run_name = run_name,
321
+ disable_tqdm = disable_tqdm,
322
+ remove_unused_columns = remove_unused_columns,
323
+ label_names = label_names,
324
+ load_best_model_at_end = load_best_model_at_end,
325
+ metric_for_best_model = metric_for_best_model,
326
+ greater_is_better = greater_is_better,
327
+ ignore_data_skip = ignore_data_skip,
328
+ fsdp = fsdp,
329
+ fsdp_min_num_params = fsdp_min_num_params,
330
+ fsdp_config = fsdp_config,
331
+ tp_size = tp_size,
332
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
333
+ accelerator_config = accelerator_config,
334
+ deepspeed = deepspeed,
335
+ label_smoothing_factor = label_smoothing_factor,
336
+ optim = optim,
337
+ optim_args = optim_args,
338
+ adafactor = adafactor,
339
+ group_by_length = group_by_length,
340
+ length_column_name = length_column_name,
341
+ report_to = report_to,
342
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
343
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
344
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
345
+ dataloader_pin_memory = dataloader_pin_memory,
346
+ dataloader_persistent_workers = dataloader_persistent_workers,
347
+ skip_memory_metrics = skip_memory_metrics,
348
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
349
+ push_to_hub = push_to_hub,
350
+ resume_from_checkpoint = resume_from_checkpoint,
351
+ hub_model_id = hub_model_id,
352
+ hub_strategy = hub_strategy,
353
+ hub_token = hub_token,
354
+ hub_private_repo = hub_private_repo,
355
+ hub_always_push = hub_always_push,
356
+ gradient_checkpointing = gradient_checkpointing,
357
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
358
+ include_inputs_for_metrics = include_inputs_for_metrics,
359
+ eval_do_concat_batches = eval_do_concat_batches,
360
+ fp16_backend = fp16_backend,
361
+ evaluation_strategy = evaluation_strategy,
362
+ push_to_hub_model_id = push_to_hub_model_id,
363
+ push_to_hub_organization = push_to_hub_organization,
364
+ push_to_hub_token = push_to_hub_token,
365
+ mp_parameters = mp_parameters,
366
+ auto_find_batch_size = auto_find_batch_size,
367
+ full_determinism = full_determinism,
368
+ torchdynamo = torchdynamo,
369
+ ray_scope = ray_scope,
370
+ ddp_timeout = ddp_timeout,
371
+ torch_compile = torch_compile,
372
+ torch_compile_backend = torch_compile_backend,
373
+ torch_compile_mode = torch_compile_mode,
374
+ dispatch_batches = dispatch_batches,
375
+ split_batches = split_batches,
376
+ include_tokens_per_second = include_tokens_per_second,
377
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
378
+ neftune_noise_alpha = neftune_noise_alpha,
379
+ optim_target_modules = optim_target_modules,
380
+ batch_eval_metrics = batch_eval_metrics,
381
+ eval_on_start = eval_on_start,
382
+ use_liger_kernel = use_liger_kernel,
383
+ eval_use_gather_object = eval_use_gather_object,
384
+ average_tokens_across_devices = average_tokens_across_devices,
385
+ model_init_kwargs = model_init_kwargs,
386
+ use_liger = use_liger,
387
+ dataset_text_field = dataset_text_field,
388
+ dataset_kwargs = dataset_kwargs,
389
+ dataset_num_proc = dataset_num_proc,
390
+ max_seq_length = max_seq_length,
391
+ packing = packing,
392
+ eval_packing = eval_packing,
393
+ dataset_batch_size = dataset_batch_size,
394
+ num_of_sequences = num_of_sequences,
395
+ chars_per_token = chars_per_token,**kwargs)
396
+ self.vllm_sampling_params = vllm_sampling_params
397
+ self.unsloth_num_chunks = unsloth_num_chunks
398
+ pass
399
+
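As a usage sketch, `UnslothSFTConfig` is constructed like any `SFTConfig`; the hyperparameters below are illustrative choices, not recommendations, and the learning rate must stay inside the roughly `[1e-7, 1]` bounds enforced in `__init__` above:

```python
# A minimal sketch; all values are illustrative assumptions.
config = UnslothSFTConfig(
    output_dir = "outputs",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,   # effective batch of 8 per optimizer step
    learning_rate = 2e-4,              # checked against the bounds raised above
    max_seq_length = 2048,
    unsloth_num_chunks = -1,           # -1 picks the most memory-efficient chunking
)
```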
400
+ class _UnslothSFTTrainer(Trainer):
401
+ """"""
402
+
403
+ _tag_names = ["trl", "sft"]
404
+
405
+ @deprecate_kwarg(
406
+ "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
407
+ )
408
+ def __init__(
409
+ self,
410
+ model: Union[str, nn.Module, PreTrainedModel],
411
+ args: Optional[Union[SFTConfig, TrainingArguments]] = None,
412
+ data_collator: Optional[DataCollator] = None, # type: ignore
413
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
414
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
415
+ processing_class: Optional[
416
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
417
+ ] = None,
418
+ compute_loss_func: Optional[Callable] = None,
419
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
420
+ callbacks: Optional[list[TrainerCallback]] = None,
421
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
422
+ optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
423
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
424
+ peft_config: Optional["PeftConfig"] = None,
425
+ formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
426
+ ):
427
+ # Args
428
+ if args is None:
429
+ model_name = model if isinstance(model, str) else model.config._name_or_path
430
+ model_name = model_name.split("/")[-1]
431
+ args = SFTConfig(f"{model_name}-SFT")
432
+ elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
433
+ dict_args = args.to_dict()
434
+ dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
435
+ dict_args.pop("push_to_hub_token")
436
+ args = SFTConfig(**dict_args)
437
+
438
+ # Model
439
+ if args.model_init_kwargs is not None and not isinstance(model, str):
440
+ warnings.warn(
441
+ "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
442
+ "The `model_init_kwargs` will be ignored."
443
+ )
444
+ if isinstance(model, str):
445
+ model = self._create_model_from_path(model, args)
446
+
447
+ # PEFT configuration and model wrapping
448
+ if False:
449
+ model = self._prepare_peft_model(model, peft_config, args)
450
+
451
+ # Handle the tokenizer
452
+ if processing_class is None:
453
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
454
+ if processing_class.pad_token is None:
455
+ processing_class.pad_token = processing_class.eos_token # required for padding when collating data
456
+
457
+ # Dataset
458
+ preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
459
+ if preprocess_dataset:
460
+ train_dataset = self._prepare_dataset(
461
+ train_dataset, processing_class, args, args.packing, formatting_func, "train"
462
+ )
463
+ if eval_dataset is not None:
464
+ packing = args.packing if args.eval_packing is None else args.eval_packing
465
+ if isinstance(eval_dataset, dict):
466
+ eval_dataset = {
467
+ key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
468
+ for key, dataset in eval_dataset.items()
469
+ }
470
+ else:
471
+ eval_dataset = self._prepare_dataset(
472
+ eval_dataset, processing_class, args, packing, formatting_func, "eval"
473
+ )
474
+
475
+ # Data collator
476
+ if data_collator is None:
477
+ data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
478
+
479
+ # Initialize the metrics
480
+ self._metrics = defaultdict(list)
481
+
482
+ # Initialize the Trainer. Parent class will handle:
483
+ # - DeepSpeed configuration (through create_accelerator_and_postprocess)
484
+ # - FSDP setup
485
+ # - Distributed training setup
486
+ # - Optimizer and scheduler creation
487
+ # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
488
+ super_init_kwargs = {}
489
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
490
+ super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
491
+ else:
492
+ if optimizer_cls_and_kwargs is not None:
493
+ warnings.warn(
494
+ "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
495
+ "The default optimizer will be used. "
496
+ "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
497
+ )
498
+ super().__init__(
499
+ model=model,
500
+ args=args,
501
+ data_collator=data_collator,
502
+ train_dataset=train_dataset,
503
+ eval_dataset=eval_dataset,
504
+ processing_class=processing_class,
505
+ compute_loss_func=compute_loss_func,
506
+ compute_metrics=compute_metrics,
507
+ callbacks=callbacks,
508
+ optimizers=optimizers,
509
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
510
+ **super_init_kwargs,
511
+ )
512
+
513
+ # Add tags for models that have been loaded with the correct transformers version
514
+ if hasattr(self.model, "add_model_tags"):
515
+ self.model.add_model_tags(self._tag_names)
516
+
517
+ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
518
+ """Creates a model from a path or model identifier."""
519
+ model_init_kwargs = args.model_init_kwargs or {}
520
+ # Handle torch dtype
521
+ torch_dtype = model_init_kwargs.get("torch_dtype")
522
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
523
+ pass # torch_dtype is already a torch.dtype or "auto" or None
524
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
525
+ torch_dtype = getattr(torch, torch_dtype)
526
+ model_init_kwargs["torch_dtype"] = torch_dtype
527
+ else:
528
+ raise ValueError(
529
+ "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
530
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
531
+ )
532
+ # Disable caching if gradient checkpointing is enabled (not supported)
533
+ if args.gradient_checkpointing:
534
+ model_init_kwargs["use_cache"] = False
535
+
536
+ # Create model
537
+ if args.use_liger:
538
+ if not is_liger_kernel_available():
539
+ raise ImportError("Please install Liger-kernel for use_liger=True")
540
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM  # needed here: not in this file's top-level imports
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
541
+ else:
542
+ model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
543
+ return model
544
+
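The `torch_dtype` branch in `_create_model_from_path` accepts a `torch.dtype` object, the string `'auto'`, `None`, or a dtype name such as `'float16'`; only the name form needs resolving. A standalone sketch of that normalization:

```python
import torch

def normalize_torch_dtype(torch_dtype):
    # Pass through dtype objects, "auto", and None; resolve dtype names
    # via getattr(torch, ...); reject anything else, as the branch above does.
    if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
        return torch_dtype
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype)
    raise ValueError(f"Invalid torch_dtype: {torch_dtype}")

assert normalize_torch_dtype("bfloat16") is torch.bfloat16
assert normalize_torch_dtype(torch.float16) is torch.float16
assert normalize_torch_dtype("auto") == "auto"
```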
545
+ def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
546
+ """Prepares a model for PEFT training."""
547
+ if not is_peft_available():
548
+ raise ImportError("To use PeftModel, you need to install the `peft` library.")
549
+
550
+ if not isinstance(peft_config, PeftConfig):
551
+ raise ValueError(
552
+ f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
553
+ "to pass a PeftConfig object to the SFTTrainer."
554
+ )
555
+
556
+ if isinstance(model, PeftModel):
557
+ return model
558
+
559
+ # Handle quantized models (QLoRA)
560
+ is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
561
+
562
+ is_sharded_qlora = False
563
+ if getattr(model, "is_loaded_in_4bit", False):
564
+ # Check if model is sharded (FSDP/DS-Zero3)
565
+ for _, param in model.named_parameters():
566
+ if param.__class__.__name__ == "Params4bit":
567
+ is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
568
+ break
569
+
570
+ # Prepare model for kbit training if needed
571
+ if is_qlora and not is_sharded_qlora:
572
+ model = self._prepare_model_for_kbit_training(model, args)
573
+ # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
574
+ args = dataclasses.replace(args, gradient_checkpointing=False)
575
+ elif args.gradient_checkpointing:
576
+ model = self._enable_gradient_checkpointing(model, args)
577
+
578
+ # Create PEFT model
579
+ if (
580
+ version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
581
+ and getattr(model, "is_loaded_in_4bit", False)
582
+ and is_sharded_qlora
583
+ ):
584
+ model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
585
+ else:
586
+ model = get_peft_model(model, peft_config)
587
+
588
+ # Handle bf16 casting for 4-bit models
589
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
590
+ peft_module_casting_to_bf16(model)
591
+
592
+ return model
593
+
594
+ def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
595
+ """Prepares a quantized model for kbit training."""
596
+ prepare_model_kwargs = {
597
+ "use_gradient_checkpointing": args.gradient_checkpointing,
598
+ "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
599
+ }
600
+
601
+ return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
602
+
603
+ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
604
+ """Enables gradient checkpointing for the model."""
605
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
606
+ use_reentrant = (
607
+ "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
608
+ )
609
+
610
+ if use_reentrant:
611
+ if hasattr(model, "enable_input_require_grads"):
612
+ model.enable_input_require_grads()
613
+ else:
614
+
615
+ def make_inputs_require_grad(module, input, output):
616
+ output.requires_grad_(True)
617
+
618
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
619
+
620
+ return model
621
+
622
+ def _prepare_dataset(
623
+ self,
624
+ dataset: Union[Dataset, IterableDataset],
625
+ processing_class,
626
+ args,
627
+ packing: bool,
628
+ formatting_func: Optional[Callable[[dict], str]],
629
+ dataset_name: str,
630
+ ) -> Union[Dataset, IterableDataset]:
631
+ # All Unsloth Zoo code licensed under LGPLv3
632
+ if isinstance(dataset, ConstantLengthDataset): return dataset
633
+
634
+ map_kwargs = {}
635
+ use_desc = isinstance(dataset, Dataset)
636
+ is_vlm = hasattr(processing_class, "tokenizer")
637
+ tokenizer = processing_class
638
+ if is_vlm: tokenizer = processing_class.tokenizer
639
+
640
+ # Get max length
641
+ max_seq_length = getattr(args, "max_length", 0)
642
+ if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
643
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
644
+ if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
645
+ if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
646
+ dataset_text_field = getattr(args, "dataset_text_field", "text")
647
+ do_truncation = max_seq_length != 0
648
+ do_formatting_func = False
649
+ do_tokenize = True
650
+
651
+ # Get correct column names
652
+ column_names = set(next(iter(dataset)).keys())
653
+ used_column_names = ["input_ids"]
654
+ if "attention_mask" in column_names:
655
+ used_column_names.append("attention_mask")
656
+
657
+ # Check if already tokenized so skip
658
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
659
+ if "labels" in column_names:
660
+ # Most likely forgot data collator!
661
+ if is_vlm and not hasattr(tokenizer, "pad"):
662
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
663
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
664
+ self.data_collator = DataCollatorForSeq2Seq(tokenizer)
665
+ used_column_names.append("labels")
666
+ do_tokenize = False
667
+ elif "input_ids" in column_names:
668
+ # Skip dataset prep, and set data collator
669
+ if is_vlm and not hasattr(tokenizer, "pad"):
670
+ # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
671
+ raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
672
+ self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
673
+ do_tokenize = False
674
+ elif dataset_text_field not in column_names:
675
+ do_formatting_func = True
676
+ if formatting_func is None:
677
+ raise RuntimeError("Unsloth: You must specify a `formatting_func`")
678
+ pass
679
+
680
+ if do_tokenize:
681
+ # Check double BOS tokens
682
+ if do_formatting_func:
683
+ test_text = formatting_func(dataset[0])
684
+ if not isinstance(test_text, list):
685
+ raise ValueError(
686
+ "Unsloth: The `formatting_func` should return a list of processed strings."
687
+ )
688
+ test_text = test_text[0]
689
+ else:
690
+ test_text = dataset[0][dataset_text_field]
691
+
692
+ # Get chat template
693
+ chat_template = getattr(processing_class, 'chat_template', '')
694
+ if chat_template == '' and is_vlm:
695
+ chat_template = getattr(tokenizer, 'chat_template', '')
696
+ if chat_template is None:
697
+ chat_template = ''
698
+
699
+ # Get bos_token
700
+ add_special_tokens = True
701
+ bos_token_1 = getattr(processing_class, 'bos_token', None)
702
+ bos_token_2 = getattr(tokenizer, 'bos_token', None)
703
+ bos_token = bos_token_1 or bos_token_2
704
+
705
+ if bos_token is not None:
706
+ if test_text.startswith(bos_token) or bos_token in chat_template:
707
+ add_special_tokens = False
708
+ print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
709
+ pass
710
+
711
+ # Create tokenize function
712
+ def _tokenize(example):
713
+ return tokenizer(
714
+ example[dataset_text_field] if not do_formatting_func else formatting_func(example),
715
+ truncation = do_truncation,
716
+ max_length = max_seq_length,
717
+ return_token_type_ids = False,
718
+ add_special_tokens = add_special_tokens,
719
+ )
720
+ pass
721
+
722
+ map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
723
+ if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
724
+ dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
725
+
726
+ # If VLM, switch data collator since .pad is needed!
727
+ if is_vlm and not hasattr(processing_class, "pad"):
728
+ data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
729
+ self.data_collator = data_collator
730
+ pass
731
+ pass
732
+ if packing:
733
+ print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
734
+ return dataset
735
+
736
+ if max_seq_length == 0:
737
+ raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
738
+
739
+ if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
740
+ dataset = dataset.select_columns(used_column_names).map(
741
+ pack_examples,
742
+ batched = True,
743
+ fn_kwargs = {"seq_length": max_seq_length,},
744
+ **map_kwargs,
745
+ )
746
+ pass
747
+ return dataset
748
+
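The double-BOS guard in `_prepare_dataset` exists because chat templates often bake a BOS token into the text, and tokenizing with `add_special_tokens=True` would then prepend a second one. A small sketch of the failure mode it prevents, using a hypothetical BOS-prepending tokenizer (the model id is an assumption for illustration only):

```python
from transformers import AutoTokenizer

# Hypothetical choice: any Llama-style tokenizer that prepends BOS behaves this way.
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-bnb-4bit")
text = tokenizer.bos_token + "Hello world"   # text already carries a BOS

ids = tokenizer(text, add_special_tokens=True)["input_ids"]
print(ids[:2])  # expected: two identical BOS ids - the duplication the guard avoids

ids = tokenizer(text, add_special_tokens=False)["input_ids"]
print(ids[:2])  # expected: a single BOS id followed by the first text token
```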
749
+ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
750
+ outputs = super().compute_loss(
751
+ model,
752
+ inputs,
753
+ return_outputs = return_outputs,
754
+ num_items_in_batch = num_items_in_batch,
755
+ )
756
+ return outputs
757
+
758
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
759
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
760
+
761
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
762
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
763
+ if next(iter(logs.keys())).startswith("eval_"):
764
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
765
+
766
+ logs = {**logs, **metrics}
767
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
768
+ super().log(logs, start_time)
769
+ else: # transformers<=4.46
770
+ super().log(logs)
771
+ self._metrics.clear()
772
+
773
+ def create_model_card(
774
+ self,
775
+ model_name: Optional[str] = None,
776
+ dataset_name: Optional[str] = None,
777
+ tags: Union[str, list[str], None] = None,
778
+ ):
779
+ """
780
+ Creates a draft of a model card using the information available to the `Trainer`.
781
+
782
+ Args:
783
+ model_name (`str` or `None`, *optional*, defaults to `None`):
784
+ Name of the model.
785
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
786
+ Name of the dataset used for training.
787
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
788
+ Tags to be associated with the model card.
789
+ """
790
+ if not self.is_world_process_zero():
791
+ return
+ if is_wandb_available(): import wandb  # wandb is referenced below but never imported in this file
792
+
793
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
794
+ base_model = self.model.config._name_or_path
795
+ else:
796
+ base_model = None
797
+
798
+ tags = tags or []
799
+ if isinstance(tags, str):
800
+ tags = [tags]
801
+
802
+ if hasattr(self.model.config, "unsloth_version"):
803
+ tags.append("unsloth")
804
+
805
+ model_card = generate_model_card(
806
+ base_model=base_model,
807
+ model_name=model_name,
808
+ hub_model_id=self.hub_model_id,
809
+ dataset_name=dataset_name,
810
+ tags=tags,
811
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
812
+ comet_url=get_comet_experiment_url(),
813
+ trainer_name="SFT",
814
+ )
815
+
816
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
817
+ class UnslothSFTTrainer(_UnslothSFTTrainer):
818
+ """
819
+
820
+ Trainer for the Supervised Fine-Tuning (SFT) method.
821
+
822
+ This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
823
+
824
+ Example:
825
+
826
+ ```python
827
+ from datasets import load_dataset
828
+ from trl import SFTTrainer
829
+
830
+ dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
831
+
832
+ trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
833
+ trainer.train()
834
+ ```
835
+
836
+ Args:
837
+ model (`Union[str, PreTrainedModel]`):
838
+ Model to be trained. Can be either:
839
+
840
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
841
+ a path to a *directory* containing model weights saved using
842
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
843
+ loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
844
+ in `args.model_init_kwargs`.
845
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
846
+ args ([`SFTConfig`], *optional*, defaults to `None`):
847
+ Configuration for this trainer. If `None`, a default configuration is used.
848
+ data_collator (`DataCollator`, *optional*):
849
+ Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
850
+ Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
851
+ of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
852
+ tokenizer.
853
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
854
+ Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
855
+ [prompt-completion](#prompt-completion) type. The format of the samples can be either:
856
+
857
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
858
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
859
+ and content).
860
+
861
+ The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
862
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
863
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
864
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
865
+ Processing class used to process the data. If `None`, the processing class is loaded from the model's name
866
+ with [`~transformers.AutoTokenizer.from_pretrained`].
867
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
868
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks
869
+ detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
870
+
871
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
872
+ method.
873
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
874
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
875
+ model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
876
+ optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
877
+ A tuple containing the optimizer class and keyword arguments to use.
878
+ Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
879
+
880
+ Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
881
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
882
+ A function that preprocesses the logits right before caching them at each evaluation step. Must take two
883
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
884
+ by this function will be reflected in the predictions received by `compute_metrics`.
885
+
886
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
887
+ peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
888
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
889
+ formatting_func (`Optional[Callable]`):
890
+ Formatting function applied to the dataset before tokenization.
891
+
892
+ """
893
+ def __init__(
894
+ self,
895
+ model,
896
+ args = None,
897
+ data_collator = None,
898
+ train_dataset = None,
899
+ eval_dataset = None,
900
+ processing_class = None,
901
+ compute_loss_func = None,
902
+ compute_metrics = None,
903
+ callbacks = None,
904
+ optimizer_cls_and_kwargs = None,
905
+ preprocess_logits_for_metrics = None,
906
+ peft_config = None,
907
+ formatting_func = None,
908
+ **kwargs
909
+ ):
910
+ if args is None: args = UnslothSFTConfig()
911
+ use_bf16 = getattr(args, 'bf16', False)
912
+ use_fp16 = getattr(args, 'fp16', False)
913
+ force_float32 = False
914
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
915
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
916
+ force_float32 = True
917
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
918
+ dtype = getattr(model.config, 'torch_dtype', None)
919
+ if dtype is None: dtype = model.get_input_embeddings().dtype
920
+ from unsloth_zoo.utils import _get_dtype
921
+ dtype = _get_dtype(dtype)
922
+ float16 = dtype == torch.float16
923
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
924
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
925
+ if force_float32:
926
+ args.fp16 = False
927
+ args.bf16 = False
928
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
929
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
930
+ args.fp16 = float16
931
+ args.bf16 = not float16
932
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
933
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
934
+ args.eval_strategy = 'steps'
935
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
936
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
937
+ if ga_steps is not None and ga_steps > 1:
938
+ from transformers import __version__ as transformers_version
939
+ if Version(transformers_version) <= Version('4.45.2'):
940
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
941
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
942
+ if getattr(args, 'eval_strategy', 'no') != 'no':
943
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
944
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
945
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
946
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
947
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
948
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
949
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
950
+ if force_float32:
951
+ args.bf16_full_eval = False
952
+ args.fp16_full_eval = False
953
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
954
+ args.bf16_full_eval = True
955
+ args.fp16_full_eval = False
956
+ elif not bf16_full_eval and not fp16_full_eval:
957
+ args.bf16_full_eval = args.bf16
958
+ args.fp16_full_eval = args.fp16
959
+ _output_logits = False
960
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
961
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
962
+ if _output_logits:
963
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
964
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
965
+ pass
966
+ else:
967
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
968
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
969
+ if args_max_seq_length is None and model_max_seq_length is not None:
970
+ max_seq_length = model.max_seq_length
971
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
972
+ if model is not None and hasattr(model, 'for_training'):
973
+ model.for_training()
974
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
975
+ if 'processing_class' in locals():
976
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
977
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
978
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
979
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
980
+ if not isinstance(data_collator, UnslothVisionDataCollator):
981
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
982
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
983
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
984
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
985
+ else:
986
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
987
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
988
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
989
+ if not isinstance(data_collator, UnslothVisionDataCollator):
990
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
991
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
992
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
993
+ else:
994
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
995
+ other_metrics = []
996
+
997
+ from unsloth_zoo.logging_utils import PatchRLStatistics
998
+ PatchRLStatistics('sft_trainer', other_metrics)
999
+ IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
1000
+ from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
1001
+ from unsloth_zoo.training_utils import fix_zero_training_loss
1002
+ if 'tokenizer' not in locals(): tokenizer = processing_class
1003
+ fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
1004
+ fix_zero_training_loss(model, tokenizer, train_dataset)
1005
+
1006
+ super().__init__(
1007
+ model = model,
1008
+ args = args,
1009
+ data_collator = data_collator,
1010
+ train_dataset = train_dataset,
1011
+ eval_dataset = eval_dataset,
1012
+ processing_class = processing_class,
1013
+ compute_loss_func = compute_loss_func,
1014
+ compute_metrics = compute_metrics,
1015
+ callbacks = callbacks,
1016
+ optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
1017
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1018
+ peft_config = peft_config,
1019
+ formatting_func = formatting_func,**kwargs)
1020
+ if hasattr(self, 'neftune_hook_handle'):
1021
+ self.neftune_hook_handle.remove()
1022
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1023
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1024
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1025
+ pass
1026
+
1027
+ pass
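End to end, a hedged usage sketch of the wrapper above (the model id and dataset split are illustrative assumptions; `FastLanguageModel` is Unsloth's loader and returns a `(model, tokenizer)` pair):

```python
from datasets import load_dataset
from unsloth import FastLanguageModel

# Illustrative model id; any Unsloth-supported causal LM works the same way.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = 2048,
)
dataset = load_dataset("roneneldan/TinyStories", split = "train[:1%]")

trainer = UnslothSFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,
    args = UnslothSFTConfig(output_dir = "outputs", max_seq_length = 2048, max_steps = 10),
)
trainer.train()
```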
unsloth_compiled_cache/UnslothXPOTrainer.py ADDED
@@ -0,0 +1,1010 @@
1
+ """
2
+ 2025.3.15
3
+ 2025.3.17
4
+ 4.50.0.dev0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothXPOConfig(XPOConfig):
44
+ """
45
+
46
+ Configuration class for the [`XPOTrainer`].
47
+
48
+ Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
49
+
50
+ Parameters:
51
+ alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
52
+ Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
53
+ and the last alpha is used for the rest of the epochs.
54
+
55
+ """
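To make the per-epoch `alpha` schedule described in the docstring concrete, here is a small sketch of the stated rule (epoch-indexed lookup with the last value reused); this is an illustration, not code lifted from the trainer:

```python
# One alpha per epoch; epochs past the end of the list reuse the last alpha.
alphas = [1e-5, 5e-6, 1e-6]  # illustrative schedule

def alpha_for_epoch(epoch: int, alphas: list) -> float:
    return alphas[epoch] if epoch < len(alphas) else alphas[-1]

assert alpha_for_epoch(0, alphas) == 1e-5
assert alpha_for_epoch(5, alphas) == 1e-6  # past the schedule: last alpha reused
```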
56
+ vllm_sampling_params: Optional[Any] = field(
57
+ default = None,
58
+ metadata = {'help': 'vLLM SamplingParams'},
59
+ )
60
+ unsloth_num_chunks : Optional[int] = field(
61
+ default = -1,
62
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
63
+ )
64
+ def __init__(
65
+ self,
66
+ output_dir = None,
67
+ overwrite_output_dir = None,
68
+ do_train = False,
69
+ do_eval = False,
70
+ do_predict = False,
71
+ eval_strategy = 'no',
72
+ prediction_loss_only = False,
73
+ per_device_train_batch_size = 4,
74
+ per_device_eval_batch_size = 4,
75
+ per_gpu_train_batch_size = None,
76
+ per_gpu_eval_batch_size = None,
77
+ gradient_accumulation_steps = 2,
78
+ eval_accumulation_steps = 2,
79
+ eval_delay = 0,
80
+ torch_empty_cache_steps = 250,
81
+ learning_rate = 5e-05,
82
+ weight_decay = 0.01,
83
+ adam_beta1 = 0.9,
84
+ adam_beta2 = 0.999,
85
+ adam_epsilon = 1e-08,
86
+ max_grad_norm = 1.0,
87
+ num_train_epochs = 3.0,
88
+ max_steps = -1,
89
+ lr_scheduler_type = 'linear',
90
+ warmup_ratio = 0.1,
91
+ warmup_steps = 0,
92
+ log_level = 'passive',
93
+ log_level_replica = 'warning',
94
+ log_on_each_node = True,
95
+ logging_dir = None,
96
+ logging_strategy = 'steps',
97
+ logging_first_step = False,
98
+ logging_steps = 1,
99
+ logging_nan_inf_filter = False,
100
+ save_strategy = 'steps',
101
+ save_steps = 500,
102
+ save_total_limit = None,
103
+ save_safetensors = True,
104
+ save_on_each_node = False,
105
+ save_only_model = False,
106
+ restore_callback_states_from_checkpoint = False,
107
+ no_cuda = False,
108
+ use_cpu = False,
109
+ use_mps_device = False,
110
+ seed = 3407,
111
+ data_seed = 3407,
112
+ jit_mode_eval = False,
113
+ use_ipex = False,
114
+ bf16 = False,
115
+ fp16 = False,
116
+ fp16_opt_level = 'O1',
117
+ half_precision_backend = 'auto',
118
+ bf16_full_eval = False,
119
+ fp16_full_eval = False,
120
+ tf32 = None,
121
+ local_rank = -1,
122
+ ddp_backend = None,
123
+ tpu_num_cores = None,
124
+ tpu_metrics_debug = False,
125
+ debug = '',
126
+ dataloader_drop_last = False,
127
+ eval_steps = None,
128
+ dataloader_num_workers = 0,
129
+ dataloader_prefetch_factor = None,
130
+ past_index = -1,
131
+ run_name = None,
132
+ disable_tqdm = None,
133
+ remove_unused_columns = True,
134
+ label_names = None,
135
+ load_best_model_at_end = False,
136
+ metric_for_best_model = None,
137
+ greater_is_better = None,
138
+ ignore_data_skip = False,
139
+ fsdp = '',
140
+ fsdp_min_num_params = 0,
141
+ fsdp_config = None,
142
+ tp_size = 0,
143
+ fsdp_transformer_layer_cls_to_wrap = None,
144
+ accelerator_config = None,
145
+ deepspeed = None,
146
+ label_smoothing_factor = 0.0,
147
+ optim = 'adamw_8bit',
148
+ optim_args = None,
149
+ adafactor = False,
150
+ group_by_length = False,
151
+ length_column_name = 'length',
152
+ report_to = None,
153
+ ddp_find_unused_parameters = None,
154
+ ddp_bucket_cap_mb = None,
155
+ ddp_broadcast_buffers = None,
156
+ dataloader_pin_memory = True,
157
+ dataloader_persistent_workers = False,
158
+ skip_memory_metrics = True,
159
+ use_legacy_prediction_loop = False,
160
+ push_to_hub = False,
161
+ resume_from_checkpoint = None,
162
+ hub_model_id = None,
163
+ hub_strategy = 'every_save',
164
+ hub_token = None,
165
+ hub_private_repo = None,
166
+ hub_always_push = False,
167
+ gradient_checkpointing = False,
168
+ gradient_checkpointing_kwargs = None,
169
+ include_inputs_for_metrics = False,
170
+ eval_do_concat_batches = True,
171
+ fp16_backend = 'auto',
172
+ evaluation_strategy = None,
173
+ push_to_hub_model_id = None,
174
+ push_to_hub_organization = None,
175
+ push_to_hub_token = None,
176
+ mp_parameters = '',
177
+ auto_find_batch_size = False,
178
+ full_determinism = False,
179
+ torchdynamo = None,
180
+ ray_scope = 'last',
181
+ ddp_timeout = 1800,
182
+ torch_compile = False,
183
+ torch_compile_backend = None,
184
+ torch_compile_mode = None,
185
+ dispatch_batches = None,
186
+ split_batches = None,
187
+ include_tokens_per_second = False,
188
+ include_num_input_tokens_seen = False,
189
+ neftune_noise_alpha = None,
190
+ optim_target_modules = None,
191
+ batch_eval_metrics = False,
192
+ eval_on_start = False,
193
+ use_liger_kernel = False,
194
+ eval_use_gather_object = False,
195
+ average_tokens_across_devices = False,
196
+ reward_model_path = None,
197
+ judge = None,
198
+ max_new_tokens = 64,
199
+ max_length = 512,
200
+ temperature = 0.9,
201
+ missing_eos_penalty = None,
202
+ loss_type = 'sigmoid',
203
+ dataset_num_proc = None,
204
+ disable_dropout = True,
205
+ use_vllm = False,
206
+ ds3_gather_for_generation = True,
207
+ vllm_sampling_params = None,
208
+ unsloth_num_chunks = -1,
209
+ **kwargs,
210
+ ):
211
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
212
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
213
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
214
+ output_dir = 'unsloth_training_checkpoints'
215
+ save_strategy = 'no'
216
+ if dataset_num_proc is None:
217
+ from multiprocessing import cpu_count
218
+ dataset_num_proc = cpu_count()
219
+
220
+ super().__init__(
221
+ output_dir = output_dir,
222
+ overwrite_output_dir = overwrite_output_dir,
223
+ do_train = do_train,
224
+ do_eval = do_eval,
225
+ do_predict = do_predict,
226
+ eval_strategy = eval_strategy,
227
+ prediction_loss_only = prediction_loss_only,
228
+ per_device_train_batch_size = per_device_train_batch_size,
229
+ per_device_eval_batch_size = per_device_eval_batch_size,
230
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
231
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
232
+ gradient_accumulation_steps = gradient_accumulation_steps,
233
+ eval_accumulation_steps = eval_accumulation_steps,
234
+ eval_delay = eval_delay,
235
+ torch_empty_cache_steps = torch_empty_cache_steps,
236
+ learning_rate = learning_rate,
237
+ weight_decay = weight_decay,
238
+ adam_beta1 = adam_beta1,
239
+ adam_beta2 = adam_beta2,
240
+ adam_epsilon = adam_epsilon,
241
+ max_grad_norm = max_grad_norm,
242
+ num_train_epochs = num_train_epochs,
243
+ max_steps = max_steps,
244
+ lr_scheduler_type = lr_scheduler_type,
245
+ warmup_ratio = warmup_ratio,
246
+ warmup_steps = warmup_steps,
247
+ log_level = log_level,
248
+ log_level_replica = log_level_replica,
249
+ log_on_each_node = log_on_each_node,
250
+ logging_dir = logging_dir,
251
+ logging_strategy = logging_strategy,
252
+ logging_first_step = logging_first_step,
253
+ logging_steps = logging_steps,
254
+ logging_nan_inf_filter = logging_nan_inf_filter,
255
+ save_strategy = save_strategy,
256
+ save_steps = save_steps,
257
+ save_total_limit = save_total_limit,
258
+ save_safetensors = save_safetensors,
259
+ save_on_each_node = save_on_each_node,
260
+ save_only_model = save_only_model,
261
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
262
+ no_cuda = no_cuda,
263
+ use_cpu = use_cpu,
264
+ use_mps_device = use_mps_device,
265
+ seed = seed,
266
+ data_seed = data_seed,
267
+ jit_mode_eval = jit_mode_eval,
268
+ use_ipex = use_ipex,
269
+ bf16 = bf16,
270
+ fp16 = fp16,
271
+ fp16_opt_level = fp16_opt_level,
272
+ half_precision_backend = half_precision_backend,
273
+ bf16_full_eval = bf16_full_eval,
274
+ fp16_full_eval = fp16_full_eval,
275
+ tf32 = tf32,
276
+ local_rank = local_rank,
277
+ ddp_backend = ddp_backend,
278
+ tpu_num_cores = tpu_num_cores,
279
+ tpu_metrics_debug = tpu_metrics_debug,
280
+ debug = debug,
281
+ dataloader_drop_last = dataloader_drop_last,
282
+ eval_steps = eval_steps,
283
+ dataloader_num_workers = dataloader_num_workers,
284
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
285
+ past_index = past_index,
286
+ run_name = run_name,
287
+ disable_tqdm = disable_tqdm,
288
+ remove_unused_columns = remove_unused_columns,
289
+ label_names = label_names,
290
+ load_best_model_at_end = load_best_model_at_end,
291
+ metric_for_best_model = metric_for_best_model,
292
+ greater_is_better = greater_is_better,
293
+ ignore_data_skip = ignore_data_skip,
294
+ fsdp = fsdp,
295
+ fsdp_min_num_params = fsdp_min_num_params,
296
+ fsdp_config = fsdp_config,
297
+ tp_size = tp_size,
298
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
299
+ accelerator_config = accelerator_config,
300
+ deepspeed = deepspeed,
301
+ label_smoothing_factor = label_smoothing_factor,
302
+ optim = optim,
303
+ optim_args = optim_args,
304
+ adafactor = adafactor,
305
+ group_by_length = group_by_length,
306
+ length_column_name = length_column_name,
307
+ report_to = report_to,
308
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
309
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
310
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
311
+ dataloader_pin_memory = dataloader_pin_memory,
312
+ dataloader_persistent_workers = dataloader_persistent_workers,
313
+ skip_memory_metrics = skip_memory_metrics,
314
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
315
+ push_to_hub = push_to_hub,
316
+ resume_from_checkpoint = resume_from_checkpoint,
317
+ hub_model_id = hub_model_id,
318
+ hub_strategy = hub_strategy,
319
+ hub_token = hub_token,
320
+ hub_private_repo = hub_private_repo,
321
+ hub_always_push = hub_always_push,
322
+ gradient_checkpointing = gradient_checkpointing,
323
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
324
+ include_inputs_for_metrics = include_inputs_for_metrics,
325
+ eval_do_concat_batches = eval_do_concat_batches,
326
+ fp16_backend = fp16_backend,
327
+ evaluation_strategy = evaluation_strategy,
328
+ push_to_hub_model_id = push_to_hub_model_id,
329
+ push_to_hub_organization = push_to_hub_organization,
330
+ push_to_hub_token = push_to_hub_token,
331
+ mp_parameters = mp_parameters,
332
+ auto_find_batch_size = auto_find_batch_size,
333
+ full_determinism = full_determinism,
334
+ torchdynamo = torchdynamo,
335
+ ray_scope = ray_scope,
336
+ ddp_timeout = ddp_timeout,
337
+ torch_compile = torch_compile,
338
+ torch_compile_backend = torch_compile_backend,
339
+ torch_compile_mode = torch_compile_mode,
340
+ dispatch_batches = dispatch_batches,
341
+ split_batches = split_batches,
342
+ include_tokens_per_second = include_tokens_per_second,
343
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
344
+ neftune_noise_alpha = neftune_noise_alpha,
345
+ optim_target_modules = optim_target_modules,
346
+ batch_eval_metrics = batch_eval_metrics,
347
+ eval_on_start = eval_on_start,
348
+ use_liger_kernel = use_liger_kernel,
349
+ eval_use_gather_object = eval_use_gather_object,
350
+ average_tokens_across_devices = average_tokens_across_devices,
351
+ reward_model_path = reward_model_path,
352
+ judge = judge,
353
+ max_new_tokens = max_new_tokens,
354
+ max_length = max_length,
355
+ temperature = temperature,
356
+ missing_eos_penalty = missing_eos_penalty,
357
+ loss_type = loss_type,
358
+ dataset_num_proc = dataset_num_proc,
359
+ disable_dropout = disable_dropout,
360
+ use_vllm = use_vllm,
361
+ ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
362
+ self.vllm_sampling_params = vllm_sampling_params
363
+ self.unsloth_num_chunks = unsloth_num_chunks
364
+ pass
365
+
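For orientation, here is a minimal usage sketch of the config class defined above. The import path mirrors the folder layout shown in this commit and is an assumption; all argument values are illustrative only.

```python
# A minimal sketch, assuming unsloth_compiled_cache/ is on the Python path;
# UnslothXPOConfig is defined earlier in this same file.
from unsloth_compiled_cache.UnslothXPOTrainer import UnslothXPOConfig

args = UnslothXPOConfig(
    output_dir = "./xpo_checkpoints",
    per_device_train_batch_size = 2,
    learning_rate = 5e-7,   # must lie in [1e-7, 1], or __init__ raises
    max_new_tokens = 64,
    temperature = 0.9,
)

# Out-of-range learning rates fail fast with the guards above:
try:
    UnslothXPOConfig(learning_rate = 1e-9)
except FloatingPointError as err:
    print(err)
```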
366
+ class _UnslothXPOTrainer(OnlineDPOTrainer):
367
+ r""""""
368
+
369
+ _tag_names = ["trl", "xpo"]
370
+
371
+ def __init__(
372
+ self,
373
+ model: Union[PreTrainedModel, nn.Module] = None,
374
+ ref_model: Union[PreTrainedModel, nn.Module] = None,
375
+ reward_model: Optional[nn.Module] = None,
376
+ judge: Optional[BasePairwiseJudge] = None,
377
+ args: Optional[XPOConfig] = None,
378
+ data_collator: Optional[Callable] = None,
379
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
380
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
381
+ processing_class: Optional[
382
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
383
+ ] = None,
384
+ peft_config: Optional[dict] = None,
385
+ compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
386
+ callbacks: Optional[list[TrainerCallback]] = None,
387
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
388
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
389
+ ) -> None:
390
+ super().__init__(
391
+ model=model,
392
+ ref_model=ref_model,
393
+ judge=judge,
394
+ reward_model=reward_model,
395
+ args=args,
396
+ data_collator=data_collator,
397
+ train_dataset=train_dataset,
398
+ eval_dataset=eval_dataset,
399
+ processing_class=processing_class,
400
+ reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
401
+ peft_config=peft_config,
402
+ compute_metrics=compute_metrics,
403
+ callbacks=callbacks,
404
+ optimizers=optimizers,
405
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
406
+ )
407
+
408
+ self._alpha = self.args.alpha
409
+
410
+ # Overwrite the stats dictionary to include XPO specific statistics
411
+ self.stats = {
412
+ # Remove "non_score_reward", "rlhf_reward", "scores"
413
+ # Add "loss/dpo", "loss/xpo"
414
+ "loss/dpo": [],
415
+ "loss/xpo": [],
416
+ "objective/kl": [],
417
+ "objective/entropy": [],
418
+ "rewards/chosen": [],
419
+ "rewards/rejected": [],
420
+ "rewards/accuracies": [],
421
+ "rewards/margins": [],
422
+ "logps/chosen": [],
423
+ "logps/rejected": [],
424
+ # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
425
+ "val/model_contain_eos_token": [],
426
+ "val/ref_contain_eos_token": [],
427
+ "alpha": [],
428
+ "beta": [],
429
+ }
430
+ if self.reward_model is not None:
431
+ # Replace "scores" by "model_scores" and "ref_scores"
432
+ self.stats["objective/model_scores"] = []
433
+ self.stats["objective/ref_scores"] = []
434
+ self.stats["objective/scores_margin"] = []
435
+
436
+ @property
437
+ def alpha(self):
438
+ if isinstance(self._alpha, list):
439
+ epoch = int(self.state.epoch)  # state.epoch is a float; cast before indexing the schedule list
440
+ return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
441
+ else:
442
+ return self._alpha
443
+
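The `alpha` property above accepts either a single float or a per-epoch list. A small standalone sketch of the fallback behaviour (pure Python, no trainer state involved):

```python
# Illustrative schedule values; with a list, the current epoch's entry is
# used, and once epochs run past the list the last entry is reused.
schedule = [1e-4, 5e-5, 2.5e-5]
for epoch in range(5):
    alpha = schedule[epoch] if epoch < len(schedule) else schedule[-1]
    print(epoch, alpha)   # epochs 3 and 4 fall back to 2.5e-5
```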
444
+ def _generate_completions(self, prompts, model):
445
+ with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
446
+ model_output = unwrapped_model.generate(
447
+ input_ids=prompts["input_ids"],
448
+ attention_mask=prompts["attention_mask"],
449
+ generation_config=self.generation_config,
450
+ )
451
+
452
+ ref_model = model if self.ref_model is None else self.ref_model
453
+ with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
454
+ ref_output = unwrapped_ref_model.generate(
455
+ input_ids=prompts["input_ids"],
456
+ attention_mask=prompts["attention_mask"],
457
+ generation_config=self.generation_config,
458
+ )
459
+
460
+ return model_output, ref_output
461
+
462
+ def _process_completions(self, model_output, ref_output, prompts):
463
+ context_length = prompts["input_ids"].shape[1]
464
+
465
+ # Process model completions
466
+ model_completion_ids = model_output[:, context_length:]
467
+ model_completion_ids, model_completion_mask = truncate_right(
468
+ model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
469
+ )
470
+ model_data = {
471
+ "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
472
+ "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
473
+ "raw": prompts["raw"],
474
+ }
475
+
476
+ # Process reference model completions
477
+ ref_completion_ids = ref_output[:, context_length:]
478
+ ref_completion_ids, ref_completion_mask = truncate_right(
479
+ ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
480
+ )
481
+ ref_data = {
482
+ "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
483
+ "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
484
+ "raw": prompts["raw"],
485
+ }
486
+
487
+ return model_data, ref_data
488
+
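`truncate_right` is a TRL helper whose implementation is not shown in this file, so the following is an illustrative re-implementation of the semantics assumed by the code above: keep everything up to and including the first EOS, pad the rest, and mask the padded positions.

```python
import torch

# Sketch only; the real helper lives in trl.trainer.utils.
def truncate_right_sketch(ids, eos_id, pad_id):
    is_eos = (ids == eos_id).long()
    after_eos = (is_eos.cumsum(dim = 1) - is_eos) > 0   # strictly after first EOS
    mask = (~after_eos).long()
    return ids.masked_fill(after_eos, pad_id), mask

ids = torch.tensor([[5, 7, 2, 9, 9]])                  # 2 = EOS, 0 = pad
print(truncate_right_sketch(ids, eos_id = 2, pad_id = 0))
# (tensor([[5, 7, 2, 0, 0]]), tensor([[1, 1, 1, 0, 0]]))
```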
489
+ def _compute_rewards(self, model_data, ref_data, context_length):
490
+ with torch.no_grad():
491
+ _, model_scores, _ = get_reward(
492
+ self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
493
+ )
494
+ _, ref_scores, _ = get_reward(
495
+ self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
496
+ )
497
+
498
+ # Apply EOS penalty if needed
499
+ if self.args.missing_eos_penalty is not None:
500
+ model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
501
+ ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
502
+ model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
503
+ ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
504
+
505
+ return model_scores, ref_scores
506
+
507
+ def _compute_judge(self, model_data, ref_data, context_length):
508
+ prompts = model_data["raw"]
509
+ model_data_completions = self.processing_class.batch_decode(
510
+ model_data["input_ids"][:, context_length:], skip_special_tokens=True
511
+ )
512
+ model_data_completions = [completion.strip() for completion in model_data_completions]
513
+
514
+ ref_data_completions = self.processing_class.batch_decode(
515
+ ref_data["input_ids"][:, context_length:], skip_special_tokens=True
516
+ )
517
+ ref_data_completions = [completion.strip() for completion in ref_data_completions]
518
+
519
+ if is_conversational({"prompt": prompts[0]}):
520
+ model_data_completions = [
521
+ [{"role": "assistant", "content": completion}] for completion in model_data_completions
522
+ ]
523
+ environment = jinja2.Environment()
524
+ template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
525
+ prompts = [template.render(messages=message) for message in prompts]
526
+ model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
527
+
528
+ ref_data_completions = [
529
+ [{"role": "assistant", "content": completion}] for completion in ref_data_completions
530
+ ]
531
+ ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
532
+
533
+ ranks_of_first_completion = self.judge.judge(
534
+ prompts,
535
+ list(zip(model_data_completions, ref_data_completions)),
536
+ )
537
+ # convert ranks to a True/False mask:
538
+ # when rank == 0, it means the first completion is the best
539
+ # when rank == 1, it means the second completion is the best
540
+ return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
541
+
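A tiny worked example of the rank-to-mask conversion performed at the end of `_compute_judge` (the rank values are made up): the judge returns 0 when the model completion wins and 1 when the reference completion wins.

```python
import torch

ranks_of_first_completion = [0, 1, 0]
chosen_mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion])
print(chosen_mask)  # tensor([ True, False,  True])
```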
542
+ def _compute_logprobs(self, model, model_data, ref_data, context_length):
543
+ def compute_logprobs_for_data(m, data):
544
+ output = m(data["input_ids"], attention_mask=data["attention_mask"])
545
+ logits = output.logits[:, context_length - 1 : -1]
546
+ token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
547
+ return token_logprobs
548
+
549
+ # Compute logprobs for model completions
550
+ model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
551
+ # Compute logprobs for model on reference completions (for XPO loss)
552
+ model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
553
+
554
+ # Compute logprobs for reference model completions
555
+ with torch.no_grad():
556
+ if self.ref_model is None:
557
+ with model.disable_adapter():
558
+ ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
559
+ ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
560
+ else:
561
+ ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
562
+ ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
563
+
564
+ # Mask padding tokens
565
+ model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
566
+ ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
567
+ model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
568
+ model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
569
+ ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
570
+ ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
571
+
572
+ return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
573
+
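`selective_log_softmax` is imported from TRL; judging from its usage above, it gathers the log-probability of each realized completion token from the logits. A naive equivalent, for illustration only (the real helper is a more memory-efficient version):

```python
import torch
import torch.nn.functional as F

def selective_log_softmax_sketch(logits, index):
    # Log-softmax over the vocab, then pick the logprob of each realized token.
    logprobs = F.log_softmax(logits, dim = -1)
    return torch.gather(logprobs, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 4, 10)            # (batch, completion_len, vocab)
tokens = torch.randint(0, 10, (2, 4))     # realized completion ids
print(selective_log_softmax_sketch(logits, tokens).shape)  # torch.Size([2, 4])
```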
574
+ def _compute_losses(
575
+ self,
576
+ model_logprobs_model_data,
577
+ model_logprobs_ref_data,
578
+ ref_logprobs_ref_data,
579
+ ref_logprobs_model_data,
580
+ chosen_mask,
581
+ ):
582
+ # Compute log probs
583
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
584
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
585
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
586
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
587
+
588
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
589
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
590
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
591
+
592
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
593
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
594
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
595
+
596
+ # Compute logits as the difference between chosen and rejected log ratios
597
+ logits = chosen_log_ratios - rejected_log_ratios
598
+
599
+ if self.args.loss_type == "sigmoid":
600
+ dpo_losses = -F.logsigmoid(self.beta * logits)
601
+ elif self.args.loss_type == "ipo":
602
+ dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
603
+ else:
604
+ raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
605
+
606
+ # Compute XPO specific loss
607
+ xpo_losses = self.alpha * model_logprobs_ref_data_sum
608
+
609
+ # Total loss
610
+ loss = (dpo_losses + xpo_losses).mean()
611
+
612
+ return loss, dpo_losses, xpo_losses
613
+
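A toy re-computation of the combined objective in `_compute_losses`, with `loss_type == "sigmoid"`. The values of `beta`, `alpha`, and the summed log-ratios below are illustrative only: the DPO term pushes the chosen completion above the rejected one, while the XPO term penalizes high model log-probability on reference completions.

```python
import torch
import torch.nn.functional as F

beta, alpha = 0.1, 1e-5                                      # illustrative
chosen_log_ratios = torch.tensor([2.0, 0.5])
rejected_log_ratios = torch.tensor([-1.0, 0.0])
model_logprobs_ref_data_sum = torch.tensor([-40.0, -55.0])   # sums over tokens

logits = chosen_log_ratios - rejected_log_ratios
dpo_losses = -F.logsigmoid(beta * logits)         # preference term
xpo_losses = alpha * model_logprobs_ref_data_sum  # exploration penalty term
print((dpo_losses + xpo_losses).mean())
```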
614
+ def _log_statistics(
615
+ self,
616
+ model_data,
617
+ ref_data,
618
+ model_logprobs_model_data,
619
+ model_logprobs_ref_data,
620
+ ref_logprobs_ref_data,
621
+ ref_logprobs_model_data,
622
+ chosen_mask,
623
+ dpo_losses,
624
+ xpo_losses,
625
+ context_length,
626
+ model_scores=None,
627
+ ref_scores=None,
628
+ ):
629
+ # Helper function to gather and compute mean
630
+ def gather_mean(tensor):
631
+ return self.accelerator.gather_for_metrics(tensor).mean().item()
632
+
633
+ # Log losses
634
+ self.stats["loss/dpo"].append(gather_mean(dpo_losses))
635
+ self.stats["loss/xpo"].append(gather_mean(xpo_losses))
636
+
637
+ # Log scores
638
+ if self.reward_model is not None:
639
+ self.stats["objective/model_scores"].append(gather_mean(model_scores))
640
+ self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
641
+ self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
642
+
643
+ # Log logprobs
644
+ model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
645
+ model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
646
+ ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
647
+ ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
648
+
649
+ chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
650
+ chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
651
+ chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
652
+
653
+ rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
654
+ rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
655
+ rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
656
+
657
+ self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
658
+ self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
659
+
660
+ # Log rewards
661
+ # Compute various statistics
662
+ chosen_rewards = chosen_log_ratios * self.beta
663
+ rejected_rewards = rejected_log_ratios * self.beta
664
+ self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
665
+ self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
666
+
667
+ # Calculate KL divergence for model and ref data
668
+ kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
669
+ kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
670
+ mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
671
+ self.stats["objective/kl"].append(gather_mean(mean_kl))
672
+
673
+ # Calculate entropy for model and ref data
674
+ entropy_model_data = -model_logprobs_model_data.sum(1)
675
+ entropy_ref_data = -model_logprobs_ref_data.sum(1)
676
+ mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
677
+ self.stats["objective/entropy"].append(gather_mean(mean_entropy))
678
+
679
+ # Calculate margins
680
+ margin = chosen_rewards - rejected_rewards
681
+ self.stats["rewards/margins"].append(gather_mean(margin.mean()))
682
+
683
+ # Calculate accuracy
684
+ accuracy = (margin > 0).float()
685
+ self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
686
+
687
+ # Log EOS token statistics
688
+ model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
689
+ ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
690
+ self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
691
+ self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
692
+
693
+ # Log alpha and beta
694
+ self.stats["alpha"].append(self.alpha)
695
+ self.stats["beta"].append(self.beta)
696
+
697
+ def training_step(
698
+ self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
699
+ ) -> torch.Tensor:
700
+ model.train()
701
+
702
+ # Apply chat template and tokenize the input
703
+ batch_size = len(next(iter(inputs.values())))
704
+ prompts = inputs["prompt"]
705
+ inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
706
+ inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
707
+ inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
708
+ inputs = self.data_collator(inputs)
709
+
710
+ # we only need the prompt_* fields (prompt_input_ids, prompt_attention_mask)
711
+ inputs = self._prepare_inputs(inputs)
712
+ context_length = inputs["prompt_input_ids"].shape[1]
713
+ prompts = {
714
+ "input_ids": inputs["prompt_input_ids"],
715
+ "attention_mask": inputs["prompt_attention_mask"],
716
+ "raw": prompts,
717
+ }
718
+ del inputs
719
+
720
+ # Sample completions from both the model and the reference model
721
+ model_output, ref_output = self._generate_completions(prompts, model)
722
+
723
+ # Process model completions
724
+ model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
725
+
726
+ # Compute rewards
727
+ if self.reward_model is not None:
728
+ model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
729
+ chosen_mask = model_scores >= ref_scores
730
+ else:
731
+ model_scores, ref_scores = None, None
732
+ chosen_mask = self._compute_judge(model_data, ref_data, context_length)
733
+
734
+ # Compute logprobs
735
+ model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
736
+ self._compute_logprobs(model, model_data, ref_data, context_length)
737
+ )
738
+
739
+ # Compute loss
740
+ loss, dpo_losses, xpo_losses = self._compute_losses(
741
+ model_logprobs_model_data,
742
+ model_logprobs_ref_data,
743
+ ref_logprobs_ref_data,
744
+ ref_logprobs_model_data,
745
+ chosen_mask,
746
+ )
747
+
748
+ # Log everything
749
+ self._log_statistics(
750
+ model_data,
751
+ ref_data,
752
+ model_logprobs_model_data.detach(),
753
+ model_logprobs_ref_data.detach(),
754
+ ref_logprobs_ref_data,
755
+ ref_logprobs_model_data,
756
+ chosen_mask,
757
+ dpo_losses.detach(),
758
+ xpo_losses.detach(),
759
+ context_length,
760
+ model_scores,
761
+ ref_scores,
762
+ )
763
+
764
+ if (
765
+ self.args.torch_empty_cache_steps is not None
766
+ and self.state.global_step % self.args.torch_empty_cache_steps == 0
767
+ ):
768
+ empty_cache()
769
+
770
+ kwargs = {}
771
+ # For LOMO optimizers you need to explicitly use the learning rate
772
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
773
+ kwargs["learning_rate"] = self._get_learning_rate()
774
+
775
+ if self.args.n_gpu > 1:
776
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
777
+
778
+ if self.use_apex:
779
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
780
+ scaled_loss.backward()
781
+ else:
782
+ self.accelerator.backward(loss, **kwargs)
783
+
784
+ return loss.detach() / self.args.gradient_accumulation_steps
785
+
786
+ def create_model_card(
787
+ self,
788
+ model_name: Optional[str] = None,
789
+ dataset_name: Optional[str] = None,
790
+ tags: Union[str, list[str], None] = None,
791
+ ):
792
+ """
793
+ Creates a draft of a model card using the information available to the `Trainer`.
794
+
795
+ Args:
796
+ model_name (`str` or `None`, *optional*, defaults to `None`):
797
+ Name of the model.
798
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
799
+ Name of the dataset used for training.
800
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
801
+ Tags to be associated with the model card.
802
+ """
803
+ if not self.is_world_process_zero():
804
+ return
805
+
806
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
807
+ base_model = self.model.config._name_or_path
808
+ else:
809
+ base_model = None
810
+
811
+ tags = tags or []
812
+ if isinstance(tags, str):
813
+ tags = [tags]
814
+
815
+ if hasattr(self.model.config, "unsloth_version"):
816
+ tags.append("unsloth")
817
+
818
+ citation = textwrap.dedent("""\
819
+ @article{xie2024exploratory,
820
+ title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
821
+ author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
822
+ year = 2024,
823
+ eprint = {arXiv:2405.21046}
824
+ }""")
825
+
826
+ model_card = generate_model_card(
827
+ base_model=base_model,
828
+ model_name=model_name,
829
+ hub_model_id=self.hub_model_id,
830
+ dataset_name=dataset_name,
831
+ tags=tags,
832
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
833
+ comet_url=get_comet_experiment_url(),
834
+ trainer_name="XPO",
835
+ trainer_citation=citation,
836
+ paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
837
+ paper_id="2405.21046",
838
+ )
839
+
840
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
841
+ class UnslothXPOTrainer(_UnslothXPOTrainer):
842
+ """
843
+
844
+ Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].
845
+
846
+ Args:
847
+ model (`transformers.PreTrainedModel`):
848
+ The model to train, preferably an `AutoModelForCausalLM`.
849
+ ref_model (`PreTrainedModelWrapper`):
850
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
851
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
852
+ reward_model (`transformers.PreTrainedModel`):
853
+ The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
854
+ judge (`BasePairwiseJudge`):
855
+ The judge to use for pairwise comparison of model completions.
856
+ args (`XPOConfig`):
857
+ The XPO config arguments to use for training.
858
+ data_collator (`transformers.DataCollator`):
859
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
860
+ which pads the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
861
+ train_dataset (`datasets.Dataset`):
862
+ The dataset to use for training.
863
+ eval_dataset (`datasets.Dataset`):
864
+ The dataset to use for evaluation.
865
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
866
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
867
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
868
+ reuse the fine-tuned model.
869
+ peft_config (`dict`):
870
+ The peft config to use for training.
871
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
872
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
873
+ a dictionary mapping metric names to metric values.
874
+ callbacks (`list[transformers.TrainerCallback]`):
875
+ The callbacks to use for training.
876
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
877
+ The optimizer and scheduler to use for training.
878
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
879
+ The function to use to preprocess the logits before computing the metrics.
880
+
881
+ """
882
+ def __init__(
883
+ self,
884
+ model = None,
885
+ ref_model = None,
886
+ reward_model = None,
887
+ judge = None,
888
+ args = None,
889
+ data_collator = None,
890
+ train_dataset = None,
891
+ eval_dataset = None,
892
+ processing_class = None,
893
+ peft_config = None,
894
+ compute_metrics = None,
895
+ callbacks = None,
896
+ preprocess_logits_for_metrics = None,
897
+ **kwargs
898
+ ):
899
+ if args is None: args = UnslothXPOConfig()
900
+ use_bf16 = getattr(args, 'bf16', False)
901
+ use_fp16 = getattr(args, 'fp16', False)
902
+ force_float32 = False
903
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
904
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
905
+ force_float32 = True
906
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
907
+ dtype = getattr(model.config, 'torch_dtype', None)
908
+ if dtype is None: dtype = model.get_input_embeddings().dtype
909
+ from unsloth_zoo.utils import _get_dtype
910
+ dtype = _get_dtype(dtype)
911
+ float16 = dtype == torch.float16
912
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
913
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
914
+ if force_float32:
915
+ args.fp16 = False
916
+ args.bf16 = False
917
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
918
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
919
+ args.fp16 = float16
920
+ args.bf16 = not float16
921
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
922
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
923
+ args.eval_strategy = 'steps'
924
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
925
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
926
+ if ga_steps is not None and ga_steps > 1:
927
+ from transformers import __version__ as transformers_version
928
+ if Version(transformers_version) <= Version('4.45.2'):
929
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
930
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
931
+ if getattr(args, 'eval_strategy', 'no') != 'no':
932
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
933
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
934
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
935
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
936
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
937
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
938
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
939
+ if force_float32:
940
+ args.bf16_full_eval = False
941
+ args.fp16_full_eval = False
942
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
943
+ args.bf16_full_eval = True
944
+ args.fp16_full_eval = False
945
+ elif not bf16_full_eval and not fp16_full_eval:
946
+ args.bf16_full_eval = args.bf16
947
+ args.fp16_full_eval = args.fp16
948
+ _output_logits = False
949
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
950
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
951
+ if _output_logits:
952
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
953
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
954
+ pass
955
+ else:
956
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
957
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
958
+ if args_max_seq_length is None and model_max_seq_length is not None:
959
+ max_seq_length = model.max_seq_length
960
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
961
+ if model is not None and hasattr(model, 'for_training'):
962
+ model.for_training()
963
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
964
+ if 'processing_class' in locals():
965
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
966
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
967
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
968
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
969
+ if not isinstance(data_collator, UnslothVisionDataCollator):
970
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
971
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
972
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
973
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
974
+ else:
975
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
976
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
977
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
978
+ if not isinstance(data_collator, UnslothVisionDataCollator):
979
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
980
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
981
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
982
+ else:
983
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
984
+ other_metrics = []
985
+
986
+ from unsloth_zoo.logging_utils import PatchRLStatistics
987
+ PatchRLStatistics('xpo_trainer', other_metrics)
988
+
989
+ super().__init__(
990
+ model = model,
991
+ ref_model = ref_model,
992
+ reward_model = reward_model,
993
+ judge = judge,
994
+ args = args,
995
+ data_collator = data_collator,
996
+ train_dataset = train_dataset,
997
+ eval_dataset = eval_dataset,
998
+ processing_class = processing_class,
999
+ peft_config = peft_config,
1000
+ compute_metrics = compute_metrics,
1001
+ callbacks = callbacks,
1002
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1003
+ if hasattr(self, 'neftune_hook_handle'):
1004
+ self.neftune_hook_handle.remove()
1005
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1006
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1007
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1008
+ pass
1009
+
1010
+ pass
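Finally, an end-to-end sketch of driving this trainer. The model checkpoint, judge, and dataset below are placeholders borrowed from common TRL examples, not part of this file, and `PairRMJudge` additionally requires the `llm-blender` package.

```python
# End-to-end sketch under the assumptions stated above.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PairRMJudge

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split = "train")

trainer = UnslothXPOTrainer(
    model = model,
    judge = PairRMJudge(),   # pairwise judge instead of a reward model
    args = UnslothXPOConfig(output_dir = "xpo-out", per_device_train_batch_size = 2),
    train_dataset = train_dataset,
    processing_class = tokenizer,
)
trainer.train()
```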
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc ADDED
Binary file (32.9 kB)
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc ADDED
Binary file (91.7 kB)
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc ADDED
Binary file (75.6 kB)
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc ADDED
Binary file (45.5 kB)
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e432dd59a309384c1403b9b00dbd3c130f21ba431cd431132cbbe27003d587c4
3
+ size 103569
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc ADDED
Binary file (37.7 kB)
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc ADDED
Binary file (78.5 kB)
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc ADDED
Binary file (87.4 kB)
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc ADDED
Binary file (47.3 kB)
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc ADDED
Binary file (75.6 kB)
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc ADDED
Binary file (67.1 kB)
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc ADDED
Binary file (62.7 kB)
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc ADDED
Binary file (36.4 kB)
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc ADDED
Binary file (54.2 kB)
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc ADDED
Binary file (38.9 kB)
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc ADDED
Binary file (47.8 kB)
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc ADDED
Binary file (49.9 kB)