Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gradio/certificate.pem +31 -0
- demo_cpu.py +182 -0
- requirements.txt +2 -1
- unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
- unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
- unsloth_compiled_cache/UnslothCPOTrainer.py +1557 -0
- unsloth_compiled_cache/UnslothDDPOTrainer.py +872 -0
- unsloth_compiled_cache/UnslothDPOTrainer.py +0 -0
- unsloth_compiled_cache/UnslothGKDTrainer.py +863 -0
- unsloth_compiled_cache/UnslothGRPOTrainer.py +1438 -0
- unsloth_compiled_cache/UnslothKTOTrainer.py +1840 -0
- unsloth_compiled_cache/UnslothNashMDTrainer.py +955 -0
- unsloth_compiled_cache/UnslothORPOTrainer.py +1543 -0
- unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +1269 -0
- unsloth_compiled_cache/UnslothPPOTrainer.py +1259 -0
- unsloth_compiled_cache/UnslothPRMTrainer.py +800 -0
- unsloth_compiled_cache/UnslothRLOOTrainer.py +1133 -0
- unsloth_compiled_cache/UnslothRewardTrainer.py +819 -0
- unsloth_compiled_cache/UnslothSFTTrainer.py +1027 -0
- unsloth_compiled_cache/UnslothXPOTrainer.py +1010 -0
- unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc +3 -0
- unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
- unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
demo_cpu.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Setup logging
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
SYSTEM_INSTRUCTION = """Convert natural language queries into boolean search queries by following these rules:
|
12 |
+
|
13 |
+
1. FIRST: Remove all meta-terms from this list (they should NEVER appear in output):
|
14 |
+
- articles, papers, research, studies
|
15 |
+
- examining, investigating, analyzing
|
16 |
+
- findings, documents, literature
|
17 |
+
- publications, journals, reviews
|
18 |
+
Example: "Research examining X" → just "X"
|
19 |
+
|
20 |
+
2. SECOND: Remove generic implied terms that don't add search value:
|
21 |
+
- Remove words like "practices," "techniques," "methods," "approaches," "strategies"
|
22 |
+
- Remove words like "impacts," "effects," "influences," "role," "applications"
|
23 |
+
- For example: "sustainable agriculture practices" → "sustainable agriculture"
|
24 |
+
- For example: "teaching methodologies" → "teaching"
|
25 |
+
- For example: "leadership styles" → "leadership"
|
26 |
+
|
27 |
+
3. THEN: Format the remaining terms:
|
28 |
+
CRITICAL QUOTING RULES:
|
29 |
+
- Multi-word phrases MUST ALWAYS be in quotes - NO EXCEPTIONS
|
30 |
+
- Examples of correct quoting:
|
31 |
+
- Wrong: machine learning AND deep learning
|
32 |
+
- Right: "machine learning" AND "deep learning"
|
33 |
+
- Wrong: natural language processing
|
34 |
+
- Right: "natural language processing"
|
35 |
+
- Single words must NEVER have quotes (e.g., science, research, learning)
|
36 |
+
- Use AND to connect required concepts
|
37 |
+
- Use OR with parentheses for alternatives
|
38 |
+
- Provide ONLY the correct Boolean query
|
39 |
+
"""
|
40 |
+
|
41 |
+
def load_model():
|
42 |
+
"""Load the model and tokenizer."""
|
43 |
+
logger.info("Loading model on CPU...")
|
44 |
+
model = AutoModelForCausalLM.from_pretrained(
|
45 |
+
"../boolean_model_llama/merged_llama",
|
46 |
+
torch_dtype=torch.float32,
|
47 |
+
device_map="cpu",
|
48 |
+
trust_remote_code=True
|
49 |
+
)
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
51 |
+
"../boolean_model_llama/merged_llama",
|
52 |
+
trust_remote_code=True
|
53 |
+
)
|
54 |
+
|
55 |
+
logger.info("Model loaded successfully")
|
56 |
+
return model, tokenizer
|
57 |
+
|
58 |
+
def extract_response(text: str) -> str:
|
59 |
+
"""Extract the actual boolean query from the model's response."""
|
60 |
+
# Split by newlines and get last non-empty line
|
61 |
+
lines = [line.strip() for line in text.split('\n') if line.strip()]
|
62 |
+
return lines[-1] if lines else ""
|
63 |
+
|
64 |
+
def get_boolean_query(query: str, model_tuple=None) -> str:
|
65 |
+
"""Generate boolean query from natural language."""
|
66 |
+
if model_tuple is None:
|
67 |
+
return ""
|
68 |
+
model, tokenizer = model_tuple
|
69 |
+
|
70 |
+
# Format the conversation
|
71 |
+
conversation = [
|
72 |
+
{"role": "system", "content": SYSTEM_INSTRUCTION},
|
73 |
+
{"role": "user", "content": query}
|
74 |
+
]
|
75 |
+
|
76 |
+
# Apply chat template and tokenize
|
77 |
+
encoded = tokenizer.apply_chat_template(
|
78 |
+
conversation,
|
79 |
+
return_tensors="pt"
|
80 |
+
)
|
81 |
+
|
82 |
+
# Create attention mask
|
83 |
+
attention_mask = torch.ones_like(encoded)
|
84 |
+
inputs = {
|
85 |
+
"input_ids": encoded,
|
86 |
+
"attention_mask": attention_mask
|
87 |
+
}
|
88 |
+
|
89 |
+
# Generate response
|
90 |
+
outputs = model.generate(
|
91 |
+
**inputs,
|
92 |
+
max_new_tokens=64,
|
93 |
+
temperature=0.1,
|
94 |
+
pad_token_id=tokenizer.eos_token_id
|
95 |
+
)
|
96 |
+
|
97 |
+
# Decode and extract response
|
98 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
99 |
+
return extract_response(response)
|
100 |
+
|
101 |
+
# Example queries demonstrating various cases
|
102 |
+
examples = [
|
103 |
+
# Testing removal of meta-terms
|
104 |
+
["Find research papers examining the long-term effects of meditation on brain structure"],
|
105 |
+
|
106 |
+
# Testing removal of generic implied terms (practices, techniques, methods)
|
107 |
+
["Articles about deep learning techniques for natural language processing tasks"],
|
108 |
+
|
109 |
+
# Testing removal of impact/effect terms
|
110 |
+
["Studies on the impact of early childhood nutrition on cognitive development"],
|
111 |
+
|
112 |
+
# Testing handling of technology applications
|
113 |
+
["Information on virtual reality applications in architectural design and urban planning"],
|
114 |
+
|
115 |
+
# Testing proper OR relationship with parentheses
|
116 |
+
["Research on electric vehicles adoption in urban environments or rural communities"],
|
117 |
+
|
118 |
+
# Testing proper quoting of multi-word concepts only
|
119 |
+
["Articles on biodiversity loss in coral reefs and rainforest ecosystems"],
|
120 |
+
|
121 |
+
# Testing removal of strategy/approach terms
|
122 |
+
["Studies about different teaching approaches for children with learning disabilities"],
|
123 |
+
|
124 |
+
# Testing complex OR relationships
|
125 |
+
["Research examining social media influence on political polarization or public discourse"],
|
126 |
+
|
127 |
+
# Testing implied terms in specific industries
|
128 |
+
["Articles about implementation strategies for blockchain in supply chain management or financial services"],
|
129 |
+
|
130 |
+
# Testing qualifiers that don't add search value
|
131 |
+
["Research on effective leadership styles in multicultural organizations"],
|
132 |
+
|
133 |
+
# Testing removal of multiple implied terms
|
134 |
+
["Studies on the effects of microplastic pollution techniques on marine ecosystem health"],
|
135 |
+
|
136 |
+
# Testing domain-specific implied terms
|
137 |
+
["Articles about successful cybersecurity protection methods for critical infrastructure"],
|
138 |
+
|
139 |
+
# Testing generalized vs specific concepts
|
140 |
+
["Research papers on quantum computing algorithms for cryptography or optimization problems"],
|
141 |
+
|
142 |
+
# Testing implied terms in outcome descriptions
|
143 |
+
["Studies examining the relationship between sleep quality and academic performance outcomes"],
|
144 |
+
|
145 |
+
# Testing complex nesting of concepts
|
146 |
+
["Articles about renewable energy integration challenges in developing countries or island nations"]
|
147 |
+
]
|
148 |
+
|
149 |
+
# Load model and tokenizer globally
|
150 |
+
logger.info("Initializing model...")
|
151 |
+
model_tuple = load_model()
|
152 |
+
|
153 |
+
# Create Gradio interface
|
154 |
+
title = "Natural Language to Boolean Search (CPU Version)"
|
155 |
+
description = """Convert natural language queries into boolean search expressions. The model will:
|
156 |
+
|
157 |
+
1. Remove search-related terms (like 'articles', 'research', etc.)
|
158 |
+
2. Handle generic implied terms (like 'practices', 'methods')
|
159 |
+
3. Format concepts using proper boolean syntax:
|
160 |
+
- Multi-word phrases in quotes
|
161 |
+
- Single words without quotes
|
162 |
+
- AND to connect required concepts
|
163 |
+
- OR with parentheses for alternatives
|
164 |
+
"""
|
165 |
+
|
166 |
+
demo = gr.Interface(
|
167 |
+
fn=lambda x: get_boolean_query(x, model_tuple),
|
168 |
+
inputs=[
|
169 |
+
gr.Textbox(
|
170 |
+
label="Enter your natural language query",
|
171 |
+
placeholder="e.g., I'm looking for information about climate change and renewable energy"
|
172 |
+
)
|
173 |
+
],
|
174 |
+
outputs=gr.Textbox(label="Boolean Search Query"),
|
175 |
+
title=title,
|
176 |
+
description=description,
|
177 |
+
examples=examples,
|
178 |
+
theme=gr.themes.Soft()
|
179 |
+
)
|
180 |
+
|
181 |
+
if __name__ == "__main__":
|
182 |
+
demo.launch(share=True)
|
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
gradio>=4.0.0
|
2 |
-
|
|
|
|
1 |
gradio>=4.0.0
|
2 |
+
transformers>=4.0.0
|
3 |
+
torch>=1.0.0
|
unsloth_compiled_cache/UnslothAlignPropTrainer.py
ADDED
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothAlignPropConfig(AlignPropConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`AlignPropTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
54 |
+
Name of this experiment (defaults to the file name without the extension).
|
55 |
+
run_name (`str`, *optional*, defaults to `""`):
|
56 |
+
Name of this run.
|
57 |
+
seed (`int`, *optional*, defaults to `0`):
|
58 |
+
Random seed for reproducibility.
|
59 |
+
log_with (`str` or `None`, *optional*, defaults to `None`):
|
60 |
+
Log with either `"wandb"` or `"tensorboard"`. Check
|
61 |
+
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
62 |
+
log_image_freq (`int`, *optional*, defaults to `1`):
|
63 |
+
Frequency for logging images.
|
64 |
+
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
65 |
+
Keyword arguments for the tracker (e.g., `wandb_project`).
|
66 |
+
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
67 |
+
Keyword arguments for the accelerator.
|
68 |
+
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
69 |
+
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
70 |
+
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
71 |
+
Name of project to use for tracking.
|
72 |
+
logdir (`str`, *optional*, defaults to `"logs"`):
|
73 |
+
Top-level logging directory for checkpoint saving.
|
74 |
+
num_epochs (`int`, *optional*, defaults to `100`):
|
75 |
+
Number of epochs to train.
|
76 |
+
save_freq (`int`, *optional*, defaults to `1`):
|
77 |
+
Number of epochs between saving model checkpoints.
|
78 |
+
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
79 |
+
Number of checkpoints to keep before overwriting old ones.
|
80 |
+
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
81 |
+
Mixed precision training.
|
82 |
+
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
83 |
+
Allow `tf32` on Ampere GPUs.
|
84 |
+
resume_from (`str`, *optional*, defaults to `""`):
|
85 |
+
Path to resume training from a checkpoint.
|
86 |
+
sample_num_steps (`int`, *optional*, defaults to `50`):
|
87 |
+
Number of sampler inference steps.
|
88 |
+
sample_eta (`float`, *optional*, defaults to `1.0`):
|
89 |
+
Eta parameter for the DDIM sampler.
|
90 |
+
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
91 |
+
Classifier-free guidance weight.
|
92 |
+
train_batch_size (`int`, *optional*, defaults to `1`):
|
93 |
+
Batch size for training.
|
94 |
+
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
95 |
+
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
96 |
+
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
97 |
+
Learning rate.
|
98 |
+
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
99 |
+
Beta1 for Adam optimizer.
|
100 |
+
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
101 |
+
Beta2 for Adam optimizer.
|
102 |
+
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
103 |
+
Weight decay for Adam optimizer.
|
104 |
+
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
105 |
+
Epsilon value for Adam optimizer.
|
106 |
+
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
107 |
+
Number of gradient accumulation steps.
|
108 |
+
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
109 |
+
Maximum gradient norm for gradient clipping.
|
110 |
+
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
111 |
+
Comma-separated list of prompts to use as negative examples.
|
112 |
+
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
113 |
+
If `True`, randomized truncation to different diffusion timesteps is used.
|
114 |
+
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
115 |
+
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
116 |
+
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
117 |
+
Range of diffusion timesteps for randomized truncated backpropagation.
|
118 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
119 |
+
Whether to push the final model to the Hub.
|
120 |
+
|
121 |
+
"""
|
122 |
+
vllm_sampling_params: Optional[Any] = field(
|
123 |
+
default = None,
|
124 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
125 |
+
)
|
126 |
+
unsloth_num_chunks : Optional[int] = field(
|
127 |
+
default = -1,
|
128 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
129 |
+
)
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
exp_name = 'demo',
|
133 |
+
run_name = '',
|
134 |
+
seed = 3407,
|
135 |
+
log_with = None,
|
136 |
+
log_image_freq = 1,
|
137 |
+
tracker_project_name = 'trl',
|
138 |
+
logdir = 'logs',
|
139 |
+
num_epochs = 100,
|
140 |
+
save_freq = 1,
|
141 |
+
num_checkpoint_limit = 5,
|
142 |
+
mixed_precision = 'fp16',
|
143 |
+
allow_tf32 = True,
|
144 |
+
resume_from = '',
|
145 |
+
sample_num_steps = 50,
|
146 |
+
sample_eta = 1.0,
|
147 |
+
sample_guidance_scale = 5.0,
|
148 |
+
train_batch_size = 1,
|
149 |
+
train_use_8bit_adam = False,
|
150 |
+
train_learning_rate = 5e-05,
|
151 |
+
train_adam_beta1 = 0.9,
|
152 |
+
train_adam_beta2 = 0.999,
|
153 |
+
train_adam_weight_decay = 0.01,
|
154 |
+
train_adam_epsilon = 1e-08,
|
155 |
+
train_gradient_accumulation_steps = 2,
|
156 |
+
train_max_grad_norm = 1.0,
|
157 |
+
negative_prompts = None,
|
158 |
+
truncated_backprop_rand = True,
|
159 |
+
truncated_backprop_timestep = 49,
|
160 |
+
push_to_hub = False,
|
161 |
+
vllm_sampling_params = None,
|
162 |
+
unsloth_num_chunks = -1,
|
163 |
+
**kwargs,
|
164 |
+
):
|
165 |
+
|
166 |
+
super().__init__(
|
167 |
+
exp_name = exp_name,
|
168 |
+
run_name = run_name,
|
169 |
+
seed = seed,
|
170 |
+
log_with = log_with,
|
171 |
+
log_image_freq = log_image_freq,
|
172 |
+
tracker_project_name = tracker_project_name,
|
173 |
+
logdir = logdir,
|
174 |
+
num_epochs = num_epochs,
|
175 |
+
save_freq = save_freq,
|
176 |
+
num_checkpoint_limit = num_checkpoint_limit,
|
177 |
+
mixed_precision = mixed_precision,
|
178 |
+
allow_tf32 = allow_tf32,
|
179 |
+
resume_from = resume_from,
|
180 |
+
sample_num_steps = sample_num_steps,
|
181 |
+
sample_eta = sample_eta,
|
182 |
+
sample_guidance_scale = sample_guidance_scale,
|
183 |
+
train_batch_size = train_batch_size,
|
184 |
+
train_use_8bit_adam = train_use_8bit_adam,
|
185 |
+
train_learning_rate = train_learning_rate,
|
186 |
+
train_adam_beta1 = train_adam_beta1,
|
187 |
+
train_adam_beta2 = train_adam_beta2,
|
188 |
+
train_adam_weight_decay = train_adam_weight_decay,
|
189 |
+
train_adam_epsilon = train_adam_epsilon,
|
190 |
+
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
191 |
+
train_max_grad_norm = train_max_grad_norm,
|
192 |
+
negative_prompts = negative_prompts,
|
193 |
+
truncated_backprop_rand = truncated_backprop_rand,
|
194 |
+
truncated_backprop_timestep = truncated_backprop_timestep,
|
195 |
+
push_to_hub = push_to_hub,**kwargs)
|
196 |
+
self.vllm_sampling_params = vllm_sampling_params
|
197 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
198 |
+
pass
|
199 |
+
|
200 |
+
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
201 |
+
""""""
|
202 |
+
|
203 |
+
_tag_names = ["trl", "alignprop"]
|
204 |
+
|
205 |
+
def __init__(
|
206 |
+
self,
|
207 |
+
config: AlignPropConfig,
|
208 |
+
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
209 |
+
prompt_function: Callable[[], tuple[str, Any]],
|
210 |
+
sd_pipeline: DDPOStableDiffusionPipeline,
|
211 |
+
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
212 |
+
):
|
213 |
+
if image_samples_hook is None:
|
214 |
+
warn("No image_samples_hook provided; no images will be logged")
|
215 |
+
|
216 |
+
self.prompt_fn = prompt_function
|
217 |
+
self.reward_fn = reward_function
|
218 |
+
self.config = config
|
219 |
+
self.image_samples_callback = image_samples_hook
|
220 |
+
|
221 |
+
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
222 |
+
|
223 |
+
if self.config.resume_from:
|
224 |
+
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
225 |
+
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
226 |
+
# get the most recent checkpoint in this directory
|
227 |
+
checkpoints = list(
|
228 |
+
filter(
|
229 |
+
lambda x: "checkpoint_" in x,
|
230 |
+
os.listdir(self.config.resume_from),
|
231 |
+
)
|
232 |
+
)
|
233 |
+
if len(checkpoints) == 0:
|
234 |
+
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
235 |
+
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
236 |
+
self.config.resume_from = os.path.join(
|
237 |
+
self.config.resume_from,
|
238 |
+
f"checkpoint_{checkpoint_numbers[-1]}",
|
239 |
+
)
|
240 |
+
|
241 |
+
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
242 |
+
|
243 |
+
self.accelerator = Accelerator(
|
244 |
+
log_with=self.config.log_with,
|
245 |
+
mixed_precision=self.config.mixed_precision,
|
246 |
+
project_config=accelerator_project_config,
|
247 |
+
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
248 |
+
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
249 |
+
# the total number of optimizer steps to accumulate across.
|
250 |
+
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
251 |
+
**self.config.accelerator_kwargs,
|
252 |
+
)
|
253 |
+
|
254 |
+
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
255 |
+
|
256 |
+
if self.accelerator.is_main_process:
|
257 |
+
self.accelerator.init_trackers(
|
258 |
+
self.config.tracker_project_name,
|
259 |
+
config=dict(alignprop_trainer_config=config.to_dict())
|
260 |
+
if not is_using_tensorboard
|
261 |
+
else config.to_dict(),
|
262 |
+
init_kwargs=self.config.tracker_kwargs,
|
263 |
+
)
|
264 |
+
|
265 |
+
logger.info(f"\n{config}")
|
266 |
+
|
267 |
+
set_seed(self.config.seed, device_specific=True)
|
268 |
+
|
269 |
+
self.sd_pipeline = sd_pipeline
|
270 |
+
|
271 |
+
self.sd_pipeline.set_progress_bar_config(
|
272 |
+
position=1,
|
273 |
+
disable=not self.accelerator.is_local_main_process,
|
274 |
+
leave=False,
|
275 |
+
desc="Timestep",
|
276 |
+
dynamic_ncols=True,
|
277 |
+
)
|
278 |
+
|
279 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
280 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
281 |
+
if self.accelerator.mixed_precision == "fp16":
|
282 |
+
inference_dtype = torch.float16
|
283 |
+
elif self.accelerator.mixed_precision == "bf16":
|
284 |
+
inference_dtype = torch.bfloat16
|
285 |
+
else:
|
286 |
+
inference_dtype = torch.float32
|
287 |
+
|
288 |
+
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
289 |
+
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
290 |
+
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
291 |
+
|
292 |
+
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
293 |
+
|
294 |
+
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
295 |
+
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
296 |
+
|
297 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
298 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
299 |
+
if self.config.allow_tf32:
|
300 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
301 |
+
|
302 |
+
self.optimizer = self._setup_optimizer(
|
303 |
+
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
304 |
+
)
|
305 |
+
|
306 |
+
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
307 |
+
self.sd_pipeline.tokenizer(
|
308 |
+
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
309 |
+
return_tensors="pt",
|
310 |
+
padding="max_length",
|
311 |
+
truncation=True,
|
312 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
313 |
+
).input_ids.to(self.accelerator.device)
|
314 |
+
)[0]
|
315 |
+
|
316 |
+
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
317 |
+
# more memory
|
318 |
+
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
319 |
+
|
320 |
+
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
321 |
+
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
322 |
+
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
323 |
+
else:
|
324 |
+
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
325 |
+
|
326 |
+
if config.resume_from:
|
327 |
+
logger.info(f"Resuming from {config.resume_from}")
|
328 |
+
self.accelerator.load_state(config.resume_from)
|
329 |
+
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
330 |
+
else:
|
331 |
+
self.first_epoch = 0
|
332 |
+
|
333 |
+
def compute_rewards(self, prompt_image_pairs):
|
334 |
+
reward, reward_metadata = self.reward_fn(
|
335 |
+
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
336 |
+
)
|
337 |
+
return reward
|
338 |
+
|
339 |
+
def step(self, epoch: int, global_step: int):
|
340 |
+
"""
|
341 |
+
Perform a single step of training.
|
342 |
+
|
343 |
+
Args:
|
344 |
+
epoch (int): The current epoch.
|
345 |
+
global_step (int): The current global step.
|
346 |
+
|
347 |
+
Side Effects:
|
348 |
+
- Model weights are updated
|
349 |
+
- Logs the statistics to the accelerator trackers.
|
350 |
+
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
global_step (int): The updated global step.
|
354 |
+
"""
|
355 |
+
info = defaultdict(list)
|
356 |
+
|
357 |
+
self.sd_pipeline.unet.train()
|
358 |
+
|
359 |
+
for _ in range(self.config.train_gradient_accumulation_steps):
|
360 |
+
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
361 |
+
prompt_image_pairs = self._generate_samples(
|
362 |
+
batch_size=self.config.train_batch_size,
|
363 |
+
)
|
364 |
+
|
365 |
+
rewards = self.compute_rewards(prompt_image_pairs)
|
366 |
+
|
367 |
+
prompt_image_pairs["rewards"] = rewards
|
368 |
+
|
369 |
+
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
370 |
+
|
371 |
+
loss = self.calculate_loss(rewards)
|
372 |
+
|
373 |
+
self.accelerator.backward(loss)
|
374 |
+
|
375 |
+
if self.accelerator.sync_gradients:
|
376 |
+
self.accelerator.clip_grad_norm_(
|
377 |
+
self.trainable_layers.parameters()
|
378 |
+
if not isinstance(self.trainable_layers, list)
|
379 |
+
else self.trainable_layers,
|
380 |
+
self.config.train_max_grad_norm,
|
381 |
+
)
|
382 |
+
|
383 |
+
self.optimizer.step()
|
384 |
+
self.optimizer.zero_grad()
|
385 |
+
|
386 |
+
info["reward_mean"].append(rewards_vis.mean())
|
387 |
+
info["reward_std"].append(rewards_vis.std())
|
388 |
+
info["loss"].append(loss.item())
|
389 |
+
|
390 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
391 |
+
if self.accelerator.sync_gradients:
|
392 |
+
# log training-related stuff
|
393 |
+
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
394 |
+
info = self.accelerator.reduce(info, reduction="mean")
|
395 |
+
info.update({"epoch": epoch})
|
396 |
+
self.accelerator.log(info, step=global_step)
|
397 |
+
global_step += 1
|
398 |
+
info = defaultdict(list)
|
399 |
+
else:
|
400 |
+
raise ValueError(
|
401 |
+
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
402 |
+
)
|
403 |
+
# Logs generated images
|
404 |
+
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
405 |
+
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
406 |
+
|
407 |
+
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
408 |
+
self.accelerator.save_state()
|
409 |
+
|
410 |
+
return global_step
|
411 |
+
|
412 |
+
def calculate_loss(self, rewards):
|
413 |
+
"""
|
414 |
+
Calculate the loss for a batch of an unpacked sample
|
415 |
+
|
416 |
+
Args:
|
417 |
+
rewards (torch.Tensor):
|
418 |
+
Differentiable reward scalars for each generated image, shape: [batch_size]
|
419 |
+
|
420 |
+
Returns:
|
421 |
+
loss (torch.Tensor)
|
422 |
+
(all of these are of shape (1,))
|
423 |
+
"""
|
424 |
+
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
425 |
+
loss = 10.0 - (rewards).mean()
|
426 |
+
return loss
|
427 |
+
|
428 |
+
def loss(
|
429 |
+
self,
|
430 |
+
advantages: torch.Tensor,
|
431 |
+
clip_range: float,
|
432 |
+
ratio: torch.Tensor,
|
433 |
+
):
|
434 |
+
unclipped_loss = -advantages * ratio
|
435 |
+
clipped_loss = -advantages * torch.clamp(
|
436 |
+
ratio,
|
437 |
+
1.0 - clip_range,
|
438 |
+
1.0 + clip_range,
|
439 |
+
)
|
440 |
+
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
441 |
+
|
442 |
+
def _setup_optimizer(self, trainable_layers_parameters):
|
443 |
+
if self.config.train_use_8bit_adam:
|
444 |
+
import bitsandbytes
|
445 |
+
|
446 |
+
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
447 |
+
else:
|
448 |
+
optimizer_cls = torch.optim.AdamW
|
449 |
+
|
450 |
+
return optimizer_cls(
|
451 |
+
trainable_layers_parameters,
|
452 |
+
lr=self.config.train_learning_rate,
|
453 |
+
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
454 |
+
weight_decay=self.config.train_adam_weight_decay,
|
455 |
+
eps=self.config.train_adam_epsilon,
|
456 |
+
)
|
457 |
+
|
458 |
+
def _save_model_hook(self, models, weights, output_dir):
|
459 |
+
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
460 |
+
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
461 |
+
|
462 |
+
def _load_model_hook(self, models, input_dir):
|
463 |
+
self.sd_pipeline.load_checkpoint(models, input_dir)
|
464 |
+
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
465 |
+
|
466 |
+
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
467 |
+
"""
|
468 |
+
Generate samples from the model
|
469 |
+
|
470 |
+
Args:
|
471 |
+
batch_size (int): Batch size to use for sampling
|
472 |
+
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
473 |
+
|
474 |
+
Returns:
|
475 |
+
prompt_image_pairs (dict[Any])
|
476 |
+
"""
|
477 |
+
prompt_image_pairs = {}
|
478 |
+
|
479 |
+
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
480 |
+
|
481 |
+
if prompts is None:
|
482 |
+
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
483 |
+
else:
|
484 |
+
prompt_metadata = [{} for _ in range(batch_size)]
|
485 |
+
|
486 |
+
prompt_ids = self.sd_pipeline.tokenizer(
|
487 |
+
prompts,
|
488 |
+
return_tensors="pt",
|
489 |
+
padding="max_length",
|
490 |
+
truncation=True,
|
491 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
492 |
+
).input_ids.to(self.accelerator.device)
|
493 |
+
|
494 |
+
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
495 |
+
|
496 |
+
if with_grad:
|
497 |
+
sd_output = self.sd_pipeline.rgb_with_grad(
|
498 |
+
prompt_embeds=prompt_embeds,
|
499 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
500 |
+
num_inference_steps=self.config.sample_num_steps,
|
501 |
+
guidance_scale=self.config.sample_guidance_scale,
|
502 |
+
eta=self.config.sample_eta,
|
503 |
+
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
504 |
+
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
505 |
+
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
506 |
+
output_type="pt",
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
sd_output = self.sd_pipeline(
|
510 |
+
prompt_embeds=prompt_embeds,
|
511 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
512 |
+
num_inference_steps=self.config.sample_num_steps,
|
513 |
+
guidance_scale=self.config.sample_guidance_scale,
|
514 |
+
eta=self.config.sample_eta,
|
515 |
+
output_type="pt",
|
516 |
+
)
|
517 |
+
|
518 |
+
images = sd_output.images
|
519 |
+
|
520 |
+
prompt_image_pairs["images"] = images
|
521 |
+
prompt_image_pairs["prompts"] = prompts
|
522 |
+
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
523 |
+
|
524 |
+
return prompt_image_pairs
|
525 |
+
|
526 |
+
def train(self, epochs: Optional[int] = None):
|
527 |
+
"""
|
528 |
+
Train the model for a given number of epochs
|
529 |
+
"""
|
530 |
+
global_step = 0
|
531 |
+
if epochs is None:
|
532 |
+
epochs = self.config.num_epochs
|
533 |
+
for epoch in range(self.first_epoch, epochs):
|
534 |
+
global_step = self.step(epoch, global_step)
|
535 |
+
|
536 |
+
def _save_pretrained(self, save_directory):
|
537 |
+
self.sd_pipeline.save_pretrained(save_directory)
|
538 |
+
self.create_model_card()
|
539 |
+
|
540 |
+
def create_model_card(
|
541 |
+
self,
|
542 |
+
model_name: Optional[str] = None,
|
543 |
+
dataset_name: Optional[str] = None,
|
544 |
+
tags: Union[str, list[str], None] = None,
|
545 |
+
):
|
546 |
+
"""
|
547 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
551 |
+
Name of the model.
|
552 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
553 |
+
Name of the dataset used for training.
|
554 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
555 |
+
Tags to be associated with the model card.
|
556 |
+
"""
|
557 |
+
if not self.is_world_process_zero():
|
558 |
+
return
|
559 |
+
|
560 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
561 |
+
base_model = self.model.config._name_or_path
|
562 |
+
else:
|
563 |
+
base_model = None
|
564 |
+
|
565 |
+
tags = tags or []
|
566 |
+
if isinstance(tags, str):
|
567 |
+
tags = [tags]
|
568 |
+
|
569 |
+
if hasattr(self.model.config, "unsloth_version"):
|
570 |
+
tags.append("unsloth")
|
571 |
+
|
572 |
+
citation = textwrap.dedent("""\
|
573 |
+
@article{prabhudesai2024aligning,
|
574 |
+
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
575 |
+
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
576 |
+
year = 2024,
|
577 |
+
eprint = {arXiv:2310.03739}
|
578 |
+
}""")
|
579 |
+
|
580 |
+
model_card = generate_model_card(
|
581 |
+
base_model=base_model,
|
582 |
+
model_name=model_name,
|
583 |
+
hub_model_id=self.hub_model_id,
|
584 |
+
dataset_name=dataset_name,
|
585 |
+
tags=tags,
|
586 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
587 |
+
comet_url=get_comet_experiment_url(),
|
588 |
+
trainer_name="AlignProp",
|
589 |
+
trainer_citation=citation,
|
590 |
+
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
591 |
+
paper_id="2310.03739",
|
592 |
+
)
|
593 |
+
|
594 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
595 |
+
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
596 |
+
"""
|
597 |
+
|
598 |
+
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
599 |
+
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
|
600 |
+
As of now only Stable Diffusion based pipelines are supported
|
601 |
+
|
602 |
+
Attributes:
|
603 |
+
config (`AlignPropConfig`):
|
604 |
+
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
605 |
+
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
606 |
+
Reward function to be used
|
607 |
+
prompt_function (`Callable[[], tuple[str, Any]]`):
|
608 |
+
Function to generate prompts to guide model
|
609 |
+
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
610 |
+
Stable Diffusion pipeline to be used for training.
|
611 |
+
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
612 |
+
Hook to be called to log images
|
613 |
+
|
614 |
+
"""
|
615 |
+
def __init__(
|
616 |
+
self,
|
617 |
+
config,
|
618 |
+
reward_function,
|
619 |
+
prompt_function,
|
620 |
+
sd_pipeline,
|
621 |
+
image_samples_hook = None,
|
622 |
+
**kwargs
|
623 |
+
):
|
624 |
+
if args is None: args = UnslothAlignPropConfig()
|
625 |
+
other_metrics = []
|
626 |
+
|
627 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
628 |
+
PatchRLStatistics('alignprop_trainer', other_metrics)
|
629 |
+
|
630 |
+
super().__init__(
|
631 |
+
config = config,
|
632 |
+
reward_function = reward_function,
|
633 |
+
prompt_function = prompt_function,
|
634 |
+
sd_pipeline = sd_pipeline,
|
635 |
+
image_samples_hook = image_samples_hook,**kwargs)
|
636 |
+
|
637 |
+
pass
|
unsloth_compiled_cache/UnslothBCOTrainer.py
ADDED
@@ -0,0 +1,1824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothBCOConfig(BCOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`BCOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
54 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
55 |
+
to use the default data collator.
|
56 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
57 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
58 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
59 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
60 |
+
and your model is an encoder-decoder.
|
61 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
62 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
63 |
+
reference model.
|
64 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
65 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
66 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
67 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
68 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
69 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
70 |
+
This argument is required if you want to use the default data collator.
|
71 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
72 |
+
Whether to disable dropout in the model and reference model.
|
73 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
74 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
75 |
+
evaluation.
|
76 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
77 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
78 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
79 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
81 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
82 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
83 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
84 |
+
string.
|
85 |
+
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
86 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
87 |
+
from a string.
|
88 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
+
Number of processes to use for processing the dataset.
|
90 |
+
prompt_sample_size (`int`, *optional*, defaults to `1024`):
|
91 |
+
Number of prompts that are fed to density ratio classifier.
|
92 |
+
min_density_ratio (`float`, *optional*, defaults to `0.5`):
|
93 |
+
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
|
94 |
+
max_density_ratio (`float`, *optional*, defaults to `10.0`):
|
95 |
+
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
|
96 |
+
|
97 |
+
"""
|
98 |
+
vllm_sampling_params: Optional[Any] = field(
|
99 |
+
default = None,
|
100 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
101 |
+
)
|
102 |
+
unsloth_num_chunks : Optional[int] = field(
|
103 |
+
default = -1,
|
104 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
105 |
+
)
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
output_dir = None,
|
109 |
+
overwrite_output_dir = None,
|
110 |
+
do_train = False,
|
111 |
+
do_eval = False,
|
112 |
+
do_predict = False,
|
113 |
+
eval_strategy = 'no',
|
114 |
+
prediction_loss_only = False,
|
115 |
+
per_device_train_batch_size = 4,
|
116 |
+
per_device_eval_batch_size = 4,
|
117 |
+
per_gpu_train_batch_size = None,
|
118 |
+
per_gpu_eval_batch_size = None,
|
119 |
+
gradient_accumulation_steps = 2,
|
120 |
+
eval_accumulation_steps = 2,
|
121 |
+
eval_delay = 0,
|
122 |
+
torch_empty_cache_steps = 250,
|
123 |
+
learning_rate = 5e-05,
|
124 |
+
weight_decay = 0.01,
|
125 |
+
adam_beta1 = 0.9,
|
126 |
+
adam_beta2 = 0.999,
|
127 |
+
adam_epsilon = 1e-08,
|
128 |
+
max_grad_norm = 1.0,
|
129 |
+
num_train_epochs = 3.0,
|
130 |
+
max_steps = -1,
|
131 |
+
lr_scheduler_type = 'linear',
|
132 |
+
warmup_ratio = 0.1,
|
133 |
+
warmup_steps = 0,
|
134 |
+
log_level = 'passive',
|
135 |
+
log_level_replica = 'warning',
|
136 |
+
log_on_each_node = True,
|
137 |
+
logging_dir = None,
|
138 |
+
logging_strategy = 'steps',
|
139 |
+
logging_first_step = False,
|
140 |
+
logging_steps = 1,
|
141 |
+
logging_nan_inf_filter = False,
|
142 |
+
save_strategy = 'steps',
|
143 |
+
save_steps = 500,
|
144 |
+
save_total_limit = None,
|
145 |
+
save_safetensors = True,
|
146 |
+
save_on_each_node = False,
|
147 |
+
save_only_model = False,
|
148 |
+
restore_callback_states_from_checkpoint = False,
|
149 |
+
no_cuda = False,
|
150 |
+
use_cpu = False,
|
151 |
+
use_mps_device = False,
|
152 |
+
seed = 3407,
|
153 |
+
data_seed = 3407,
|
154 |
+
jit_mode_eval = False,
|
155 |
+
use_ipex = False,
|
156 |
+
bf16 = False,
|
157 |
+
fp16 = False,
|
158 |
+
fp16_opt_level = 'O1',
|
159 |
+
half_precision_backend = 'auto',
|
160 |
+
bf16_full_eval = False,
|
161 |
+
fp16_full_eval = False,
|
162 |
+
tf32 = None,
|
163 |
+
local_rank = -1,
|
164 |
+
ddp_backend = None,
|
165 |
+
tpu_num_cores = None,
|
166 |
+
tpu_metrics_debug = False,
|
167 |
+
debug = '',
|
168 |
+
dataloader_drop_last = False,
|
169 |
+
eval_steps = None,
|
170 |
+
dataloader_num_workers = 0,
|
171 |
+
dataloader_prefetch_factor = None,
|
172 |
+
past_index = -1,
|
173 |
+
run_name = None,
|
174 |
+
disable_tqdm = None,
|
175 |
+
remove_unused_columns = True,
|
176 |
+
label_names = None,
|
177 |
+
load_best_model_at_end = False,
|
178 |
+
metric_for_best_model = None,
|
179 |
+
greater_is_better = None,
|
180 |
+
ignore_data_skip = False,
|
181 |
+
fsdp = '',
|
182 |
+
fsdp_min_num_params = 0,
|
183 |
+
fsdp_config = None,
|
184 |
+
tp_size = 0,
|
185 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
186 |
+
accelerator_config = None,
|
187 |
+
deepspeed = None,
|
188 |
+
label_smoothing_factor = 0.0,
|
189 |
+
optim = 'adamw_8bit',
|
190 |
+
optim_args = None,
|
191 |
+
adafactor = False,
|
192 |
+
group_by_length = False,
|
193 |
+
length_column_name = 'length',
|
194 |
+
report_to = None,
|
195 |
+
ddp_find_unused_parameters = None,
|
196 |
+
ddp_bucket_cap_mb = None,
|
197 |
+
ddp_broadcast_buffers = None,
|
198 |
+
dataloader_pin_memory = True,
|
199 |
+
dataloader_persistent_workers = False,
|
200 |
+
skip_memory_metrics = True,
|
201 |
+
use_legacy_prediction_loop = False,
|
202 |
+
push_to_hub = False,
|
203 |
+
resume_from_checkpoint = None,
|
204 |
+
hub_model_id = None,
|
205 |
+
hub_strategy = 'every_save',
|
206 |
+
hub_token = None,
|
207 |
+
hub_private_repo = None,
|
208 |
+
hub_always_push = False,
|
209 |
+
gradient_checkpointing = False,
|
210 |
+
gradient_checkpointing_kwargs = None,
|
211 |
+
include_inputs_for_metrics = False,
|
212 |
+
eval_do_concat_batches = True,
|
213 |
+
fp16_backend = 'auto',
|
214 |
+
evaluation_strategy = None,
|
215 |
+
push_to_hub_model_id = None,
|
216 |
+
push_to_hub_organization = None,
|
217 |
+
push_to_hub_token = None,
|
218 |
+
mp_parameters = '',
|
219 |
+
auto_find_batch_size = False,
|
220 |
+
full_determinism = False,
|
221 |
+
torchdynamo = None,
|
222 |
+
ray_scope = 'last',
|
223 |
+
ddp_timeout = 1800,
|
224 |
+
torch_compile = False,
|
225 |
+
torch_compile_backend = None,
|
226 |
+
torch_compile_mode = None,
|
227 |
+
dispatch_batches = None,
|
228 |
+
split_batches = None,
|
229 |
+
include_tokens_per_second = False,
|
230 |
+
include_num_input_tokens_seen = False,
|
231 |
+
neftune_noise_alpha = None,
|
232 |
+
optim_target_modules = None,
|
233 |
+
batch_eval_metrics = False,
|
234 |
+
eval_on_start = False,
|
235 |
+
use_liger_kernel = False,
|
236 |
+
eval_use_gather_object = False,
|
237 |
+
average_tokens_across_devices = False,
|
238 |
+
max_length = 1024,
|
239 |
+
max_prompt_length = 512,
|
240 |
+
max_completion_length = None,
|
241 |
+
beta = 0.1,
|
242 |
+
label_pad_token_id = -100,
|
243 |
+
padding_value = None,
|
244 |
+
truncation_mode = 'keep_end',
|
245 |
+
disable_dropout = True,
|
246 |
+
generate_during_eval = False,
|
247 |
+
is_encoder_decoder = None,
|
248 |
+
precompute_ref_log_probs = False,
|
249 |
+
model_init_kwargs = None,
|
250 |
+
ref_model_init_kwargs = None,
|
251 |
+
dataset_num_proc = None,
|
252 |
+
prompt_sample_size = 1024,
|
253 |
+
min_density_ratio = 0.5,
|
254 |
+
max_density_ratio = 10.0,
|
255 |
+
vllm_sampling_params = None,
|
256 |
+
unsloth_num_chunks = -1,
|
257 |
+
**kwargs,
|
258 |
+
):
|
259 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
260 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
261 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
262 |
+
output_dir = 'unsloth_training_checkpoints'
|
263 |
+
save_strategy = 'no'
|
264 |
+
if dataset_num_proc is None:
|
265 |
+
from multiprocessing import cpu_count
|
266 |
+
dataset_num_proc = cpu_count()
|
267 |
+
|
268 |
+
super().__init__(
|
269 |
+
output_dir = output_dir,
|
270 |
+
overwrite_output_dir = overwrite_output_dir,
|
271 |
+
do_train = do_train,
|
272 |
+
do_eval = do_eval,
|
273 |
+
do_predict = do_predict,
|
274 |
+
eval_strategy = eval_strategy,
|
275 |
+
prediction_loss_only = prediction_loss_only,
|
276 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
277 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
278 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
279 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
280 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
281 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
282 |
+
eval_delay = eval_delay,
|
283 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
284 |
+
learning_rate = learning_rate,
|
285 |
+
weight_decay = weight_decay,
|
286 |
+
adam_beta1 = adam_beta1,
|
287 |
+
adam_beta2 = adam_beta2,
|
288 |
+
adam_epsilon = adam_epsilon,
|
289 |
+
max_grad_norm = max_grad_norm,
|
290 |
+
num_train_epochs = num_train_epochs,
|
291 |
+
max_steps = max_steps,
|
292 |
+
lr_scheduler_type = lr_scheduler_type,
|
293 |
+
warmup_ratio = warmup_ratio,
|
294 |
+
warmup_steps = warmup_steps,
|
295 |
+
log_level = log_level,
|
296 |
+
log_level_replica = log_level_replica,
|
297 |
+
log_on_each_node = log_on_each_node,
|
298 |
+
logging_dir = logging_dir,
|
299 |
+
logging_strategy = logging_strategy,
|
300 |
+
logging_first_step = logging_first_step,
|
301 |
+
logging_steps = logging_steps,
|
302 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
303 |
+
save_strategy = save_strategy,
|
304 |
+
save_steps = save_steps,
|
305 |
+
save_total_limit = save_total_limit,
|
306 |
+
save_safetensors = save_safetensors,
|
307 |
+
save_on_each_node = save_on_each_node,
|
308 |
+
save_only_model = save_only_model,
|
309 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
310 |
+
no_cuda = no_cuda,
|
311 |
+
use_cpu = use_cpu,
|
312 |
+
use_mps_device = use_mps_device,
|
313 |
+
seed = seed,
|
314 |
+
data_seed = data_seed,
|
315 |
+
jit_mode_eval = jit_mode_eval,
|
316 |
+
use_ipex = use_ipex,
|
317 |
+
bf16 = bf16,
|
318 |
+
fp16 = fp16,
|
319 |
+
fp16_opt_level = fp16_opt_level,
|
320 |
+
half_precision_backend = half_precision_backend,
|
321 |
+
bf16_full_eval = bf16_full_eval,
|
322 |
+
fp16_full_eval = fp16_full_eval,
|
323 |
+
tf32 = tf32,
|
324 |
+
local_rank = local_rank,
|
325 |
+
ddp_backend = ddp_backend,
|
326 |
+
tpu_num_cores = tpu_num_cores,
|
327 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
328 |
+
debug = debug,
|
329 |
+
dataloader_drop_last = dataloader_drop_last,
|
330 |
+
eval_steps = eval_steps,
|
331 |
+
dataloader_num_workers = dataloader_num_workers,
|
332 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
333 |
+
past_index = past_index,
|
334 |
+
run_name = run_name,
|
335 |
+
disable_tqdm = disable_tqdm,
|
336 |
+
remove_unused_columns = remove_unused_columns,
|
337 |
+
label_names = label_names,
|
338 |
+
load_best_model_at_end = load_best_model_at_end,
|
339 |
+
metric_for_best_model = metric_for_best_model,
|
340 |
+
greater_is_better = greater_is_better,
|
341 |
+
ignore_data_skip = ignore_data_skip,
|
342 |
+
fsdp = fsdp,
|
343 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
344 |
+
fsdp_config = fsdp_config,
|
345 |
+
tp_size = tp_size,
|
346 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
347 |
+
accelerator_config = accelerator_config,
|
348 |
+
deepspeed = deepspeed,
|
349 |
+
label_smoothing_factor = label_smoothing_factor,
|
350 |
+
optim = optim,
|
351 |
+
optim_args = optim_args,
|
352 |
+
adafactor = adafactor,
|
353 |
+
group_by_length = group_by_length,
|
354 |
+
length_column_name = length_column_name,
|
355 |
+
report_to = report_to,
|
356 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
357 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
358 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
359 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
360 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
361 |
+
skip_memory_metrics = skip_memory_metrics,
|
362 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
363 |
+
push_to_hub = push_to_hub,
|
364 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
365 |
+
hub_model_id = hub_model_id,
|
366 |
+
hub_strategy = hub_strategy,
|
367 |
+
hub_token = hub_token,
|
368 |
+
hub_private_repo = hub_private_repo,
|
369 |
+
hub_always_push = hub_always_push,
|
370 |
+
gradient_checkpointing = gradient_checkpointing,
|
371 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
372 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
373 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
374 |
+
fp16_backend = fp16_backend,
|
375 |
+
evaluation_strategy = evaluation_strategy,
|
376 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
377 |
+
push_to_hub_organization = push_to_hub_organization,
|
378 |
+
push_to_hub_token = push_to_hub_token,
|
379 |
+
mp_parameters = mp_parameters,
|
380 |
+
auto_find_batch_size = auto_find_batch_size,
|
381 |
+
full_determinism = full_determinism,
|
382 |
+
torchdynamo = torchdynamo,
|
383 |
+
ray_scope = ray_scope,
|
384 |
+
ddp_timeout = ddp_timeout,
|
385 |
+
torch_compile = torch_compile,
|
386 |
+
torch_compile_backend = torch_compile_backend,
|
387 |
+
torch_compile_mode = torch_compile_mode,
|
388 |
+
dispatch_batches = dispatch_batches,
|
389 |
+
split_batches = split_batches,
|
390 |
+
include_tokens_per_second = include_tokens_per_second,
|
391 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
392 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
393 |
+
optim_target_modules = optim_target_modules,
|
394 |
+
batch_eval_metrics = batch_eval_metrics,
|
395 |
+
eval_on_start = eval_on_start,
|
396 |
+
use_liger_kernel = use_liger_kernel,
|
397 |
+
eval_use_gather_object = eval_use_gather_object,
|
398 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
399 |
+
max_length = max_length,
|
400 |
+
max_prompt_length = max_prompt_length,
|
401 |
+
max_completion_length = max_completion_length,
|
402 |
+
beta = beta,
|
403 |
+
label_pad_token_id = label_pad_token_id,
|
404 |
+
padding_value = padding_value,
|
405 |
+
truncation_mode = truncation_mode,
|
406 |
+
disable_dropout = disable_dropout,
|
407 |
+
generate_during_eval = generate_during_eval,
|
408 |
+
is_encoder_decoder = is_encoder_decoder,
|
409 |
+
precompute_ref_log_probs = precompute_ref_log_probs,
|
410 |
+
model_init_kwargs = model_init_kwargs,
|
411 |
+
ref_model_init_kwargs = ref_model_init_kwargs,
|
412 |
+
dataset_num_proc = dataset_num_proc,
|
413 |
+
prompt_sample_size = prompt_sample_size,
|
414 |
+
min_density_ratio = min_density_ratio,
|
415 |
+
max_density_ratio = max_density_ratio,**kwargs)
|
416 |
+
self.vllm_sampling_params = vllm_sampling_params
|
417 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
418 |
+
pass
|
419 |
+
|
420 |
+
class _UnslothBCOTrainer(Trainer):
|
421 |
+
r""""""
|
422 |
+
|
423 |
+
_tag_names = ["trl", "bco"]
|
424 |
+
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
model: Union[PreTrainedModel, nn.Module, str] = None,
|
428 |
+
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
429 |
+
args: BCOConfig = None,
|
430 |
+
train_dataset: Optional[Dataset] = None,
|
431 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
432 |
+
processing_class: Optional[
|
433 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
434 |
+
] = None,
|
435 |
+
data_collator: Optional[DataCollator] = None,
|
436 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
437 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
438 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
439 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
440 |
+
peft_config: Optional[dict] = None,
|
441 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
442 |
+
model_adapter_name: Optional[str] = None,
|
443 |
+
ref_adapter_name: Optional[str] = None,
|
444 |
+
embedding_func: Optional[Callable] = None,
|
445 |
+
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
446 |
+
):
|
447 |
+
if not is_sklearn_available():
|
448 |
+
raise ImportError(
|
449 |
+
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
|
450 |
+
)
|
451 |
+
|
452 |
+
if type(args) is TrainingArguments:
|
453 |
+
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
|
454 |
+
|
455 |
+
if not isinstance(model, str) and ref_model is model:
|
456 |
+
raise ValueError(
|
457 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
458 |
+
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
459 |
+
)
|
460 |
+
|
461 |
+
if args.model_init_kwargs is None:
|
462 |
+
model_init_kwargs = {}
|
463 |
+
elif not isinstance(model, str):
|
464 |
+
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
|
465 |
+
else:
|
466 |
+
model_init_kwargs = args.model_init_kwargs
|
467 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
468 |
+
if torch_dtype is not None:
|
469 |
+
# Convert to `torch.dtype` if an str is passed
|
470 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
471 |
+
torch_dtype = getattr(torch, torch_dtype)
|
472 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
473 |
+
raise ValueError(
|
474 |
+
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
475 |
+
)
|
476 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
477 |
+
|
478 |
+
if args.ref_model_init_kwargs is None:
|
479 |
+
ref_model_init_kwargs = {}
|
480 |
+
elif not isinstance(ref_model, str):
|
481 |
+
raise ValueError(
|
482 |
+
"You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
ref_model_init_kwargs = args.ref_model_init_kwargs
|
486 |
+
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
487 |
+
if torch_dtype is not None:
|
488 |
+
# Convert to `torch.dtype` if an str is passed
|
489 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
490 |
+
torch_dtype = getattr(torch, torch_dtype)
|
491 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
492 |
+
raise ValueError(
|
493 |
+
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
494 |
+
)
|
495 |
+
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
496 |
+
|
497 |
+
if isinstance(model, str):
|
498 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
499 |
+
|
500 |
+
if isinstance(ref_model, str):
|
501 |
+
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
502 |
+
|
503 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
504 |
+
# has been called in order to properly call autocast if needed.
|
505 |
+
self._peft_has_been_casted_to_bf16 = False
|
506 |
+
|
507 |
+
if not is_peft_available() and peft_config is not None:
|
508 |
+
raise ValueError(
|
509 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
510 |
+
)
|
511 |
+
elif is_peft_available() and peft_config is not None:
|
512 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
513 |
+
if isinstance(model, PeftModel):
|
514 |
+
model = model.merge_and_unload()
|
515 |
+
|
516 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
517 |
+
_support_gc_kwargs = hasattr(
|
518 |
+
args, "gradient_checkpointing_kwargs"
|
519 |
+
) and "gradient_checkpointing_kwargs" in list(
|
520 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
521 |
+
)
|
522 |
+
|
523 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
524 |
+
|
525 |
+
if _support_gc_kwargs:
|
526 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
527 |
+
|
528 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
529 |
+
elif getattr(args, "gradient_checkpointing", False):
|
530 |
+
# For backward compatibility with older versions of transformers
|
531 |
+
if hasattr(model, "enable_input_require_grads"):
|
532 |
+
model.enable_input_require_grads()
|
533 |
+
else:
|
534 |
+
|
535 |
+
def make_inputs_require_grad(module, input, output):
|
536 |
+
output.requires_grad_(True)
|
537 |
+
|
538 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
539 |
+
|
540 |
+
# get peft model with the given config
|
541 |
+
model = model
|
542 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
543 |
+
peft_module_casting_to_bf16(model)
|
544 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
545 |
+
self._peft_has_been_casted_to_bf16 = True
|
546 |
+
|
547 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
548 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
549 |
+
# fail or completely fail.
|
550 |
+
elif getattr(args, "gradient_checkpointing", False):
|
551 |
+
# For backward compatibility with older versions of transformers
|
552 |
+
if hasattr(model, "enable_input_require_grads"):
|
553 |
+
model.enable_input_require_grads()
|
554 |
+
else:
|
555 |
+
|
556 |
+
def make_inputs_require_grad(module, input, output):
|
557 |
+
output.requires_grad_(True)
|
558 |
+
|
559 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
560 |
+
|
561 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
562 |
+
raise ValueError(
|
563 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
564 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
565 |
+
)
|
566 |
+
|
567 |
+
if model is not None:
|
568 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
569 |
+
elif args.is_encoder_decoder is None:
|
570 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
571 |
+
else:
|
572 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
573 |
+
|
574 |
+
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
575 |
+
self.model_adapter_name = model_adapter_name
|
576 |
+
self.ref_adapter_name = ref_adapter_name
|
577 |
+
|
578 |
+
if ref_model:
|
579 |
+
self.ref_model = ref_model
|
580 |
+
elif self.is_peft_model or args.precompute_ref_log_probs:
|
581 |
+
# The `model` with adapters turned off will be used as the reference model
|
582 |
+
self.ref_model = None
|
583 |
+
else:
|
584 |
+
self.ref_model = create_reference_model(model)
|
585 |
+
|
586 |
+
if processing_class is None:
|
587 |
+
raise ValueError(
|
588 |
+
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
589 |
+
)
|
590 |
+
if args.max_length is None:
|
591 |
+
warnings.warn(
|
592 |
+
"When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
|
593 |
+
"It will be set to `512` by default, but you should do it yourself in the future.",
|
594 |
+
UserWarning,
|
595 |
+
)
|
596 |
+
max_length = 512
|
597 |
+
if args.max_length is not None:
|
598 |
+
max_length = args.max_length
|
599 |
+
|
600 |
+
if args.max_prompt_length is None:
|
601 |
+
warnings.warn(
|
602 |
+
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
|
603 |
+
"It will be set to `128` by default, but you should do it yourself in the future.",
|
604 |
+
UserWarning,
|
605 |
+
)
|
606 |
+
max_prompt_length = 128
|
607 |
+
if args.max_prompt_length is not None:
|
608 |
+
max_prompt_length = args.max_prompt_length
|
609 |
+
|
610 |
+
max_completion_length = None
|
611 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
612 |
+
warnings.warn(
|
613 |
+
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
|
614 |
+
" it will be set to `128` by default, but you should do it yourself in the future.",
|
615 |
+
UserWarning,
|
616 |
+
)
|
617 |
+
max_completion_length = 128
|
618 |
+
if args.max_completion_length is not None and self.is_encoder_decoder:
|
619 |
+
max_completion_length = args.max_completion_length
|
620 |
+
|
621 |
+
if data_collator is None:
|
622 |
+
data_collator = DPODataCollatorWithPadding(
|
623 |
+
pad_token_id=processing_class.pad_token_id,
|
624 |
+
label_pad_token_id=args.label_pad_token_id,
|
625 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
626 |
+
)
|
627 |
+
|
628 |
+
if args.remove_unused_columns:
|
629 |
+
args.remove_unused_columns = False
|
630 |
+
# warn users
|
631 |
+
warnings.warn(
|
632 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
|
633 |
+
" we have set it for you, but you should do it yourself in the future.",
|
634 |
+
UserWarning,
|
635 |
+
)
|
636 |
+
|
637 |
+
self.use_dpo_data_collator = True
|
638 |
+
else:
|
639 |
+
self.use_dpo_data_collator = False
|
640 |
+
|
641 |
+
# Disable dropout in the model and reference model
|
642 |
+
if args.disable_dropout:
|
643 |
+
disable_dropout_in_model(model)
|
644 |
+
if self.ref_model is not None:
|
645 |
+
disable_dropout_in_model(self.ref_model)
|
646 |
+
|
647 |
+
self.max_length = max_length
|
648 |
+
self.generate_during_eval = args.generate_during_eval
|
649 |
+
self.label_pad_token_id = args.label_pad_token_id
|
650 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
651 |
+
self.max_prompt_length = max_prompt_length
|
652 |
+
self.truncation_mode = args.truncation_mode
|
653 |
+
self.max_completion_length = max_completion_length
|
654 |
+
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
655 |
+
|
656 |
+
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
657 |
+
# keep track of first called to avoid computation of future calls
|
658 |
+
self._precomputed_train_ref_log_probs = False
|
659 |
+
self._precomputed_eval_ref_log_probs = False
|
660 |
+
|
661 |
+
# metric
|
662 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
663 |
+
|
664 |
+
# BCO parameter
|
665 |
+
self.beta = args.beta
|
666 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
667 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
668 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
669 |
+
warnings.warn(
|
670 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
671 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
672 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
673 |
+
"loss.",
|
674 |
+
UserWarning,
|
675 |
+
)
|
676 |
+
|
677 |
+
# Underlying Distribution Matching argument
|
678 |
+
self.embedding_func = embedding_func
|
679 |
+
self.embedding_tokenizer = embedding_tokenizer
|
680 |
+
|
681 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
682 |
+
# input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
|
683 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
684 |
+
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
685 |
+
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
686 |
+
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
687 |
+
# issued.
|
688 |
+
model.warnings_issued["estimate_tokens"] = True
|
689 |
+
|
690 |
+
with PartialState().local_main_process_first():
|
691 |
+
# Apply the chat template if needed
|
692 |
+
train_dataset = train_dataset.map(
|
693 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
694 |
+
)
|
695 |
+
if eval_dataset is not None:
|
696 |
+
eval_dataset = eval_dataset.map(
|
697 |
+
maybe_apply_chat_template,
|
698 |
+
fn_kwargs={"tokenizer": processing_class},
|
699 |
+
num_proc=args.dataset_num_proc,
|
700 |
+
)
|
701 |
+
# Shuffle the datasets
|
702 |
+
train_dataset = train_dataset.shuffle(seed=args.data_seed)
|
703 |
+
if eval_dataset is not None:
|
704 |
+
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
|
705 |
+
# Tokenize and prepare the training datasets
|
706 |
+
train_dataset = train_dataset.map(
|
707 |
+
_tokenize,
|
708 |
+
batched=True,
|
709 |
+
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
710 |
+
num_proc=args.dataset_num_proc,
|
711 |
+
desc="Tokenizing train dataset",
|
712 |
+
)
|
713 |
+
|
714 |
+
# Prepare the datasets
|
715 |
+
fn_kwargs = {
|
716 |
+
"prefix": "",
|
717 |
+
"is_encoder_decoder": self.is_encoder_decoder,
|
718 |
+
"tokenizer": processing_class,
|
719 |
+
"max_length": self.max_length,
|
720 |
+
"truncation_mode": self.truncation_mode,
|
721 |
+
"label_pad_token_id": self.label_pad_token_id,
|
722 |
+
"max_prompt_length": self.max_prompt_length,
|
723 |
+
"max_completion_length": self.max_completion_length,
|
724 |
+
}
|
725 |
+
train_dataset = train_dataset.map(
|
726 |
+
_process_tokens,
|
727 |
+
fn_kwargs=fn_kwargs,
|
728 |
+
num_proc=args.dataset_num_proc,
|
729 |
+
desc="Processing tokenized train dataset",
|
730 |
+
)
|
731 |
+
|
732 |
+
if eval_dataset is not None:
|
733 |
+
# Tokenize
|
734 |
+
eval_dataset = eval_dataset.map(
|
735 |
+
_tokenize,
|
736 |
+
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
737 |
+
batched=True,
|
738 |
+
num_proc=args.dataset_num_proc,
|
739 |
+
desc="Tokenizing eval dataset",
|
740 |
+
)
|
741 |
+
|
742 |
+
# Process
|
743 |
+
fn_kwargs = {
|
744 |
+
"prefix": "",
|
745 |
+
"is_encoder_decoder": self.is_encoder_decoder,
|
746 |
+
"tokenizer": processing_class,
|
747 |
+
"max_length": self.max_length,
|
748 |
+
"truncation_mode": self.truncation_mode,
|
749 |
+
"label_pad_token_id": self.label_pad_token_id,
|
750 |
+
"max_prompt_length": self.max_prompt_length,
|
751 |
+
"max_completion_length": self.max_completion_length,
|
752 |
+
}
|
753 |
+
eval_dataset = eval_dataset.map(
|
754 |
+
_process_tokens,
|
755 |
+
fn_kwargs=fn_kwargs,
|
756 |
+
num_proc=args.dataset_num_proc,
|
757 |
+
desc="Processing tokenized eval dataset",
|
758 |
+
)
|
759 |
+
|
760 |
+
desirable = train_dataset.filter(
|
761 |
+
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
|
762 |
+
)
|
763 |
+
undesirable = train_dataset.filter(
|
764 |
+
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
|
765 |
+
)
|
766 |
+
|
767 |
+
desirable = desirable.shuffle(seed=args.data_seed)
|
768 |
+
undesirable = undesirable.shuffle(seed=args.data_seed)
|
769 |
+
|
770 |
+
super().__init__(
|
771 |
+
model=model,
|
772 |
+
args=args,
|
773 |
+
data_collator=data_collator,
|
774 |
+
train_dataset=train_dataset,
|
775 |
+
eval_dataset=eval_dataset,
|
776 |
+
processing_class=processing_class,
|
777 |
+
model_init=model_init,
|
778 |
+
compute_metrics=compute_metrics,
|
779 |
+
callbacks=callbacks,
|
780 |
+
optimizers=optimizers,
|
781 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
782 |
+
)
|
783 |
+
|
784 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
785 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
786 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
787 |
+
self.model_accepts_loss_kwargs = False
|
788 |
+
|
789 |
+
# Add tags for models that have been loaded with the correct transformers version
|
790 |
+
if hasattr(self.model, "add_model_tags"):
|
791 |
+
self.model.add_model_tags(self._tag_names)
|
792 |
+
|
793 |
+
if not hasattr(self, "accelerator"):
|
794 |
+
raise AttributeError(
|
795 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
796 |
+
)
|
797 |
+
|
798 |
+
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
799 |
+
if self.is_deepspeed_enabled:
|
800 |
+
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
801 |
+
raise ValueError(
|
802 |
+
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
803 |
+
)
|
804 |
+
|
805 |
+
if self.ref_model is None:
|
806 |
+
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
807 |
+
raise ValueError(
|
808 |
+
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
809 |
+
)
|
810 |
+
else:
|
811 |
+
if self.is_deepspeed_enabled:
|
812 |
+
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
813 |
+
else:
|
814 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
815 |
+
|
816 |
+
self.running = RunningMoments(accelerator=self.accelerator)
|
817 |
+
|
818 |
+
if self.embedding_func is None:
|
819 |
+
return
|
820 |
+
|
821 |
+
chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
|
822 |
+
rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
|
823 |
+
|
824 |
+
embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
|
825 |
+
labels = torch.cat(
|
826 |
+
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
|
827 |
+
)
|
828 |
+
|
829 |
+
self.clf = LogisticRegression(class_weight="balanced").fit(
|
830 |
+
embeddings.cpu().float().numpy(), labels.cpu().numpy()
|
831 |
+
)
|
832 |
+
|
833 |
+
@property
|
834 |
+
def match_underlying_distribution(self):
|
835 |
+
return self.embedding_func is not None and self.embedding_tokenizer is not None
|
836 |
+
|
837 |
+
def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
838 |
+
"""
|
839 |
+
Calculates the probability if the given prompt embedding is from desirable dataset.
|
840 |
+
This function calculates the probability in the process and ensemble across processes.
|
841 |
+
"""
|
842 |
+
dtype = prompt_embeddings.dtype
|
843 |
+
device = prompt_embeddings.device
|
844 |
+
rank = self.accelerator.process_index
|
845 |
+
|
846 |
+
padded_prompt_embeddings = self.accelerator.pad_across_processes(
|
847 |
+
prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
|
848 |
+
)
|
849 |
+
sample_size = padded_prompt_embeddings.shape[0]
|
850 |
+
nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
|
851 |
+
prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
|
852 |
+
|
853 |
+
# cannot predict for all empty values
|
854 |
+
if prompt_embeddings.shape[0] == 0:
|
855 |
+
return torch.tensor([], device=device, dtype=dtype)
|
856 |
+
|
857 |
+
prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
|
858 |
+
prob = torch.as_tensor(prob, dtype=dtype, device=device)
|
859 |
+
prob = self.accelerator.reduce(prob, reduction="mean")
|
860 |
+
|
861 |
+
prob = prob[sample_size * rank : sample_size * (rank + 1)]
|
862 |
+
prob = prob[nonzero]
|
863 |
+
|
864 |
+
return prob
|
865 |
+
|
866 |
+
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
867 |
+
"""
|
868 |
+
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
869 |
+
and applies self.embedding_func
|
870 |
+
"""
|
871 |
+
input_ids = torch.where(
|
872 |
+
input_ids == self.processing_class.pad_token_id,
|
873 |
+
self.embedding_tokenizer.pad_token_id,
|
874 |
+
input_ids,
|
875 |
+
)
|
876 |
+
|
877 |
+
with torch.no_grad():
|
878 |
+
embeddings = self.embedding_func(
|
879 |
+
input_ids=input_ids,
|
880 |
+
attention_mask=attention_mask,
|
881 |
+
)
|
882 |
+
|
883 |
+
return embeddings
|
884 |
+
|
885 |
+
def _get_prompt_embeddings(
|
886 |
+
self, batch: dict[str, Union[list, torch.LongTensor]]
|
887 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
888 |
+
"""Extract embeddings from frozen embedding model"""
|
889 |
+
|
890 |
+
if not self.match_underlying_distribution:
|
891 |
+
return None, None
|
892 |
+
|
893 |
+
embeddings = self._vectorize_prompt(
|
894 |
+
input_ids=batch["embedding_input_ids"],
|
895 |
+
attention_mask=batch["embedding_attention_mask"],
|
896 |
+
)
|
897 |
+
|
898 |
+
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
899 |
+
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
900 |
+
|
901 |
+
chosen_embeddings = embeddings[chosen_idx, ...]
|
902 |
+
rejected_embeddings = embeddings[rejected_idx, ...]
|
903 |
+
|
904 |
+
return (chosen_embeddings, rejected_embeddings)
|
905 |
+
|
906 |
+
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
|
907 |
+
"""
|
908 |
+
Sample instances from dataset and get prompt embeddings.
|
909 |
+
Used for density ratio classifier training.
|
910 |
+
"""
|
911 |
+
n_samples = min(len(dataset), sample_size)
|
912 |
+
rand_indices = np.random.choice(len(dataset), size=(n_samples,))
|
913 |
+
|
914 |
+
embedding_dataset = dataset.select(rand_indices)
|
915 |
+
|
916 |
+
dataloader_params = {
|
917 |
+
"batch_size": self.args.per_device_train_batch_size,
|
918 |
+
"collate_fn": self.data_collator,
|
919 |
+
"num_workers": self.args.dataloader_num_workers,
|
920 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
921 |
+
"shuffle": False,
|
922 |
+
}
|
923 |
+
|
924 |
+
# prepare dataloader
|
925 |
+
data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
|
926 |
+
|
927 |
+
with torch.no_grad():
|
928 |
+
all_embeddings = torch.empty(0)
|
929 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
|
930 |
+
embeddings = self._vectorize_prompt(
|
931 |
+
input_ids=padded_batch["embedding_input_ids"],
|
932 |
+
attention_mask=padded_batch["embedding_attention_mask"],
|
933 |
+
)
|
934 |
+
embeddings = self.accelerator.gather_for_metrics(embeddings)
|
935 |
+
all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
|
936 |
+
|
937 |
+
return all_embeddings
|
938 |
+
|
939 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
940 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
941 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
942 |
+
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
943 |
+
|
944 |
+
if model is not None:
|
945 |
+
if hasattr(model, "config"):
|
946 |
+
hidden_size = (
|
947 |
+
max(model.config.hidden_sizes)
|
948 |
+
if getattr(model.config, "hidden_sizes", None)
|
949 |
+
else getattr(model.config, "hidden_size", None)
|
950 |
+
)
|
951 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
952 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
953 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
954 |
+
config_kwargs.update(
|
955 |
+
{
|
956 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
957 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
958 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
959 |
+
}
|
960 |
+
)
|
961 |
+
|
962 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
963 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
964 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
965 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
966 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
967 |
+
model.eval()
|
968 |
+
return model
|
969 |
+
|
970 |
+
def _save_optimizer_and_scheduler(self, output_dir):
|
971 |
+
super()._save_optimizer_and_scheduler(output_dir)
|
972 |
+
|
973 |
+
# When saving optimizer and scheduler to checkpoint, save also the running delta object.
|
974 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
975 |
+
|
976 |
+
self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
|
977 |
+
|
978 |
+
if self.match_underlying_distribution:
|
979 |
+
torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
|
980 |
+
|
981 |
+
def _load_optimizer_and_scheduler(self, checkpoint):
|
982 |
+
super()._load_optimizer_and_scheduler(checkpoint)
|
983 |
+
|
984 |
+
if checkpoint is None:
|
985 |
+
return
|
986 |
+
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
|
987 |
+
running_file = os.path.join(checkpoint, RUNNING_NAME)
|
988 |
+
if os.path.isfile(running_file):
|
989 |
+
self.running = RunningMoments.load_from_json(self.accelerator, running_file)
|
990 |
+
|
991 |
+
if self.match_underlying_distribution:
|
992 |
+
clf_file = os.path.join(checkpoint, CLF_NAME)
|
993 |
+
if os.path.isfile(running_file):
|
994 |
+
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
995 |
+
|
996 |
+
@contextmanager
|
997 |
+
def null_ref_context(self):
|
998 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
999 |
+
with (
|
1000 |
+
self.accelerator.unwrap_model(self.model).disable_adapter()
|
1001 |
+
if self.is_peft_model and not self.ref_adapter_name
|
1002 |
+
else nullcontext()
|
1003 |
+
):
|
1004 |
+
if self.ref_adapter_name:
|
1005 |
+
self.model.set_adapter(self.ref_adapter_name)
|
1006 |
+
yield
|
1007 |
+
if self.ref_adapter_name:
|
1008 |
+
self.model.set_adapter(self.model_adapter_name or "default")
|
1009 |
+
|
1010 |
+
def get_train_dataloader(self) -> DataLoader:
|
1011 |
+
"""
|
1012 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
1013 |
+
|
1014 |
+
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
1015 |
+
"""
|
1016 |
+
|
1017 |
+
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
1018 |
+
dataloader_params = {
|
1019 |
+
"batch_size": self.args.per_device_train_batch_size,
|
1020 |
+
"collate_fn": self.data_collator,
|
1021 |
+
"num_workers": self.args.dataloader_num_workers,
|
1022 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
1023 |
+
"shuffle": False,
|
1024 |
+
}
|
1025 |
+
|
1026 |
+
# prepare dataloader
|
1027 |
+
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
1028 |
+
reference_completion_logps = []
|
1029 |
+
|
1030 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
1031 |
+
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
1032 |
+
|
1033 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1034 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
1035 |
+
|
1036 |
+
self.train_dataset = self.train_dataset.add_column(
|
1037 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1038 |
+
)
|
1039 |
+
|
1040 |
+
self._precomputed_train_ref_log_probs = True
|
1041 |
+
|
1042 |
+
return super().get_train_dataloader()
|
1043 |
+
|
1044 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
1045 |
+
"""
|
1046 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
1047 |
+
|
1048 |
+
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
1049 |
+
|
1050 |
+
Args:
|
1051 |
+
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
1052 |
+
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
1053 |
+
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
1054 |
+
"""
|
1055 |
+
if eval_dataset is None and self.eval_dataset is None:
|
1056 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
1057 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
1058 |
+
|
1059 |
+
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
1060 |
+
dataloader_params = {
|
1061 |
+
"batch_size": self.args.per_device_eval_batch_size,
|
1062 |
+
"collate_fn": self.data_collator,
|
1063 |
+
"num_workers": self.args.dataloader_num_workers,
|
1064 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
1065 |
+
"shuffle": False,
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
# prepare dataloader
|
1069 |
+
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
1070 |
+
|
1071 |
+
reference_completion_logps = []
|
1072 |
+
|
1073 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
1074 |
+
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
1075 |
+
|
1076 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1077 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
1078 |
+
|
1079 |
+
eval_dataset = eval_dataset.add_column(
|
1080 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
1084 |
+
if self.eval_dataset is not None:
|
1085 |
+
self.eval_dataset = eval_dataset
|
1086 |
+
self._precomputed_eval_ref_log_probs = True
|
1087 |
+
|
1088 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1089 |
+
|
1090 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
1091 |
+
"""Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
|
1092 |
+
with torch.no_grad():
|
1093 |
+
if self.ref_model is None:
|
1094 |
+
with self.null_ref_context():
|
1095 |
+
if self.is_encoder_decoder:
|
1096 |
+
completion_logits = self.model(
|
1097 |
+
padded_batch["prompt_input_ids"],
|
1098 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1099 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1100 |
+
labels=padded_batch["completion_labels"],
|
1101 |
+
).logits
|
1102 |
+
|
1103 |
+
else:
|
1104 |
+
completion_logits = self.model(
|
1105 |
+
padded_batch["completion_input_ids"],
|
1106 |
+
attention_mask=padded_batch["completion_attention_mask"],
|
1107 |
+
).logits
|
1108 |
+
|
1109 |
+
else:
|
1110 |
+
if self.is_encoder_decoder:
|
1111 |
+
completion_logits = self.ref_model(
|
1112 |
+
padded_batch["prompt_input_ids"],
|
1113 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1114 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1115 |
+
labels=padded_batch["completion_labels"],
|
1116 |
+
).logits
|
1117 |
+
|
1118 |
+
else:
|
1119 |
+
completion_logits = self.ref_model(
|
1120 |
+
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
1121 |
+
).logits
|
1122 |
+
|
1123 |
+
completion_logps = self.get_batch_logps(
|
1124 |
+
completion_logits,
|
1125 |
+
padded_batch["completion_labels"],
|
1126 |
+
average_log_prob=False,
|
1127 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1128 |
+
label_pad_token_id=self.label_pad_token_id,
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
return completion_logps
|
1132 |
+
|
1133 |
+
@staticmethod
|
1134 |
+
def get_batch_logps(
|
1135 |
+
logits: torch.FloatTensor,
|
1136 |
+
labels: torch.LongTensor,
|
1137 |
+
average_log_prob: bool = False,
|
1138 |
+
label_pad_token_id: int = -100,
|
1139 |
+
is_encoder_decoder: bool = False,
|
1140 |
+
) -> torch.FloatTensor:
|
1141 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
1142 |
+
|
1143 |
+
Args:
|
1144 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1145 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1146 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1147 |
+
|
1148 |
+
Returns:
|
1149 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1150 |
+
"""
|
1151 |
+
if logits.shape[:-1] != labels.shape:
|
1152 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1153 |
+
|
1154 |
+
if not is_encoder_decoder:
|
1155 |
+
labels = labels[:, 1:].clone()
|
1156 |
+
logits = logits[:, :-1, :]
|
1157 |
+
else:
|
1158 |
+
# Fixes end-dec RuntimeError
|
1159 |
+
labels = labels.clone()
|
1160 |
+
|
1161 |
+
loss_mask = labels != label_pad_token_id
|
1162 |
+
|
1163 |
+
# dummy token; we'll ignore the losses on these tokens later
|
1164 |
+
labels[labels == label_pad_token_id] = 0
|
1165 |
+
|
1166 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
1167 |
+
|
1168 |
+
if average_log_prob:
|
1169 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1170 |
+
else:
|
1171 |
+
return (per_token_logps * loss_mask).sum(-1)
|
1172 |
+
|
1173 |
+
def forward(
|
1174 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1175 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1176 |
+
model_kwargs = (
|
1177 |
+
{
|
1178 |
+
"labels": batch["completion_labels"],
|
1179 |
+
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
1180 |
+
}
|
1181 |
+
if self.is_encoder_decoder
|
1182 |
+
else {}
|
1183 |
+
)
|
1184 |
+
if self.aux_loss_enabled:
|
1185 |
+
model_kwargs["output_router_logits"] = True
|
1186 |
+
|
1187 |
+
outputs = model(
|
1188 |
+
batch["completion_input_ids"],
|
1189 |
+
attention_mask=batch["completion_attention_mask"],
|
1190 |
+
**model_kwargs,
|
1191 |
+
)
|
1192 |
+
completion_logits = outputs.logits
|
1193 |
+
|
1194 |
+
completion_logps = self.get_batch_logps(
|
1195 |
+
completion_logits,
|
1196 |
+
batch["completion_labels"],
|
1197 |
+
average_log_prob=False,
|
1198 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1199 |
+
label_pad_token_id=self.label_pad_token_id,
|
1200 |
+
)
|
1201 |
+
|
1202 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
1203 |
+
raise ValueError(
|
1204 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
1205 |
+
"examples for which an output sequence was predicted."
|
1206 |
+
)
|
1207 |
+
|
1208 |
+
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
1209 |
+
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
1210 |
+
|
1211 |
+
chosen_logps = completion_logps[chosen_idx, ...]
|
1212 |
+
rejected_logps = completion_logps[rejected_idx, ...]
|
1213 |
+
|
1214 |
+
chosen_logits = completion_logits[chosen_idx, ...]
|
1215 |
+
rejected_logits = completion_logits[rejected_idx, ...]
|
1216 |
+
|
1217 |
+
if self.aux_loss_enabled:
|
1218 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
|
1219 |
+
else:
|
1220 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
1221 |
+
|
1222 |
+
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
1223 |
+
prob_desirable = self._get_chosen_prob(rejected_embeddings)
|
1224 |
+
min_ratio = self.args.min_density_ratio
|
1225 |
+
max_ratio = self.args.max_density_ratio
|
1226 |
+
|
1227 |
+
weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
|
1228 |
+
|
1229 |
+
return weight
|
1230 |
+
|
1231 |
+
def bco_loss(
|
1232 |
+
self,
|
1233 |
+
policy_chosen_logps: torch.FloatTensor,
|
1234 |
+
policy_rejected_logps: torch.FloatTensor,
|
1235 |
+
reference_chosen_logps: torch.FloatTensor,
|
1236 |
+
reference_rejected_logps: torch.FloatTensor,
|
1237 |
+
chosen_embeddings: Optional[torch.FloatTensor],
|
1238 |
+
rejected_embeddings: Optional[torch.FloatTensor],
|
1239 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1240 |
+
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
|
1241 |
+
|
1242 |
+
Args:
|
1243 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1244 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1245 |
+
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1246 |
+
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1247 |
+
chosen_embeddings: embeddings of desirable prompts
|
1248 |
+
rejected_embeddings: embeddings of undesirable prompts
|
1249 |
+
|
1250 |
+
Returns:
|
1251 |
+
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
|
1252 |
+
The losses tensor contains the BCO loss for each example in the batch.
|
1253 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
1254 |
+
The delta value contains the moving average of all implicit rewards.
|
1255 |
+
"""
|
1256 |
+
|
1257 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1258 |
+
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
1259 |
+
chosen_rewards = self.beta * chosen_logratios
|
1260 |
+
else:
|
1261 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1262 |
+
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
1263 |
+
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1264 |
+
|
1265 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1266 |
+
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
1267 |
+
rejected_rewards = self.beta * rejected_logratios
|
1268 |
+
else:
|
1269 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1270 |
+
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
1271 |
+
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1272 |
+
|
1273 |
+
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
|
1274 |
+
self.running.update(rewards)
|
1275 |
+
delta = self.running.mean
|
1276 |
+
|
1277 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1278 |
+
chosen_losses = -F.logsigmoid(chosen_rewards - delta)
|
1279 |
+
|
1280 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1281 |
+
rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
|
1282 |
+
|
1283 |
+
if self.match_underlying_distribution:
|
1284 |
+
chosen_weight = torch.ones_like(chosen_losses)
|
1285 |
+
rejected_weight = self._get_udm_weight(rejected_embeddings)
|
1286 |
+
|
1287 |
+
losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
|
1288 |
+
else:
|
1289 |
+
losses = torch.cat((chosen_losses, rejected_losses), dim=0)
|
1290 |
+
|
1291 |
+
return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
1292 |
+
|
1293 |
+
def get_batch_loss_metrics(
|
1294 |
+
self,
|
1295 |
+
model,
|
1296 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
1297 |
+
):
|
1298 |
+
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
|
1299 |
+
metrics = {}
|
1300 |
+
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
1301 |
+
|
1302 |
+
forward_output = self.forward(model, batch)
|
1303 |
+
(
|
1304 |
+
policy_chosen_logps,
|
1305 |
+
policy_rejected_logps,
|
1306 |
+
policy_chosen_logits,
|
1307 |
+
policy_rejected_logits,
|
1308 |
+
) = forward_output[:4]
|
1309 |
+
if self.aux_loss_enabled:
|
1310 |
+
aux_loss = forward_output[4]
|
1311 |
+
|
1312 |
+
# if reference_logps in batch use them, otherwise use the reference model
|
1313 |
+
if "reference_logps" in batch:
|
1314 |
+
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
1315 |
+
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
1316 |
+
|
1317 |
+
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
1318 |
+
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
1319 |
+
else:
|
1320 |
+
with torch.no_grad():
|
1321 |
+
if self.ref_model is None:
|
1322 |
+
with self.null_ref_context():
|
1323 |
+
(
|
1324 |
+
reference_chosen_logps,
|
1325 |
+
reference_rejected_logps,
|
1326 |
+
_,
|
1327 |
+
_,
|
1328 |
+
) = self.forward(self.model, batch)[:4]
|
1329 |
+
else:
|
1330 |
+
(
|
1331 |
+
reference_chosen_logps,
|
1332 |
+
reference_rejected_logps,
|
1333 |
+
_,
|
1334 |
+
_,
|
1335 |
+
) = self.forward(self.ref_model, batch)[:4]
|
1336 |
+
|
1337 |
+
chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
|
1338 |
+
|
1339 |
+
losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
|
1340 |
+
policy_chosen_logps,
|
1341 |
+
policy_rejected_logps,
|
1342 |
+
reference_chosen_logps,
|
1343 |
+
reference_rejected_logps,
|
1344 |
+
chosen_embeddings,
|
1345 |
+
rejected_embeddings,
|
1346 |
+
)
|
1347 |
+
metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
|
1348 |
+
|
1349 |
+
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
1350 |
+
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
1351 |
+
|
1352 |
+
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
1353 |
+
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
1354 |
+
|
1355 |
+
if all_num_chosen > 0:
|
1356 |
+
metrics["rewards/chosen_sum"] = (
|
1357 |
+
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
1358 |
+
)
|
1359 |
+
metrics["logps/chosen_sum"] = (
|
1360 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
1361 |
+
)
|
1362 |
+
metrics["logits/chosen_sum"] = (
|
1363 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
1364 |
+
)
|
1365 |
+
metrics["count/chosen"] = all_num_chosen
|
1366 |
+
|
1367 |
+
if all_num_rejected > 0:
|
1368 |
+
metrics["rewards/rejected_sum"] = (
|
1369 |
+
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
1370 |
+
)
|
1371 |
+
metrics["logps/rejected_sum"] = (
|
1372 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
1373 |
+
)
|
1374 |
+
metrics["logits/rejected_sum"] = (
|
1375 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
1376 |
+
)
|
1377 |
+
metrics["count/rejected"] = all_num_rejected
|
1378 |
+
|
1379 |
+
loss = losses.nanmean()
|
1380 |
+
if self.aux_loss_enabled:
|
1381 |
+
loss += self.aux_loss_coef * aux_loss
|
1382 |
+
|
1383 |
+
return loss, metrics
|
1384 |
+
|
1385 |
+
def compute_loss(
|
1386 |
+
self,
|
1387 |
+
model: Union[PreTrainedModel, nn.Module],
|
1388 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1389 |
+
return_outputs=False,
|
1390 |
+
num_items_in_batch=None,
|
1391 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1392 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1393 |
+
|
1394 |
+
with compute_loss_context_manager:
|
1395 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1396 |
+
|
1397 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1398 |
+
loss = loss.to(self.args.device)
|
1399 |
+
# force log the metrics
|
1400 |
+
if self.accelerator.is_main_process:
|
1401 |
+
self.store_metrics(metrics, train_eval="train")
|
1402 |
+
|
1403 |
+
if return_outputs:
|
1404 |
+
return (loss, metrics)
|
1405 |
+
return loss
|
1406 |
+
|
1407 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1408 |
+
for key, value in metrics.items():
|
1409 |
+
self._stored_metrics[train_eval][key].append(value)
|
1410 |
+
|
1411 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1412 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
1413 |
+
return None
|
1414 |
+
return SequentialSampler(self.train_dataset)
|
1415 |
+
|
1416 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
1417 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1418 |
+
|
1419 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1420 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1421 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1422 |
+
with generate_context_manager:
|
1423 |
+
policy_output = model.generate(
|
1424 |
+
input_ids=batch["prompt_input_ids"],
|
1425 |
+
attention_mask=batch["prompt_attention_mask"],
|
1426 |
+
max_length=self.max_length,
|
1427 |
+
do_sample=True,
|
1428 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1429 |
+
)
|
1430 |
+
|
1431 |
+
# if reference_output in batch use that otherwise use the reference model
|
1432 |
+
if "reference_output" in batch:
|
1433 |
+
reference_output = batch["reference_output"]
|
1434 |
+
else:
|
1435 |
+
if self.ref_model is None:
|
1436 |
+
with self.null_ref_context():
|
1437 |
+
reference_output = self.model.generate(
|
1438 |
+
input_ids=batch["prompt_input_ids"],
|
1439 |
+
attention_mask=batch["prompt_attention_mask"],
|
1440 |
+
max_length=self.max_length,
|
1441 |
+
do_sample=True,
|
1442 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1443 |
+
)
|
1444 |
+
else:
|
1445 |
+
reference_output = self.ref_model.generate(
|
1446 |
+
input_ids=batch["prompt_input_ids"],
|
1447 |
+
attention_mask=batch["prompt_attention_mask"],
|
1448 |
+
max_length=self.max_length,
|
1449 |
+
do_sample=True,
|
1450 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1451 |
+
)
|
1452 |
+
|
1453 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1454 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1455 |
+
|
1456 |
+
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
1457 |
+
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
1458 |
+
|
1459 |
+
return policy_output_decoded, reference_output_decoded
|
1460 |
+
|
1461 |
+
def prediction_step(
|
1462 |
+
self,
|
1463 |
+
model: Union[PreTrainedModel, nn.Module],
|
1464 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1465 |
+
prediction_loss_only: bool,
|
1466 |
+
ignore_keys: Optional[list[str]] = None,
|
1467 |
+
):
|
1468 |
+
if ignore_keys is None:
|
1469 |
+
if hasattr(model, "config"):
|
1470 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1471 |
+
else:
|
1472 |
+
ignore_keys = []
|
1473 |
+
|
1474 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1475 |
+
with torch.no_grad(), prediction_context_manager:
|
1476 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1477 |
+
|
1478 |
+
# force log the metrics
|
1479 |
+
if self.accelerator.is_main_process:
|
1480 |
+
self.store_metrics(metrics, train_eval="eval")
|
1481 |
+
|
1482 |
+
if prediction_loss_only:
|
1483 |
+
return (loss.detach(), None, None)
|
1484 |
+
|
1485 |
+
# logits for the chosen and rejected samples from model
|
1486 |
+
logits_dict = {
|
1487 |
+
"eval_logits/chosen": metrics["logits/chosen"],
|
1488 |
+
"eval_logits/rejected": metrics["logits/rejected"],
|
1489 |
+
}
|
1490 |
+
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1491 |
+
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1492 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1493 |
+
|
1494 |
+
return (loss.detach(), logits, labels)
|
1495 |
+
|
1496 |
+
def evaluation_loop(
|
1497 |
+
self,
|
1498 |
+
dataloader: DataLoader,
|
1499 |
+
description: str,
|
1500 |
+
prediction_loss_only: Optional[bool] = None,
|
1501 |
+
ignore_keys: Optional[list[str]] = None,
|
1502 |
+
metric_key_prefix: str = "eval",
|
1503 |
+
) -> EvalLoopOutput:
|
1504 |
+
"""
|
1505 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
1506 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1507 |
+
|
1508 |
+
Works both with or without labels.
|
1509 |
+
"""
|
1510 |
+
|
1511 |
+
# Sample and save to game log if requested (for one batch to save time)
|
1512 |
+
if self.generate_during_eval:
|
1513 |
+
# Generate random indices within the range of the total number of samples
|
1514 |
+
num_samples = len(dataloader.dataset)
|
1515 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1516 |
+
|
1517 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1518 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1519 |
+
random_batch = self.data_collator(random_batch_dataset)
|
1520 |
+
random_batch = self._prepare_inputs(random_batch)
|
1521 |
+
|
1522 |
+
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
1523 |
+
target_batch = {
|
1524 |
+
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
1525 |
+
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
1526 |
+
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
1527 |
+
}
|
1528 |
+
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
1529 |
+
|
1530 |
+
table = pd.DataFrame(
|
1531 |
+
columns=["Prompt", "Policy", "Ref Model"],
|
1532 |
+
data=[
|
1533 |
+
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
1534 |
+
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
1535 |
+
],
|
1536 |
+
)
|
1537 |
+
if "wandb" in self.args.report_to:
|
1538 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
1539 |
+
|
1540 |
+
if "comet_ml" in self.args.report_to:
|
1541 |
+
log_table_to_comet_experiment(
|
1542 |
+
name="game_log.csv",
|
1543 |
+
table=table,
|
1544 |
+
)
|
1545 |
+
|
1546 |
+
# Base evaluation
|
1547 |
+
initial_output = super().evaluation_loop(
|
1548 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1549 |
+
)
|
1550 |
+
|
1551 |
+
return initial_output
|
1552 |
+
|
1553 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1554 |
+
"""
|
1555 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1556 |
+
|
1557 |
+
Args:
|
1558 |
+
logs (`dict[str, float]`):
|
1559 |
+
The values to log.
|
1560 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1561 |
+
Start time of the training.
|
1562 |
+
"""
|
1563 |
+
# logs either has 'loss' or 'eval_loss'
|
1564 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1565 |
+
# train metrics should have no prefix, eval should have 'eval_'
|
1566 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
1567 |
+
# accumulate average metrics from sums and lengths
|
1568 |
+
for split in ["chosen", "rejected"]:
|
1569 |
+
if f"count/{split}" in self._stored_metrics[train_eval]:
|
1570 |
+
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
1571 |
+
for metric in ["rewards", "logps", "logits"]:
|
1572 |
+
logs[f"{prefix}{metric}/{split}"] = (
|
1573 |
+
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
1574 |
+
/ count_sum
|
1575 |
+
)
|
1576 |
+
# delete obsolete metric
|
1577 |
+
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
1578 |
+
del self._stored_metrics[train_eval][f"count/{split}"]
|
1579 |
+
# calculate reward margin
|
1580 |
+
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
1581 |
+
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
1582 |
+
# Add averaged stored metrics to logs
|
1583 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1584 |
+
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
1585 |
+
del self._stored_metrics[train_eval]
|
1586 |
+
|
1587 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1588 |
+
return super().log(logs, start_time)
|
1589 |
+
else: # transformers<=4.46
|
1590 |
+
return super().log(logs)
|
1591 |
+
|
1592 |
+
def create_model_card(
|
1593 |
+
self,
|
1594 |
+
model_name: Optional[str] = None,
|
1595 |
+
dataset_name: Optional[str] = None,
|
1596 |
+
tags: Union[str, list[str], None] = None,
|
1597 |
+
):
|
1598 |
+
"""
|
1599 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1600 |
+
|
1601 |
+
Args:
|
1602 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1603 |
+
Name of the model.
|
1604 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1605 |
+
Name of the dataset used for training.
|
1606 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1607 |
+
Tags to be associated with the model card.
|
1608 |
+
"""
|
1609 |
+
if not self.is_world_process_zero():
|
1610 |
+
return
|
1611 |
+
|
1612 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1613 |
+
base_model = self.model.config._name_or_path
|
1614 |
+
else:
|
1615 |
+
base_model = None
|
1616 |
+
|
1617 |
+
tags = tags or []
|
1618 |
+
if isinstance(tags, str):
|
1619 |
+
tags = [tags]
|
1620 |
+
|
1621 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1622 |
+
tags.append("unsloth")
|
1623 |
+
|
1624 |
+
citation = textwrap.dedent("""\
|
1625 |
+
@article{jung2024binary,
|
1626 |
+
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
|
1627 |
+
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
|
1628 |
+
year = 2024,
|
1629 |
+
eprint = {arXiv:2404.04656}
|
1630 |
+
}""")
|
1631 |
+
|
1632 |
+
model_card = generate_model_card(
|
1633 |
+
base_model=base_model,
|
1634 |
+
model_name=model_name,
|
1635 |
+
hub_model_id=self.hub_model_id,
|
1636 |
+
dataset_name=dataset_name,
|
1637 |
+
tags=tags,
|
1638 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1639 |
+
comet_url=get_comet_experiment_url(),
|
1640 |
+
trainer_name="BCO",
|
1641 |
+
trainer_citation=citation,
|
1642 |
+
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
|
1643 |
+
paper_id="2404.04656",
|
1644 |
+
)
|
1645 |
+
|
1646 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1647 |
+
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
1648 |
+
"""
|
1649 |
+
|
1650 |
+
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
1651 |
+
|
1652 |
+
Args:
|
1653 |
+
model (`transformers.PreTrainedModel`):
|
1654 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1655 |
+
ref_model (`PreTrainedModelWrapper`):
|
1656 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
1657 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
1658 |
+
args (`BCOConfig`):
|
1659 |
+
The arguments to use for training.
|
1660 |
+
train_dataset (`datasets.Dataset`):
|
1661 |
+
The dataset to use for training.
|
1662 |
+
eval_dataset (`datasets.Dataset`):
|
1663 |
+
The dataset to use for evaluation.
|
1664 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1665 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1666 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1667 |
+
reuse the fine-tuned model.
|
1668 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
1669 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1670 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1671 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1672 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1673 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1674 |
+
The callbacks to use for training.
|
1675 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1676 |
+
The optimizer and scheduler to use for training.
|
1677 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1678 |
+
The function to use to preprocess the logits before computing the metrics.
|
1679 |
+
peft_config (`dict`, defaults to `None`):
|
1680 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1681 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1682 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1683 |
+
a dictionary string to metric values.
|
1684 |
+
model_adapter_name (`str`, defaults to `None`):
|
1685 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
1686 |
+
ref_adapter_name (`str`, defaults to `None`):
|
1687 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
1688 |
+
|
1689 |
+
"""
|
1690 |
+
def __init__(
|
1691 |
+
self,
|
1692 |
+
model = None,
|
1693 |
+
ref_model = None,
|
1694 |
+
args = None,
|
1695 |
+
train_dataset = None,
|
1696 |
+
eval_dataset = None,
|
1697 |
+
processing_class = None,
|
1698 |
+
data_collator = None,
|
1699 |
+
model_init = None,
|
1700 |
+
callbacks = None,
|
1701 |
+
preprocess_logits_for_metrics = None,
|
1702 |
+
peft_config = None,
|
1703 |
+
compute_metrics = None,
|
1704 |
+
model_adapter_name = None,
|
1705 |
+
ref_adapter_name = None,
|
1706 |
+
embedding_func = None,
|
1707 |
+
embedding_tokenizer = None,
|
1708 |
+
**kwargs
|
1709 |
+
):
|
1710 |
+
if args is None: args = UnslothBCOConfig()
|
1711 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1712 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1713 |
+
force_float32 = False
|
1714 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1715 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1716 |
+
force_float32 = True
|
1717 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1718 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1719 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1720 |
+
from unsloth_zoo.utils import _get_dtype
|
1721 |
+
dtype = _get_dtype(dtype)
|
1722 |
+
float16 = dtype == torch.float16
|
1723 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1724 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1725 |
+
if force_float32:
|
1726 |
+
args.fp16 = False
|
1727 |
+
args.bf16 = False
|
1728 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1729 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1730 |
+
args.fp16 = float16
|
1731 |
+
args.bf16 = not float16
|
1732 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1733 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1734 |
+
args.eval_strategy = 'steps'
|
1735 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1736 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1737 |
+
if ga_steps is not None and ga_steps > 1:
|
1738 |
+
from transformers import __version__ as transformers_version
|
1739 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1740 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1741 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1742 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1743 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1744 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1745 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1746 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1747 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1748 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1749 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1750 |
+
if force_float32:
|
1751 |
+
args.bf16_full_eval = False
|
1752 |
+
args.fp16_full_eval = False
|
1753 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1754 |
+
args.bf16_full_eval = True
|
1755 |
+
args.fp16_full_eval = False
|
1756 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1757 |
+
args.bf16_full_eval = args.bf16
|
1758 |
+
args.fp16_full_eval = args.fp16
|
1759 |
+
_output_logits = False
|
1760 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1761 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1762 |
+
if _output_logits:
|
1763 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1764 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1765 |
+
pass
|
1766 |
+
else:
|
1767 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1768 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1769 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1770 |
+
max_seq_length = model.max_seq_length
|
1771 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1772 |
+
if model is not None and hasattr(model, 'for_training'):
|
1773 |
+
model.for_training()
|
1774 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1775 |
+
if 'processing_class' in locals():
|
1776 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1777 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1778 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1779 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1780 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1781 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1782 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1783 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1784 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1785 |
+
else:
|
1786 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1787 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1788 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1789 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1790 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1791 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1792 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1793 |
+
else:
|
1794 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1795 |
+
other_metrics = []
|
1796 |
+
|
1797 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1798 |
+
PatchRLStatistics('bco_trainer', other_metrics)
|
1799 |
+
|
1800 |
+
super().__init__(
|
1801 |
+
model = model,
|
1802 |
+
ref_model = ref_model,
|
1803 |
+
args = args,
|
1804 |
+
train_dataset = train_dataset,
|
1805 |
+
eval_dataset = eval_dataset,
|
1806 |
+
processing_class = processing_class,
|
1807 |
+
data_collator = data_collator,
|
1808 |
+
model_init = model_init,
|
1809 |
+
callbacks = callbacks,
|
1810 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1811 |
+
peft_config = peft_config,
|
1812 |
+
compute_metrics = compute_metrics,
|
1813 |
+
model_adapter_name = model_adapter_name,
|
1814 |
+
ref_adapter_name = ref_adapter_name,
|
1815 |
+
embedding_func = embedding_func,
|
1816 |
+
embedding_tokenizer = embedding_tokenizer,**kwargs)
|
1817 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1818 |
+
self.neftune_hook_handle.remove()
|
1819 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1820 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1821 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1822 |
+
pass
|
1823 |
+
|
1824 |
+
pass
|
unsloth_compiled_cache/UnslothCPOTrainer.py
ADDED
@@ -0,0 +1,1557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothCPOConfig(CPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`CPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
67 |
+
the [paper](https://huggingface.co/papers/2310.12036).
|
68 |
+
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
69 |
+
Label smoothing factor. This argument is required if you want to use the default data collator.
|
70 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
71 |
+
Type of loss to use. Possible values are:
|
72 |
+
|
73 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
74 |
+
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
|
75 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
76 |
+
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
77 |
+
|
78 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
79 |
+
Whether to disable dropout in the model.
|
80 |
+
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
81 |
+
Weight of the BC regularizer in CPO training.
|
82 |
+
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
83 |
+
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
84 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
85 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
86 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
87 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
88 |
+
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
89 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
90 |
+
This argument is required if you want to use the default data collator.
|
91 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
92 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
93 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
94 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
95 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
96 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
98 |
+
string.
|
99 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
100 |
+
Number of processes to use for processing the dataset.
|
101 |
+
|
102 |
+
"""
|
103 |
+
vllm_sampling_params: Optional[Any] = field(
|
104 |
+
default = None,
|
105 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
106 |
+
)
|
107 |
+
unsloth_num_chunks : Optional[int] = field(
|
108 |
+
default = -1,
|
109 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
110 |
+
)
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
output_dir = None,
|
114 |
+
overwrite_output_dir = None,
|
115 |
+
do_train = False,
|
116 |
+
do_eval = False,
|
117 |
+
do_predict = False,
|
118 |
+
eval_strategy = 'no',
|
119 |
+
prediction_loss_only = False,
|
120 |
+
per_device_train_batch_size = 4,
|
121 |
+
per_device_eval_batch_size = 4,
|
122 |
+
per_gpu_train_batch_size = None,
|
123 |
+
per_gpu_eval_batch_size = None,
|
124 |
+
gradient_accumulation_steps = 2,
|
125 |
+
eval_accumulation_steps = 2,
|
126 |
+
eval_delay = 0,
|
127 |
+
torch_empty_cache_steps = 250,
|
128 |
+
learning_rate = 5e-05,
|
129 |
+
weight_decay = 0.01,
|
130 |
+
adam_beta1 = 0.9,
|
131 |
+
adam_beta2 = 0.999,
|
132 |
+
adam_epsilon = 1e-08,
|
133 |
+
max_grad_norm = 1.0,
|
134 |
+
num_train_epochs = 3.0,
|
135 |
+
max_steps = -1,
|
136 |
+
lr_scheduler_type = 'linear',
|
137 |
+
warmup_ratio = 0.1,
|
138 |
+
warmup_steps = 0,
|
139 |
+
log_level = 'passive',
|
140 |
+
log_level_replica = 'warning',
|
141 |
+
log_on_each_node = True,
|
142 |
+
logging_dir = None,
|
143 |
+
logging_strategy = 'steps',
|
144 |
+
logging_first_step = False,
|
145 |
+
logging_steps = 1,
|
146 |
+
logging_nan_inf_filter = False,
|
147 |
+
save_strategy = 'steps',
|
148 |
+
save_steps = 500,
|
149 |
+
save_total_limit = None,
|
150 |
+
save_safetensors = True,
|
151 |
+
save_on_each_node = False,
|
152 |
+
save_only_model = False,
|
153 |
+
restore_callback_states_from_checkpoint = False,
|
154 |
+
no_cuda = False,
|
155 |
+
use_cpu = False,
|
156 |
+
use_mps_device = False,
|
157 |
+
seed = 3407,
|
158 |
+
data_seed = 3407,
|
159 |
+
jit_mode_eval = False,
|
160 |
+
use_ipex = False,
|
161 |
+
bf16 = False,
|
162 |
+
fp16 = False,
|
163 |
+
fp16_opt_level = 'O1',
|
164 |
+
half_precision_backend = 'auto',
|
165 |
+
bf16_full_eval = False,
|
166 |
+
fp16_full_eval = False,
|
167 |
+
tf32 = None,
|
168 |
+
local_rank = -1,
|
169 |
+
ddp_backend = None,
|
170 |
+
tpu_num_cores = None,
|
171 |
+
tpu_metrics_debug = False,
|
172 |
+
debug = '',
|
173 |
+
dataloader_drop_last = False,
|
174 |
+
eval_steps = None,
|
175 |
+
dataloader_num_workers = 0,
|
176 |
+
dataloader_prefetch_factor = None,
|
177 |
+
past_index = -1,
|
178 |
+
run_name = None,
|
179 |
+
disable_tqdm = None,
|
180 |
+
remove_unused_columns = True,
|
181 |
+
label_names = None,
|
182 |
+
load_best_model_at_end = False,
|
183 |
+
metric_for_best_model = None,
|
184 |
+
greater_is_better = None,
|
185 |
+
ignore_data_skip = False,
|
186 |
+
fsdp = '',
|
187 |
+
fsdp_min_num_params = 0,
|
188 |
+
fsdp_config = None,
|
189 |
+
tp_size = 0,
|
190 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
191 |
+
accelerator_config = None,
|
192 |
+
deepspeed = None,
|
193 |
+
label_smoothing_factor = 0.0,
|
194 |
+
optim = 'adamw_8bit',
|
195 |
+
optim_args = None,
|
196 |
+
adafactor = False,
|
197 |
+
group_by_length = False,
|
198 |
+
length_column_name = 'length',
|
199 |
+
report_to = None,
|
200 |
+
ddp_find_unused_parameters = None,
|
201 |
+
ddp_bucket_cap_mb = None,
|
202 |
+
ddp_broadcast_buffers = None,
|
203 |
+
dataloader_pin_memory = True,
|
204 |
+
dataloader_persistent_workers = False,
|
205 |
+
skip_memory_metrics = True,
|
206 |
+
use_legacy_prediction_loop = False,
|
207 |
+
push_to_hub = False,
|
208 |
+
resume_from_checkpoint = None,
|
209 |
+
hub_model_id = None,
|
210 |
+
hub_strategy = 'every_save',
|
211 |
+
hub_token = None,
|
212 |
+
hub_private_repo = None,
|
213 |
+
hub_always_push = False,
|
214 |
+
gradient_checkpointing = False,
|
215 |
+
gradient_checkpointing_kwargs = None,
|
216 |
+
include_inputs_for_metrics = False,
|
217 |
+
eval_do_concat_batches = True,
|
218 |
+
fp16_backend = 'auto',
|
219 |
+
evaluation_strategy = None,
|
220 |
+
push_to_hub_model_id = None,
|
221 |
+
push_to_hub_organization = None,
|
222 |
+
push_to_hub_token = None,
|
223 |
+
mp_parameters = '',
|
224 |
+
auto_find_batch_size = False,
|
225 |
+
full_determinism = False,
|
226 |
+
torchdynamo = None,
|
227 |
+
ray_scope = 'last',
|
228 |
+
ddp_timeout = 1800,
|
229 |
+
torch_compile = False,
|
230 |
+
torch_compile_backend = None,
|
231 |
+
torch_compile_mode = None,
|
232 |
+
dispatch_batches = None,
|
233 |
+
split_batches = None,
|
234 |
+
include_tokens_per_second = False,
|
235 |
+
include_num_input_tokens_seen = False,
|
236 |
+
neftune_noise_alpha = None,
|
237 |
+
optim_target_modules = None,
|
238 |
+
batch_eval_metrics = False,
|
239 |
+
eval_on_start = False,
|
240 |
+
use_liger_kernel = False,
|
241 |
+
eval_use_gather_object = False,
|
242 |
+
average_tokens_across_devices = False,
|
243 |
+
max_length = 1024,
|
244 |
+
max_prompt_length = 512,
|
245 |
+
max_completion_length = None,
|
246 |
+
beta = 0.1,
|
247 |
+
label_smoothing = 0.0,
|
248 |
+
loss_type = 'sigmoid',
|
249 |
+
disable_dropout = True,
|
250 |
+
cpo_alpha = 1.0,
|
251 |
+
simpo_gamma = 0.5,
|
252 |
+
label_pad_token_id = -100,
|
253 |
+
padding_value = None,
|
254 |
+
truncation_mode = 'keep_end',
|
255 |
+
generate_during_eval = False,
|
256 |
+
is_encoder_decoder = None,
|
257 |
+
model_init_kwargs = None,
|
258 |
+
dataset_num_proc = None,
|
259 |
+
vllm_sampling_params = None,
|
260 |
+
unsloth_num_chunks = -1,
|
261 |
+
**kwargs,
|
262 |
+
):
|
263 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
264 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
265 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
266 |
+
output_dir = 'unsloth_training_checkpoints'
|
267 |
+
save_strategy = 'no'
|
268 |
+
if dataset_num_proc is None:
|
269 |
+
from multiprocessing import cpu_count
|
270 |
+
dataset_num_proc = cpu_count()
|
271 |
+
|
272 |
+
super().__init__(
|
273 |
+
output_dir = output_dir,
|
274 |
+
overwrite_output_dir = overwrite_output_dir,
|
275 |
+
do_train = do_train,
|
276 |
+
do_eval = do_eval,
|
277 |
+
do_predict = do_predict,
|
278 |
+
eval_strategy = eval_strategy,
|
279 |
+
prediction_loss_only = prediction_loss_only,
|
280 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
281 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
282 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
283 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
284 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
285 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
286 |
+
eval_delay = eval_delay,
|
287 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
288 |
+
learning_rate = learning_rate,
|
289 |
+
weight_decay = weight_decay,
|
290 |
+
adam_beta1 = adam_beta1,
|
291 |
+
adam_beta2 = adam_beta2,
|
292 |
+
adam_epsilon = adam_epsilon,
|
293 |
+
max_grad_norm = max_grad_norm,
|
294 |
+
num_train_epochs = num_train_epochs,
|
295 |
+
max_steps = max_steps,
|
296 |
+
lr_scheduler_type = lr_scheduler_type,
|
297 |
+
warmup_ratio = warmup_ratio,
|
298 |
+
warmup_steps = warmup_steps,
|
299 |
+
log_level = log_level,
|
300 |
+
log_level_replica = log_level_replica,
|
301 |
+
log_on_each_node = log_on_each_node,
|
302 |
+
logging_dir = logging_dir,
|
303 |
+
logging_strategy = logging_strategy,
|
304 |
+
logging_first_step = logging_first_step,
|
305 |
+
logging_steps = logging_steps,
|
306 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
307 |
+
save_strategy = save_strategy,
|
308 |
+
save_steps = save_steps,
|
309 |
+
save_total_limit = save_total_limit,
|
310 |
+
save_safetensors = save_safetensors,
|
311 |
+
save_on_each_node = save_on_each_node,
|
312 |
+
save_only_model = save_only_model,
|
313 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
314 |
+
no_cuda = no_cuda,
|
315 |
+
use_cpu = use_cpu,
|
316 |
+
use_mps_device = use_mps_device,
|
317 |
+
seed = seed,
|
318 |
+
data_seed = data_seed,
|
319 |
+
jit_mode_eval = jit_mode_eval,
|
320 |
+
use_ipex = use_ipex,
|
321 |
+
bf16 = bf16,
|
322 |
+
fp16 = fp16,
|
323 |
+
fp16_opt_level = fp16_opt_level,
|
324 |
+
half_precision_backend = half_precision_backend,
|
325 |
+
bf16_full_eval = bf16_full_eval,
|
326 |
+
fp16_full_eval = fp16_full_eval,
|
327 |
+
tf32 = tf32,
|
328 |
+
local_rank = local_rank,
|
329 |
+
ddp_backend = ddp_backend,
|
330 |
+
tpu_num_cores = tpu_num_cores,
|
331 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
332 |
+
debug = debug,
|
333 |
+
dataloader_drop_last = dataloader_drop_last,
|
334 |
+
eval_steps = eval_steps,
|
335 |
+
dataloader_num_workers = dataloader_num_workers,
|
336 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
337 |
+
past_index = past_index,
|
338 |
+
run_name = run_name,
|
339 |
+
disable_tqdm = disable_tqdm,
|
340 |
+
remove_unused_columns = remove_unused_columns,
|
341 |
+
label_names = label_names,
|
342 |
+
load_best_model_at_end = load_best_model_at_end,
|
343 |
+
metric_for_best_model = metric_for_best_model,
|
344 |
+
greater_is_better = greater_is_better,
|
345 |
+
ignore_data_skip = ignore_data_skip,
|
346 |
+
fsdp = fsdp,
|
347 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
348 |
+
fsdp_config = fsdp_config,
|
349 |
+
tp_size = tp_size,
|
350 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
351 |
+
accelerator_config = accelerator_config,
|
352 |
+
deepspeed = deepspeed,
|
353 |
+
label_smoothing_factor = label_smoothing_factor,
|
354 |
+
optim = optim,
|
355 |
+
optim_args = optim_args,
|
356 |
+
adafactor = adafactor,
|
357 |
+
group_by_length = group_by_length,
|
358 |
+
length_column_name = length_column_name,
|
359 |
+
report_to = report_to,
|
360 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
361 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
362 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
363 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
364 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
365 |
+
skip_memory_metrics = skip_memory_metrics,
|
366 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
367 |
+
push_to_hub = push_to_hub,
|
368 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
369 |
+
hub_model_id = hub_model_id,
|
370 |
+
hub_strategy = hub_strategy,
|
371 |
+
hub_token = hub_token,
|
372 |
+
hub_private_repo = hub_private_repo,
|
373 |
+
hub_always_push = hub_always_push,
|
374 |
+
gradient_checkpointing = gradient_checkpointing,
|
375 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
376 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
377 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
378 |
+
fp16_backend = fp16_backend,
|
379 |
+
evaluation_strategy = evaluation_strategy,
|
380 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
381 |
+
push_to_hub_organization = push_to_hub_organization,
|
382 |
+
push_to_hub_token = push_to_hub_token,
|
383 |
+
mp_parameters = mp_parameters,
|
384 |
+
auto_find_batch_size = auto_find_batch_size,
|
385 |
+
full_determinism = full_determinism,
|
386 |
+
torchdynamo = torchdynamo,
|
387 |
+
ray_scope = ray_scope,
|
388 |
+
ddp_timeout = ddp_timeout,
|
389 |
+
torch_compile = torch_compile,
|
390 |
+
torch_compile_backend = torch_compile_backend,
|
391 |
+
torch_compile_mode = torch_compile_mode,
|
392 |
+
dispatch_batches = dispatch_batches,
|
393 |
+
split_batches = split_batches,
|
394 |
+
include_tokens_per_second = include_tokens_per_second,
|
395 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
396 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
397 |
+
optim_target_modules = optim_target_modules,
|
398 |
+
batch_eval_metrics = batch_eval_metrics,
|
399 |
+
eval_on_start = eval_on_start,
|
400 |
+
use_liger_kernel = use_liger_kernel,
|
401 |
+
eval_use_gather_object = eval_use_gather_object,
|
402 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
403 |
+
max_length = max_length,
|
404 |
+
max_prompt_length = max_prompt_length,
|
405 |
+
max_completion_length = max_completion_length,
|
406 |
+
beta = beta,
|
407 |
+
label_smoothing = label_smoothing,
|
408 |
+
loss_type = loss_type,
|
409 |
+
disable_dropout = disable_dropout,
|
410 |
+
cpo_alpha = cpo_alpha,
|
411 |
+
simpo_gamma = simpo_gamma,
|
412 |
+
label_pad_token_id = label_pad_token_id,
|
413 |
+
padding_value = padding_value,
|
414 |
+
truncation_mode = truncation_mode,
|
415 |
+
generate_during_eval = generate_during_eval,
|
416 |
+
is_encoder_decoder = is_encoder_decoder,
|
417 |
+
model_init_kwargs = model_init_kwargs,
|
418 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
419 |
+
self.vllm_sampling_params = vllm_sampling_params
|
420 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
421 |
+
pass
|
422 |
+
|
423 |
+
class _UnslothCPOTrainer(Trainer):
|
424 |
+
r""""""
|
425 |
+
|
426 |
+
_tag_names = ["trl", "cpo"]
|
427 |
+
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
431 |
+
args: Optional[CPOConfig] = None,
|
432 |
+
data_collator: Optional[DataCollator] = None,
|
433 |
+
train_dataset: Optional[Dataset] = None,
|
434 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
435 |
+
processing_class: Optional[
|
436 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
437 |
+
] = None,
|
438 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
439 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
440 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
441 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
442 |
+
peft_config: Optional[dict] = None,
|
443 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
444 |
+
):
|
445 |
+
if args.model_init_kwargs is None:
|
446 |
+
model_init_kwargs = {}
|
447 |
+
elif not isinstance(model, str):
|
448 |
+
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
449 |
+
else:
|
450 |
+
model_init_kwargs = args.model_init_kwargs
|
451 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
452 |
+
if torch_dtype is not None:
|
453 |
+
# Convert to `torch.dtype` if an str is passed
|
454 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
455 |
+
torch_dtype = getattr(torch, torch_dtype)
|
456 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
457 |
+
raise ValueError(
|
458 |
+
f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
459 |
+
)
|
460 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
461 |
+
|
462 |
+
if isinstance(model, str):
|
463 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
464 |
+
|
465 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
466 |
+
# has been called in order to properly call autocast if needed.
|
467 |
+
self._peft_has_been_casted_to_bf16 = False
|
468 |
+
|
469 |
+
if not is_peft_available() and peft_config is not None:
|
470 |
+
raise ValueError(
|
471 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
472 |
+
)
|
473 |
+
elif is_peft_available() and peft_config is not None:
|
474 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
475 |
+
if isinstance(model, PeftModel):
|
476 |
+
model = model.merge_and_unload()
|
477 |
+
|
478 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
479 |
+
_support_gc_kwargs = hasattr(
|
480 |
+
args, "gradient_checkpointing_kwargs"
|
481 |
+
) and "gradient_checkpointing_kwargs" in list(
|
482 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
483 |
+
)
|
484 |
+
|
485 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
486 |
+
|
487 |
+
if _support_gc_kwargs:
|
488 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
489 |
+
|
490 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
491 |
+
elif getattr(args, "gradient_checkpointing", False):
|
492 |
+
# For backward compatibility with older versions of transformers
|
493 |
+
if hasattr(model, "enable_input_require_grads"):
|
494 |
+
model.enable_input_require_grads()
|
495 |
+
else:
|
496 |
+
|
497 |
+
def make_inputs_require_grad(module, input, output):
|
498 |
+
output.requires_grad_(True)
|
499 |
+
|
500 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
501 |
+
|
502 |
+
# get peft model with the given config
|
503 |
+
model = model
|
504 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
505 |
+
peft_module_casting_to_bf16(model)
|
506 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
507 |
+
self._peft_has_been_casted_to_bf16 = True
|
508 |
+
|
509 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
510 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
511 |
+
# fail or completely fail.
|
512 |
+
elif getattr(args, "gradient_checkpointing", False):
|
513 |
+
# For backward compatibility with older versions of transformers
|
514 |
+
if hasattr(model, "enable_input_require_grads"):
|
515 |
+
model.enable_input_require_grads()
|
516 |
+
else:
|
517 |
+
|
518 |
+
def make_inputs_require_grad(module, input, output):
|
519 |
+
output.requires_grad_(True)
|
520 |
+
|
521 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
522 |
+
|
523 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
524 |
+
raise ValueError(
|
525 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
526 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
527 |
+
)
|
528 |
+
|
529 |
+
if model is not None:
|
530 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
531 |
+
elif args.is_encoder_decoder is None:
|
532 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
533 |
+
else:
|
534 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
535 |
+
|
536 |
+
if self.is_encoder_decoder:
|
537 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
538 |
+
self.pad_token_id = model.config.pad_token_id
|
539 |
+
|
540 |
+
if processing_class is None:
|
541 |
+
raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
|
542 |
+
if args.max_length is None:
|
543 |
+
warnings.warn(
|
544 |
+
"`max_length` is not set in the CPOConfig's init"
|
545 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
546 |
+
UserWarning,
|
547 |
+
)
|
548 |
+
max_length = 512
|
549 |
+
else:
|
550 |
+
max_length = args.max_length
|
551 |
+
if args.max_prompt_length is None:
|
552 |
+
warnings.warn(
|
553 |
+
"`max_prompt_length` is not set in the CPOConfig's init"
|
554 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
555 |
+
UserWarning,
|
556 |
+
)
|
557 |
+
max_prompt_length = 128
|
558 |
+
else:
|
559 |
+
max_prompt_length = args.max_prompt_length
|
560 |
+
|
561 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
562 |
+
warnings.warn(
|
563 |
+
"When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
|
564 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
565 |
+
UserWarning,
|
566 |
+
)
|
567 |
+
max_completion_length = 128
|
568 |
+
else:
|
569 |
+
max_completion_length = args.max_completion_length
|
570 |
+
|
571 |
+
if data_collator is None:
|
572 |
+
data_collator = DPODataCollatorWithPadding(
|
573 |
+
pad_token_id=processing_class.pad_token_id,
|
574 |
+
label_pad_token_id=args.label_pad_token_id,
|
575 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
576 |
+
)
|
577 |
+
|
578 |
+
if args.remove_unused_columns:
|
579 |
+
args.remove_unused_columns = False
|
580 |
+
# warn users
|
581 |
+
warnings.warn(
|
582 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
583 |
+
" we have set it for you, but you should do it yourself in the future.",
|
584 |
+
UserWarning,
|
585 |
+
)
|
586 |
+
|
587 |
+
self.use_dpo_data_collator = True
|
588 |
+
else:
|
589 |
+
self.use_dpo_data_collator = False
|
590 |
+
|
591 |
+
# Disable dropout in the model
|
592 |
+
if args.disable_dropout:
|
593 |
+
disable_dropout_in_model(model)
|
594 |
+
|
595 |
+
self.max_length = max_length
|
596 |
+
self.generate_during_eval = args.generate_during_eval
|
597 |
+
self.label_pad_token_id = args.label_pad_token_id
|
598 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
599 |
+
self.max_prompt_length = max_prompt_length
|
600 |
+
self.truncation_mode = args.truncation_mode
|
601 |
+
self.max_completion_length = max_completion_length
|
602 |
+
self.processing_class = processing_class
|
603 |
+
|
604 |
+
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
|
605 |
+
warnings.warn(
|
606 |
+
f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
|
607 |
+
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
|
608 |
+
UserWarning,
|
609 |
+
)
|
610 |
+
if args.loss_type == "kto_pair":
|
611 |
+
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
|
612 |
+
|
613 |
+
self.beta = args.beta
|
614 |
+
self.label_smoothing = args.label_smoothing
|
615 |
+
self.loss_type = args.loss_type
|
616 |
+
self.cpo_alpha = args.cpo_alpha
|
617 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
618 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
619 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
620 |
+
warnings.warn(
|
621 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
622 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
623 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
624 |
+
"loss.",
|
625 |
+
UserWarning,
|
626 |
+
)
|
627 |
+
|
628 |
+
if args.loss_type == "simpo":
|
629 |
+
self.simpo_gamma = args.simpo_gamma
|
630 |
+
|
631 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
632 |
+
|
633 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
634 |
+
# input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
|
635 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
636 |
+
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
637 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
638 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
639 |
+
# that the warning has already been issued.
|
640 |
+
model.warnings_issued["estimate_tokens"] = True
|
641 |
+
|
642 |
+
# Compute that only on the main process for faster data processing.
|
643 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
644 |
+
with PartialState().local_main_process_first():
|
645 |
+
# Extract the prompt if needed, and apply the chat template if needed
|
646 |
+
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
647 |
+
train_dataset = train_dataset.map(
|
648 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
649 |
+
)
|
650 |
+
if eval_dataset is not None:
|
651 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
652 |
+
eval_dataset = eval_dataset.map(
|
653 |
+
maybe_apply_chat_template,
|
654 |
+
fn_kwargs={"tokenizer": processing_class},
|
655 |
+
num_proc=args.dataset_num_proc,
|
656 |
+
)
|
657 |
+
|
658 |
+
# tokenize the dataset
|
659 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
660 |
+
if eval_dataset is not None:
|
661 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
662 |
+
|
663 |
+
super().__init__(
|
664 |
+
model=model,
|
665 |
+
args=args,
|
666 |
+
data_collator=data_collator,
|
667 |
+
train_dataset=train_dataset,
|
668 |
+
eval_dataset=eval_dataset,
|
669 |
+
processing_class=processing_class,
|
670 |
+
model_init=model_init,
|
671 |
+
compute_metrics=compute_metrics,
|
672 |
+
callbacks=callbacks,
|
673 |
+
optimizers=optimizers,
|
674 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
675 |
+
)
|
676 |
+
|
677 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
678 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
679 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
680 |
+
self.model_accepts_loss_kwargs = False
|
681 |
+
|
682 |
+
# Add tags for models that have been loaded with the correct transformers version
|
683 |
+
if hasattr(self.model, "add_model_tags"):
|
684 |
+
self.model.add_model_tags(self._tag_names)
|
685 |
+
|
686 |
+
if not hasattr(self, "accelerator"):
|
687 |
+
raise AttributeError(
|
688 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
689 |
+
)
|
690 |
+
|
691 |
+
def build_tokenized_answer(self, prompt, answer):
|
692 |
+
"""
|
693 |
+
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
694 |
+
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
695 |
+
Reference:
|
696 |
+
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
697 |
+
"""
|
698 |
+
|
699 |
+
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
700 |
+
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
701 |
+
|
702 |
+
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
703 |
+
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
704 |
+
|
705 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
706 |
+
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
707 |
+
|
708 |
+
# Prepare input tokens for token by token comparison
|
709 |
+
full_input_ids = np.array(full_tokenized["input_ids"])
|
710 |
+
|
711 |
+
if len(full_input_ids) != len(full_concat_input_ids):
|
712 |
+
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
713 |
+
|
714 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
715 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
716 |
+
# on the last token from the prompt being different when tokenized on its own
|
717 |
+
# vs when done as prompt+answer.
|
718 |
+
response_token_ids_start_idx = len(prompt_input_ids)
|
719 |
+
|
720 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
721 |
+
# last token has changed due to merging.
|
722 |
+
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
723 |
+
response_token_ids_start_idx -= 1
|
724 |
+
|
725 |
+
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
726 |
+
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
727 |
+
|
728 |
+
if len(prompt_input_ids) != len(prompt_attention_mask):
|
729 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
730 |
+
|
731 |
+
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
732 |
+
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
733 |
+
|
734 |
+
return dict(
|
735 |
+
prompt_input_ids=prompt_input_ids,
|
736 |
+
prompt_attention_mask=prompt_attention_mask,
|
737 |
+
input_ids=answer_input_ids,
|
738 |
+
attention_mask=answer_attention_mask,
|
739 |
+
)
|
740 |
+
|
741 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
742 |
+
"""Tokenize a single row from a CPO specific dataset.
|
743 |
+
|
744 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
745 |
+
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
746 |
+
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
747 |
+
|
748 |
+
We also create the labels for the chosen/rejected responses, which are of length equal to
|
749 |
+
the sum of the length of the prompt and the chosen/rejected response, with
|
750 |
+
label_pad_token_id for the prompt tokens.
|
751 |
+
"""
|
752 |
+
batch = {}
|
753 |
+
prompt = feature["prompt"]
|
754 |
+
chosen = feature["chosen"]
|
755 |
+
rejected = feature["rejected"]
|
756 |
+
|
757 |
+
if not self.is_encoder_decoder:
|
758 |
+
# Check issues below for more details
|
759 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
760 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
761 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
762 |
+
|
763 |
+
if not isinstance(prompt, str):
|
764 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
765 |
+
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
766 |
+
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
767 |
+
|
768 |
+
if not isinstance(chosen, str):
|
769 |
+
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
770 |
+
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
771 |
+
|
772 |
+
if not isinstance(rejected, str):
|
773 |
+
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
774 |
+
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
775 |
+
|
776 |
+
# Last prompt token might get merged by tokenizer and
|
777 |
+
# it should not be included for generation if that happens
|
778 |
+
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
779 |
+
|
780 |
+
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
781 |
+
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
782 |
+
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
783 |
+
|
784 |
+
for k, v in prompt_tokens.items():
|
785 |
+
prompt_tokens[k] = v[:prompt_len_input_ids]
|
786 |
+
|
787 |
+
# Make sure prompts only have one different token at most an
|
788 |
+
# and length only differs by 1 at most
|
789 |
+
num_diff_tokens = sum(
|
790 |
+
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
791 |
+
)
|
792 |
+
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
793 |
+
if num_diff_tokens > 1 or num_diff_len > 1:
|
794 |
+
raise ValueError(
|
795 |
+
"Chosen and rejected prompt_input_ids might only differ on the "
|
796 |
+
"last token due to tokenizer merge ops."
|
797 |
+
)
|
798 |
+
|
799 |
+
# add BOS token to head of prompt. Avoid adding if it's already there
|
800 |
+
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
801 |
+
self.processing_class.bos_token_id,
|
802 |
+
prompt_len_input_ids,
|
803 |
+
prompt_tokens,
|
804 |
+
chosen_prompt_len_input_ids,
|
805 |
+
chosen_tokens,
|
806 |
+
rejected_prompt_len_input_ids,
|
807 |
+
rejected_tokens,
|
808 |
+
)
|
809 |
+
|
810 |
+
# add EOS token to end of answer. Avoid adding if it's already there
|
811 |
+
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
812 |
+
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
813 |
+
)
|
814 |
+
|
815 |
+
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
816 |
+
|
817 |
+
# if combined sequence is too long, truncate the prompt
|
818 |
+
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
819 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
820 |
+
if self.truncation_mode == "keep_start":
|
821 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
822 |
+
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
823 |
+
elif self.truncation_mode == "keep_end":
|
824 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
825 |
+
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
826 |
+
else:
|
827 |
+
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
828 |
+
|
829 |
+
# if that's still too long, truncate the response
|
830 |
+
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
831 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
832 |
+
for k in ["input_ids", "attention_mask"]:
|
833 |
+
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
834 |
+
|
835 |
+
# Create labels
|
836 |
+
chosen_sequence_tokens = {
|
837 |
+
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
838 |
+
}
|
839 |
+
rejected_sequence_tokens = {
|
840 |
+
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
841 |
+
}
|
842 |
+
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
843 |
+
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
844 |
+
self.label_pad_token_id
|
845 |
+
] * len(chosen_tokens["prompt_input_ids"])
|
846 |
+
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
847 |
+
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
848 |
+
self.label_pad_token_id
|
849 |
+
] * len(rejected_tokens["prompt_input_ids"])
|
850 |
+
|
851 |
+
for k, toks in {
|
852 |
+
"chosen_": chosen_sequence_tokens,
|
853 |
+
"rejected_": rejected_sequence_tokens,
|
854 |
+
"": prompt_tokens,
|
855 |
+
}.items():
|
856 |
+
for type_key, tokens in toks.items():
|
857 |
+
if type_key == "token_type_ids":
|
858 |
+
continue
|
859 |
+
batch[f"{k}{type_key}"] = tokens
|
860 |
+
|
861 |
+
else:
|
862 |
+
chosen_tokens = self.processing_class(
|
863 |
+
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
864 |
+
)
|
865 |
+
rejected_tokens = self.processing_class(
|
866 |
+
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
867 |
+
)
|
868 |
+
prompt_tokens = self.processing_class(
|
869 |
+
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
870 |
+
)
|
871 |
+
|
872 |
+
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
873 |
+
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
874 |
+
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
875 |
+
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
876 |
+
|
877 |
+
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
878 |
+
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
879 |
+
labels=torch.tensor(batch["rejected_labels"])
|
880 |
+
)
|
881 |
+
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
882 |
+
labels=torch.tensor(batch["chosen_labels"])
|
883 |
+
)
|
884 |
+
|
885 |
+
return batch
|
886 |
+
|
887 |
+
@staticmethod
|
888 |
+
def concatenated_inputs(
|
889 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
890 |
+
is_encoder_decoder: bool = False,
|
891 |
+
label_pad_token_id: int = -100,
|
892 |
+
padding_value: int = 0,
|
893 |
+
device: Optional[torch.device] = None,
|
894 |
+
) -> dict[str, torch.LongTensor]:
|
895 |
+
"""Concatenate the chosen and rejected inputs into a single tensor.
|
896 |
+
|
897 |
+
Args:
|
898 |
+
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
899 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
900 |
+
label_pad_token_id: The label pad token id.
|
901 |
+
padding_value: The padding value to use for the concatenated inputs_ids.
|
902 |
+
device: The device for the concatenated inputs.
|
903 |
+
|
904 |
+
Returns:
|
905 |
+
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
906 |
+
"""
|
907 |
+
concatenated_batch = {}
|
908 |
+
|
909 |
+
if is_encoder_decoder:
|
910 |
+
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
911 |
+
else:
|
912 |
+
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
913 |
+
|
914 |
+
for k in batch:
|
915 |
+
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
916 |
+
if "labels" in k or is_encoder_decoder:
|
917 |
+
pad_value = label_pad_token_id
|
918 |
+
elif k.endswith("_input_ids"):
|
919 |
+
pad_value = padding_value
|
920 |
+
elif k.endswith("_attention_mask"):
|
921 |
+
pad_value = 0
|
922 |
+
concatenated_key = k.replace("chosen", "concatenated")
|
923 |
+
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
924 |
+
for k in batch:
|
925 |
+
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
926 |
+
if "labels" in k or is_encoder_decoder:
|
927 |
+
pad_value = label_pad_token_id
|
928 |
+
elif k.endswith("_input_ids"):
|
929 |
+
pad_value = padding_value
|
930 |
+
elif k.endswith("_attention_mask"):
|
931 |
+
pad_value = 0
|
932 |
+
concatenated_key = k.replace("rejected", "concatenated")
|
933 |
+
concatenated_batch[concatenated_key] = torch.cat(
|
934 |
+
(
|
935 |
+
concatenated_batch[concatenated_key],
|
936 |
+
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
937 |
+
),
|
938 |
+
dim=0,
|
939 |
+
).to(device=device)
|
940 |
+
|
941 |
+
if is_encoder_decoder:
|
942 |
+
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
943 |
+
concatenated_batch["concatenated_attention_mask"] = (
|
944 |
+
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
945 |
+
)
|
946 |
+
|
947 |
+
return concatenated_batch
|
948 |
+
|
949 |
+
def cpo_loss(
|
950 |
+
self,
|
951 |
+
policy_chosen_logps: torch.FloatTensor,
|
952 |
+
policy_rejected_logps: torch.FloatTensor,
|
953 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
954 |
+
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
955 |
+
|
956 |
+
Args:
|
957 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
958 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
959 |
+
|
960 |
+
Returns:
|
961 |
+
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
962 |
+
The losses tensor contains the CPO loss for each example in the batch.
|
963 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
964 |
+
"""
|
965 |
+
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
966 |
+
|
967 |
+
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
968 |
+
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
969 |
+
# calculates a conservative CPO loss.
|
970 |
+
|
971 |
+
if self.loss_type == "simpo":
|
972 |
+
gamma_logratios = self.simpo_gamma / self.beta
|
973 |
+
logits = logits - gamma_logratios
|
974 |
+
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
975 |
+
losses = (
|
976 |
+
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
977 |
+
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
978 |
+
)
|
979 |
+
elif self.loss_type == "sigmoid":
|
980 |
+
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
981 |
+
losses = (
|
982 |
+
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
983 |
+
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
984 |
+
)
|
985 |
+
elif self.loss_type == "hinge":
|
986 |
+
losses = torch.relu(1 - self.beta * logits)
|
987 |
+
elif self.loss_type == "ipo":
|
988 |
+
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
989 |
+
losses = (logits - 1 / (2 * self.beta)) ** 2
|
990 |
+
else:
|
991 |
+
raise ValueError(
|
992 |
+
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
|
993 |
+
)
|
994 |
+
|
995 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
996 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
997 |
+
|
998 |
+
return losses, chosen_rewards, rejected_rewards
|
999 |
+
|
1000 |
+
@staticmethod
|
1001 |
+
def get_batch_logps(
|
1002 |
+
logits: torch.FloatTensor,
|
1003 |
+
labels: torch.LongTensor,
|
1004 |
+
average_log_prob: bool = False,
|
1005 |
+
label_pad_token_id: int = -100,
|
1006 |
+
is_encoder_decoder: bool = False,
|
1007 |
+
) -> torch.FloatTensor:
|
1008 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
1009 |
+
|
1010 |
+
Args:
|
1011 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1012 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1013 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1014 |
+
label_pad_token_id: The label pad token id.
|
1015 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
1016 |
+
|
1017 |
+
Returns:
|
1018 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1019 |
+
"""
|
1020 |
+
if logits.shape[:-1] != labels.shape:
|
1021 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1022 |
+
|
1023 |
+
if not is_encoder_decoder:
|
1024 |
+
labels = labels[:, 1:].clone()
|
1025 |
+
logits = logits[:, :-1, :]
|
1026 |
+
loss_mask = labels != label_pad_token_id
|
1027 |
+
|
1028 |
+
# dummy token; we'll ignore the losses on these tokens later
|
1029 |
+
labels[labels == label_pad_token_id] = 0
|
1030 |
+
|
1031 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
1032 |
+
|
1033 |
+
if average_log_prob:
|
1034 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1035 |
+
else:
|
1036 |
+
return (per_token_logps * loss_mask).sum(-1)
|
1037 |
+
|
1038 |
+
def concatenated_forward(
|
1039 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1040 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1041 |
+
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
1042 |
+
|
1043 |
+
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
1044 |
+
"""
|
1045 |
+
concatenated_batch = self.concatenated_inputs(
|
1046 |
+
batch,
|
1047 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1048 |
+
label_pad_token_id=self.label_pad_token_id,
|
1049 |
+
padding_value=self.padding_value,
|
1050 |
+
device=self.accelerator.device,
|
1051 |
+
)
|
1052 |
+
len_chosen = batch["chosen_labels"].shape[0]
|
1053 |
+
|
1054 |
+
model_kwargs = (
|
1055 |
+
{
|
1056 |
+
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
1057 |
+
}
|
1058 |
+
if self.is_encoder_decoder
|
1059 |
+
else {}
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
if self.aux_loss_enabled:
|
1063 |
+
model_kwargs["output_router_logits"] = True
|
1064 |
+
|
1065 |
+
outputs = model(
|
1066 |
+
concatenated_batch["concatenated_input_ids"],
|
1067 |
+
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
1068 |
+
use_cache=False,
|
1069 |
+
**model_kwargs,
|
1070 |
+
)
|
1071 |
+
all_logits = outputs.logits
|
1072 |
+
|
1073 |
+
def cross_entropy_loss(logits, labels):
|
1074 |
+
if not self.is_encoder_decoder:
|
1075 |
+
# Shift so that tokens < n predict n
|
1076 |
+
logits = logits[..., :-1, :].contiguous()
|
1077 |
+
labels = labels[..., 1:].contiguous()
|
1078 |
+
# Flatten the tokens
|
1079 |
+
loss_fct = nn.CrossEntropyLoss()
|
1080 |
+
logits = logits.view(-1, logits.shape[-1])
|
1081 |
+
labels = labels.view(-1)
|
1082 |
+
# Enable model parallelism
|
1083 |
+
labels = labels.to(logits.device)
|
1084 |
+
loss = loss_fct(logits, labels)
|
1085 |
+
return loss
|
1086 |
+
|
1087 |
+
labels = concatenated_batch["concatenated_labels"].clone()
|
1088 |
+
|
1089 |
+
if self.cpo_alpha == 0:
|
1090 |
+
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
|
1091 |
+
else:
|
1092 |
+
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
1093 |
+
|
1094 |
+
all_logps = self.get_batch_logps(
|
1095 |
+
all_logits,
|
1096 |
+
concatenated_batch["concatenated_labels"],
|
1097 |
+
average_log_prob=self.loss_type in ["ipo", "simpo"],
|
1098 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1099 |
+
label_pad_token_id=self.label_pad_token_id,
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
chosen_logps = all_logps[:len_chosen]
|
1103 |
+
rejected_logps = all_logps[len_chosen:]
|
1104 |
+
|
1105 |
+
chosen_logits = all_logits[:len_chosen]
|
1106 |
+
rejected_logits = all_logits[len_chosen:]
|
1107 |
+
|
1108 |
+
if self.aux_loss_enabled:
|
1109 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
|
1110 |
+
|
1111 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
1112 |
+
|
1113 |
+
def get_batch_loss_metrics(
|
1114 |
+
self,
|
1115 |
+
model,
|
1116 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
1117 |
+
train_eval: Literal["train", "eval"] = "train",
|
1118 |
+
):
|
1119 |
+
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
1120 |
+
metrics = {}
|
1121 |
+
|
1122 |
+
forward_output = self.concatenated_forward(model, batch)
|
1123 |
+
(
|
1124 |
+
policy_chosen_logps,
|
1125 |
+
policy_rejected_logps,
|
1126 |
+
policy_chosen_logits,
|
1127 |
+
policy_rejected_logits,
|
1128 |
+
policy_nll_loss,
|
1129 |
+
) = forward_output[:5]
|
1130 |
+
if self.aux_loss_enabled:
|
1131 |
+
aux_loss = forward_output[5]
|
1132 |
+
|
1133 |
+
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
1134 |
+
policy_chosen_logps,
|
1135 |
+
policy_rejected_logps,
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
1139 |
+
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
1140 |
+
|
1141 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
1142 |
+
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
1143 |
+
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
1144 |
+
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
1145 |
+
metrics[f"{prefix}rewards/margins"] = (
|
1146 |
+
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
1147 |
+
)
|
1148 |
+
metrics[f"{prefix}logps/rejected"] = (
|
1149 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
|
1150 |
+
)
|
1151 |
+
metrics[f"{prefix}logps/chosen"] = (
|
1152 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
|
1153 |
+
)
|
1154 |
+
metrics[f"{prefix}logits/rejected"] = (
|
1155 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
|
1156 |
+
)
|
1157 |
+
metrics[f"{prefix}logits/chosen"] = (
|
1158 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
|
1159 |
+
)
|
1160 |
+
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
1161 |
+
|
1162 |
+
if self.aux_loss_enabled:
|
1163 |
+
loss += self.aux_loss_coef * aux_loss
|
1164 |
+
|
1165 |
+
return loss, metrics
|
1166 |
+
|
1167 |
+
def compute_loss(
|
1168 |
+
self,
|
1169 |
+
model: Union[PreTrainedModel, nn.Module],
|
1170 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1171 |
+
return_outputs=False,
|
1172 |
+
num_items_in_batch=None,
|
1173 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1174 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1175 |
+
|
1176 |
+
with compute_loss_context_manager:
|
1177 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
1178 |
+
|
1179 |
+
# force log the metrics
|
1180 |
+
self.store_metrics(metrics, train_eval="train")
|
1181 |
+
|
1182 |
+
if return_outputs:
|
1183 |
+
return (loss, metrics)
|
1184 |
+
return loss
|
1185 |
+
|
1186 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
1187 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1188 |
+
|
1189 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1190 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1191 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1192 |
+
|
1193 |
+
with generate_context_manager:
|
1194 |
+
policy_output = model.generate(
|
1195 |
+
input_ids=batch["prompt_input_ids"],
|
1196 |
+
attention_mask=batch["prompt_attention_mask"],
|
1197 |
+
max_length=self.max_length,
|
1198 |
+
do_sample=True,
|
1199 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1200 |
+
)
|
1201 |
+
|
1202 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1203 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1204 |
+
|
1205 |
+
return policy_output_decoded
|
1206 |
+
|
1207 |
+
def prediction_step(
|
1208 |
+
self,
|
1209 |
+
model: Union[PreTrainedModel, nn.Module],
|
1210 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1211 |
+
prediction_loss_only: bool,
|
1212 |
+
ignore_keys: Optional[list[str]] = None,
|
1213 |
+
):
|
1214 |
+
if ignore_keys is None:
|
1215 |
+
if hasattr(model, "config"):
|
1216 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1217 |
+
else:
|
1218 |
+
ignore_keys = []
|
1219 |
+
|
1220 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1221 |
+
|
1222 |
+
with torch.no_grad(), prediction_context_manager:
|
1223 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
1224 |
+
|
1225 |
+
# force log the metrics
|
1226 |
+
self.store_metrics(metrics, train_eval="eval")
|
1227 |
+
|
1228 |
+
if prediction_loss_only:
|
1229 |
+
return (loss.detach(), None, None)
|
1230 |
+
|
1231 |
+
# logits for the chosen and rejected samples from model
|
1232 |
+
logits_dict = {
|
1233 |
+
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
1234 |
+
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
1235 |
+
}
|
1236 |
+
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1237 |
+
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1238 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1239 |
+
|
1240 |
+
return (loss.detach(), logits, labels)
|
1241 |
+
|
1242 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1243 |
+
for key, value in metrics.items():
|
1244 |
+
self._stored_metrics[train_eval][key].append(value)
|
1245 |
+
|
1246 |
+
def evaluation_loop(
|
1247 |
+
self,
|
1248 |
+
dataloader: DataLoader,
|
1249 |
+
description: str,
|
1250 |
+
prediction_loss_only: Optional[bool] = None,
|
1251 |
+
ignore_keys: Optional[list[str]] = None,
|
1252 |
+
metric_key_prefix: str = "eval",
|
1253 |
+
) -> EvalLoopOutput:
|
1254 |
+
"""
|
1255 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
1256 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1257 |
+
|
1258 |
+
Works both with or without labels.
|
1259 |
+
"""
|
1260 |
+
|
1261 |
+
# Sample and save to game log if requested (for one batch to save time)
|
1262 |
+
if self.generate_during_eval:
|
1263 |
+
# Generate random indices within the range of the total number of samples
|
1264 |
+
num_samples = len(dataloader.dataset)
|
1265 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1266 |
+
|
1267 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1268 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1269 |
+
random_batch = self.data_collator(random_batch_dataset)
|
1270 |
+
random_batch = self._prepare_inputs(random_batch)
|
1271 |
+
|
1272 |
+
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
1273 |
+
|
1274 |
+
table = pd.DataFrame(
|
1275 |
+
columns=["Prompt", "Policy"],
|
1276 |
+
data=[
|
1277 |
+
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
1278 |
+
],
|
1279 |
+
)
|
1280 |
+
if "wandb" in self.args.report_to:
|
1281 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
1282 |
+
|
1283 |
+
if "comet_ml" in self.args.report_to:
|
1284 |
+
log_table_to_comet_experiment(
|
1285 |
+
name="game_log.csv",
|
1286 |
+
table=table,
|
1287 |
+
)
|
1288 |
+
|
1289 |
+
# Base evaluation
|
1290 |
+
initial_output = super().evaluation_loop(
|
1291 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1292 |
+
)
|
1293 |
+
|
1294 |
+
return initial_output
|
1295 |
+
|
1296 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1297 |
+
"""
|
1298 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1299 |
+
|
1300 |
+
Args:
|
1301 |
+
logs (`dict[str, float]`):
|
1302 |
+
The values to log.
|
1303 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1304 |
+
Start time of the training.
|
1305 |
+
"""
|
1306 |
+
# logs either has 'loss' or 'eval_loss'
|
1307 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1308 |
+
# Add averaged stored metrics to logs
|
1309 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1310 |
+
logs[key] = torch.tensor(metrics).mean().item()
|
1311 |
+
del self._stored_metrics[train_eval]
|
1312 |
+
|
1313 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1314 |
+
return super().log(logs, start_time)
|
1315 |
+
else: # transformers<=4.46
|
1316 |
+
return super().log(logs)
|
1317 |
+
|
1318 |
+
def _shift_right(self, input_ids):
|
1319 |
+
if self.decoder_start_token_id is None:
|
1320 |
+
raise ValueError(
|
1321 |
+
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
1322 |
+
)
|
1323 |
+
|
1324 |
+
# shift inputs to the right
|
1325 |
+
if is_torch_fx_proxy(input_ids):
|
1326 |
+
# Item assignment is not supported natively for proxies.
|
1327 |
+
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
1328 |
+
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
1329 |
+
else:
|
1330 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1331 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1332 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1333 |
+
|
1334 |
+
if self.pad_token_id is None:
|
1335 |
+
raise ValueError("model.config.pad_token_id has to be defined.")
|
1336 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1337 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
1338 |
+
|
1339 |
+
return shifted_input_ids
|
1340 |
+
|
1341 |
+
def create_model_card(
|
1342 |
+
self,
|
1343 |
+
model_name: Optional[str] = None,
|
1344 |
+
dataset_name: Optional[str] = None,
|
1345 |
+
tags: Union[str, list[str], None] = None,
|
1346 |
+
):
|
1347 |
+
"""
|
1348 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1349 |
+
|
1350 |
+
Args:
|
1351 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1352 |
+
Name of the model.
|
1353 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1354 |
+
Name of the dataset used for training.
|
1355 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1356 |
+
Tags to be associated with the model card.
|
1357 |
+
"""
|
1358 |
+
if not self.is_world_process_zero():
|
1359 |
+
return
|
1360 |
+
|
1361 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1362 |
+
base_model = self.model.config._name_or_path
|
1363 |
+
else:
|
1364 |
+
base_model = None
|
1365 |
+
|
1366 |
+
tags = tags or []
|
1367 |
+
if isinstance(tags, str):
|
1368 |
+
tags = [tags]
|
1369 |
+
|
1370 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1371 |
+
tags.append("unsloth")
|
1372 |
+
|
1373 |
+
citation = textwrap.dedent("""\
|
1374 |
+
@inproceedings{xu2024contrastive,
|
1375 |
+
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
|
1376 |
+
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
|
1377 |
+
year = 2024,
|
1378 |
+
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
1379 |
+
publisher = {OpenReview.net},
|
1380 |
+
url = {https://openreview.net/forum?id=51iwkioZpn}
|
1381 |
+
}""")
|
1382 |
+
|
1383 |
+
model_card = generate_model_card(
|
1384 |
+
base_model=base_model,
|
1385 |
+
model_name=model_name,
|
1386 |
+
hub_model_id=self.hub_model_id,
|
1387 |
+
dataset_name=dataset_name,
|
1388 |
+
tags=tags,
|
1389 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1390 |
+
comet_url=get_comet_experiment_url(),
|
1391 |
+
trainer_name="CPO",
|
1392 |
+
trainer_citation=citation,
|
1393 |
+
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
|
1394 |
+
paper_id="2401.08417",
|
1395 |
+
)
|
1396 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1397 |
+
class UnslothCPOTrainer(_UnslothCPOTrainer):
|
1398 |
+
"""
|
1399 |
+
|
1400 |
+
Initialize CPOTrainer.
|
1401 |
+
|
1402 |
+
Args:
|
1403 |
+
model (`transformers.PreTrainedModel`):
|
1404 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1405 |
+
args (`CPOConfig`):
|
1406 |
+
The CPO config arguments to use for training.
|
1407 |
+
data_collator (`transformers.DataCollator`):
|
1408 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1409 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1410 |
+
train_dataset (`datasets.Dataset`):
|
1411 |
+
The dataset to use for training.
|
1412 |
+
eval_dataset (`datasets.Dataset`):
|
1413 |
+
The dataset to use for evaluation.
|
1414 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1415 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1416 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1417 |
+
reuse the fine-tuned model.
|
1418 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1419 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1420 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1421 |
+
The callbacks to use for training.
|
1422 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1423 |
+
The optimizer and scheduler to use for training.
|
1424 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1425 |
+
The function to use to preprocess the logits before computing the metrics.
|
1426 |
+
peft_config (`dict`, defaults to `None`):
|
1427 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1428 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1429 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1430 |
+
a dictionary string to metric values.
|
1431 |
+
|
1432 |
+
"""
|
1433 |
+
def __init__(
|
1434 |
+
self,
|
1435 |
+
model = None,
|
1436 |
+
args = None,
|
1437 |
+
data_collator = None,
|
1438 |
+
train_dataset = None,
|
1439 |
+
eval_dataset = None,
|
1440 |
+
processing_class = None,
|
1441 |
+
model_init = None,
|
1442 |
+
callbacks = None,
|
1443 |
+
preprocess_logits_for_metrics = None,
|
1444 |
+
peft_config = None,
|
1445 |
+
compute_metrics = None,
|
1446 |
+
**kwargs
|
1447 |
+
):
|
1448 |
+
if args is None: args = UnslothCPOConfig()
|
1449 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1450 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1451 |
+
force_float32 = False
|
1452 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1453 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1454 |
+
force_float32 = True
|
1455 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1456 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1457 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1458 |
+
from unsloth_zoo.utils import _get_dtype
|
1459 |
+
dtype = _get_dtype(dtype)
|
1460 |
+
float16 = dtype == torch.float16
|
1461 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1462 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1463 |
+
if force_float32:
|
1464 |
+
args.fp16 = False
|
1465 |
+
args.bf16 = False
|
1466 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1467 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1468 |
+
args.fp16 = float16
|
1469 |
+
args.bf16 = not float16
|
1470 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1471 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1472 |
+
args.eval_strategy = 'steps'
|
1473 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1474 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1475 |
+
if ga_steps is not None and ga_steps > 1:
|
1476 |
+
from transformers import __version__ as transformers_version
|
1477 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1478 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1479 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1480 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1481 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1482 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1483 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1484 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1485 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1486 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1487 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1488 |
+
if force_float32:
|
1489 |
+
args.bf16_full_eval = False
|
1490 |
+
args.fp16_full_eval = False
|
1491 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1492 |
+
args.bf16_full_eval = True
|
1493 |
+
args.fp16_full_eval = False
|
1494 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1495 |
+
args.bf16_full_eval = args.bf16
|
1496 |
+
args.fp16_full_eval = args.fp16
|
1497 |
+
_output_logits = False
|
1498 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1499 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1500 |
+
if _output_logits:
|
1501 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1502 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1503 |
+
pass
|
1504 |
+
else:
|
1505 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1506 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1507 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1508 |
+
max_seq_length = model.max_seq_length
|
1509 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1510 |
+
if model is not None and hasattr(model, 'for_training'):
|
1511 |
+
model.for_training()
|
1512 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1513 |
+
if 'processing_class' in locals():
|
1514 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1515 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1516 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1517 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1518 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1519 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1520 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1521 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1522 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1523 |
+
else:
|
1524 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1525 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1526 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1527 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1528 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1529 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1530 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1531 |
+
else:
|
1532 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1533 |
+
other_metrics = []
|
1534 |
+
|
1535 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1536 |
+
PatchRLStatistics('cpo_trainer', other_metrics)
|
1537 |
+
|
1538 |
+
super().__init__(
|
1539 |
+
model = model,
|
1540 |
+
args = args,
|
1541 |
+
data_collator = data_collator,
|
1542 |
+
train_dataset = train_dataset,
|
1543 |
+
eval_dataset = eval_dataset,
|
1544 |
+
processing_class = processing_class,
|
1545 |
+
model_init = model_init,
|
1546 |
+
callbacks = callbacks,
|
1547 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1548 |
+
peft_config = peft_config,
|
1549 |
+
compute_metrics = compute_metrics,**kwargs)
|
1550 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1551 |
+
self.neftune_hook_handle.remove()
|
1552 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1553 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1554 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1555 |
+
pass
|
1556 |
+
|
1557 |
+
pass
|
unsloth_compiled_cache/UnslothDDPOTrainer.py
ADDED
@@ -0,0 +1,872 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothDDPOConfig(DDPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`DDPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
54 |
+
Name of this experiment (by default is the file name without the extension name).
|
55 |
+
run_name (`str`, *optional*, defaults to `""`):
|
56 |
+
Name of this run.
|
57 |
+
seed (`int`, *optional*, defaults to `0`):
|
58 |
+
Random seed.
|
59 |
+
log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
|
60 |
+
Log with either 'wandb' or 'tensorboard', check
|
61 |
+
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
|
62 |
+
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
|
63 |
+
Keyword arguments for the tracker (e.g. wandb_project).
|
64 |
+
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
|
65 |
+
Keyword arguments for the accelerator.
|
66 |
+
project_kwargs (`Dict`, *optional*, defaults to `{}`):
|
67 |
+
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
68 |
+
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
69 |
+
Name of project to use for tracking.
|
70 |
+
logdir (`str`, *optional*, defaults to `"logs"`):
|
71 |
+
Top-level logging directory for checkpoint saving.
|
72 |
+
num_epochs (`int`, *optional*, defaults to `100`):
|
73 |
+
Number of epochs to train.
|
74 |
+
save_freq (`int`, *optional*, defaults to `1`):
|
75 |
+
Number of epochs between saving model checkpoints.
|
76 |
+
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
77 |
+
Number of checkpoints to keep before overwriting old ones.
|
78 |
+
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
79 |
+
Mixed precision training.
|
80 |
+
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
81 |
+
Allow `tf32` on Ampere GPUs.
|
82 |
+
resume_from (`str`, *optional*, defaults to `""`):
|
83 |
+
Resume training from a checkpoint.
|
84 |
+
sample_num_steps (`int`, *optional*, defaults to `50`):
|
85 |
+
Number of sampler inference steps.
|
86 |
+
sample_eta (`float`, *optional*, defaults to `1.0`):
|
87 |
+
Eta parameter for the DDIM sampler.
|
88 |
+
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
89 |
+
Classifier-free guidance weight.
|
90 |
+
sample_batch_size (`int`, *optional*, defaults to `1`):
|
91 |
+
Batch size (per GPU) to use for sampling.
|
92 |
+
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
|
93 |
+
Number of batches to sample per epoch.
|
94 |
+
train_batch_size (`int`, *optional*, defaults to `1`):
|
95 |
+
Batch size (per GPU) to use for training.
|
96 |
+
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
97 |
+
Use 8bit Adam optimizer from bitsandbytes.
|
98 |
+
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
|
99 |
+
Learning rate.
|
100 |
+
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
101 |
+
Adam beta1.
|
102 |
+
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
103 |
+
Adam beta2.
|
104 |
+
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
105 |
+
Adam weight decay.
|
106 |
+
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
107 |
+
Adam epsilon.
|
108 |
+
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
109 |
+
Number of gradient accumulation steps.
|
110 |
+
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
111 |
+
Maximum gradient norm for gradient clipping.
|
112 |
+
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
|
113 |
+
Number of inner epochs per outer epoch.
|
114 |
+
train_cfg (`bool`, *optional*, defaults to `True`):
|
115 |
+
Whether to use classifier-free guidance during training.
|
116 |
+
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
|
117 |
+
Clip advantages to the range.
|
118 |
+
train_clip_range (`float`, *optional*, defaults to `1e-4`):
|
119 |
+
PPO clip range.
|
120 |
+
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
|
121 |
+
Fraction of timesteps to train on.
|
122 |
+
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
|
123 |
+
Whether to track statistics for each prompt separately.
|
124 |
+
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
|
125 |
+
Number of reward values to store in the buffer for each prompt.
|
126 |
+
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
|
127 |
+
Minimum number of reward values to store in the buffer.
|
128 |
+
async_reward_computation (`bool`, *optional*, defaults to `False`):
|
129 |
+
Whether to compute rewards asynchronously.
|
130 |
+
max_workers (`int`, *optional*, defaults to `2`):
|
131 |
+
Maximum number of workers to use for async reward computation.
|
132 |
+
negative_prompts (`str`, *optional*, defaults to `""`):
|
133 |
+
Comma-separated list of prompts to use as negative examples.
|
134 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
135 |
+
Whether to push the final model checkpoint to the Hub.
|
136 |
+
|
137 |
+
"""
|
138 |
+
vllm_sampling_params: Optional[Any] = field(
|
139 |
+
default = None,
|
140 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
141 |
+
)
|
142 |
+
unsloth_num_chunks : Optional[int] = field(
|
143 |
+
default = -1,
|
144 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
145 |
+
)
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
exp_name = 'demo',
|
149 |
+
run_name = '',
|
150 |
+
seed = 3407,
|
151 |
+
log_with = None,
|
152 |
+
tracker_project_name = 'trl',
|
153 |
+
logdir = 'logs',
|
154 |
+
num_epochs = 100,
|
155 |
+
save_freq = 1,
|
156 |
+
num_checkpoint_limit = 5,
|
157 |
+
mixed_precision = 'fp16',
|
158 |
+
allow_tf32 = True,
|
159 |
+
resume_from = '',
|
160 |
+
sample_num_steps = 50,
|
161 |
+
sample_eta = 1.0,
|
162 |
+
sample_guidance_scale = 5.0,
|
163 |
+
sample_batch_size = 1,
|
164 |
+
sample_num_batches_per_epoch = 2,
|
165 |
+
train_batch_size = 1,
|
166 |
+
train_use_8bit_adam = False,
|
167 |
+
train_learning_rate = 5e-05,
|
168 |
+
train_adam_beta1 = 0.9,
|
169 |
+
train_adam_beta2 = 0.999,
|
170 |
+
train_adam_weight_decay = 0.01,
|
171 |
+
train_adam_epsilon = 1e-08,
|
172 |
+
train_gradient_accumulation_steps = 2,
|
173 |
+
train_max_grad_norm = 1.0,
|
174 |
+
train_num_inner_epochs = 1,
|
175 |
+
train_cfg = True,
|
176 |
+
train_adv_clip_max = 5.0,
|
177 |
+
train_clip_range = 0.0001,
|
178 |
+
train_timestep_fraction = 1.0,
|
179 |
+
per_prompt_stat_tracking = False,
|
180 |
+
per_prompt_stat_tracking_buffer_size = 16,
|
181 |
+
per_prompt_stat_tracking_min_count = 16,
|
182 |
+
async_reward_computation = False,
|
183 |
+
max_workers = 2,
|
184 |
+
negative_prompts = '',
|
185 |
+
push_to_hub = False,
|
186 |
+
vllm_sampling_params = None,
|
187 |
+
unsloth_num_chunks = -1,
|
188 |
+
**kwargs,
|
189 |
+
):
|
190 |
+
|
191 |
+
super().__init__(
|
192 |
+
exp_name = exp_name,
|
193 |
+
run_name = run_name,
|
194 |
+
seed = seed,
|
195 |
+
log_with = log_with,
|
196 |
+
tracker_project_name = tracker_project_name,
|
197 |
+
logdir = logdir,
|
198 |
+
num_epochs = num_epochs,
|
199 |
+
save_freq = save_freq,
|
200 |
+
num_checkpoint_limit = num_checkpoint_limit,
|
201 |
+
mixed_precision = mixed_precision,
|
202 |
+
allow_tf32 = allow_tf32,
|
203 |
+
resume_from = resume_from,
|
204 |
+
sample_num_steps = sample_num_steps,
|
205 |
+
sample_eta = sample_eta,
|
206 |
+
sample_guidance_scale = sample_guidance_scale,
|
207 |
+
sample_batch_size = sample_batch_size,
|
208 |
+
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
209 |
+
train_batch_size = train_batch_size,
|
210 |
+
train_use_8bit_adam = train_use_8bit_adam,
|
211 |
+
train_learning_rate = train_learning_rate,
|
212 |
+
train_adam_beta1 = train_adam_beta1,
|
213 |
+
train_adam_beta2 = train_adam_beta2,
|
214 |
+
train_adam_weight_decay = train_adam_weight_decay,
|
215 |
+
train_adam_epsilon = train_adam_epsilon,
|
216 |
+
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
217 |
+
train_max_grad_norm = train_max_grad_norm,
|
218 |
+
train_num_inner_epochs = train_num_inner_epochs,
|
219 |
+
train_cfg = train_cfg,
|
220 |
+
train_adv_clip_max = train_adv_clip_max,
|
221 |
+
train_clip_range = train_clip_range,
|
222 |
+
train_timestep_fraction = train_timestep_fraction,
|
223 |
+
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
224 |
+
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
225 |
+
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
226 |
+
async_reward_computation = async_reward_computation,
|
227 |
+
max_workers = max_workers,
|
228 |
+
negative_prompts = negative_prompts,
|
229 |
+
push_to_hub = push_to_hub,**kwargs)
|
230 |
+
self.vllm_sampling_params = vllm_sampling_params
|
231 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
232 |
+
pass
|
233 |
+
|
234 |
+
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
235 |
+
""""""
|
236 |
+
|
237 |
+
_tag_names = ["trl", "ddpo"]
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
config: DDPOConfig,
|
242 |
+
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
243 |
+
prompt_function: Callable[[], tuple[str, Any]],
|
244 |
+
sd_pipeline: DDPOStableDiffusionPipeline,
|
245 |
+
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
246 |
+
):
|
247 |
+
if image_samples_hook is None:
|
248 |
+
warn("No image_samples_hook provided; no images will be logged")
|
249 |
+
|
250 |
+
self.prompt_fn = prompt_function
|
251 |
+
self.reward_fn = reward_function
|
252 |
+
self.config = config
|
253 |
+
self.image_samples_callback = image_samples_hook
|
254 |
+
|
255 |
+
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
256 |
+
|
257 |
+
if self.config.resume_from:
|
258 |
+
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
259 |
+
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
260 |
+
# get the most recent checkpoint in this directory
|
261 |
+
checkpoints = list(
|
262 |
+
filter(
|
263 |
+
lambda x: "checkpoint_" in x,
|
264 |
+
os.listdir(self.config.resume_from),
|
265 |
+
)
|
266 |
+
)
|
267 |
+
if len(checkpoints) == 0:
|
268 |
+
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
269 |
+
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
270 |
+
self.config.resume_from = os.path.join(
|
271 |
+
self.config.resume_from,
|
272 |
+
f"checkpoint_{checkpoint_numbers[-1]}",
|
273 |
+
)
|
274 |
+
|
275 |
+
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
276 |
+
|
277 |
+
# number of timesteps within each trajectory to train on
|
278 |
+
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
279 |
+
|
280 |
+
self.accelerator = Accelerator(
|
281 |
+
log_with=self.config.log_with,
|
282 |
+
mixed_precision=self.config.mixed_precision,
|
283 |
+
project_config=accelerator_project_config,
|
284 |
+
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
285 |
+
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
286 |
+
# the total number of optimizer steps to accumulate across.
|
287 |
+
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
288 |
+
**self.config.accelerator_kwargs,
|
289 |
+
)
|
290 |
+
|
291 |
+
is_okay, message = self._config_check()
|
292 |
+
if not is_okay:
|
293 |
+
raise ValueError(message)
|
294 |
+
|
295 |
+
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
296 |
+
|
297 |
+
if self.accelerator.is_main_process:
|
298 |
+
self.accelerator.init_trackers(
|
299 |
+
self.config.tracker_project_name,
|
300 |
+
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
301 |
+
init_kwargs=self.config.tracker_kwargs,
|
302 |
+
)
|
303 |
+
|
304 |
+
logger.info(f"\n{config}")
|
305 |
+
|
306 |
+
set_seed(self.config.seed, device_specific=True)
|
307 |
+
|
308 |
+
self.sd_pipeline = sd_pipeline
|
309 |
+
|
310 |
+
self.sd_pipeline.set_progress_bar_config(
|
311 |
+
position=1,
|
312 |
+
disable=not self.accelerator.is_local_main_process,
|
313 |
+
leave=False,
|
314 |
+
desc="Timestep",
|
315 |
+
dynamic_ncols=True,
|
316 |
+
)
|
317 |
+
|
318 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
319 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
320 |
+
if self.accelerator.mixed_precision == "fp16":
|
321 |
+
inference_dtype = torch.float16
|
322 |
+
elif self.accelerator.mixed_precision == "bf16":
|
323 |
+
inference_dtype = torch.bfloat16
|
324 |
+
else:
|
325 |
+
inference_dtype = torch.float32
|
326 |
+
|
327 |
+
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
328 |
+
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
329 |
+
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
330 |
+
|
331 |
+
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
332 |
+
|
333 |
+
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
334 |
+
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
335 |
+
|
336 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
337 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
338 |
+
if self.config.allow_tf32:
|
339 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
340 |
+
|
341 |
+
self.optimizer = self._setup_optimizer(
|
342 |
+
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
343 |
+
)
|
344 |
+
|
345 |
+
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
346 |
+
self.sd_pipeline.tokenizer(
|
347 |
+
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
348 |
+
return_tensors="pt",
|
349 |
+
padding="max_length",
|
350 |
+
truncation=True,
|
351 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
352 |
+
).input_ids.to(self.accelerator.device)
|
353 |
+
)[0]
|
354 |
+
|
355 |
+
if config.per_prompt_stat_tracking:
|
356 |
+
self.stat_tracker = PerPromptStatTracker(
|
357 |
+
config.per_prompt_stat_tracking_buffer_size,
|
358 |
+
config.per_prompt_stat_tracking_min_count,
|
359 |
+
)
|
360 |
+
|
361 |
+
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
362 |
+
# more memory
|
363 |
+
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
364 |
+
|
365 |
+
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
366 |
+
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
367 |
+
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
368 |
+
else:
|
369 |
+
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
370 |
+
|
371 |
+
if self.config.async_reward_computation:
|
372 |
+
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
373 |
+
|
374 |
+
if config.resume_from:
|
375 |
+
logger.info(f"Resuming from {config.resume_from}")
|
376 |
+
self.accelerator.load_state(config.resume_from)
|
377 |
+
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
378 |
+
else:
|
379 |
+
self.first_epoch = 0
|
380 |
+
|
381 |
+
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
382 |
+
if not is_async:
|
383 |
+
rewards = []
|
384 |
+
for images, prompts, prompt_metadata in prompt_image_pairs:
|
385 |
+
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
386 |
+
rewards.append(
|
387 |
+
(
|
388 |
+
torch.as_tensor(reward, device=self.accelerator.device),
|
389 |
+
reward_metadata,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
else:
|
393 |
+
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
394 |
+
rewards = [
|
395 |
+
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
396 |
+
for reward, reward_metadata in rewards
|
397 |
+
]
|
398 |
+
|
399 |
+
return zip(*rewards)
|
400 |
+
|
401 |
+
def step(self, epoch: int, global_step: int):
|
402 |
+
"""
|
403 |
+
Perform a single step of training.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
epoch (int): The current epoch.
|
407 |
+
global_step (int): The current global step.
|
408 |
+
|
409 |
+
Side Effects:
|
410 |
+
- Model weights are updated
|
411 |
+
- Logs the statistics to the accelerator trackers.
|
412 |
+
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
global_step (int): The updated global step.
|
416 |
+
|
417 |
+
"""
|
418 |
+
samples, prompt_image_data = self._generate_samples(
|
419 |
+
iterations=self.config.sample_num_batches_per_epoch,
|
420 |
+
batch_size=self.config.sample_batch_size,
|
421 |
+
)
|
422 |
+
|
423 |
+
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
424 |
+
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
425 |
+
rewards, rewards_metadata = self.compute_rewards(
|
426 |
+
prompt_image_data, is_async=self.config.async_reward_computation
|
427 |
+
)
|
428 |
+
|
429 |
+
for i, image_data in enumerate(prompt_image_data):
|
430 |
+
image_data.extend([rewards[i], rewards_metadata[i]])
|
431 |
+
|
432 |
+
if self.image_samples_callback is not None:
|
433 |
+
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
434 |
+
|
435 |
+
rewards = torch.cat(rewards)
|
436 |
+
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
437 |
+
|
438 |
+
self.accelerator.log(
|
439 |
+
{
|
440 |
+
"reward": rewards,
|
441 |
+
"epoch": epoch,
|
442 |
+
"reward_mean": rewards.mean(),
|
443 |
+
"reward_std": rewards.std(),
|
444 |
+
},
|
445 |
+
step=global_step,
|
446 |
+
)
|
447 |
+
|
448 |
+
if self.config.per_prompt_stat_tracking:
|
449 |
+
# gather the prompts across processes
|
450 |
+
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
451 |
+
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
452 |
+
advantages = self.stat_tracker.update(prompts, rewards)
|
453 |
+
else:
|
454 |
+
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
455 |
+
|
456 |
+
# ungather advantages; keep the entries corresponding to the samples on this process
|
457 |
+
samples["advantages"] = (
|
458 |
+
torch.as_tensor(advantages)
|
459 |
+
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
460 |
+
.to(self.accelerator.device)
|
461 |
+
)
|
462 |
+
|
463 |
+
del samples["prompt_ids"]
|
464 |
+
|
465 |
+
total_batch_size, num_timesteps = samples["timesteps"].shape
|
466 |
+
|
467 |
+
for inner_epoch in range(self.config.train_num_inner_epochs):
|
468 |
+
# shuffle samples along batch dimension
|
469 |
+
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
470 |
+
samples = {k: v[perm] for k, v in samples.items()}
|
471 |
+
|
472 |
+
# shuffle along time dimension independently for each sample
|
473 |
+
# still trying to understand the code below
|
474 |
+
perms = torch.stack(
|
475 |
+
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
476 |
+
)
|
477 |
+
|
478 |
+
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
479 |
+
samples[key] = samples[key][
|
480 |
+
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
481 |
+
perms,
|
482 |
+
]
|
483 |
+
|
484 |
+
original_keys = samples.keys()
|
485 |
+
original_values = samples.values()
|
486 |
+
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
487 |
+
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
488 |
+
|
489 |
+
# Transpose the list of original values
|
490 |
+
transposed_values = zip(*reshaped_values)
|
491 |
+
# Create new dictionaries for each row of transposed values
|
492 |
+
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
493 |
+
|
494 |
+
self.sd_pipeline.unet.train()
|
495 |
+
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
496 |
+
# ensure optimization step at the end of the inner epoch
|
497 |
+
if not self.accelerator.sync_gradients:
|
498 |
+
raise ValueError(
|
499 |
+
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
500 |
+
)
|
501 |
+
|
502 |
+
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
503 |
+
self.accelerator.save_state()
|
504 |
+
|
505 |
+
return global_step
|
506 |
+
|
507 |
+
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
508 |
+
"""
|
509 |
+
Calculate the loss for a batch of an unpacked sample
|
510 |
+
|
511 |
+
Args:
|
512 |
+
latents (torch.Tensor):
|
513 |
+
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
514 |
+
timesteps (torch.Tensor):
|
515 |
+
The timesteps sampled from the diffusion model, shape: [batch_size]
|
516 |
+
next_latents (torch.Tensor):
|
517 |
+
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
518 |
+
log_probs (torch.Tensor):
|
519 |
+
The log probabilities of the latents, shape: [batch_size]
|
520 |
+
advantages (torch.Tensor):
|
521 |
+
The advantages of the latents, shape: [batch_size]
|
522 |
+
embeds (torch.Tensor):
|
523 |
+
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
|
524 |
+
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
525 |
+
|
526 |
+
Returns:
|
527 |
+
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
|
528 |
+
(all of these are of shape (1,))
|
529 |
+
"""
|
530 |
+
with self.autocast():
|
531 |
+
if self.config.train_cfg:
|
532 |
+
noise_pred = self.sd_pipeline.unet(
|
533 |
+
torch.cat([latents] * 2),
|
534 |
+
torch.cat([timesteps] * 2),
|
535 |
+
embeds,
|
536 |
+
).sample
|
537 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
538 |
+
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
539 |
+
noise_pred_text - noise_pred_uncond
|
540 |
+
)
|
541 |
+
else:
|
542 |
+
noise_pred = self.sd_pipeline.unet(
|
543 |
+
latents,
|
544 |
+
timesteps,
|
545 |
+
embeds,
|
546 |
+
).sample
|
547 |
+
# compute the log prob of next_latents given latents under the current model
|
548 |
+
|
549 |
+
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
550 |
+
noise_pred,
|
551 |
+
timesteps,
|
552 |
+
latents,
|
553 |
+
eta=self.config.sample_eta,
|
554 |
+
prev_sample=next_latents,
|
555 |
+
)
|
556 |
+
|
557 |
+
log_prob = scheduler_step_output.log_probs
|
558 |
+
|
559 |
+
advantages = torch.clamp(
|
560 |
+
advantages,
|
561 |
+
-self.config.train_adv_clip_max,
|
562 |
+
self.config.train_adv_clip_max,
|
563 |
+
)
|
564 |
+
|
565 |
+
ratio = torch.exp(log_prob - log_probs)
|
566 |
+
|
567 |
+
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
568 |
+
|
569 |
+
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
570 |
+
|
571 |
+
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
572 |
+
|
573 |
+
return loss, approx_kl, clipfrac
|
574 |
+
|
575 |
+
def loss(
|
576 |
+
self,
|
577 |
+
advantages: torch.Tensor,
|
578 |
+
clip_range: float,
|
579 |
+
ratio: torch.Tensor,
|
580 |
+
):
|
581 |
+
unclipped_loss = -advantages * ratio
|
582 |
+
clipped_loss = -advantages * torch.clamp(
|
583 |
+
ratio,
|
584 |
+
1.0 - clip_range,
|
585 |
+
1.0 + clip_range,
|
586 |
+
)
|
587 |
+
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
588 |
+
|
589 |
+
def _setup_optimizer(self, trainable_layers_parameters):
|
590 |
+
if self.config.train_use_8bit_adam:
|
591 |
+
import bitsandbytes
|
592 |
+
|
593 |
+
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
594 |
+
else:
|
595 |
+
optimizer_cls = torch.optim.AdamW
|
596 |
+
|
597 |
+
return optimizer_cls(
|
598 |
+
trainable_layers_parameters,
|
599 |
+
lr=self.config.train_learning_rate,
|
600 |
+
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
601 |
+
weight_decay=self.config.train_adam_weight_decay,
|
602 |
+
eps=self.config.train_adam_epsilon,
|
603 |
+
)
|
604 |
+
|
605 |
+
def _save_model_hook(self, models, weights, output_dir):
|
606 |
+
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
607 |
+
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
608 |
+
|
609 |
+
def _load_model_hook(self, models, input_dir):
|
610 |
+
self.sd_pipeline.load_checkpoint(models, input_dir)
|
611 |
+
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
612 |
+
|
613 |
+
def _generate_samples(self, iterations, batch_size):
|
614 |
+
"""
|
615 |
+
Generate samples from the model
|
616 |
+
|
617 |
+
Args:
|
618 |
+
iterations (int): Number of iterations to generate samples for
|
619 |
+
batch_size (int): Batch size to use for sampling
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
|
623 |
+
"""
|
624 |
+
samples = []
|
625 |
+
prompt_image_pairs = []
|
626 |
+
self.sd_pipeline.unet.eval()
|
627 |
+
|
628 |
+
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
629 |
+
|
630 |
+
for _ in range(iterations):
|
631 |
+
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
632 |
+
|
633 |
+
prompt_ids = self.sd_pipeline.tokenizer(
|
634 |
+
prompts,
|
635 |
+
return_tensors="pt",
|
636 |
+
padding="max_length",
|
637 |
+
truncation=True,
|
638 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
639 |
+
).input_ids.to(self.accelerator.device)
|
640 |
+
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
641 |
+
|
642 |
+
with self.autocast():
|
643 |
+
sd_output = self.sd_pipeline(
|
644 |
+
prompt_embeds=prompt_embeds,
|
645 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
646 |
+
num_inference_steps=self.config.sample_num_steps,
|
647 |
+
guidance_scale=self.config.sample_guidance_scale,
|
648 |
+
eta=self.config.sample_eta,
|
649 |
+
output_type="pt",
|
650 |
+
)
|
651 |
+
|
652 |
+
images = sd_output.images
|
653 |
+
latents = sd_output.latents
|
654 |
+
log_probs = sd_output.log_probs
|
655 |
+
|
656 |
+
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
657 |
+
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
658 |
+
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
659 |
+
|
660 |
+
samples.append(
|
661 |
+
{
|
662 |
+
"prompt_ids": prompt_ids,
|
663 |
+
"prompt_embeds": prompt_embeds,
|
664 |
+
"timesteps": timesteps,
|
665 |
+
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
666 |
+
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
667 |
+
"log_probs": log_probs,
|
668 |
+
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
669 |
+
}
|
670 |
+
)
|
671 |
+
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
672 |
+
|
673 |
+
return samples, prompt_image_pairs
|
674 |
+
|
675 |
+
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
676 |
+
"""
|
677 |
+
Train on a batch of samples. Main training segment
|
678 |
+
|
679 |
+
Args:
|
680 |
+
inner_epoch (int): The current inner epoch
|
681 |
+
epoch (int): The current epoch
|
682 |
+
global_step (int): The current global step
|
683 |
+
batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
|
684 |
+
|
685 |
+
Side Effects:
|
686 |
+
- Model weights are updated
|
687 |
+
- Logs the statistics to the accelerator trackers.
|
688 |
+
|
689 |
+
Returns:
|
690 |
+
global_step (int): The updated global step
|
691 |
+
"""
|
692 |
+
info = defaultdict(list)
|
693 |
+
for _i, sample in enumerate(batched_samples):
|
694 |
+
if self.config.train_cfg:
|
695 |
+
# concat negative prompts to sample prompts to avoid two forward passes
|
696 |
+
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
697 |
+
else:
|
698 |
+
embeds = sample["prompt_embeds"]
|
699 |
+
|
700 |
+
for j in range(self.num_train_timesteps):
|
701 |
+
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
702 |
+
loss, approx_kl, clipfrac = self.calculate_loss(
|
703 |
+
sample["latents"][:, j],
|
704 |
+
sample["timesteps"][:, j],
|
705 |
+
sample["next_latents"][:, j],
|
706 |
+
sample["log_probs"][:, j],
|
707 |
+
sample["advantages"],
|
708 |
+
embeds,
|
709 |
+
)
|
710 |
+
info["approx_kl"].append(approx_kl)
|
711 |
+
info["clipfrac"].append(clipfrac)
|
712 |
+
info["loss"].append(loss)
|
713 |
+
|
714 |
+
self.accelerator.backward(loss)
|
715 |
+
if self.accelerator.sync_gradients:
|
716 |
+
self.accelerator.clip_grad_norm_(
|
717 |
+
self.trainable_layers.parameters()
|
718 |
+
if not isinstance(self.trainable_layers, list)
|
719 |
+
else self.trainable_layers,
|
720 |
+
self.config.train_max_grad_norm,
|
721 |
+
)
|
722 |
+
self.optimizer.step()
|
723 |
+
self.optimizer.zero_grad()
|
724 |
+
|
725 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
726 |
+
if self.accelerator.sync_gradients:
|
727 |
+
# log training-related stuff
|
728 |
+
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
729 |
+
info = self.accelerator.reduce(info, reduction="mean")
|
730 |
+
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
731 |
+
self.accelerator.log(info, step=global_step)
|
732 |
+
global_step += 1
|
733 |
+
info = defaultdict(list)
|
734 |
+
return global_step
|
735 |
+
|
736 |
+
def _config_check(self) -> tuple[bool, str]:
|
737 |
+
samples_per_epoch = (
|
738 |
+
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
739 |
+
)
|
740 |
+
total_train_batch_size = (
|
741 |
+
self.config.train_batch_size
|
742 |
+
* self.accelerator.num_processes
|
743 |
+
* self.config.train_gradient_accumulation_steps
|
744 |
+
)
|
745 |
+
|
746 |
+
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
747 |
+
return (
|
748 |
+
False,
|
749 |
+
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
750 |
+
)
|
751 |
+
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
752 |
+
return (
|
753 |
+
False,
|
754 |
+
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
755 |
+
)
|
756 |
+
if not samples_per_epoch % total_train_batch_size == 0:
|
757 |
+
return (
|
758 |
+
False,
|
759 |
+
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
760 |
+
)
|
761 |
+
return True, ""
|
762 |
+
|
763 |
+
def train(self, epochs: Optional[int] = None):
|
764 |
+
"""
|
765 |
+
Train the model for a given number of epochs
|
766 |
+
"""
|
767 |
+
global_step = 0
|
768 |
+
if epochs is None:
|
769 |
+
epochs = self.config.num_epochs
|
770 |
+
for epoch in range(self.first_epoch, epochs):
|
771 |
+
global_step = self.step(epoch, global_step)
|
772 |
+
|
773 |
+
def _save_pretrained(self, save_directory):
|
774 |
+
self.sd_pipeline.save_pretrained(save_directory)
|
775 |
+
self.create_model_card()
|
776 |
+
|
777 |
+
def create_model_card(
|
778 |
+
self,
|
779 |
+
model_name: Optional[str] = None,
|
780 |
+
dataset_name: Optional[str] = None,
|
781 |
+
tags: Union[str, list[str], None] = None,
|
782 |
+
):
|
783 |
+
"""
|
784 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
788 |
+
Name of the model.
|
789 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
790 |
+
Name of the dataset used for training.
|
791 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
792 |
+
Tags to be associated with the model card.
|
793 |
+
"""
|
794 |
+
if not self.is_world_process_zero():
|
795 |
+
return
|
796 |
+
|
797 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
798 |
+
base_model = self.model.config._name_or_path
|
799 |
+
else:
|
800 |
+
base_model = None
|
801 |
+
|
802 |
+
tags = tags or []
|
803 |
+
if isinstance(tags, str):
|
804 |
+
tags = [tags]
|
805 |
+
|
806 |
+
if hasattr(self.model.config, "unsloth_version"):
|
807 |
+
tags.append("unsloth")
|
808 |
+
|
809 |
+
citation = textwrap.dedent("""\
|
810 |
+
@inproceedings{black2024training,
|
811 |
+
title = {{Training Diffusion Models with Reinforcement Learning}},
|
812 |
+
author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
|
813 |
+
year = 2024,
|
814 |
+
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
815 |
+
publisher = {OpenReview.net},
|
816 |
+
url = {https://openreview.net/forum?id=YCWjhGrJFD},
|
817 |
+
}""")
|
818 |
+
|
819 |
+
model_card = generate_model_card(
|
820 |
+
base_model=base_model,
|
821 |
+
model_name=model_name,
|
822 |
+
hub_model_id=self.hub_model_id,
|
823 |
+
dataset_name=dataset_name,
|
824 |
+
tags=tags,
|
825 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
826 |
+
comet_url=get_comet_experiment_url(),
|
827 |
+
trainer_name="DDPO",
|
828 |
+
trainer_citation=citation,
|
829 |
+
paper_title="Training Diffusion Models with Reinforcement Learning",
|
830 |
+
paper_id="2305.13301",
|
831 |
+
)
|
832 |
+
|
833 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
834 |
+
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
|
835 |
+
"""
|
836 |
+
|
837 |
+
The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
838 |
+
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
|
839 |
+
As of now only Stable Diffusion based pipelines are supported
|
840 |
+
|
841 |
+
Attributes:
|
842 |
+
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
|
843 |
+
details.
|
844 |
+
**reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
|
845 |
+
**prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
|
846 |
+
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
847 |
+
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
848 |
+
|
849 |
+
"""
|
850 |
+
def __init__(
|
851 |
+
self,
|
852 |
+
config,
|
853 |
+
reward_function,
|
854 |
+
prompt_function,
|
855 |
+
sd_pipeline,
|
856 |
+
image_samples_hook = None,
|
857 |
+
**kwargs
|
858 |
+
):
|
859 |
+
if args is None: args = UnslothDDPOConfig()
|
860 |
+
other_metrics = []
|
861 |
+
|
862 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
863 |
+
PatchRLStatistics('ddpo_trainer', other_metrics)
|
864 |
+
|
865 |
+
super().__init__(
|
866 |
+
config = config,
|
867 |
+
reward_function = reward_function,
|
868 |
+
prompt_function = prompt_function,
|
869 |
+
sd_pipeline = sd_pipeline,
|
870 |
+
image_samples_hook = image_samples_hook,**kwargs)
|
871 |
+
|
872 |
+
pass
|
unsloth_compiled_cache/UnslothDPOTrainer.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
unsloth_compiled_cache/UnslothGKDTrainer.py
ADDED
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothGKDConfig(GKDConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for [`GKDTrainer`].
|
47 |
+
|
48 |
+
Args:
|
49 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
50 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
51 |
+
lmbda (`float`, *optional*, defaults to `0.5`):
|
52 |
+
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
53 |
+
student-generated outputs).
|
54 |
+
beta (`float`, *optional*, defaults to `0.5`):
|
55 |
+
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
56 |
+
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
57 |
+
max_new_tokens (`int`, *optional*, defaults to `128`):
|
58 |
+
Maximum number of tokens to generate per completion.
|
59 |
+
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
60 |
+
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
|
61 |
+
being trained.
|
62 |
+
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
63 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
64 |
+
from a string.
|
65 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
66 |
+
Whether to disable dropout in the model.
|
67 |
+
seq_kd (`bool`, *optional*, defaults to `False`):
|
68 |
+
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
|
69 |
+
on teacher-generated output).
|
70 |
+
|
71 |
+
"""
|
72 |
+
vllm_sampling_params: Optional[Any] = field(
|
73 |
+
default = None,
|
74 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
75 |
+
)
|
76 |
+
unsloth_num_chunks : Optional[int] = field(
|
77 |
+
default = -1,
|
78 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
79 |
+
)
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
output_dir = None,
|
83 |
+
overwrite_output_dir = None,
|
84 |
+
do_train = False,
|
85 |
+
do_eval = False,
|
86 |
+
do_predict = False,
|
87 |
+
eval_strategy = 'no',
|
88 |
+
prediction_loss_only = False,
|
89 |
+
per_device_train_batch_size = 4,
|
90 |
+
per_device_eval_batch_size = 4,
|
91 |
+
per_gpu_train_batch_size = None,
|
92 |
+
per_gpu_eval_batch_size = None,
|
93 |
+
gradient_accumulation_steps = 2,
|
94 |
+
eval_accumulation_steps = 2,
|
95 |
+
eval_delay = 0,
|
96 |
+
torch_empty_cache_steps = 250,
|
97 |
+
learning_rate = 5e-05,
|
98 |
+
weight_decay = 0.01,
|
99 |
+
adam_beta1 = 0.9,
|
100 |
+
adam_beta2 = 0.999,
|
101 |
+
adam_epsilon = 1e-08,
|
102 |
+
max_grad_norm = 1.0,
|
103 |
+
num_train_epochs = 3.0,
|
104 |
+
max_steps = -1,
|
105 |
+
lr_scheduler_type = 'linear',
|
106 |
+
warmup_ratio = 0.1,
|
107 |
+
warmup_steps = 0,
|
108 |
+
log_level = 'passive',
|
109 |
+
log_level_replica = 'warning',
|
110 |
+
log_on_each_node = True,
|
111 |
+
logging_dir = None,
|
112 |
+
logging_strategy = 'steps',
|
113 |
+
logging_first_step = False,
|
114 |
+
logging_steps = 1,
|
115 |
+
logging_nan_inf_filter = False,
|
116 |
+
save_strategy = 'steps',
|
117 |
+
save_steps = 500,
|
118 |
+
save_total_limit = None,
|
119 |
+
save_safetensors = True,
|
120 |
+
save_on_each_node = False,
|
121 |
+
save_only_model = False,
|
122 |
+
restore_callback_states_from_checkpoint = False,
|
123 |
+
no_cuda = False,
|
124 |
+
use_cpu = False,
|
125 |
+
use_mps_device = False,
|
126 |
+
seed = 3407,
|
127 |
+
data_seed = 3407,
|
128 |
+
jit_mode_eval = False,
|
129 |
+
use_ipex = False,
|
130 |
+
bf16 = False,
|
131 |
+
fp16 = False,
|
132 |
+
fp16_opt_level = 'O1',
|
133 |
+
half_precision_backend = 'auto',
|
134 |
+
bf16_full_eval = False,
|
135 |
+
fp16_full_eval = False,
|
136 |
+
tf32 = None,
|
137 |
+
local_rank = -1,
|
138 |
+
ddp_backend = None,
|
139 |
+
tpu_num_cores = None,
|
140 |
+
tpu_metrics_debug = False,
|
141 |
+
debug = '',
|
142 |
+
dataloader_drop_last = False,
|
143 |
+
eval_steps = None,
|
144 |
+
dataloader_num_workers = 0,
|
145 |
+
dataloader_prefetch_factor = None,
|
146 |
+
past_index = -1,
|
147 |
+
run_name = None,
|
148 |
+
disable_tqdm = None,
|
149 |
+
remove_unused_columns = True,
|
150 |
+
label_names = None,
|
151 |
+
load_best_model_at_end = False,
|
152 |
+
metric_for_best_model = None,
|
153 |
+
greater_is_better = None,
|
154 |
+
ignore_data_skip = False,
|
155 |
+
fsdp = '',
|
156 |
+
fsdp_min_num_params = 0,
|
157 |
+
fsdp_config = None,
|
158 |
+
tp_size = 0,
|
159 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
160 |
+
accelerator_config = None,
|
161 |
+
deepspeed = None,
|
162 |
+
label_smoothing_factor = 0.0,
|
163 |
+
optim = 'adamw_8bit',
|
164 |
+
optim_args = None,
|
165 |
+
adafactor = False,
|
166 |
+
group_by_length = False,
|
167 |
+
length_column_name = 'length',
|
168 |
+
report_to = None,
|
169 |
+
ddp_find_unused_parameters = None,
|
170 |
+
ddp_bucket_cap_mb = None,
|
171 |
+
ddp_broadcast_buffers = None,
|
172 |
+
dataloader_pin_memory = True,
|
173 |
+
dataloader_persistent_workers = False,
|
174 |
+
skip_memory_metrics = True,
|
175 |
+
use_legacy_prediction_loop = False,
|
176 |
+
push_to_hub = False,
|
177 |
+
resume_from_checkpoint = None,
|
178 |
+
hub_model_id = None,
|
179 |
+
hub_strategy = 'every_save',
|
180 |
+
hub_token = None,
|
181 |
+
hub_private_repo = None,
|
182 |
+
hub_always_push = False,
|
183 |
+
gradient_checkpointing = False,
|
184 |
+
gradient_checkpointing_kwargs = None,
|
185 |
+
include_inputs_for_metrics = False,
|
186 |
+
eval_do_concat_batches = True,
|
187 |
+
fp16_backend = 'auto',
|
188 |
+
evaluation_strategy = None,
|
189 |
+
push_to_hub_model_id = None,
|
190 |
+
push_to_hub_organization = None,
|
191 |
+
push_to_hub_token = None,
|
192 |
+
mp_parameters = '',
|
193 |
+
auto_find_batch_size = False,
|
194 |
+
full_determinism = False,
|
195 |
+
torchdynamo = None,
|
196 |
+
ray_scope = 'last',
|
197 |
+
ddp_timeout = 1800,
|
198 |
+
torch_compile = False,
|
199 |
+
torch_compile_backend = None,
|
200 |
+
torch_compile_mode = None,
|
201 |
+
dispatch_batches = None,
|
202 |
+
split_batches = None,
|
203 |
+
include_tokens_per_second = False,
|
204 |
+
include_num_input_tokens_seen = False,
|
205 |
+
neftune_noise_alpha = None,
|
206 |
+
optim_target_modules = None,
|
207 |
+
batch_eval_metrics = False,
|
208 |
+
eval_on_start = False,
|
209 |
+
use_liger_kernel = False,
|
210 |
+
eval_use_gather_object = False,
|
211 |
+
average_tokens_across_devices = False,
|
212 |
+
model_init_kwargs = None,
|
213 |
+
use_liger = False,
|
214 |
+
dataset_text_field = 'text',
|
215 |
+
dataset_kwargs = None,
|
216 |
+
dataset_num_proc = None,
|
217 |
+
max_seq_length = None,
|
218 |
+
packing = False,
|
219 |
+
eval_packing = None,
|
220 |
+
dataset_batch_size = None,
|
221 |
+
num_of_sequences = None,
|
222 |
+
chars_per_token = None,
|
223 |
+
temperature = 0.9,
|
224 |
+
lmbda = 0.5,
|
225 |
+
beta = 0.5,
|
226 |
+
max_new_tokens = 128,
|
227 |
+
teacher_model_name_or_path = None,
|
228 |
+
teacher_model_init_kwargs = None,
|
229 |
+
disable_dropout = True,
|
230 |
+
seq_kd = False,
|
231 |
+
vllm_sampling_params = None,
|
232 |
+
unsloth_num_chunks = -1,
|
233 |
+
**kwargs,
|
234 |
+
):
|
235 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
236 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
237 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
238 |
+
output_dir = 'unsloth_training_checkpoints'
|
239 |
+
save_strategy = 'no'
|
240 |
+
if dataset_num_proc is None:
|
241 |
+
from multiprocessing import cpu_count
|
242 |
+
dataset_num_proc = cpu_count()
|
243 |
+
|
244 |
+
super().__init__(
|
245 |
+
output_dir = output_dir,
|
246 |
+
overwrite_output_dir = overwrite_output_dir,
|
247 |
+
do_train = do_train,
|
248 |
+
do_eval = do_eval,
|
249 |
+
do_predict = do_predict,
|
250 |
+
eval_strategy = eval_strategy,
|
251 |
+
prediction_loss_only = prediction_loss_only,
|
252 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
253 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
254 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
255 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
256 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
257 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
258 |
+
eval_delay = eval_delay,
|
259 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
260 |
+
learning_rate = learning_rate,
|
261 |
+
weight_decay = weight_decay,
|
262 |
+
adam_beta1 = adam_beta1,
|
263 |
+
adam_beta2 = adam_beta2,
|
264 |
+
adam_epsilon = adam_epsilon,
|
265 |
+
max_grad_norm = max_grad_norm,
|
266 |
+
num_train_epochs = num_train_epochs,
|
267 |
+
max_steps = max_steps,
|
268 |
+
lr_scheduler_type = lr_scheduler_type,
|
269 |
+
warmup_ratio = warmup_ratio,
|
270 |
+
warmup_steps = warmup_steps,
|
271 |
+
log_level = log_level,
|
272 |
+
log_level_replica = log_level_replica,
|
273 |
+
log_on_each_node = log_on_each_node,
|
274 |
+
logging_dir = logging_dir,
|
275 |
+
logging_strategy = logging_strategy,
|
276 |
+
logging_first_step = logging_first_step,
|
277 |
+
logging_steps = logging_steps,
|
278 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
279 |
+
save_strategy = save_strategy,
|
280 |
+
save_steps = save_steps,
|
281 |
+
save_total_limit = save_total_limit,
|
282 |
+
save_safetensors = save_safetensors,
|
283 |
+
save_on_each_node = save_on_each_node,
|
284 |
+
save_only_model = save_only_model,
|
285 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
286 |
+
no_cuda = no_cuda,
|
287 |
+
use_cpu = use_cpu,
|
288 |
+
use_mps_device = use_mps_device,
|
289 |
+
seed = seed,
|
290 |
+
data_seed = data_seed,
|
291 |
+
jit_mode_eval = jit_mode_eval,
|
292 |
+
use_ipex = use_ipex,
|
293 |
+
bf16 = bf16,
|
294 |
+
fp16 = fp16,
|
295 |
+
fp16_opt_level = fp16_opt_level,
|
296 |
+
half_precision_backend = half_precision_backend,
|
297 |
+
bf16_full_eval = bf16_full_eval,
|
298 |
+
fp16_full_eval = fp16_full_eval,
|
299 |
+
tf32 = tf32,
|
300 |
+
local_rank = local_rank,
|
301 |
+
ddp_backend = ddp_backend,
|
302 |
+
tpu_num_cores = tpu_num_cores,
|
303 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
304 |
+
debug = debug,
|
305 |
+
dataloader_drop_last = dataloader_drop_last,
|
306 |
+
eval_steps = eval_steps,
|
307 |
+
dataloader_num_workers = dataloader_num_workers,
|
308 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
309 |
+
past_index = past_index,
|
310 |
+
run_name = run_name,
|
311 |
+
disable_tqdm = disable_tqdm,
|
312 |
+
remove_unused_columns = remove_unused_columns,
|
313 |
+
label_names = label_names,
|
314 |
+
load_best_model_at_end = load_best_model_at_end,
|
315 |
+
metric_for_best_model = metric_for_best_model,
|
316 |
+
greater_is_better = greater_is_better,
|
317 |
+
ignore_data_skip = ignore_data_skip,
|
318 |
+
fsdp = fsdp,
|
319 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
320 |
+
fsdp_config = fsdp_config,
|
321 |
+
tp_size = tp_size,
|
322 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
323 |
+
accelerator_config = accelerator_config,
|
324 |
+
deepspeed = deepspeed,
|
325 |
+
label_smoothing_factor = label_smoothing_factor,
|
326 |
+
optim = optim,
|
327 |
+
optim_args = optim_args,
|
328 |
+
adafactor = adafactor,
|
329 |
+
group_by_length = group_by_length,
|
330 |
+
length_column_name = length_column_name,
|
331 |
+
report_to = report_to,
|
332 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
333 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
334 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
335 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
336 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
337 |
+
skip_memory_metrics = skip_memory_metrics,
|
338 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
339 |
+
push_to_hub = push_to_hub,
|
340 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
341 |
+
hub_model_id = hub_model_id,
|
342 |
+
hub_strategy = hub_strategy,
|
343 |
+
hub_token = hub_token,
|
344 |
+
hub_private_repo = hub_private_repo,
|
345 |
+
hub_always_push = hub_always_push,
|
346 |
+
gradient_checkpointing = gradient_checkpointing,
|
347 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
348 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
349 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
350 |
+
fp16_backend = fp16_backend,
|
351 |
+
evaluation_strategy = evaluation_strategy,
|
352 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
353 |
+
push_to_hub_organization = push_to_hub_organization,
|
354 |
+
push_to_hub_token = push_to_hub_token,
|
355 |
+
mp_parameters = mp_parameters,
|
356 |
+
auto_find_batch_size = auto_find_batch_size,
|
357 |
+
full_determinism = full_determinism,
|
358 |
+
torchdynamo = torchdynamo,
|
359 |
+
ray_scope = ray_scope,
|
360 |
+
ddp_timeout = ddp_timeout,
|
361 |
+
torch_compile = torch_compile,
|
362 |
+
torch_compile_backend = torch_compile_backend,
|
363 |
+
torch_compile_mode = torch_compile_mode,
|
364 |
+
dispatch_batches = dispatch_batches,
|
365 |
+
split_batches = split_batches,
|
366 |
+
include_tokens_per_second = include_tokens_per_second,
|
367 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
368 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
369 |
+
optim_target_modules = optim_target_modules,
|
370 |
+
batch_eval_metrics = batch_eval_metrics,
|
371 |
+
eval_on_start = eval_on_start,
|
372 |
+
use_liger_kernel = use_liger_kernel,
|
373 |
+
eval_use_gather_object = eval_use_gather_object,
|
374 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
375 |
+
model_init_kwargs = model_init_kwargs,
|
376 |
+
use_liger = use_liger,
|
377 |
+
dataset_text_field = dataset_text_field,
|
378 |
+
dataset_kwargs = dataset_kwargs,
|
379 |
+
dataset_num_proc = dataset_num_proc,
|
380 |
+
max_seq_length = max_seq_length,
|
381 |
+
packing = packing,
|
382 |
+
eval_packing = eval_packing,
|
383 |
+
dataset_batch_size = dataset_batch_size,
|
384 |
+
num_of_sequences = num_of_sequences,
|
385 |
+
chars_per_token = chars_per_token,
|
386 |
+
temperature = temperature,
|
387 |
+
lmbda = lmbda,
|
388 |
+
beta = beta,
|
389 |
+
max_new_tokens = max_new_tokens,
|
390 |
+
teacher_model_name_or_path = teacher_model_name_or_path,
|
391 |
+
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
392 |
+
disable_dropout = disable_dropout,
|
393 |
+
seq_kd = seq_kd,**kwargs)
|
394 |
+
self.vllm_sampling_params = vllm_sampling_params
|
395 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
396 |
+
pass
|
397 |
+
|
398 |
+
class _UnslothGKDTrainer(SFTTrainer):
|
399 |
+
_tag_names = ["trl", "gkd"]
|
400 |
+
|
401 |
+
def __init__(
|
402 |
+
self,
|
403 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
404 |
+
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
|
405 |
+
args: Optional[GKDConfig] = None,
|
406 |
+
data_collator: Optional[DataCollator] = None, # type: ignore
|
407 |
+
train_dataset: Optional[Dataset] = None,
|
408 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
409 |
+
processing_class: Optional[
|
410 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
411 |
+
] = None,
|
412 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
413 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
414 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
415 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
416 |
+
peft_config: Optional["PeftConfig"] = None,
|
417 |
+
formatting_func: Optional[Callable] = None,
|
418 |
+
):
|
419 |
+
# add remove_unused_columns=False to the dataclass args
|
420 |
+
args.remove_unused_columns = False
|
421 |
+
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
|
422 |
+
|
423 |
+
super().__init__(
|
424 |
+
model,
|
425 |
+
args=args,
|
426 |
+
data_collator=data_collator,
|
427 |
+
train_dataset=train_dataset,
|
428 |
+
eval_dataset=eval_dataset,
|
429 |
+
processing_class=processing_class,
|
430 |
+
compute_metrics=compute_metrics,
|
431 |
+
callbacks=callbacks,
|
432 |
+
optimizers=optimizers,
|
433 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
434 |
+
peft_config=peft_config,
|
435 |
+
formatting_func=formatting_func,
|
436 |
+
)
|
437 |
+
|
438 |
+
if args.teacher_model_init_kwargs is None:
|
439 |
+
teacher_model_init_kwargs = {}
|
440 |
+
elif not isinstance(teacher_model, str):
|
441 |
+
raise ValueError(
|
442 |
+
"You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
teacher_model_init_kwargs = args.teacher_model_init_kwargs
|
446 |
+
teacher_model_init_kwargs["torch_dtype"] = (
|
447 |
+
teacher_model_init_kwargs["torch_dtype"]
|
448 |
+
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
|
449 |
+
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
|
450 |
+
)
|
451 |
+
|
452 |
+
if isinstance(teacher_model, str):
|
453 |
+
if args.use_liger:
|
454 |
+
teacher_model = AutoLigerKernelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
455 |
+
else:
|
456 |
+
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
457 |
+
|
458 |
+
# Disable dropout in the model
|
459 |
+
if args.disable_dropout:
|
460 |
+
disable_dropout_in_model(self.model)
|
461 |
+
|
462 |
+
if self.is_deepspeed_enabled:
|
463 |
+
self.teacher_model = self._prepare_deepspeed(teacher_model)
|
464 |
+
else:
|
465 |
+
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
466 |
+
|
467 |
+
self.lmbda = args.lmbda
|
468 |
+
self.beta = args.beta
|
469 |
+
self.temperature = args.temperature
|
470 |
+
self.seq_kd = args.seq_kd
|
471 |
+
|
472 |
+
self.generation_config = GenerationConfig(
|
473 |
+
max_new_tokens=args.max_new_tokens,
|
474 |
+
temperature=args.temperature,
|
475 |
+
do_sample=True,
|
476 |
+
top_k=0,
|
477 |
+
use_cache=False if args.gradient_checkpointing else True,
|
478 |
+
pad_token_id=self.processing_class.pad_token_id,
|
479 |
+
)
|
480 |
+
# Set custom EOS tokens if they are specified by the model's generation
|
481 |
+
# config. This is important for models with the Llama 3 chat template,
|
482 |
+
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
483 |
+
# turns or messages.
|
484 |
+
if (
|
485 |
+
hasattr(self.model.generation_config, "eos_token_id")
|
486 |
+
and self.model.generation_config.eos_token_id is not None
|
487 |
+
):
|
488 |
+
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
489 |
+
|
490 |
+
def _prepare_dataset(self, dataset, *args):
|
491 |
+
# SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
|
492 |
+
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
|
493 |
+
dataset = dataset.add_column("_messages", dataset["messages"])
|
494 |
+
dataset = super()._prepare_dataset(dataset, *args)
|
495 |
+
dataset = dataset.rename_column("_messages", "messages")
|
496 |
+
return dataset
|
497 |
+
|
498 |
+
@staticmethod
|
499 |
+
def generalized_jsd_loss(
|
500 |
+
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
501 |
+
):
|
502 |
+
"""
|
503 |
+
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
|
504 |
+
of https://huggingface.co/papers/2306.13649 for the definition.
|
505 |
+
|
506 |
+
Args:
|
507 |
+
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
508 |
+
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
509 |
+
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
|
510 |
+
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
|
511 |
+
temperature: Softmax temperature (default: 1.0)
|
512 |
+
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
loss: Scalar tensor with the generalized JSD loss
|
516 |
+
"""
|
517 |
+
|
518 |
+
# Apply temperature scaling
|
519 |
+
student_logits = student_logits / temperature
|
520 |
+
teacher_logits = teacher_logits / temperature
|
521 |
+
|
522 |
+
# Compute log probabilities for student and probabilities for teacher
|
523 |
+
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
524 |
+
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
525 |
+
|
526 |
+
# Compute the log of the mixture distribution
|
527 |
+
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
528 |
+
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
529 |
+
mixture_log_probs = torch.logsumexp(
|
530 |
+
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
|
531 |
+
dim=0,
|
532 |
+
)
|
533 |
+
|
534 |
+
# Compute KL divergences using F.kl_div
|
535 |
+
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
|
536 |
+
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
537 |
+
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
538 |
+
|
539 |
+
# Compute the Generalized Jensen-Shannon Divergence
|
540 |
+
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
541 |
+
|
542 |
+
# Masking
|
543 |
+
if labels is not None:
|
544 |
+
mask = labels != -100
|
545 |
+
jsd = jsd[mask]
|
546 |
+
|
547 |
+
# Apply reduction
|
548 |
+
if reduction == "batchmean":
|
549 |
+
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
|
550 |
+
elif reduction == "sum":
|
551 |
+
return jsd.sum()
|
552 |
+
elif reduction == "mean":
|
553 |
+
return jsd.mean()
|
554 |
+
else:
|
555 |
+
return jsd
|
556 |
+
|
557 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
558 |
+
# compute student output
|
559 |
+
outputs_student = model(
|
560 |
+
input_ids=inputs["input_ids"],
|
561 |
+
attention_mask=inputs["attention_mask"],
|
562 |
+
)
|
563 |
+
|
564 |
+
# compute teacher output in eval mode
|
565 |
+
self.teacher_model.eval()
|
566 |
+
with torch.no_grad():
|
567 |
+
outputs_teacher = self.teacher_model(
|
568 |
+
input_ids=inputs["input_ids"],
|
569 |
+
attention_mask=inputs["attention_mask"],
|
570 |
+
)
|
571 |
+
|
572 |
+
# slice the logits for the generated tokens using the inputs["prompts"] lengths
|
573 |
+
prompt_lengths = inputs["prompts"].shape[1]
|
574 |
+
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
|
575 |
+
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
|
576 |
+
shifted_labels = inputs["labels"][:, prompt_lengths:]
|
577 |
+
|
578 |
+
# compute loss
|
579 |
+
loss = self.generalized_jsd_loss(
|
580 |
+
student_logits=shifted_student_logits,
|
581 |
+
teacher_logits=shifted_teacher_logits,
|
582 |
+
labels=shifted_labels,
|
583 |
+
beta=self.beta,
|
584 |
+
)
|
585 |
+
|
586 |
+
# empty cache
|
587 |
+
empty_cache()
|
588 |
+
|
589 |
+
# Return loss
|
590 |
+
return (loss, outputs_student) if return_outputs else loss
|
591 |
+
|
592 |
+
@staticmethod
|
593 |
+
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
594 |
+
# Generate output with respect to the prompt only
|
595 |
+
generated_outputs = model.generate(
|
596 |
+
input_ids=inputs["prompts"],
|
597 |
+
attention_mask=inputs.get("prompt_attention_mask", None),
|
598 |
+
generation_config=generation_config,
|
599 |
+
return_dict_in_generate=True,
|
600 |
+
)
|
601 |
+
|
602 |
+
# Get the generated token IDs
|
603 |
+
generated_tokens = generated_outputs.sequences
|
604 |
+
# Calculate new attention mask
|
605 |
+
new_attention_mask = torch.ones_like(generated_tokens)
|
606 |
+
new_labels = generated_tokens.clone()
|
607 |
+
|
608 |
+
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
609 |
+
if pad_token_id is not None:
|
610 |
+
new_labels[new_labels == pad_token_id] = -100
|
611 |
+
new_attention_mask[generated_tokens == pad_token_id] = 0
|
612 |
+
|
613 |
+
return generated_tokens, new_attention_mask, new_labels
|
614 |
+
|
615 |
+
def training_step(
|
616 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
617 |
+
) -> torch.Tensor:
|
618 |
+
"""
|
619 |
+
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
620 |
+
|
621 |
+
This method implements the on-policy learning approach described in the GKD paper.
|
622 |
+
With probability `self.lmbda`, it generates new responses using the student model,
|
623 |
+
which are then used for training instead of the original inputs.
|
624 |
+
"""
|
625 |
+
if self.seq_kd:
|
626 |
+
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
627 |
+
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
628 |
+
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
629 |
+
)
|
630 |
+
inputs["input_ids"] = new_input_ids
|
631 |
+
inputs["attention_mask"] = new_attention_mask
|
632 |
+
inputs["labels"] = new_labels
|
633 |
+
if random.random() <= self.lmbda:
|
634 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
635 |
+
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
636 |
+
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
637 |
+
)
|
638 |
+
inputs["input_ids"] = new_input_ids
|
639 |
+
inputs["attention_mask"] = new_attention_mask
|
640 |
+
inputs["labels"] = new_labels
|
641 |
+
|
642 |
+
loss = super().training_step(model, inputs, num_items_in_batch)
|
643 |
+
return loss
|
644 |
+
|
645 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
646 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
647 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
648 |
+
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
649 |
+
|
650 |
+
if model is not None:
|
651 |
+
if hasattr(model, "config"):
|
652 |
+
hidden_size = (
|
653 |
+
max(model.config.hidden_sizes)
|
654 |
+
if getattr(model.config, "hidden_sizes", None)
|
655 |
+
else getattr(model.config, "hidden_size", None)
|
656 |
+
)
|
657 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
658 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
659 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
660 |
+
config_kwargs.update(
|
661 |
+
{
|
662 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
663 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
664 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
665 |
+
}
|
666 |
+
)
|
667 |
+
|
668 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
669 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
670 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
671 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
672 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
673 |
+
model.eval()
|
674 |
+
return model
|
675 |
+
|
676 |
+
def create_model_card(
|
677 |
+
self,
|
678 |
+
model_name: Optional[str] = None,
|
679 |
+
dataset_name: Optional[str] = None,
|
680 |
+
tags: Union[str, list[str], None] = None,
|
681 |
+
):
|
682 |
+
"""
|
683 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
684 |
+
|
685 |
+
Args:
|
686 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
687 |
+
Name of the model.
|
688 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
689 |
+
Name of the dataset used for training.
|
690 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
691 |
+
Tags to be associated with the model card.
|
692 |
+
"""
|
693 |
+
if not self.is_world_process_zero():
|
694 |
+
return
|
695 |
+
|
696 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
697 |
+
base_model = self.model.config._name_or_path
|
698 |
+
else:
|
699 |
+
base_model = None
|
700 |
+
|
701 |
+
tags = tags or []
|
702 |
+
if isinstance(tags, str):
|
703 |
+
tags = [tags]
|
704 |
+
|
705 |
+
if hasattr(self.model.config, "unsloth_version"):
|
706 |
+
tags.append("unsloth")
|
707 |
+
|
708 |
+
citation = textwrap.dedent("""\
|
709 |
+
@inproceedings{agarwal2024on-policy,
|
710 |
+
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
|
711 |
+
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
|
712 |
+
year = 2024,
|
713 |
+
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
714 |
+
publisher = {OpenReview.net},
|
715 |
+
url = {https://openreview.net/forum?id=3zKtaqxLhW},
|
716 |
+
}""")
|
717 |
+
|
718 |
+
model_card = generate_model_card(
|
719 |
+
base_model=base_model,
|
720 |
+
model_name=model_name,
|
721 |
+
hub_model_id=self.hub_model_id,
|
722 |
+
dataset_name=dataset_name,
|
723 |
+
tags=tags,
|
724 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
725 |
+
comet_url=get_comet_experiment_url(),
|
726 |
+
trainer_name="GKD",
|
727 |
+
trainer_citation=citation,
|
728 |
+
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
|
729 |
+
paper_id="2306.13649",
|
730 |
+
)
|
731 |
+
|
732 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
733 |
+
class UnslothGKDTrainer(_UnslothGKDTrainer):
|
734 |
+
"""
|
735 |
+
|
736 |
+
"""
|
737 |
+
def __init__(
|
738 |
+
self,
|
739 |
+
model = None,
|
740 |
+
teacher_model = None,
|
741 |
+
args = None,
|
742 |
+
data_collator = None,
|
743 |
+
train_dataset = None,
|
744 |
+
eval_dataset = None,
|
745 |
+
processing_class = None,
|
746 |
+
compute_metrics = None,
|
747 |
+
callbacks = None,
|
748 |
+
preprocess_logits_for_metrics = None,
|
749 |
+
peft_config = None,
|
750 |
+
formatting_func = None,
|
751 |
+
**kwargs
|
752 |
+
):
|
753 |
+
if args is None: args = UnslothGKDConfig()
|
754 |
+
use_bf16 = getattr(args, 'bf16', False)
|
755 |
+
use_fp16 = getattr(args, 'fp16', False)
|
756 |
+
force_float32 = False
|
757 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
758 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
759 |
+
force_float32 = True
|
760 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
761 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
762 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
763 |
+
from unsloth_zoo.utils import _get_dtype
|
764 |
+
dtype = _get_dtype(dtype)
|
765 |
+
float16 = dtype == torch.float16
|
766 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
767 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
768 |
+
if force_float32:
|
769 |
+
args.fp16 = False
|
770 |
+
args.bf16 = False
|
771 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
772 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
773 |
+
args.fp16 = float16
|
774 |
+
args.bf16 = not float16
|
775 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
776 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
777 |
+
args.eval_strategy = 'steps'
|
778 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
779 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
780 |
+
if ga_steps is not None and ga_steps > 1:
|
781 |
+
from transformers import __version__ as transformers_version
|
782 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
783 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
784 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
785 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
786 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
787 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
788 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
789 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
790 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
791 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
792 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
793 |
+
if force_float32:
|
794 |
+
args.bf16_full_eval = False
|
795 |
+
args.fp16_full_eval = False
|
796 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
797 |
+
args.bf16_full_eval = True
|
798 |
+
args.fp16_full_eval = False
|
799 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
800 |
+
args.bf16_full_eval = args.bf16
|
801 |
+
args.fp16_full_eval = args.fp16
|
802 |
+
_output_logits = False
|
803 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
804 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
805 |
+
if _output_logits:
|
806 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
807 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
808 |
+
pass
|
809 |
+
else:
|
810 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
811 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
812 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
813 |
+
max_seq_length = model.max_seq_length
|
814 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
815 |
+
if model is not None and hasattr(model, 'for_training'):
|
816 |
+
model.for_training()
|
817 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
818 |
+
if 'processing_class' in locals():
|
819 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
820 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
821 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
822 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
823 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
824 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
825 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
826 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
827 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
828 |
+
else:
|
829 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
830 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
831 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
832 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
833 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
834 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
835 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
836 |
+
else:
|
837 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
838 |
+
other_metrics = []
|
839 |
+
|
840 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
841 |
+
PatchRLStatistics('gkd_trainer', other_metrics)
|
842 |
+
|
843 |
+
super().__init__(
|
844 |
+
model = model,
|
845 |
+
teacher_model = teacher_model,
|
846 |
+
args = args,
|
847 |
+
data_collator = data_collator,
|
848 |
+
train_dataset = train_dataset,
|
849 |
+
eval_dataset = eval_dataset,
|
850 |
+
processing_class = processing_class,
|
851 |
+
compute_metrics = compute_metrics,
|
852 |
+
callbacks = callbacks,
|
853 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
854 |
+
peft_config = peft_config,
|
855 |
+
formatting_func = formatting_func,**kwargs)
|
856 |
+
if hasattr(self, 'neftune_hook_handle'):
|
857 |
+
self.neftune_hook_handle.remove()
|
858 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
859 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
860 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
861 |
+
pass
|
862 |
+
|
863 |
+
pass
|
unsloth_compiled_cache/UnslothGRPOTrainer.py
ADDED
@@ -0,0 +1,1438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.grpo_trainer import (Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, GRPOConfig, GRPOTrainer, GenerationConfig, IterableDataset, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, RepeatRandomSampler, RewardFunc, Sampler, SyncRefModelCallback, Trainer, TrainerCallback, Union, apply_chat_template, broadcast_object_list, create_reference_model, defaultdict, gather, gather_object, generate_model_card, get_comet_experiment_url, is_conversational, is_deepspeed_zero3_enabled, is_peft_model, is_wandb_available, maybe_apply_chat_template, nn, os, pad, patch, prepare_deepspeed, set_seed, textwrap, torch, transformers, unwrap_model_for_generation, version, warnings, os, torch, transformers, Any, Union, apply_chat_template, broadcast_object_list, gather, gather_object, is_conversational, maybe_apply_chat_template, nn, os, pad, torch, unwrap_model_for_generation, GRPOTrainer, Trainer, gather, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
|
43 |
+
def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages):
|
44 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
45 |
+
old_logits = old_logits.to(torch.float32)
|
46 |
+
new_logits = new_logits.to(torch.float32)
|
47 |
+
input_ids = input_ids.unsqueeze(-1)
|
48 |
+
|
49 |
+
# x_i - logsumexp(x_i)
|
50 |
+
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
|
51 |
+
new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
|
52 |
+
old = old_x - torch.logsumexp(old_logits, dim = -1)
|
53 |
+
new = new_x - torch.logsumexp(new_logits, dim = -1)
|
54 |
+
|
55 |
+
# Reverse KL
|
56 |
+
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
57 |
+
# Full correct reverse KL divergence?? Missing term maybe?
|
58 |
+
# kl_i = torch.exp(new) * kl_i
|
59 |
+
|
60 |
+
# Below is forward KL (normal KL)
|
61 |
+
# kl_i = torch.exp(old) * (old - new)
|
62 |
+
|
63 |
+
# Must detach - otherwise gradients are not propagated correctly!
|
64 |
+
# exp(x - x) == 1
|
65 |
+
loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
|
66 |
+
loss_i = -(loss_i - beta * kl_i)
|
67 |
+
|
68 |
+
mask = mask.to(torch.float32)
|
69 |
+
n_mask_per_reward = mask.sum(1)
|
70 |
+
|
71 |
+
# See https://github.com/huggingface/trl/pull/2881
|
72 |
+
loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
|
73 |
+
loss = loss_per_reward.mean()
|
74 |
+
# loss = (loss_i * mask).sum() / mask.sum()
|
75 |
+
|
76 |
+
# Get metrics as well which are folded
|
77 |
+
with torch.inference_mode():
|
78 |
+
completion_length = n_mask_per_reward.mean()
|
79 |
+
mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
|
80 |
+
mean_kl = mean_kl_per_reward.mean()
|
81 |
+
pass
|
82 |
+
return loss, completion_length, mean_kl
|
83 |
+
|
84 |
+
class UnslothEfficientGRPO(torch.autograd.Function):
|
85 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
86 |
+
@staticmethod
|
87 |
+
def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1):
|
88 |
+
def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling):
|
89 |
+
new_logits = torch.matmul(new_hidden_states, lm_head.t())
|
90 |
+
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
91 |
+
old_logits = torch.matmul(old_hidden_states, lm_head.t())
|
92 |
+
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
93 |
+
loss, completion_length, mean_kl = grpo_compute_loss(
|
94 |
+
old_logits, new_logits, input_ids, mask, beta, advantages,
|
95 |
+
)
|
96 |
+
# Scale loss if needed for mixed precision training
|
97 |
+
scaled_loss = loss * scaling
|
98 |
+
# Must add .loss.detach otherwise autograd uses 2x VRAM
|
99 |
+
return scaled_loss, (loss.detach(), completion_length, mean_kl,)
|
100 |
+
pass
|
101 |
+
|
102 |
+
device =_new_hidden_states.device
|
103 |
+
grad_inputs = torch.empty_like(_new_hidden_states)
|
104 |
+
accumulated_loss = torch.zeros(1, device = device)
|
105 |
+
accumulated_completion_length = torch.zeros(1, device = device)
|
106 |
+
accumulated_mean_kl = torch.zeros(1, device = device)
|
107 |
+
|
108 |
+
def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling):
|
109 |
+
(chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
|
110 |
+
compute_loss,
|
111 |
+
argnums = (0,),
|
112 |
+
has_aux = True,
|
113 |
+
)(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
|
114 |
+
accumulated_loss .add_(unscaled_loss)
|
115 |
+
accumulated_completion_length.add_(chunk_completion_length)
|
116 |
+
accumulated_mean_kl .add_(chunk_mean_kl)
|
117 |
+
return chunk_grad_input
|
118 |
+
pass
|
119 |
+
|
120 |
+
accumulate_chunk = torch.compile(
|
121 |
+
accumulate_chunk,
|
122 |
+
fullgraph = True,
|
123 |
+
options = torch_compile_options,
|
124 |
+
)
|
125 |
+
|
126 |
+
grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
|
127 |
+
new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
|
128 |
+
old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
|
129 |
+
input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0)
|
130 |
+
mask = torch.chunk(_mask, chunks = n_chunks, dim = 0)
|
131 |
+
advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0)
|
132 |
+
|
133 |
+
# Get mixed precision scaling if seen
|
134 |
+
scaling = scaler.get_scale() if scaler is not None else 1.0
|
135 |
+
|
136 |
+
# Force torch.compile to use dynamic shapes for seqlen dim
|
137 |
+
mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1)
|
138 |
+
|
139 |
+
for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \
|
140 |
+
zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages):
|
141 |
+
|
142 |
+
mark_dynamic(new_hidden_states_j)
|
143 |
+
mark_dynamic(old_hidden_states_j)
|
144 |
+
mark_dynamic(input_ids_j)
|
145 |
+
mark_dynamic(mask_j)
|
146 |
+
|
147 |
+
grad_inputs_j.copy_(
|
148 |
+
accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling)
|
149 |
+
)
|
150 |
+
pass
|
151 |
+
|
152 |
+
grad_inputs .div_(n_chunks)
|
153 |
+
accumulated_loss .div_(n_chunks)
|
154 |
+
accumulated_completion_length.div_(n_chunks)
|
155 |
+
accumulated_mean_kl .div_(n_chunks)
|
156 |
+
ctx.save_for_backward(grad_inputs)
|
157 |
+
|
158 |
+
return (
|
159 |
+
accumulated_loss,
|
160 |
+
accumulated_completion_length,
|
161 |
+
accumulated_mean_kl,
|
162 |
+
)
|
163 |
+
pass
|
164 |
+
|
165 |
+
@staticmethod
|
166 |
+
def backward(ctx, grad_output, dcompletion_length, dmean_kl):
|
167 |
+
(grad_input,) = ctx.saved_tensors
|
168 |
+
return (grad_input, None, None, None, None, None, None, None, None,)
|
169 |
+
pass
|
170 |
+
|
171 |
+
def grpo_accumulated_loss(
|
172 |
+
trainer,
|
173 |
+
input_ids,
|
174 |
+
logits_to_keep,
|
175 |
+
completion_mask,
|
176 |
+
advantages,
|
177 |
+
n_chunks = -1,
|
178 |
+
):
|
179 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
180 |
+
bsz, qlen = input_ids.shape
|
181 |
+
# Find closest multiple
|
182 |
+
factors = [i for i in range(1, bsz + 1) if bsz % i == 0]
|
183 |
+
if n_chunks == -1: n_chunks = bsz
|
184 |
+
n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)]
|
185 |
+
|
186 |
+
mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
|
187 |
+
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
|
188 |
+
|
189 |
+
completion_input_ids = input_ids[:, -logits_to_keep:]
|
190 |
+
lm_head = trainer.model.get_output_embeddings().weight
|
191 |
+
|
192 |
+
with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype):
|
193 |
+
with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter():
|
194 |
+
old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
|
195 |
+
pass
|
196 |
+
|
197 |
+
new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits
|
198 |
+
|
199 |
+
loss, completion_length, mean_kl = UnslothEfficientGRPO.apply(
|
200 |
+
new_hidden_states, old_hidden_states, lm_head,
|
201 |
+
completion_input_ids, completion_mask, advantages, trainer.beta,
|
202 |
+
trainer.accelerator.scaler,
|
203 |
+
n_chunks,
|
204 |
+
)
|
205 |
+
return loss, completion_length, mean_kl
|
206 |
+
|
207 |
+
# Old non efficient code path
|
208 |
+
new_logits = torch.matmul(new_hidden_states, lm_head.t())
|
209 |
+
new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
210 |
+
old_logits = torch.matmul(old_hidden_states, lm_head.t())
|
211 |
+
old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred
|
212 |
+
loss, completion_length, mean_kl = grpo_compute_loss(
|
213 |
+
old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages,
|
214 |
+
)
|
215 |
+
return loss, completion_length, mean_kl
|
216 |
+
pass
|
217 |
+
|
218 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)
|
219 |
+
def grpo_compute_loss_slow(old_logits, new_logits, input_ids, mask, beta, advantages):
|
220 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
221 |
+
old_logits = old_logits.to(torch.float32)
|
222 |
+
new_logits = new_logits.to(torch.float32)
|
223 |
+
input_ids = input_ids.unsqueeze(-1)
|
224 |
+
|
225 |
+
# x_i - logsumexp(x_i)
|
226 |
+
old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1)
|
227 |
+
new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1)
|
228 |
+
old = old_x - torch.logsumexp(old_logits, dim = -1)
|
229 |
+
new = new_x - torch.logsumexp(new_logits, dim = -1)
|
230 |
+
|
231 |
+
# Reverse KL
|
232 |
+
kl_i = torch.exp(old - new) - (old - new) - 1.0
|
233 |
+
# Full correct reverse KL divergence?? Missing term maybe?
|
234 |
+
# kl_i = torch.exp(new) * kl_i
|
235 |
+
|
236 |
+
# Below is forward KL (normal KL)
|
237 |
+
# kl_i = torch.exp(old) * (old - new)
|
238 |
+
|
239 |
+
# Must detach - otherwise gradients are not propagated correctly!
|
240 |
+
# exp(x - x) == 1
|
241 |
+
loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1)
|
242 |
+
loss_i = -(loss_i - beta * kl_i)
|
243 |
+
|
244 |
+
mask = mask.to(torch.float32)
|
245 |
+
n_mask_per_reward = mask.sum(1)
|
246 |
+
|
247 |
+
# See https://github.com/huggingface/trl/pull/2881
|
248 |
+
loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward
|
249 |
+
loss = loss_per_reward.mean()
|
250 |
+
# loss = (loss_i * mask).sum() / mask.sum()
|
251 |
+
|
252 |
+
# Get metrics as well which are folded
|
253 |
+
with torch.inference_mode():
|
254 |
+
completion_length = n_mask_per_reward.mean()
|
255 |
+
mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward
|
256 |
+
mean_kl = mean_kl_per_reward.mean()
|
257 |
+
pass
|
258 |
+
return loss, completion_length, mean_kl
|
259 |
+
|
260 |
+
def vLLMSamplingParams(**kwargs):
|
261 |
+
from vllm import SamplingParams
|
262 |
+
sampling_params = SamplingParams(**kwargs)
|
263 |
+
sampling_params._set_kwargs = kwargs
|
264 |
+
return sampling_params
|
265 |
+
@dataclass
|
266 |
+
class UnslothGRPOConfig(GRPOConfig):
|
267 |
+
"""
|
268 |
+
|
269 |
+
Configuration class for the [`GRPOTrainer`].
|
270 |
+
|
271 |
+
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
272 |
+
[`~transformers.TrainingArguments`] documentation.
|
273 |
+
|
274 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
275 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
276 |
+
command line.
|
277 |
+
|
278 |
+
Parameters:
|
279 |
+
> Parameters that control the model and reference model
|
280 |
+
|
281 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
282 |
+
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
283 |
+
argument of the [`GRPOTrainer`] is provided as a string.
|
284 |
+
|
285 |
+
> Parameters that control the data preprocessing
|
286 |
+
|
287 |
+
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
288 |
+
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
289 |
+
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
290 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
291 |
+
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
292 |
+
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
293 |
+
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
294 |
+
must be divisible by this value.
|
295 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
296 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
297 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
298 |
+
Maximum length of the generated completion.
|
299 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
300 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
301 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
302 |
+
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
303 |
+
with vLLM generation.
|
304 |
+
|
305 |
+
> Parameters that control generation acceleration powered by vLLM
|
306 |
+
|
307 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
308 |
+
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
309 |
+
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
310 |
+
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
311 |
+
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
312 |
+
automatically select the next available GPU after the last one used for training. This assumes that
|
313 |
+
training has not already occupied all available GPUs. If only one device is available, the device will be
|
314 |
+
shared between both training and vLLM.
|
315 |
+
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
316 |
+
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
317 |
+
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
318 |
+
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
319 |
+
during initialization.
|
320 |
+
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
321 |
+
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
322 |
+
based on the model configuration. Find the supported values in the vLLM documentation.
|
323 |
+
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
324 |
+
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
325 |
+
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
326 |
+
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
327 |
+
|
328 |
+
> Parameters that control the training
|
329 |
+
|
330 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
331 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
332 |
+
[`~transformers.TrainingArguments`].
|
333 |
+
beta (`float`, *optional*, defaults to `0.04`):
|
334 |
+
KL coefficient.
|
335 |
+
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
336 |
+
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
337 |
+
weighted equally with weight `1.0`.
|
338 |
+
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
339 |
+
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
340 |
+
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
341 |
+
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
342 |
+
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
|
343 |
+
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
344 |
+
between the current policy and the previous reference policy during updates. The reference policy is
|
345 |
+
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
346 |
+
must set `sync_ref_model=True`.
|
347 |
+
ref_model_sync_steps (`int`, *optional*, defaults to `64`):
|
348 |
+
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
349 |
+
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
350 |
+
set `sync_ref_model=True`.
|
351 |
+
|
352 |
+
> Parameters that control the logging
|
353 |
+
|
354 |
+
log_completions (`bool`, *optional*, defaults to `False`):
|
355 |
+
Whether to log the completions during training.
|
356 |
+
|
357 |
+
"""
|
358 |
+
vllm_sampling_params: Optional[Any] = field(
|
359 |
+
default = None,
|
360 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
361 |
+
)
|
362 |
+
unsloth_num_chunks : Optional[int] = field(
|
363 |
+
default = -1,
|
364 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
365 |
+
)
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
output_dir = None,
|
369 |
+
overwrite_output_dir = None,
|
370 |
+
do_train = False,
|
371 |
+
do_eval = False,
|
372 |
+
do_predict = False,
|
373 |
+
eval_strategy = 'no',
|
374 |
+
prediction_loss_only = False,
|
375 |
+
per_device_train_batch_size = 4,
|
376 |
+
per_device_eval_batch_size = 4,
|
377 |
+
per_gpu_train_batch_size = None,
|
378 |
+
per_gpu_eval_batch_size = None,
|
379 |
+
gradient_accumulation_steps = 2,
|
380 |
+
eval_accumulation_steps = 2,
|
381 |
+
eval_delay = 0,
|
382 |
+
torch_empty_cache_steps = 250,
|
383 |
+
learning_rate = 5e-05,
|
384 |
+
weight_decay = 0.01,
|
385 |
+
adam_beta1 = 0.9,
|
386 |
+
adam_beta2 = 0.999,
|
387 |
+
adam_epsilon = 1e-08,
|
388 |
+
max_grad_norm = 1.0,
|
389 |
+
num_train_epochs = 3.0,
|
390 |
+
max_steps = -1,
|
391 |
+
lr_scheduler_type = 'linear',
|
392 |
+
warmup_ratio = 0.1,
|
393 |
+
warmup_steps = 0,
|
394 |
+
log_level = 'passive',
|
395 |
+
log_level_replica = 'warning',
|
396 |
+
log_on_each_node = True,
|
397 |
+
logging_dir = None,
|
398 |
+
logging_strategy = 'steps',
|
399 |
+
logging_first_step = False,
|
400 |
+
logging_steps = 1,
|
401 |
+
logging_nan_inf_filter = False,
|
402 |
+
save_strategy = 'steps',
|
403 |
+
save_steps = 500,
|
404 |
+
save_total_limit = None,
|
405 |
+
save_safetensors = True,
|
406 |
+
save_on_each_node = False,
|
407 |
+
save_only_model = False,
|
408 |
+
restore_callback_states_from_checkpoint = False,
|
409 |
+
no_cuda = False,
|
410 |
+
use_cpu = False,
|
411 |
+
use_mps_device = False,
|
412 |
+
seed = 3407,
|
413 |
+
data_seed = 3407,
|
414 |
+
jit_mode_eval = False,
|
415 |
+
use_ipex = False,
|
416 |
+
bf16 = False,
|
417 |
+
fp16 = False,
|
418 |
+
fp16_opt_level = 'O1',
|
419 |
+
half_precision_backend = 'auto',
|
420 |
+
bf16_full_eval = False,
|
421 |
+
fp16_full_eval = False,
|
422 |
+
tf32 = None,
|
423 |
+
local_rank = -1,
|
424 |
+
ddp_backend = None,
|
425 |
+
tpu_num_cores = None,
|
426 |
+
tpu_metrics_debug = False,
|
427 |
+
debug = '',
|
428 |
+
dataloader_drop_last = False,
|
429 |
+
eval_steps = None,
|
430 |
+
dataloader_num_workers = 0,
|
431 |
+
dataloader_prefetch_factor = None,
|
432 |
+
past_index = -1,
|
433 |
+
run_name = None,
|
434 |
+
disable_tqdm = None,
|
435 |
+
remove_unused_columns = False,
|
436 |
+
label_names = None,
|
437 |
+
load_best_model_at_end = False,
|
438 |
+
metric_for_best_model = None,
|
439 |
+
greater_is_better = None,
|
440 |
+
ignore_data_skip = False,
|
441 |
+
fsdp = '',
|
442 |
+
fsdp_min_num_params = 0,
|
443 |
+
fsdp_config = None,
|
444 |
+
tp_size = 0,
|
445 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
446 |
+
accelerator_config = None,
|
447 |
+
deepspeed = None,
|
448 |
+
label_smoothing_factor = 0.0,
|
449 |
+
optim = 'adamw_8bit',
|
450 |
+
optim_args = None,
|
451 |
+
adafactor = False,
|
452 |
+
group_by_length = False,
|
453 |
+
length_column_name = 'length',
|
454 |
+
report_to = None,
|
455 |
+
ddp_find_unused_parameters = None,
|
456 |
+
ddp_bucket_cap_mb = None,
|
457 |
+
ddp_broadcast_buffers = None,
|
458 |
+
dataloader_pin_memory = True,
|
459 |
+
dataloader_persistent_workers = False,
|
460 |
+
skip_memory_metrics = True,
|
461 |
+
use_legacy_prediction_loop = False,
|
462 |
+
push_to_hub = False,
|
463 |
+
resume_from_checkpoint = None,
|
464 |
+
hub_model_id = None,
|
465 |
+
hub_strategy = 'every_save',
|
466 |
+
hub_token = None,
|
467 |
+
hub_private_repo = None,
|
468 |
+
hub_always_push = False,
|
469 |
+
gradient_checkpointing = False,
|
470 |
+
gradient_checkpointing_kwargs = None,
|
471 |
+
include_inputs_for_metrics = False,
|
472 |
+
eval_do_concat_batches = True,
|
473 |
+
fp16_backend = 'auto',
|
474 |
+
evaluation_strategy = None,
|
475 |
+
push_to_hub_model_id = None,
|
476 |
+
push_to_hub_organization = None,
|
477 |
+
push_to_hub_token = None,
|
478 |
+
mp_parameters = '',
|
479 |
+
auto_find_batch_size = False,
|
480 |
+
full_determinism = False,
|
481 |
+
torchdynamo = None,
|
482 |
+
ray_scope = 'last',
|
483 |
+
ddp_timeout = 1800,
|
484 |
+
torch_compile = False,
|
485 |
+
torch_compile_backend = None,
|
486 |
+
torch_compile_mode = None,
|
487 |
+
dispatch_batches = None,
|
488 |
+
split_batches = None,
|
489 |
+
include_tokens_per_second = False,
|
490 |
+
include_num_input_tokens_seen = False,
|
491 |
+
neftune_noise_alpha = None,
|
492 |
+
optim_target_modules = None,
|
493 |
+
batch_eval_metrics = False,
|
494 |
+
eval_on_start = False,
|
495 |
+
use_liger_kernel = False,
|
496 |
+
eval_use_gather_object = False,
|
497 |
+
average_tokens_across_devices = False,
|
498 |
+
model_init_kwargs = None,
|
499 |
+
max_prompt_length = 512,
|
500 |
+
num_generations = 8,
|
501 |
+
temperature = 0.9,
|
502 |
+
max_completion_length = 256,
|
503 |
+
ds3_gather_for_generation = True,
|
504 |
+
use_vllm = False,
|
505 |
+
vllm_device = 'auto',
|
506 |
+
vllm_gpu_memory_utilization = 0.9,
|
507 |
+
vllm_dtype = 'auto',
|
508 |
+
vllm_max_model_len = None,
|
509 |
+
beta = 0.04,
|
510 |
+
reward_weights = None,
|
511 |
+
sync_ref_model = False,
|
512 |
+
ref_model_mixup_alpha = 0.9,
|
513 |
+
ref_model_sync_steps = 64,
|
514 |
+
log_completions = False,
|
515 |
+
vllm_sampling_params = None,
|
516 |
+
unsloth_num_chunks = -1,
|
517 |
+
**kwargs,
|
518 |
+
):
|
519 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
520 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
521 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
522 |
+
output_dir = 'unsloth_training_checkpoints'
|
523 |
+
save_strategy = 'no'
|
524 |
+
div = per_device_train_batch_size // num_generations
|
525 |
+
if div * num_generations != per_device_train_batch_size:
|
526 |
+
print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\nWe will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))
|
527 |
+
per_device_train_batch_size = num_generations
|
528 |
+
|
529 |
+
super().__init__(
|
530 |
+
output_dir = output_dir,
|
531 |
+
overwrite_output_dir = overwrite_output_dir,
|
532 |
+
do_train = do_train,
|
533 |
+
do_eval = do_eval,
|
534 |
+
do_predict = do_predict,
|
535 |
+
eval_strategy = eval_strategy,
|
536 |
+
prediction_loss_only = prediction_loss_only,
|
537 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
538 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
539 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
540 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
541 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
542 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
543 |
+
eval_delay = eval_delay,
|
544 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
545 |
+
learning_rate = learning_rate,
|
546 |
+
weight_decay = weight_decay,
|
547 |
+
adam_beta1 = adam_beta1,
|
548 |
+
adam_beta2 = adam_beta2,
|
549 |
+
adam_epsilon = adam_epsilon,
|
550 |
+
max_grad_norm = max_grad_norm,
|
551 |
+
num_train_epochs = num_train_epochs,
|
552 |
+
max_steps = max_steps,
|
553 |
+
lr_scheduler_type = lr_scheduler_type,
|
554 |
+
warmup_ratio = warmup_ratio,
|
555 |
+
warmup_steps = warmup_steps,
|
556 |
+
log_level = log_level,
|
557 |
+
log_level_replica = log_level_replica,
|
558 |
+
log_on_each_node = log_on_each_node,
|
559 |
+
logging_dir = logging_dir,
|
560 |
+
logging_strategy = logging_strategy,
|
561 |
+
logging_first_step = logging_first_step,
|
562 |
+
logging_steps = logging_steps,
|
563 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
564 |
+
save_strategy = save_strategy,
|
565 |
+
save_steps = save_steps,
|
566 |
+
save_total_limit = save_total_limit,
|
567 |
+
save_safetensors = save_safetensors,
|
568 |
+
save_on_each_node = save_on_each_node,
|
569 |
+
save_only_model = save_only_model,
|
570 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
571 |
+
no_cuda = no_cuda,
|
572 |
+
use_cpu = use_cpu,
|
573 |
+
use_mps_device = use_mps_device,
|
574 |
+
seed = seed,
|
575 |
+
data_seed = data_seed,
|
576 |
+
jit_mode_eval = jit_mode_eval,
|
577 |
+
use_ipex = use_ipex,
|
578 |
+
bf16 = bf16,
|
579 |
+
fp16 = fp16,
|
580 |
+
fp16_opt_level = fp16_opt_level,
|
581 |
+
half_precision_backend = half_precision_backend,
|
582 |
+
bf16_full_eval = bf16_full_eval,
|
583 |
+
fp16_full_eval = fp16_full_eval,
|
584 |
+
tf32 = tf32,
|
585 |
+
local_rank = local_rank,
|
586 |
+
ddp_backend = ddp_backend,
|
587 |
+
tpu_num_cores = tpu_num_cores,
|
588 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
589 |
+
debug = debug,
|
590 |
+
dataloader_drop_last = dataloader_drop_last,
|
591 |
+
eval_steps = eval_steps,
|
592 |
+
dataloader_num_workers = dataloader_num_workers,
|
593 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
594 |
+
past_index = past_index,
|
595 |
+
run_name = run_name,
|
596 |
+
disable_tqdm = disable_tqdm,
|
597 |
+
remove_unused_columns = remove_unused_columns,
|
598 |
+
label_names = label_names,
|
599 |
+
load_best_model_at_end = load_best_model_at_end,
|
600 |
+
metric_for_best_model = metric_for_best_model,
|
601 |
+
greater_is_better = greater_is_better,
|
602 |
+
ignore_data_skip = ignore_data_skip,
|
603 |
+
fsdp = fsdp,
|
604 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
605 |
+
fsdp_config = fsdp_config,
|
606 |
+
tp_size = tp_size,
|
607 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
608 |
+
accelerator_config = accelerator_config,
|
609 |
+
deepspeed = deepspeed,
|
610 |
+
label_smoothing_factor = label_smoothing_factor,
|
611 |
+
optim = optim,
|
612 |
+
optim_args = optim_args,
|
613 |
+
adafactor = adafactor,
|
614 |
+
group_by_length = group_by_length,
|
615 |
+
length_column_name = length_column_name,
|
616 |
+
report_to = report_to,
|
617 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
618 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
619 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
620 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
621 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
622 |
+
skip_memory_metrics = skip_memory_metrics,
|
623 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
624 |
+
push_to_hub = push_to_hub,
|
625 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
626 |
+
hub_model_id = hub_model_id,
|
627 |
+
hub_strategy = hub_strategy,
|
628 |
+
hub_token = hub_token,
|
629 |
+
hub_private_repo = hub_private_repo,
|
630 |
+
hub_always_push = hub_always_push,
|
631 |
+
gradient_checkpointing = gradient_checkpointing,
|
632 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
633 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
634 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
635 |
+
fp16_backend = fp16_backend,
|
636 |
+
evaluation_strategy = evaluation_strategy,
|
637 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
638 |
+
push_to_hub_organization = push_to_hub_organization,
|
639 |
+
push_to_hub_token = push_to_hub_token,
|
640 |
+
mp_parameters = mp_parameters,
|
641 |
+
auto_find_batch_size = auto_find_batch_size,
|
642 |
+
full_determinism = full_determinism,
|
643 |
+
torchdynamo = torchdynamo,
|
644 |
+
ray_scope = ray_scope,
|
645 |
+
ddp_timeout = ddp_timeout,
|
646 |
+
torch_compile = torch_compile,
|
647 |
+
torch_compile_backend = torch_compile_backend,
|
648 |
+
torch_compile_mode = torch_compile_mode,
|
649 |
+
dispatch_batches = dispatch_batches,
|
650 |
+
split_batches = split_batches,
|
651 |
+
include_tokens_per_second = include_tokens_per_second,
|
652 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
653 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
654 |
+
optim_target_modules = optim_target_modules,
|
655 |
+
batch_eval_metrics = batch_eval_metrics,
|
656 |
+
eval_on_start = eval_on_start,
|
657 |
+
use_liger_kernel = use_liger_kernel,
|
658 |
+
eval_use_gather_object = eval_use_gather_object,
|
659 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
660 |
+
model_init_kwargs = model_init_kwargs,
|
661 |
+
max_prompt_length = max_prompt_length,
|
662 |
+
num_generations = num_generations,
|
663 |
+
temperature = temperature,
|
664 |
+
max_completion_length = max_completion_length,
|
665 |
+
ds3_gather_for_generation = ds3_gather_for_generation,
|
666 |
+
use_vllm = use_vllm,
|
667 |
+
vllm_device = vllm_device,
|
668 |
+
vllm_gpu_memory_utilization = vllm_gpu_memory_utilization,
|
669 |
+
vllm_dtype = vllm_dtype,
|
670 |
+
vllm_max_model_len = vllm_max_model_len,
|
671 |
+
beta = beta,
|
672 |
+
reward_weights = reward_weights,
|
673 |
+
sync_ref_model = sync_ref_model,
|
674 |
+
ref_model_mixup_alpha = ref_model_mixup_alpha,
|
675 |
+
ref_model_sync_steps = ref_model_sync_steps,
|
676 |
+
log_completions = log_completions,**kwargs)
|
677 |
+
self.vllm_sampling_params = vllm_sampling_params
|
678 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
679 |
+
pass
|
680 |
+
|
681 |
+
class _UnslothGRPOTrainer(Trainer):
|
682 |
+
""""""
|
683 |
+
|
684 |
+
_tag_names = ["trl", "grpo"]
|
685 |
+
|
686 |
+
def __init__(
|
687 |
+
self,
|
688 |
+
model: Union[str, PreTrainedModel],
|
689 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
690 |
+
args: GRPOConfig = None,
|
691 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
692 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
693 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
694 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
695 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
696 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
697 |
+
peft_config: Optional["PeftConfig"] = None,
|
698 |
+
):
|
699 |
+
|
700 |
+
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
701 |
+
# Args
|
702 |
+
if args is None:
|
703 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
704 |
+
model_name = model_name.split("/")[-1]
|
705 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
706 |
+
|
707 |
+
# Models
|
708 |
+
# Trained model
|
709 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
710 |
+
if isinstance(model, str):
|
711 |
+
model_id = model
|
712 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
713 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
714 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
715 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
716 |
+
torch_dtype = getattr(torch, torch_dtype)
|
717 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
718 |
+
else:
|
719 |
+
raise ValueError(
|
720 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
721 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
722 |
+
)
|
723 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
724 |
+
model_init_kwargs["use_cache"] = (
|
725 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
726 |
+
)
|
727 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
728 |
+
else:
|
729 |
+
model_id = model.config._name_or_path
|
730 |
+
if args.model_init_kwargs is not None:
|
731 |
+
raise ValueError(
|
732 |
+
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
733 |
+
"This argument can only be used when the `model` argument is a string."
|
734 |
+
)
|
735 |
+
|
736 |
+
if False:
|
737 |
+
model = model
|
738 |
+
|
739 |
+
# Reference model
|
740 |
+
if is_deepspeed_zero3_enabled():
|
741 |
+
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
742 |
+
elif not is_peft_model(model):
|
743 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
744 |
+
self.ref_model = create_reference_model(model)
|
745 |
+
else:
|
746 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
747 |
+
# to revert to the initial model.
|
748 |
+
self.ref_model = None
|
749 |
+
|
750 |
+
# Processing class
|
751 |
+
if processing_class is None:
|
752 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
753 |
+
|
754 |
+
# Reward functions
|
755 |
+
if not isinstance(reward_funcs, list):
|
756 |
+
reward_funcs = [reward_funcs]
|
757 |
+
for i, reward_func in enumerate(reward_funcs):
|
758 |
+
if isinstance(reward_func, str):
|
759 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
760 |
+
reward_func, num_labels=1, **model_init_kwargs
|
761 |
+
)
|
762 |
+
self.reward_funcs = reward_funcs
|
763 |
+
|
764 |
+
# Reward weights
|
765 |
+
if args.reward_weights is not None:
|
766 |
+
if len(args.reward_weights) != len(reward_funcs):
|
767 |
+
raise ValueError(
|
768 |
+
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
769 |
+
f"functions ({len(reward_funcs)})"
|
770 |
+
)
|
771 |
+
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
772 |
+
else:
|
773 |
+
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
774 |
+
|
775 |
+
# Reward processing class
|
776 |
+
if reward_processing_classes is None:
|
777 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
778 |
+
elif not isinstance(reward_processing_classes, list):
|
779 |
+
reward_processing_classes = [reward_processing_classes]
|
780 |
+
else:
|
781 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
782 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
783 |
+
|
784 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
785 |
+
if isinstance(reward_func, PreTrainedModel):
|
786 |
+
if reward_processing_class is None:
|
787 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
788 |
+
if reward_processing_class.pad_token_id is None:
|
789 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
790 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
791 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
792 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
793 |
+
reward_processing_classes[i] = reward_processing_class
|
794 |
+
self.reward_processing_classes = reward_processing_classes
|
795 |
+
|
796 |
+
# Data collator
|
797 |
+
def data_collator(features): # No data collation is needed in GRPO
|
798 |
+
return features
|
799 |
+
|
800 |
+
# Training arguments
|
801 |
+
self.max_prompt_length = args.max_prompt_length
|
802 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
803 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
804 |
+
self.use_vllm = args.use_vllm
|
805 |
+
|
806 |
+
self.beta = args.beta
|
807 |
+
|
808 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
809 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
810 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
811 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
812 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
813 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
814 |
+
model.warnings_issued["estimate_tokens"] = True
|
815 |
+
|
816 |
+
# Initialize the metrics
|
817 |
+
self._metrics = defaultdict(list)
|
818 |
+
self.log_completions = args.log_completions
|
819 |
+
|
820 |
+
super().__init__(
|
821 |
+
model=model,
|
822 |
+
args=args,
|
823 |
+
data_collator=data_collator,
|
824 |
+
train_dataset=train_dataset,
|
825 |
+
eval_dataset=eval_dataset,
|
826 |
+
processing_class=processing_class,
|
827 |
+
callbacks=callbacks,
|
828 |
+
optimizers=optimizers,
|
829 |
+
)
|
830 |
+
|
831 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
832 |
+
num_processes = self.accelerator.num_processes
|
833 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
834 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
835 |
+
if self.num_generations not in possible_values:
|
836 |
+
raise ValueError(
|
837 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
838 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
839 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
840 |
+
)
|
841 |
+
if self.args.eval_strategy != "no":
|
842 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
843 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
844 |
+
if self.num_generations not in possible_values:
|
845 |
+
raise ValueError(
|
846 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
847 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
848 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
849 |
+
)
|
850 |
+
|
851 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
852 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
853 |
+
# it's safer to set it in all cases.
|
854 |
+
set_seed(args.seed, device_specific=True)
|
855 |
+
|
856 |
+
if self.use_vllm:
|
857 |
+
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.sampling_params = SamplingParams(
|
858 |
+
temperature=args.temperature,
|
859 |
+
max_tokens=self.max_completion_length,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
860 |
+
else:
|
861 |
+
self.generation_config = GenerationConfig(
|
862 |
+
max_new_tokens=self.max_completion_length,
|
863 |
+
do_sample=True,
|
864 |
+
temperature=args.temperature,
|
865 |
+
pad_token_id=processing_class.pad_token_id,
|
866 |
+
)
|
867 |
+
|
868 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
869 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
870 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
871 |
+
self.model_accepts_loss_kwargs = False
|
872 |
+
|
873 |
+
# Add tags to the model
|
874 |
+
self.model.add_model_tags(self._tag_names)
|
875 |
+
|
876 |
+
if self.ref_model is not None:
|
877 |
+
if self.is_deepspeed_enabled:
|
878 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
879 |
+
else:
|
880 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
881 |
+
|
882 |
+
if args.sync_ref_model:
|
883 |
+
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
884 |
+
|
885 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
886 |
+
if isinstance(reward_func, PreTrainedModel):
|
887 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
888 |
+
|
889 |
+
def _set_signature_columns_if_needed(self):
|
890 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
891 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
892 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
893 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
894 |
+
if self._signature_columns is None:
|
895 |
+
self._signature_columns = ["prompt"]
|
896 |
+
|
897 |
+
def _get_train_sampler(self) -> Sampler:
|
898 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
899 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
900 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
901 |
+
# preventing discrepancies in group formation.
|
902 |
+
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
|
903 |
+
|
904 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
905 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
906 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
907 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
908 |
+
# preventing discrepancies in group formation.
|
909 |
+
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
|
910 |
+
|
911 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
912 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
913 |
+
if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0':
|
914 |
+
return None # Unsloth efficient GRPO
|
915 |
+
# Otherwise, calculate normally:
|
916 |
+
if not hasattr(self, '_autocast_dtype'):
|
917 |
+
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
|
918 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
|
919 |
+
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
|
920 |
+
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
921 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
922 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
923 |
+
|
924 |
+
input_ids = input_ids[:, -logits_to_keep:]
|
925 |
+
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
926 |
+
# See https://github.com/huggingface/trl/issues/2770
|
927 |
+
logits = logits[:, -logits_to_keep:]
|
928 |
+
return logits
|
929 |
+
# return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
930 |
+
pass
|
931 |
+
|
932 |
+
def _move_model_to_vllm(self, *args, **kwargs): return None
|
933 |
+
|
934 |
+
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
935 |
+
device = self.accelerator.device
|
936 |
+
prompts = [x["prompt"] for x in inputs]
|
937 |
+
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
938 |
+
prompt_inputs = self.processing_class(
|
939 |
+
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
940 |
+
)
|
941 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
942 |
+
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
943 |
+
|
944 |
+
if self.max_prompt_length is not None:
|
945 |
+
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
946 |
+
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
947 |
+
|
948 |
+
# Generate completions using either vLLM or regular generation
|
949 |
+
if self.args.use_vllm:
|
950 |
+
# First, have main process load weights if needed
|
951 |
+
if self.state.global_step != self._last_loaded_step:
|
952 |
+
self._move_model_to_vllm()
|
953 |
+
self._last_loaded_step = self.state.global_step
|
954 |
+
|
955 |
+
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
956 |
+
all_prompts_text = gather_object(prompts_text)
|
957 |
+
if self.accelerator.is_main_process:
|
958 |
+
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False, lora_request = self.model.load_lora('grpo_trainer_lora_model', load_tensors = True))
|
959 |
+
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
960 |
+
else:
|
961 |
+
completion_ids = [None] * len(all_prompts_text)
|
962 |
+
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
963 |
+
# corresponding slice.
|
964 |
+
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
965 |
+
process_slice = slice(
|
966 |
+
self.accelerator.process_index * len(prompts),
|
967 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
968 |
+
)
|
969 |
+
completion_ids = completion_ids[process_slice]
|
970 |
+
|
971 |
+
# Pad the completions, and concatenate them with the prompts
|
972 |
+
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
973 |
+
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
974 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
975 |
+
else:
|
976 |
+
# Regular generation path
|
977 |
+
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
978 |
+
prompt_completion_ids = unwrapped_model.generate(
|
979 |
+
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
|
980 |
+
)
|
981 |
+
|
982 |
+
# Compute prompt length and extract completion ids
|
983 |
+
prompt_length = prompt_ids.size(1)
|
984 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
985 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
986 |
+
|
987 |
+
# Mask everything after the first EOS token
|
988 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
989 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
990 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
991 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
992 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
993 |
+
|
994 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
995 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
996 |
+
|
997 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
998 |
+
|
999 |
+
with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
|
1000 |
+
if self.ref_model is not None:
|
1001 |
+
ref_per_token_logps = self._get_per_token_logps(
|
1002 |
+
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
|
1003 |
+
)
|
1004 |
+
else:
|
1005 |
+
with self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False).disable_adapter():
|
1006 |
+
ref_per_token_logps = self._get_per_token_logps(
|
1007 |
+
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
# Decode the generated completions
|
1011 |
+
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
1012 |
+
if is_conversational(inputs[0]):
|
1013 |
+
completions = []
|
1014 |
+
for prompt, completion in zip(prompts, completions_text):
|
1015 |
+
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
1016 |
+
completions.append([{"role": "assistant", "content": bootstrap + completion}])
|
1017 |
+
else:
|
1018 |
+
completions = completions_text
|
1019 |
+
|
1020 |
+
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
1021 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
1022 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
1023 |
+
):
|
1024 |
+
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
1025 |
+
if is_conversational(inputs[0]):
|
1026 |
+
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
1027 |
+
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
1028 |
+
else:
|
1029 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
1030 |
+
reward_inputs = reward_processing_class(
|
1031 |
+
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
1032 |
+
)
|
1033 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
1034 |
+
with torch.inference_mode(), torch.amp.autocast(device_type = 'cuda', dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) if not torch.is_autocast_enabled('cuda') else nullcontext())if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):
|
1035 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
1036 |
+
else:
|
1037 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
1038 |
+
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
|
1039 |
+
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
|
1040 |
+
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
1041 |
+
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
1042 |
+
|
1043 |
+
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
|
1044 |
+
# completions may be distributed across processes
|
1045 |
+
rewards_per_func = gather(rewards_per_func)
|
1046 |
+
|
1047 |
+
# Apply weights to each reward function's output and sum
|
1048 |
+
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
|
1049 |
+
|
1050 |
+
# Compute grouped-wise rewards
|
1051 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
1052 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
1053 |
+
|
1054 |
+
# Normalize the rewards to compute the advantages
|
1055 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
1056 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
1057 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
1058 |
+
|
1059 |
+
# Slice to keep only the local part of the data
|
1060 |
+
process_slice = slice(
|
1061 |
+
self.accelerator.process_index * len(prompts),
|
1062 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
1063 |
+
)
|
1064 |
+
advantages = advantages[process_slice]
|
1065 |
+
|
1066 |
+
# Log the metrics
|
1067 |
+
reward_per_func = rewards_per_func.mean(0)
|
1068 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
1069 |
+
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
1070 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
1071 |
+
else:
|
1072 |
+
reward_func_name = reward_func.__name__
|
1073 |
+
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
1074 |
+
|
1075 |
+
self._metrics["reward"].append(rewards.mean().item())
|
1076 |
+
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
|
1077 |
+
|
1078 |
+
if (
|
1079 |
+
self.log_completions
|
1080 |
+
and self.state.global_step % self.args.logging_steps == 0
|
1081 |
+
and "wandb" in self.args.report_to
|
1082 |
+
):
|
1083 |
+
import pandas as pd
|
1084 |
+
|
1085 |
+
# For logging
|
1086 |
+
table = {
|
1087 |
+
"step": [str(self.state.global_step)] * len(rewards),
|
1088 |
+
"prompt": gather_object(prompts_text),
|
1089 |
+
"completion": gather_object(completions_text),
|
1090 |
+
"reward": rewards.tolist(),
|
1091 |
+
}
|
1092 |
+
df = pd.DataFrame(table)
|
1093 |
+
|
1094 |
+
if wandb.run is not None and self.accelerator.is_main_process:
|
1095 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
1096 |
+
|
1097 |
+
return {
|
1098 |
+
"prompt_ids": prompt_ids,
|
1099 |
+
"prompt_mask": prompt_mask,
|
1100 |
+
"completion_ids": completion_ids,
|
1101 |
+
"completion_mask": completion_mask,
|
1102 |
+
"ref_per_token_logps": ref_per_token_logps,
|
1103 |
+
"advantages": advantages,
|
1104 |
+
}
|
1105 |
+
|
1106 |
+
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
|
1107 |
+
if return_outputs:
|
1108 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
1109 |
+
# Compute the per-token log probabilities for the model
|
1110 |
+
|
1111 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
1112 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
1113 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
1114 |
+
bsz, qlen = input_ids.shape
|
1115 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
1116 |
+
# attention_mask = None
|
1117 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
1118 |
+
_input_ids = input_ids
|
1119 |
+
_logits_to_keep = logits_to_keep
|
1120 |
+
|
1121 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
1122 |
+
|
1123 |
+
# Compute the KL divergence between the model and the reference model
|
1124 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
1125 |
+
# per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
1126 |
+
|
1127 |
+
# x - x.detach() allows for preserving gradients from x
|
1128 |
+
advantages = inputs["advantages"]
|
1129 |
+
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
1130 |
+
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
1131 |
+
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
1132 |
+
input_ids = input_ids[:, -logits_to_keep:]
|
1133 |
+
if per_token_logps is not None:
|
1134 |
+
loss, completion_length, mean_kl = grpo_compute_loss_slow(
|
1135 |
+
ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages,
|
1136 |
+
)
|
1137 |
+
else:
|
1138 |
+
loss, completion_length, mean_kl = grpo_accumulated_loss(
|
1139 |
+
self, _input_ids, logits_to_keep, completion_mask, advantages,
|
1140 |
+
n_chunks = self.args.unsloth_num_chunks,
|
1141 |
+
)
|
1142 |
+
|
1143 |
+
# Log the metrics
|
1144 |
+
# completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
1145 |
+
|
1146 |
+
# mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
1147 |
+
# self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
1148 |
+
|
1149 |
+
if "train" in self._metrics:
|
1150 |
+
mode = "eval" if self.control.should_evaluate else "train"
|
1151 |
+
self._metrics[mode]["completion_length"].append(completion_length.item())
|
1152 |
+
self._metrics[mode]["kl"].append(mean_kl.item())
|
1153 |
+
else:
|
1154 |
+
self._metrics["completion_length"].append(completion_length.item())
|
1155 |
+
self._metrics["kl"].append(mean_kl.item())
|
1156 |
+
return loss
|
1157 |
+
|
1158 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
|
1159 |
+
inputs = self._prepare_inputs(inputs)
|
1160 |
+
with torch.no_grad():
|
1161 |
+
with self.compute_loss_context_manager():
|
1162 |
+
loss = self.compute_loss(model, inputs)
|
1163 |
+
loss = loss.mean().detach()
|
1164 |
+
return loss, None, None
|
1165 |
+
|
1166 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1167 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
1168 |
+
|
1169 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
1170 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
1171 |
+
if next(iter(logs.keys())).startswith("eval_"):
|
1172 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
1173 |
+
|
1174 |
+
logs = {**logs, **metrics}
|
1175 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1176 |
+
super().log(logs, start_time)
|
1177 |
+
else: # transformers<=4.46
|
1178 |
+
super().log(logs)
|
1179 |
+
self._metrics.clear()
|
1180 |
+
|
1181 |
+
def create_model_card(
|
1182 |
+
self,
|
1183 |
+
model_name: Optional[str] = None,
|
1184 |
+
dataset_name: Optional[str] = None,
|
1185 |
+
tags: Union[str, list[str], None] = None,
|
1186 |
+
):
|
1187 |
+
"""
|
1188 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1189 |
+
|
1190 |
+
Args:
|
1191 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1192 |
+
Name of the model.
|
1193 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1194 |
+
Name of the dataset used for training.
|
1195 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1196 |
+
Tags to be associated with the model card.
|
1197 |
+
"""
|
1198 |
+
if not self.is_world_process_zero():
|
1199 |
+
return
|
1200 |
+
|
1201 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1202 |
+
base_model = self.model.config._name_or_path
|
1203 |
+
else:
|
1204 |
+
base_model = None
|
1205 |
+
|
1206 |
+
tags = tags or []
|
1207 |
+
if isinstance(tags, str):
|
1208 |
+
tags = [tags]
|
1209 |
+
|
1210 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1211 |
+
tags.append("unsloth")
|
1212 |
+
|
1213 |
+
citation = textwrap.dedent(
|
1214 |
+
"""\
|
1215 |
+
@article{zhihong2024deepseekmath,
|
1216 |
+
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
1217 |
+
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
1218 |
+
year = 2024,
|
1219 |
+
eprint = {arXiv:2402.03300},
|
1220 |
+
}
|
1221 |
+
"""
|
1222 |
+
)
|
1223 |
+
|
1224 |
+
model_card = generate_model_card(
|
1225 |
+
base_model=base_model,
|
1226 |
+
model_name=model_name,
|
1227 |
+
hub_model_id=self.hub_model_id,
|
1228 |
+
dataset_name=dataset_name,
|
1229 |
+
tags=tags,
|
1230 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1231 |
+
comet_url=get_comet_experiment_url(),
|
1232 |
+
trainer_name="GRPO",
|
1233 |
+
trainer_citation=citation,
|
1234 |
+
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
1235 |
+
paper_id="2402.03300",
|
1236 |
+
)
|
1237 |
+
|
1238 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1239 |
+
class UnslothGRPOTrainer(_UnslothGRPOTrainer):
|
1240 |
+
"""
|
1241 |
+
|
1242 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
1243 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
1244 |
+
|
1245 |
+
Example:
|
1246 |
+
|
1247 |
+
```python
|
1248 |
+
from datasets import load_dataset
|
1249 |
+
from trl import GRPOTrainer
|
1250 |
+
|
1251 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
1252 |
+
|
1253 |
+
def reward_func(completions, **kwargs):
|
1254 |
+
# Dummy reward function that rewards completions with more unique letters.
|
1255 |
+
return [float(len(set(completion))) for completion in completions]
|
1256 |
+
|
1257 |
+
trainer = GRPOTrainer(
|
1258 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
1259 |
+
reward_funcs=reward_func,
|
1260 |
+
train_dataset=dataset,
|
1261 |
+
)
|
1262 |
+
|
1263 |
+
trainer.train()
|
1264 |
+
```
|
1265 |
+
|
1266 |
+
Args:
|
1267 |
+
model (`Union[str, PreTrainedModel]`):
|
1268 |
+
Model to be trained. Can be either:
|
1269 |
+
|
1270 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
1271 |
+
a path to a *directory* containing model weights saved using
|
1272 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
1273 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
1274 |
+
in `args.model_init_kwargs`.
|
1275 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
1276 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
1277 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
1278 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
1279 |
+
|
1280 |
+
- A single reward function, such as:
|
1281 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
1282 |
+
path to a *directory* containing model weights saved using
|
1283 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
1284 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
1285 |
+
keyword arguments in `args.model_init_kwargs`.
|
1286 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
1287 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
1288 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
1289 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
1290 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
1291 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
1292 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
1293 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
1294 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
1295 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
1296 |
+
ignored. The format of the samples can be either:
|
1297 |
+
|
1298 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
1299 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
1300 |
+
and content).
|
1301 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
1302 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
1303 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
1304 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
1305 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
1306 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
1307 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
1308 |
+
|
1309 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
1310 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
1311 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
1312 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
1313 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
1314 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
1315 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
1316 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
1317 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
1318 |
+
|
1319 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
1320 |
+
method.
|
1321 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
1322 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
1323 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
1324 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
1325 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
1326 |
+
|
1327 |
+
"""
|
1328 |
+
def __init__(
|
1329 |
+
self,
|
1330 |
+
model,
|
1331 |
+
reward_funcs,
|
1332 |
+
args = None,
|
1333 |
+
train_dataset = None,
|
1334 |
+
eval_dataset = None,
|
1335 |
+
processing_class = None,
|
1336 |
+
reward_processing_classes = None,
|
1337 |
+
callbacks = None,
|
1338 |
+
peft_config = None,
|
1339 |
+
**kwargs
|
1340 |
+
):
|
1341 |
+
if args is None: args = UnslothGRPOConfig()
|
1342 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1343 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1344 |
+
force_float32 = False
|
1345 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1346 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1347 |
+
force_float32 = True
|
1348 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1349 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1350 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1351 |
+
from unsloth_zoo.utils import _get_dtype
|
1352 |
+
dtype = _get_dtype(dtype)
|
1353 |
+
float16 = dtype == torch.float16
|
1354 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1355 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1356 |
+
if force_float32:
|
1357 |
+
args.fp16 = False
|
1358 |
+
args.bf16 = False
|
1359 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1360 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1361 |
+
args.fp16 = float16
|
1362 |
+
args.bf16 = not float16
|
1363 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1364 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1365 |
+
args.eval_strategy = 'steps'
|
1366 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1367 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1368 |
+
if ga_steps is not None and ga_steps > 1:
|
1369 |
+
from transformers import __version__ as transformers_version
|
1370 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1371 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1372 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1373 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1374 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1375 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1376 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1377 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1378 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1379 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1380 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1381 |
+
if force_float32:
|
1382 |
+
args.bf16_full_eval = False
|
1383 |
+
args.fp16_full_eval = False
|
1384 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1385 |
+
args.bf16_full_eval = True
|
1386 |
+
args.fp16_full_eval = False
|
1387 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1388 |
+
args.bf16_full_eval = args.bf16
|
1389 |
+
args.fp16_full_eval = args.fp16
|
1390 |
+
_output_logits = False
|
1391 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1392 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1393 |
+
if _output_logits:
|
1394 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1395 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1396 |
+
pass
|
1397 |
+
else:
|
1398 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1399 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1400 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1401 |
+
max_seq_length = model.max_seq_length
|
1402 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1403 |
+
if model is not None and hasattr(model, 'for_training'):
|
1404 |
+
model.for_training()
|
1405 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1406 |
+
if 'processing_class' in locals():
|
1407 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1408 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1409 |
+
other_metrics = []
|
1410 |
+
if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]
|
1411 |
+
else: _reward_funcs = reward_funcs
|
1412 |
+
for reward_func in _reward_funcs:
|
1413 |
+
try:
|
1414 |
+
reward_func_name = reward_func.__name__
|
1415 |
+
other_metrics.append(f'rewards/{reward_func_name}')
|
1416 |
+
except: pass
|
1417 |
+
|
1418 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1419 |
+
PatchRLStatistics('grpo_trainer', other_metrics)
|
1420 |
+
|
1421 |
+
super().__init__(
|
1422 |
+
model = model,
|
1423 |
+
reward_funcs = reward_funcs,
|
1424 |
+
args = args,
|
1425 |
+
train_dataset = train_dataset,
|
1426 |
+
eval_dataset = eval_dataset,
|
1427 |
+
processing_class = processing_class,
|
1428 |
+
reward_processing_classes = reward_processing_classes,
|
1429 |
+
callbacks = callbacks,
|
1430 |
+
peft_config = peft_config,**kwargs)
|
1431 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1432 |
+
self.neftune_hook_handle.remove()
|
1433 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1434 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1435 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1436 |
+
pass
|
1437 |
+
|
1438 |
+
pass
|
unsloth_compiled_cache/UnslothKTOTrainer.py
ADDED
@@ -0,0 +1,1840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothKTOConfig(KTOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`KTOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
66 |
+
reference model.
|
67 |
+
loss_type (`str`, *optional*, defaults to `"kto"`):
|
68 |
+
Type of loss to use. Possible values are:
|
69 |
+
|
70 |
+
- `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
|
71 |
+
- `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
|
72 |
+
|
73 |
+
desirable_weight (`float`, *optional*, defaults to `1.0`):
|
74 |
+
Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
75 |
+
undesirable_weight (`float`, *optional*, defaults to `1.0`):
|
76 |
+
Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
77 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
78 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
79 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
80 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
81 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
82 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
83 |
+
This argument is required if you want to use the default data collator.
|
84 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
85 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
86 |
+
evaluation.
|
87 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
88 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
89 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
90 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
92 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
93 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
94 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
95 |
+
string.
|
96 |
+
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
97 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
98 |
+
from a string.
|
99 |
+
dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
|
100 |
+
Number of processes to use for processing the dataset.
|
101 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
102 |
+
Whether to disable dropout in the model and reference model.
|
103 |
+
|
104 |
+
"""
|
105 |
+
vllm_sampling_params: Optional[Any] = field(
|
106 |
+
default = None,
|
107 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
108 |
+
)
|
109 |
+
unsloth_num_chunks : Optional[int] = field(
|
110 |
+
default = -1,
|
111 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
112 |
+
)
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
output_dir = None,
|
116 |
+
overwrite_output_dir = None,
|
117 |
+
do_train = False,
|
118 |
+
do_eval = False,
|
119 |
+
do_predict = False,
|
120 |
+
eval_strategy = 'no',
|
121 |
+
prediction_loss_only = False,
|
122 |
+
per_device_train_batch_size = 4,
|
123 |
+
per_device_eval_batch_size = 4,
|
124 |
+
per_gpu_train_batch_size = None,
|
125 |
+
per_gpu_eval_batch_size = None,
|
126 |
+
gradient_accumulation_steps = 2,
|
127 |
+
eval_accumulation_steps = 2,
|
128 |
+
eval_delay = 0,
|
129 |
+
torch_empty_cache_steps = 250,
|
130 |
+
learning_rate = 5e-05,
|
131 |
+
weight_decay = 0.01,
|
132 |
+
adam_beta1 = 0.9,
|
133 |
+
adam_beta2 = 0.999,
|
134 |
+
adam_epsilon = 1e-08,
|
135 |
+
max_grad_norm = 1.0,
|
136 |
+
num_train_epochs = 3.0,
|
137 |
+
max_steps = -1,
|
138 |
+
lr_scheduler_type = 'linear',
|
139 |
+
warmup_ratio = 0.1,
|
140 |
+
warmup_steps = 0,
|
141 |
+
log_level = 'passive',
|
142 |
+
log_level_replica = 'warning',
|
143 |
+
log_on_each_node = True,
|
144 |
+
logging_dir = None,
|
145 |
+
logging_strategy = 'steps',
|
146 |
+
logging_first_step = False,
|
147 |
+
logging_steps = 1,
|
148 |
+
logging_nan_inf_filter = False,
|
149 |
+
save_strategy = 'steps',
|
150 |
+
save_steps = 500,
|
151 |
+
save_total_limit = None,
|
152 |
+
save_safetensors = True,
|
153 |
+
save_on_each_node = False,
|
154 |
+
save_only_model = False,
|
155 |
+
restore_callback_states_from_checkpoint = False,
|
156 |
+
no_cuda = False,
|
157 |
+
use_cpu = False,
|
158 |
+
use_mps_device = False,
|
159 |
+
seed = 3407,
|
160 |
+
data_seed = 3407,
|
161 |
+
jit_mode_eval = False,
|
162 |
+
use_ipex = False,
|
163 |
+
bf16 = False,
|
164 |
+
fp16 = False,
|
165 |
+
fp16_opt_level = 'O1',
|
166 |
+
half_precision_backend = 'auto',
|
167 |
+
bf16_full_eval = False,
|
168 |
+
fp16_full_eval = False,
|
169 |
+
tf32 = None,
|
170 |
+
local_rank = -1,
|
171 |
+
ddp_backend = None,
|
172 |
+
tpu_num_cores = None,
|
173 |
+
tpu_metrics_debug = False,
|
174 |
+
debug = '',
|
175 |
+
dataloader_drop_last = False,
|
176 |
+
eval_steps = None,
|
177 |
+
dataloader_num_workers = 0,
|
178 |
+
dataloader_prefetch_factor = None,
|
179 |
+
past_index = -1,
|
180 |
+
run_name = None,
|
181 |
+
disable_tqdm = None,
|
182 |
+
remove_unused_columns = True,
|
183 |
+
label_names = None,
|
184 |
+
load_best_model_at_end = False,
|
185 |
+
metric_for_best_model = None,
|
186 |
+
greater_is_better = None,
|
187 |
+
ignore_data_skip = False,
|
188 |
+
fsdp = '',
|
189 |
+
fsdp_min_num_params = 0,
|
190 |
+
fsdp_config = None,
|
191 |
+
tp_size = 0,
|
192 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
193 |
+
accelerator_config = None,
|
194 |
+
deepspeed = None,
|
195 |
+
label_smoothing_factor = 0.0,
|
196 |
+
optim = 'adamw_8bit',
|
197 |
+
optim_args = None,
|
198 |
+
adafactor = False,
|
199 |
+
group_by_length = False,
|
200 |
+
length_column_name = 'length',
|
201 |
+
report_to = None,
|
202 |
+
ddp_find_unused_parameters = None,
|
203 |
+
ddp_bucket_cap_mb = None,
|
204 |
+
ddp_broadcast_buffers = None,
|
205 |
+
dataloader_pin_memory = True,
|
206 |
+
dataloader_persistent_workers = False,
|
207 |
+
skip_memory_metrics = True,
|
208 |
+
use_legacy_prediction_loop = False,
|
209 |
+
push_to_hub = False,
|
210 |
+
resume_from_checkpoint = None,
|
211 |
+
hub_model_id = None,
|
212 |
+
hub_strategy = 'every_save',
|
213 |
+
hub_token = None,
|
214 |
+
hub_private_repo = None,
|
215 |
+
hub_always_push = False,
|
216 |
+
gradient_checkpointing = False,
|
217 |
+
gradient_checkpointing_kwargs = None,
|
218 |
+
include_inputs_for_metrics = False,
|
219 |
+
eval_do_concat_batches = True,
|
220 |
+
fp16_backend = 'auto',
|
221 |
+
evaluation_strategy = None,
|
222 |
+
push_to_hub_model_id = None,
|
223 |
+
push_to_hub_organization = None,
|
224 |
+
push_to_hub_token = None,
|
225 |
+
mp_parameters = '',
|
226 |
+
auto_find_batch_size = False,
|
227 |
+
full_determinism = False,
|
228 |
+
torchdynamo = None,
|
229 |
+
ray_scope = 'last',
|
230 |
+
ddp_timeout = 1800,
|
231 |
+
torch_compile = False,
|
232 |
+
torch_compile_backend = None,
|
233 |
+
torch_compile_mode = None,
|
234 |
+
dispatch_batches = None,
|
235 |
+
split_batches = None,
|
236 |
+
include_tokens_per_second = False,
|
237 |
+
include_num_input_tokens_seen = False,
|
238 |
+
neftune_noise_alpha = None,
|
239 |
+
optim_target_modules = None,
|
240 |
+
batch_eval_metrics = False,
|
241 |
+
eval_on_start = False,
|
242 |
+
use_liger_kernel = False,
|
243 |
+
eval_use_gather_object = False,
|
244 |
+
average_tokens_across_devices = False,
|
245 |
+
max_length = 1024,
|
246 |
+
max_prompt_length = 512,
|
247 |
+
max_completion_length = None,
|
248 |
+
beta = 0.1,
|
249 |
+
loss_type = 'kto',
|
250 |
+
desirable_weight = 1.0,
|
251 |
+
undesirable_weight = 1.0,
|
252 |
+
label_pad_token_id = -100,
|
253 |
+
padding_value = None,
|
254 |
+
truncation_mode = 'keep_end',
|
255 |
+
generate_during_eval = False,
|
256 |
+
is_encoder_decoder = None,
|
257 |
+
disable_dropout = True,
|
258 |
+
precompute_ref_log_probs = False,
|
259 |
+
model_init_kwargs = None,
|
260 |
+
ref_model_init_kwargs = None,
|
261 |
+
dataset_num_proc = None,
|
262 |
+
vllm_sampling_params = None,
|
263 |
+
unsloth_num_chunks = -1,
|
264 |
+
**kwargs,
|
265 |
+
):
|
266 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
267 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
268 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
269 |
+
output_dir = 'unsloth_training_checkpoints'
|
270 |
+
save_strategy = 'no'
|
271 |
+
if dataset_num_proc is None:
|
272 |
+
from multiprocessing import cpu_count
|
273 |
+
dataset_num_proc = cpu_count()
|
274 |
+
|
275 |
+
super().__init__(
|
276 |
+
output_dir = output_dir,
|
277 |
+
overwrite_output_dir = overwrite_output_dir,
|
278 |
+
do_train = do_train,
|
279 |
+
do_eval = do_eval,
|
280 |
+
do_predict = do_predict,
|
281 |
+
eval_strategy = eval_strategy,
|
282 |
+
prediction_loss_only = prediction_loss_only,
|
283 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
284 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
285 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
286 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
287 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
288 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
289 |
+
eval_delay = eval_delay,
|
290 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
291 |
+
learning_rate = learning_rate,
|
292 |
+
weight_decay = weight_decay,
|
293 |
+
adam_beta1 = adam_beta1,
|
294 |
+
adam_beta2 = adam_beta2,
|
295 |
+
adam_epsilon = adam_epsilon,
|
296 |
+
max_grad_norm = max_grad_norm,
|
297 |
+
num_train_epochs = num_train_epochs,
|
298 |
+
max_steps = max_steps,
|
299 |
+
lr_scheduler_type = lr_scheduler_type,
|
300 |
+
warmup_ratio = warmup_ratio,
|
301 |
+
warmup_steps = warmup_steps,
|
302 |
+
log_level = log_level,
|
303 |
+
log_level_replica = log_level_replica,
|
304 |
+
log_on_each_node = log_on_each_node,
|
305 |
+
logging_dir = logging_dir,
|
306 |
+
logging_strategy = logging_strategy,
|
307 |
+
logging_first_step = logging_first_step,
|
308 |
+
logging_steps = logging_steps,
|
309 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
310 |
+
save_strategy = save_strategy,
|
311 |
+
save_steps = save_steps,
|
312 |
+
save_total_limit = save_total_limit,
|
313 |
+
save_safetensors = save_safetensors,
|
314 |
+
save_on_each_node = save_on_each_node,
|
315 |
+
save_only_model = save_only_model,
|
316 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
317 |
+
no_cuda = no_cuda,
|
318 |
+
use_cpu = use_cpu,
|
319 |
+
use_mps_device = use_mps_device,
|
320 |
+
seed = seed,
|
321 |
+
data_seed = data_seed,
|
322 |
+
jit_mode_eval = jit_mode_eval,
|
323 |
+
use_ipex = use_ipex,
|
324 |
+
bf16 = bf16,
|
325 |
+
fp16 = fp16,
|
326 |
+
fp16_opt_level = fp16_opt_level,
|
327 |
+
half_precision_backend = half_precision_backend,
|
328 |
+
bf16_full_eval = bf16_full_eval,
|
329 |
+
fp16_full_eval = fp16_full_eval,
|
330 |
+
tf32 = tf32,
|
331 |
+
local_rank = local_rank,
|
332 |
+
ddp_backend = ddp_backend,
|
333 |
+
tpu_num_cores = tpu_num_cores,
|
334 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
335 |
+
debug = debug,
|
336 |
+
dataloader_drop_last = dataloader_drop_last,
|
337 |
+
eval_steps = eval_steps,
|
338 |
+
dataloader_num_workers = dataloader_num_workers,
|
339 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
340 |
+
past_index = past_index,
|
341 |
+
run_name = run_name,
|
342 |
+
disable_tqdm = disable_tqdm,
|
343 |
+
remove_unused_columns = remove_unused_columns,
|
344 |
+
label_names = label_names,
|
345 |
+
load_best_model_at_end = load_best_model_at_end,
|
346 |
+
metric_for_best_model = metric_for_best_model,
|
347 |
+
greater_is_better = greater_is_better,
|
348 |
+
ignore_data_skip = ignore_data_skip,
|
349 |
+
fsdp = fsdp,
|
350 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
351 |
+
fsdp_config = fsdp_config,
|
352 |
+
tp_size = tp_size,
|
353 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
354 |
+
accelerator_config = accelerator_config,
|
355 |
+
deepspeed = deepspeed,
|
356 |
+
label_smoothing_factor = label_smoothing_factor,
|
357 |
+
optim = optim,
|
358 |
+
optim_args = optim_args,
|
359 |
+
adafactor = adafactor,
|
360 |
+
group_by_length = group_by_length,
|
361 |
+
length_column_name = length_column_name,
|
362 |
+
report_to = report_to,
|
363 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
364 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
365 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
366 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
367 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
368 |
+
skip_memory_metrics = skip_memory_metrics,
|
369 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
370 |
+
push_to_hub = push_to_hub,
|
371 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
372 |
+
hub_model_id = hub_model_id,
|
373 |
+
hub_strategy = hub_strategy,
|
374 |
+
hub_token = hub_token,
|
375 |
+
hub_private_repo = hub_private_repo,
|
376 |
+
hub_always_push = hub_always_push,
|
377 |
+
gradient_checkpointing = gradient_checkpointing,
|
378 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
379 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
380 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
381 |
+
fp16_backend = fp16_backend,
|
382 |
+
evaluation_strategy = evaluation_strategy,
|
383 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
384 |
+
push_to_hub_organization = push_to_hub_organization,
|
385 |
+
push_to_hub_token = push_to_hub_token,
|
386 |
+
mp_parameters = mp_parameters,
|
387 |
+
auto_find_batch_size = auto_find_batch_size,
|
388 |
+
full_determinism = full_determinism,
|
389 |
+
torchdynamo = torchdynamo,
|
390 |
+
ray_scope = ray_scope,
|
391 |
+
ddp_timeout = ddp_timeout,
|
392 |
+
torch_compile = torch_compile,
|
393 |
+
torch_compile_backend = torch_compile_backend,
|
394 |
+
torch_compile_mode = torch_compile_mode,
|
395 |
+
dispatch_batches = dispatch_batches,
|
396 |
+
split_batches = split_batches,
|
397 |
+
include_tokens_per_second = include_tokens_per_second,
|
398 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
399 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
400 |
+
optim_target_modules = optim_target_modules,
|
401 |
+
batch_eval_metrics = batch_eval_metrics,
|
402 |
+
eval_on_start = eval_on_start,
|
403 |
+
use_liger_kernel = use_liger_kernel,
|
404 |
+
eval_use_gather_object = eval_use_gather_object,
|
405 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
406 |
+
max_length = max_length,
|
407 |
+
max_prompt_length = max_prompt_length,
|
408 |
+
max_completion_length = max_completion_length,
|
409 |
+
beta = beta,
|
410 |
+
loss_type = loss_type,
|
411 |
+
desirable_weight = desirable_weight,
|
412 |
+
undesirable_weight = undesirable_weight,
|
413 |
+
label_pad_token_id = label_pad_token_id,
|
414 |
+
padding_value = padding_value,
|
415 |
+
truncation_mode = truncation_mode,
|
416 |
+
generate_during_eval = generate_during_eval,
|
417 |
+
is_encoder_decoder = is_encoder_decoder,
|
418 |
+
disable_dropout = disable_dropout,
|
419 |
+
precompute_ref_log_probs = precompute_ref_log_probs,
|
420 |
+
model_init_kwargs = model_init_kwargs,
|
421 |
+
ref_model_init_kwargs = ref_model_init_kwargs,
|
422 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
423 |
+
self.vllm_sampling_params = vllm_sampling_params
|
424 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
425 |
+
pass
|
426 |
+
|
427 |
+
class _UnslothKTOTrainer(Trainer):
|
428 |
+
r""""""
|
429 |
+
|
430 |
+
_tag_names = ["trl", "kto"]
|
431 |
+
|
432 |
+
def __init__(
|
433 |
+
self,
|
434 |
+
model: Union[PreTrainedModel, nn.Module, str] = None,
|
435 |
+
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
436 |
+
args: KTOConfig = None,
|
437 |
+
train_dataset: Optional[Dataset] = None,
|
438 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
439 |
+
processing_class: Optional[
|
440 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
441 |
+
] = None,
|
442 |
+
data_collator: Optional[DataCollator] = None,
|
443 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
444 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
445 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
446 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
447 |
+
peft_config: Optional[dict] = None,
|
448 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
449 |
+
model_adapter_name: Optional[str] = None,
|
450 |
+
ref_adapter_name: Optional[str] = None,
|
451 |
+
):
|
452 |
+
if type(args) is TrainingArguments:
|
453 |
+
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
454 |
+
|
455 |
+
if not isinstance(model, str) and ref_model is model:
|
456 |
+
raise ValueError(
|
457 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
458 |
+
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
459 |
+
)
|
460 |
+
|
461 |
+
if args.model_init_kwargs is None:
|
462 |
+
model_init_kwargs = {}
|
463 |
+
elif not isinstance(model, str):
|
464 |
+
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
465 |
+
else:
|
466 |
+
model_init_kwargs = args.model_init_kwargs
|
467 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
468 |
+
if torch_dtype is not None:
|
469 |
+
# Convert to `torch.dtype` if an str is passed
|
470 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
471 |
+
torch_dtype = getattr(torch, torch_dtype)
|
472 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
473 |
+
raise ValueError(
|
474 |
+
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
475 |
+
)
|
476 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
477 |
+
|
478 |
+
if args.ref_model_init_kwargs is None:
|
479 |
+
ref_model_init_kwargs = {}
|
480 |
+
elif not isinstance(ref_model, str):
|
481 |
+
raise ValueError(
|
482 |
+
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
ref_model_init_kwargs = args.ref_model_init_kwargs
|
486 |
+
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
487 |
+
if torch_dtype is not None:
|
488 |
+
# Convert to `torch.dtype` if an str is passed
|
489 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
490 |
+
torch_dtype = getattr(torch, torch_dtype)
|
491 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
492 |
+
raise ValueError(
|
493 |
+
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
494 |
+
)
|
495 |
+
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
496 |
+
|
497 |
+
if isinstance(model, str):
|
498 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
499 |
+
|
500 |
+
if isinstance(ref_model, str):
|
501 |
+
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
502 |
+
|
503 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
504 |
+
# has been called in order to properly call autocast if needed.
|
505 |
+
self._peft_has_been_casted_to_bf16 = False
|
506 |
+
|
507 |
+
if not is_peft_available() and peft_config is not None:
|
508 |
+
raise ValueError(
|
509 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
510 |
+
)
|
511 |
+
elif is_peft_available() and peft_config is not None:
|
512 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
513 |
+
if isinstance(model, PeftModel):
|
514 |
+
model = model.merge_and_unload()
|
515 |
+
|
516 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
517 |
+
_support_gc_kwargs = hasattr(
|
518 |
+
args, "gradient_checkpointing_kwargs"
|
519 |
+
) and "gradient_checkpointing_kwargs" in list(
|
520 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
521 |
+
)
|
522 |
+
|
523 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
524 |
+
|
525 |
+
if _support_gc_kwargs:
|
526 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
527 |
+
|
528 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
529 |
+
elif getattr(args, "gradient_checkpointing", False):
|
530 |
+
# For backward compatibility with older versions of transformers
|
531 |
+
if hasattr(model, "enable_input_require_grads"):
|
532 |
+
model.enable_input_require_grads()
|
533 |
+
else:
|
534 |
+
|
535 |
+
def make_inputs_require_grad(module, input, output):
|
536 |
+
output.requires_grad_(True)
|
537 |
+
|
538 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
539 |
+
|
540 |
+
# get peft model with the given config
|
541 |
+
model = model
|
542 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
543 |
+
peft_module_casting_to_bf16(model)
|
544 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
545 |
+
self._peft_has_been_casted_to_bf16 = True
|
546 |
+
|
547 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
548 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
549 |
+
# fail or completely fail.
|
550 |
+
elif getattr(args, "gradient_checkpointing", False):
|
551 |
+
# For backward compatibility with older versions of transformers
|
552 |
+
if hasattr(model, "enable_input_require_grads"):
|
553 |
+
model.enable_input_require_grads()
|
554 |
+
else:
|
555 |
+
|
556 |
+
def make_inputs_require_grad(module, input, output):
|
557 |
+
output.requires_grad_(True)
|
558 |
+
|
559 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
560 |
+
|
561 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
562 |
+
raise ValueError(
|
563 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
564 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
565 |
+
)
|
566 |
+
|
567 |
+
if model is not None:
|
568 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
569 |
+
elif args.is_encoder_decoder is None:
|
570 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
571 |
+
else:
|
572 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
573 |
+
|
574 |
+
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
575 |
+
self.model_adapter_name = model_adapter_name
|
576 |
+
self.ref_adapter_name = ref_adapter_name
|
577 |
+
|
578 |
+
if ref_model:
|
579 |
+
self.ref_model = ref_model
|
580 |
+
elif self.is_peft_model or args.precompute_ref_log_probs:
|
581 |
+
# The `model` with adapters turned off will be used as the reference model
|
582 |
+
self.ref_model = None
|
583 |
+
else:
|
584 |
+
self.ref_model = create_reference_model(model)
|
585 |
+
|
586 |
+
if processing_class is None:
|
587 |
+
raise ValueError(
|
588 |
+
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
589 |
+
)
|
590 |
+
if args.max_length is None:
|
591 |
+
warnings.warn(
|
592 |
+
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
593 |
+
" it will be set to `512` by default, but you should do it yourself in the future.",
|
594 |
+
UserWarning,
|
595 |
+
)
|
596 |
+
max_length = 512
|
597 |
+
if args.max_length is not None:
|
598 |
+
max_length = args.max_length
|
599 |
+
|
600 |
+
if args.max_prompt_length is None:
|
601 |
+
warnings.warn(
|
602 |
+
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
603 |
+
" it will be set to `128` by default, but you should do it yourself in the future.",
|
604 |
+
UserWarning,
|
605 |
+
)
|
606 |
+
max_prompt_length = 128
|
607 |
+
if args.max_prompt_length is not None:
|
608 |
+
max_prompt_length = args.max_prompt_length
|
609 |
+
|
610 |
+
max_completion_length = None
|
611 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
612 |
+
warnings.warn(
|
613 |
+
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
614 |
+
" it will be set to `128` by default, but you should do it yourself in the future.",
|
615 |
+
UserWarning,
|
616 |
+
)
|
617 |
+
max_completion_length = 128
|
618 |
+
if args.max_completion_length is not None and self.is_encoder_decoder:
|
619 |
+
max_completion_length = args.max_completion_length
|
620 |
+
|
621 |
+
if data_collator is None:
|
622 |
+
data_collator = DPODataCollatorWithPadding(
|
623 |
+
pad_token_id=processing_class.pad_token_id,
|
624 |
+
label_pad_token_id=args.label_pad_token_id,
|
625 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
626 |
+
)
|
627 |
+
|
628 |
+
if args.remove_unused_columns:
|
629 |
+
args.remove_unused_columns = False
|
630 |
+
# warn users
|
631 |
+
warnings.warn(
|
632 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
633 |
+
" we have set it for you, but you should do it yourself in the future.",
|
634 |
+
UserWarning,
|
635 |
+
)
|
636 |
+
|
637 |
+
self.use_dpo_data_collator = True
|
638 |
+
else:
|
639 |
+
self.use_dpo_data_collator = False
|
640 |
+
|
641 |
+
# Disable dropout in the model and reference model
|
642 |
+
if args.disable_dropout:
|
643 |
+
disable_dropout_in_model(model)
|
644 |
+
if self.ref_model is not None:
|
645 |
+
disable_dropout_in_model(self.ref_model)
|
646 |
+
|
647 |
+
self.loss_type = args.loss_type
|
648 |
+
self.max_length = max_length
|
649 |
+
self.generate_during_eval = args.generate_during_eval
|
650 |
+
self.label_pad_token_id = args.label_pad_token_id
|
651 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
652 |
+
self.max_prompt_length = max_prompt_length
|
653 |
+
self.truncation_mode = args.truncation_mode
|
654 |
+
self.max_completion_length = max_completion_length
|
655 |
+
self.processing_class = processing_class
|
656 |
+
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
657 |
+
|
658 |
+
# Not all losses require a KL calculation
|
659 |
+
self.calculate_KL = True
|
660 |
+
if self.loss_type in ["apo_zero_unpaired"]:
|
661 |
+
self.calculate_KL = False
|
662 |
+
|
663 |
+
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
664 |
+
# keep track of first called to avoid computation of future calls
|
665 |
+
self._precomputed_train_ref_log_probs = False
|
666 |
+
self._precomputed_eval_ref_log_probs = False
|
667 |
+
|
668 |
+
# metric
|
669 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
670 |
+
|
671 |
+
# KTO parameter
|
672 |
+
self.beta = args.beta
|
673 |
+
self.desirable_weight = args.desirable_weight
|
674 |
+
self.undesirable_weight = args.undesirable_weight
|
675 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
676 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
677 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
678 |
+
warnings.warn(
|
679 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
680 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
681 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
682 |
+
"loss.",
|
683 |
+
UserWarning,
|
684 |
+
)
|
685 |
+
|
686 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
687 |
+
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
688 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
689 |
+
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
690 |
+
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
691 |
+
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
692 |
+
# issued.
|
693 |
+
model.warnings_issued["estimate_tokens"] = True
|
694 |
+
|
695 |
+
# Compute that only on the main process for faster data processing.
|
696 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
697 |
+
with PartialState().local_main_process_first():
|
698 |
+
# Extract the prompt if needed
|
699 |
+
train_dataset = train_dataset.map(
|
700 |
+
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
701 |
+
)
|
702 |
+
# Unpair the dataset if needed
|
703 |
+
train_dataset = maybe_unpair_preference_dataset(
|
704 |
+
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
705 |
+
)
|
706 |
+
# Apply the chat template if needed
|
707 |
+
train_dataset = train_dataset.map(
|
708 |
+
maybe_apply_chat_template,
|
709 |
+
fn_kwargs={"tokenizer": processing_class},
|
710 |
+
num_proc=args.dataset_num_proc,
|
711 |
+
desc="Applying chat template to train dataset",
|
712 |
+
)
|
713 |
+
if eval_dataset is not None:
|
714 |
+
eval_dataset = eval_dataset.map(
|
715 |
+
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
716 |
+
)
|
717 |
+
eval_dataset = maybe_unpair_preference_dataset(
|
718 |
+
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
719 |
+
)
|
720 |
+
eval_dataset = eval_dataset.map(
|
721 |
+
maybe_apply_chat_template,
|
722 |
+
fn_kwargs={"tokenizer": processing_class},
|
723 |
+
num_proc=args.dataset_num_proc,
|
724 |
+
desc="Applying chat template to eval dataset",
|
725 |
+
)
|
726 |
+
|
727 |
+
# Tokenize and prepare the training datasets
|
728 |
+
train_dataset = train_dataset.map(
|
729 |
+
_tokenize,
|
730 |
+
batched=True,
|
731 |
+
fn_kwargs={"tokenizer": self.processing_class},
|
732 |
+
num_proc=args.dataset_num_proc,
|
733 |
+
desc="Tokenizing train dataset",
|
734 |
+
)
|
735 |
+
|
736 |
+
fn_kwargs = {
|
737 |
+
"prefix": "",
|
738 |
+
"is_encoder_decoder": self.is_encoder_decoder,
|
739 |
+
"tokenizer": self.processing_class,
|
740 |
+
"max_length": self.max_length,
|
741 |
+
"truncation_mode": self.truncation_mode,
|
742 |
+
"label_pad_token_id": self.label_pad_token_id,
|
743 |
+
"max_prompt_length": self.max_prompt_length,
|
744 |
+
"max_completion_length": self.max_completion_length,
|
745 |
+
}
|
746 |
+
|
747 |
+
train_dataset = train_dataset.map(
|
748 |
+
_process_tokens,
|
749 |
+
fn_kwargs=fn_kwargs,
|
750 |
+
num_proc=args.dataset_num_proc,
|
751 |
+
desc="Processing tokenized train dataset",
|
752 |
+
)
|
753 |
+
|
754 |
+
# Tokenize and prepare the eval datasets
|
755 |
+
if eval_dataset is not None:
|
756 |
+
eval_dataset = eval_dataset.map(
|
757 |
+
_tokenize,
|
758 |
+
fn_kwargs={"tokenizer": self.processing_class},
|
759 |
+
batched=True,
|
760 |
+
num_proc=args.dataset_num_proc,
|
761 |
+
desc="Tokenizing eval dataset",
|
762 |
+
)
|
763 |
+
|
764 |
+
eval_dataset = eval_dataset.map(
|
765 |
+
_process_tokens,
|
766 |
+
fn_kwargs=fn_kwargs,
|
767 |
+
num_proc=args.dataset_num_proc,
|
768 |
+
desc="Processing tokenized eval dataset",
|
769 |
+
)
|
770 |
+
|
771 |
+
# Get KL datasets if needed
|
772 |
+
if self.calculate_KL:
|
773 |
+
if args.per_device_train_batch_size <= 1:
|
774 |
+
raise ValueError(
|
775 |
+
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
776 |
+
)
|
777 |
+
|
778 |
+
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
779 |
+
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
780 |
+
train_kl_dataset = train_dataset.map(
|
781 |
+
_get_kl_dataset,
|
782 |
+
batched=True,
|
783 |
+
batch_size=args.per_device_train_batch_size,
|
784 |
+
num_proc=args.dataset_num_proc,
|
785 |
+
desc="Extracting KL train dataset",
|
786 |
+
)
|
787 |
+
|
788 |
+
fn_kwargs["prefix"] = "KL_"
|
789 |
+
train_kl_dataset = train_kl_dataset.map(
|
790 |
+
_process_tokens,
|
791 |
+
fn_kwargs=fn_kwargs,
|
792 |
+
num_proc=args.dataset_num_proc,
|
793 |
+
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
794 |
+
desc="Processing tokenized train KL dataset",
|
795 |
+
)
|
796 |
+
|
797 |
+
# merge the datasets
|
798 |
+
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
799 |
+
|
800 |
+
if eval_dataset is not None:
|
801 |
+
# Get KL dataset
|
802 |
+
eval_kl_dataset = eval_dataset.map(
|
803 |
+
_get_kl_dataset,
|
804 |
+
batched=True,
|
805 |
+
batch_size=args.per_device_train_batch_size,
|
806 |
+
num_proc=args.dataset_num_proc,
|
807 |
+
desc="Extracting eval KL dataset",
|
808 |
+
)
|
809 |
+
|
810 |
+
eval_kl_dataset = eval_kl_dataset.map(
|
811 |
+
_process_tokens,
|
812 |
+
fn_kwargs=fn_kwargs,
|
813 |
+
num_proc=args.dataset_num_proc,
|
814 |
+
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
815 |
+
desc="Processing tokenized eval KL dataset",
|
816 |
+
)
|
817 |
+
|
818 |
+
# merge the datasets
|
819 |
+
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
820 |
+
|
821 |
+
# calculate dataset desirability balance
|
822 |
+
num_desirable = max(sum(train_dataset["label"]), 1)
|
823 |
+
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
824 |
+
|
825 |
+
if num_desirable != num_undesirable:
|
826 |
+
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
827 |
+
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
828 |
+
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
829 |
+
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
830 |
+
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
831 |
+
|
832 |
+
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
833 |
+
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
834 |
+
|
835 |
+
if not (des_weight_in_range or und_weight_in_range):
|
836 |
+
warnings.warn(
|
837 |
+
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
838 |
+
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
839 |
+
f"on your data, we recommend EITHER "
|
840 |
+
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
841 |
+
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
842 |
+
"See the documentation on how to optimally set these weights.",
|
843 |
+
UserWarning,
|
844 |
+
)
|
845 |
+
|
846 |
+
super().__init__(
|
847 |
+
model=model,
|
848 |
+
args=args,
|
849 |
+
data_collator=data_collator,
|
850 |
+
train_dataset=train_dataset,
|
851 |
+
eval_dataset=eval_dataset,
|
852 |
+
processing_class=processing_class,
|
853 |
+
model_init=model_init,
|
854 |
+
compute_metrics=compute_metrics,
|
855 |
+
callbacks=callbacks,
|
856 |
+
optimizers=optimizers,
|
857 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
858 |
+
)
|
859 |
+
|
860 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
861 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
862 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
863 |
+
self.model_accepts_loss_kwargs = False
|
864 |
+
|
865 |
+
# Add tags for models that have been loaded with the correct transformers version
|
866 |
+
if hasattr(self.model, "add_model_tags"):
|
867 |
+
self.model.add_model_tags(self._tag_names)
|
868 |
+
|
869 |
+
if not hasattr(self, "accelerator"):
|
870 |
+
raise AttributeError(
|
871 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
872 |
+
)
|
873 |
+
|
874 |
+
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
875 |
+
if self.is_deepspeed_enabled:
|
876 |
+
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
877 |
+
raise ValueError(
|
878 |
+
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
879 |
+
)
|
880 |
+
|
881 |
+
if self.ref_model is None:
|
882 |
+
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
883 |
+
raise ValueError(
|
884 |
+
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
885 |
+
)
|
886 |
+
else:
|
887 |
+
if self.is_deepspeed_enabled:
|
888 |
+
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
889 |
+
else:
|
890 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
891 |
+
|
892 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
893 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
894 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
895 |
+
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
896 |
+
|
897 |
+
if model is not None:
|
898 |
+
if hasattr(model, "config"):
|
899 |
+
hidden_size = (
|
900 |
+
max(model.config.hidden_sizes)
|
901 |
+
if getattr(model.config, "hidden_sizes", None)
|
902 |
+
else getattr(model.config, "hidden_size", None)
|
903 |
+
)
|
904 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
905 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
906 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
907 |
+
config_kwargs.update(
|
908 |
+
{
|
909 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
910 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
911 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
912 |
+
}
|
913 |
+
)
|
914 |
+
|
915 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
916 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
917 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
918 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
919 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
920 |
+
model.eval()
|
921 |
+
return model
|
922 |
+
|
923 |
+
@contextmanager
|
924 |
+
def null_ref_context(self):
|
925 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
926 |
+
with (
|
927 |
+
self.accelerator.unwrap_model(self.model).disable_adapter()
|
928 |
+
if self.is_peft_model and not self.ref_adapter_name
|
929 |
+
else nullcontext()
|
930 |
+
):
|
931 |
+
if self.ref_adapter_name:
|
932 |
+
self.model.set_adapter(self.ref_adapter_name)
|
933 |
+
yield
|
934 |
+
if self.ref_adapter_name:
|
935 |
+
self.model.set_adapter(self.model_adapter_name or "default")
|
936 |
+
|
937 |
+
def get_train_dataloader(self) -> DataLoader:
|
938 |
+
"""
|
939 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
940 |
+
|
941 |
+
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
942 |
+
"""
|
943 |
+
|
944 |
+
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
945 |
+
dataloader_params = {
|
946 |
+
"batch_size": self.args.per_device_train_batch_size,
|
947 |
+
"collate_fn": self.data_collator,
|
948 |
+
"num_workers": self.args.dataloader_num_workers,
|
949 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
950 |
+
"shuffle": False,
|
951 |
+
}
|
952 |
+
|
953 |
+
# prepare dataloader
|
954 |
+
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
955 |
+
reference_completion_logps = []
|
956 |
+
reference_KL_logps = []
|
957 |
+
|
958 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
959 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
960 |
+
|
961 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
962 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
963 |
+
|
964 |
+
if self.calculate_KL:
|
965 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
966 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
967 |
+
|
968 |
+
self.train_dataset = self.train_dataset.add_column(
|
969 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
970 |
+
)
|
971 |
+
|
972 |
+
if self.calculate_KL:
|
973 |
+
self.train_dataset = self.train_dataset.add_column(
|
974 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
975 |
+
)
|
976 |
+
|
977 |
+
self._precomputed_train_ref_log_probs = True
|
978 |
+
|
979 |
+
return super().get_train_dataloader()
|
980 |
+
|
981 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
982 |
+
"""
|
983 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
984 |
+
|
985 |
+
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
986 |
+
|
987 |
+
Args:
|
988 |
+
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
989 |
+
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
990 |
+
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
991 |
+
"""
|
992 |
+
if eval_dataset is None and self.eval_dataset is None:
|
993 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
994 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
995 |
+
|
996 |
+
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
997 |
+
dataloader_params = {
|
998 |
+
"batch_size": self.args.per_device_eval_batch_size,
|
999 |
+
"collate_fn": self.data_collator,
|
1000 |
+
"num_workers": self.args.dataloader_num_workers,
|
1001 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
1002 |
+
"shuffle": False,
|
1003 |
+
}
|
1004 |
+
|
1005 |
+
# prepare dataloader
|
1006 |
+
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
1007 |
+
|
1008 |
+
reference_completion_logps = []
|
1009 |
+
reference_KL_logps = []
|
1010 |
+
|
1011 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
1012 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
1013 |
+
|
1014 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
1015 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
1016 |
+
|
1017 |
+
if self.calculate_KL:
|
1018 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
1019 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
1020 |
+
|
1021 |
+
eval_dataset = eval_dataset.add_column(
|
1022 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
1023 |
+
)
|
1024 |
+
if self.calculate_KL:
|
1025 |
+
eval_dataset = eval_dataset.add_column(
|
1026 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
1027 |
+
)
|
1028 |
+
|
1029 |
+
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
1030 |
+
if self.eval_dataset is not None:
|
1031 |
+
self.eval_dataset = eval_dataset
|
1032 |
+
self._precomputed_eval_ref_log_probs = True
|
1033 |
+
|
1034 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
1035 |
+
|
1036 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
1037 |
+
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
|
1038 |
+
with torch.no_grad():
|
1039 |
+
if self.ref_model is None:
|
1040 |
+
with self.null_ref_context():
|
1041 |
+
if self.is_encoder_decoder:
|
1042 |
+
completion_logits = self.model(
|
1043 |
+
padded_batch["prompt_input_ids"],
|
1044 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1045 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1046 |
+
labels=padded_batch["completion_labels"],
|
1047 |
+
).logits
|
1048 |
+
|
1049 |
+
if self.calculate_KL:
|
1050 |
+
KL_logits = self.model(
|
1051 |
+
padded_batch["KL_prompt_input_ids"],
|
1052 |
+
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1053 |
+
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1054 |
+
labels=padded_batch["KL_completion_labels"],
|
1055 |
+
).logits
|
1056 |
+
else:
|
1057 |
+
completion_logits = self.model(
|
1058 |
+
padded_batch["completion_input_ids"],
|
1059 |
+
attention_mask=padded_batch["completion_attention_mask"],
|
1060 |
+
).logits
|
1061 |
+
|
1062 |
+
if self.calculate_KL:
|
1063 |
+
KL_logits = self.model(
|
1064 |
+
padded_batch["KL_completion_input_ids"],
|
1065 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1066 |
+
).logits
|
1067 |
+
else:
|
1068 |
+
if self.is_encoder_decoder:
|
1069 |
+
completion_logits = self.ref_model(
|
1070 |
+
padded_batch["prompt_input_ids"],
|
1071 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
1072 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
1073 |
+
labels=padded_batch["completion_labels"],
|
1074 |
+
).logits
|
1075 |
+
|
1076 |
+
if self.calculate_KL:
|
1077 |
+
KL_logits = self.ref_model(
|
1078 |
+
padded_batch["KL_prompt_input_ids"],
|
1079 |
+
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
1080 |
+
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
1081 |
+
labels=padded_batch["KL_completion_labels"],
|
1082 |
+
).logits
|
1083 |
+
else:
|
1084 |
+
completion_logits = self.ref_model(
|
1085 |
+
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
1086 |
+
).logits
|
1087 |
+
|
1088 |
+
if self.calculate_KL:
|
1089 |
+
KL_logits = self.ref_model(
|
1090 |
+
padded_batch["KL_completion_input_ids"],
|
1091 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
1092 |
+
).logits
|
1093 |
+
|
1094 |
+
completion_logps = self.get_batch_logps(
|
1095 |
+
completion_logits,
|
1096 |
+
padded_batch["completion_labels"],
|
1097 |
+
average_log_prob=False,
|
1098 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1099 |
+
label_pad_token_id=self.label_pad_token_id,
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
if self.calculate_KL:
|
1103 |
+
KL_logps = self.get_batch_logps(
|
1104 |
+
KL_logits,
|
1105 |
+
padded_batch["KL_completion_labels"],
|
1106 |
+
average_log_prob=False,
|
1107 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1108 |
+
label_pad_token_id=self.label_pad_token_id,
|
1109 |
+
)
|
1110 |
+
else:
|
1111 |
+
KL_logps = None
|
1112 |
+
|
1113 |
+
return completion_logps, KL_logps
|
1114 |
+
|
1115 |
+
@staticmethod
|
1116 |
+
def get_batch_logps(
|
1117 |
+
logits: torch.FloatTensor,
|
1118 |
+
labels: torch.LongTensor,
|
1119 |
+
average_log_prob: bool = False,
|
1120 |
+
label_pad_token_id: int = -100,
|
1121 |
+
is_encoder_decoder: bool = False,
|
1122 |
+
) -> torch.FloatTensor:
|
1123 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
1124 |
+
|
1125 |
+
Args:
|
1126 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
1127 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
1128 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
1129 |
+
|
1130 |
+
Returns:
|
1131 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
1132 |
+
"""
|
1133 |
+
if logits.shape[:-1] != labels.shape:
|
1134 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
1135 |
+
|
1136 |
+
if not is_encoder_decoder:
|
1137 |
+
labels = labels[:, 1:].clone()
|
1138 |
+
logits = logits[:, :-1, :]
|
1139 |
+
else:
|
1140 |
+
# Fixes end-dec RuntimeError
|
1141 |
+
labels = labels.clone()
|
1142 |
+
|
1143 |
+
loss_mask = labels != label_pad_token_id
|
1144 |
+
|
1145 |
+
# dummy token; we'll ignore the losses on these tokens later
|
1146 |
+
labels[labels == label_pad_token_id] = 0
|
1147 |
+
|
1148 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
1149 |
+
|
1150 |
+
if average_log_prob:
|
1151 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1152 |
+
else:
|
1153 |
+
return (per_token_logps * loss_mask).sum(-1)
|
1154 |
+
|
1155 |
+
def forward(
|
1156 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1157 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1158 |
+
if self.calculate_KL:
|
1159 |
+
KL_logps = None
|
1160 |
+
KL_model_kwargs = (
|
1161 |
+
{
|
1162 |
+
"input_ids": batch["KL_prompt_input_ids"],
|
1163 |
+
"attention_mask": batch["KL_prompt_attention_mask"],
|
1164 |
+
"labels": batch["KL_completion_labels"],
|
1165 |
+
"decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
|
1166 |
+
}
|
1167 |
+
if self.is_encoder_decoder
|
1168 |
+
else {
|
1169 |
+
"input_ids": batch["KL_completion_input_ids"],
|
1170 |
+
"attention_mask": batch["KL_completion_attention_mask"],
|
1171 |
+
}
|
1172 |
+
)
|
1173 |
+
with torch.no_grad():
|
1174 |
+
KL_logits = model(
|
1175 |
+
**KL_model_kwargs,
|
1176 |
+
).logits
|
1177 |
+
|
1178 |
+
KL_logps = self.get_batch_logps(
|
1179 |
+
KL_logits,
|
1180 |
+
batch["KL_completion_labels"],
|
1181 |
+
average_log_prob=False,
|
1182 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1183 |
+
label_pad_token_id=self.label_pad_token_id,
|
1184 |
+
)
|
1185 |
+
else:
|
1186 |
+
KL_logps = None
|
1187 |
+
|
1188 |
+
model_kwargs = (
|
1189 |
+
{
|
1190 |
+
"labels": batch["completion_labels"],
|
1191 |
+
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
1192 |
+
}
|
1193 |
+
if self.is_encoder_decoder
|
1194 |
+
else {}
|
1195 |
+
)
|
1196 |
+
if self.aux_loss_enabled:
|
1197 |
+
model_kwargs["output_router_logits"] = True
|
1198 |
+
|
1199 |
+
outputs = model(
|
1200 |
+
batch["completion_input_ids"],
|
1201 |
+
attention_mask=batch["completion_attention_mask"],
|
1202 |
+
**model_kwargs,
|
1203 |
+
)
|
1204 |
+
completion_logits = outputs.logits
|
1205 |
+
|
1206 |
+
completion_logps = self.get_batch_logps(
|
1207 |
+
completion_logits,
|
1208 |
+
batch["completion_labels"],
|
1209 |
+
average_log_prob=False,
|
1210 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1211 |
+
label_pad_token_id=self.label_pad_token_id,
|
1212 |
+
)
|
1213 |
+
|
1214 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
1215 |
+
raise ValueError(
|
1216 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
1217 |
+
"examples for which an output sequence was predicted."
|
1218 |
+
)
|
1219 |
+
|
1220 |
+
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
1221 |
+
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
1222 |
+
|
1223 |
+
chosen_logps = completion_logps[chosen_idx, ...]
|
1224 |
+
rejected_logps = completion_logps[rejected_idx, ...]
|
1225 |
+
|
1226 |
+
chosen_logits = completion_logits[chosen_idx, ...]
|
1227 |
+
rejected_logits = completion_logits[rejected_idx, ...]
|
1228 |
+
|
1229 |
+
if self.aux_loss_enabled:
|
1230 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
|
1231 |
+
else:
|
1232 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
1233 |
+
|
1234 |
+
def kto_loss(
|
1235 |
+
self,
|
1236 |
+
policy_chosen_logps: torch.FloatTensor,
|
1237 |
+
policy_rejected_logps: torch.FloatTensor,
|
1238 |
+
policy_KL_logps: torch.FloatTensor,
|
1239 |
+
reference_chosen_logps: torch.FloatTensor,
|
1240 |
+
reference_rejected_logps: torch.FloatTensor,
|
1241 |
+
reference_KL_logps: torch.FloatTensor,
|
1242 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1243 |
+
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
|
1244 |
+
|
1245 |
+
Args:
|
1246 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1247 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1248 |
+
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
|
1249 |
+
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
1250 |
+
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
1251 |
+
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
|
1252 |
+
|
1253 |
+
Returns:
|
1254 |
+
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
|
1255 |
+
The losses tensor contains the KTO loss for each example in the batch.
|
1256 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
1257 |
+
The KL tensor contains the detached KL divergence estimate between the policy and reference models.
|
1258 |
+
"""
|
1259 |
+
if self.calculate_KL:
|
1260 |
+
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
1261 |
+
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
1262 |
+
else:
|
1263 |
+
kl = torch.zeros(1).to(policy_chosen_logps.device)
|
1264 |
+
|
1265 |
+
# Chosen losses
|
1266 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
1267 |
+
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
1268 |
+
|
1269 |
+
if self.loss_type == "kto":
|
1270 |
+
# Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
|
1271 |
+
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
|
1272 |
+
elif self.loss_type == "apo_zero_unpaired":
|
1273 |
+
# Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
|
1274 |
+
# Use this loss when you believe the chosen outputs are better than your model's default output
|
1275 |
+
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
|
1276 |
+
|
1277 |
+
chosen_rewards = self.beta * chosen_logratios.detach()
|
1278 |
+
|
1279 |
+
else:
|
1280 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1281 |
+
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
1282 |
+
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1283 |
+
|
1284 |
+
# Rejected losses
|
1285 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
1286 |
+
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
1287 |
+
|
1288 |
+
if self.loss_type == "kto":
|
1289 |
+
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
|
1290 |
+
elif self.loss_type == "apo_zero_unpaired":
|
1291 |
+
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
|
1292 |
+
|
1293 |
+
rejected_rewards = self.beta * rejected_logratios.detach()
|
1294 |
+
else:
|
1295 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
1296 |
+
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
1297 |
+
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
1298 |
+
|
1299 |
+
losses = torch.cat(
|
1300 |
+
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
|
1301 |
+
0,
|
1302 |
+
)
|
1303 |
+
|
1304 |
+
return losses, chosen_rewards, rejected_rewards, kl
|
1305 |
+
|
1306 |
+
def get_batch_loss_metrics(
|
1307 |
+
self,
|
1308 |
+
model,
|
1309 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
1310 |
+
):
|
1311 |
+
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
1312 |
+
metrics = {}
|
1313 |
+
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
1314 |
+
|
1315 |
+
forward_output = self.forward(model, batch)
|
1316 |
+
(
|
1317 |
+
policy_chosen_logps,
|
1318 |
+
policy_rejected_logps,
|
1319 |
+
policy_chosen_logits,
|
1320 |
+
policy_rejected_logits,
|
1321 |
+
policy_KL_logps,
|
1322 |
+
) = forward_output[:5]
|
1323 |
+
if self.aux_loss_enabled:
|
1324 |
+
aux_loss = forward_output[5]
|
1325 |
+
|
1326 |
+
# if reference_logps in batch use them, otherwise use the reference model
|
1327 |
+
if "reference_logps" in batch:
|
1328 |
+
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
1329 |
+
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
1330 |
+
|
1331 |
+
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
1332 |
+
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
1333 |
+
if self.calculate_KL:
|
1334 |
+
reference_KL_logps = batch["reference_KL_logps"]
|
1335 |
+
else:
|
1336 |
+
reference_KL_logps = None
|
1337 |
+
else:
|
1338 |
+
with torch.no_grad():
|
1339 |
+
if self.ref_model is None:
|
1340 |
+
with self.null_ref_context():
|
1341 |
+
(
|
1342 |
+
reference_chosen_logps,
|
1343 |
+
reference_rejected_logps,
|
1344 |
+
_,
|
1345 |
+
_,
|
1346 |
+
reference_KL_logps,
|
1347 |
+
) = self.forward(self.model, batch)[:5]
|
1348 |
+
else:
|
1349 |
+
(
|
1350 |
+
reference_chosen_logps,
|
1351 |
+
reference_rejected_logps,
|
1352 |
+
_,
|
1353 |
+
_,
|
1354 |
+
reference_KL_logps,
|
1355 |
+
) = self.forward(self.ref_model, batch)[:5]
|
1356 |
+
|
1357 |
+
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
1358 |
+
policy_chosen_logps,
|
1359 |
+
policy_rejected_logps,
|
1360 |
+
policy_KL_logps,
|
1361 |
+
reference_chosen_logps,
|
1362 |
+
reference_rejected_logps,
|
1363 |
+
reference_KL_logps,
|
1364 |
+
)
|
1365 |
+
metrics["kl"] = kl.item()
|
1366 |
+
|
1367 |
+
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
1368 |
+
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
1369 |
+
|
1370 |
+
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
1371 |
+
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
1372 |
+
|
1373 |
+
if all_num_chosen > 0:
|
1374 |
+
metrics["rewards/chosen_sum"] = (
|
1375 |
+
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
1376 |
+
)
|
1377 |
+
metrics["logps/chosen_sum"] = (
|
1378 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
1379 |
+
)
|
1380 |
+
metrics["logits/chosen_sum"] = (
|
1381 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
1382 |
+
)
|
1383 |
+
metrics["count/chosen"] = all_num_chosen
|
1384 |
+
|
1385 |
+
if all_num_rejected > 0:
|
1386 |
+
metrics["rewards/rejected_sum"] = (
|
1387 |
+
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
1388 |
+
)
|
1389 |
+
metrics["logps/rejected_sum"] = (
|
1390 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
1391 |
+
)
|
1392 |
+
metrics["logits/rejected_sum"] = (
|
1393 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
1394 |
+
)
|
1395 |
+
metrics["count/rejected"] = all_num_rejected
|
1396 |
+
|
1397 |
+
loss = losses.nanmean()
|
1398 |
+
if self.aux_loss_enabled:
|
1399 |
+
loss += self.aux_loss_coef * aux_loss
|
1400 |
+
|
1401 |
+
return loss, metrics
|
1402 |
+
|
1403 |
+
def compute_loss(
|
1404 |
+
self,
|
1405 |
+
model: Union[PreTrainedModel, nn.Module],
|
1406 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1407 |
+
return_outputs=False,
|
1408 |
+
num_items_in_batch=None,
|
1409 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1410 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1411 |
+
|
1412 |
+
with compute_loss_context_manager:
|
1413 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1414 |
+
|
1415 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1416 |
+
loss = loss.to(self.args.device)
|
1417 |
+
# force log the metrics
|
1418 |
+
if self.accelerator.is_main_process:
|
1419 |
+
self.store_metrics(metrics, train_eval="train")
|
1420 |
+
|
1421 |
+
if return_outputs:
|
1422 |
+
return (loss, metrics)
|
1423 |
+
return loss
|
1424 |
+
|
1425 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1426 |
+
for key, value in metrics.items():
|
1427 |
+
self._stored_metrics[train_eval][key].append(value)
|
1428 |
+
|
1429 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
1430 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
1431 |
+
return None
|
1432 |
+
return SequentialSampler(self.train_dataset)
|
1433 |
+
|
1434 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
1435 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1436 |
+
|
1437 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1438 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1439 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1440 |
+
|
1441 |
+
with generate_context_manager:
|
1442 |
+
policy_output = model.generate(
|
1443 |
+
input_ids=batch["prompt_input_ids"],
|
1444 |
+
attention_mask=batch["prompt_attention_mask"],
|
1445 |
+
max_length=self.max_length,
|
1446 |
+
do_sample=True,
|
1447 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1448 |
+
)
|
1449 |
+
|
1450 |
+
# if reference_output in batch use that otherwise use the reference model
|
1451 |
+
if "reference_output" in batch:
|
1452 |
+
reference_output = batch["reference_output"]
|
1453 |
+
else:
|
1454 |
+
if self.ref_model is None:
|
1455 |
+
with self.null_ref_context():
|
1456 |
+
reference_output = self.model.generate(
|
1457 |
+
input_ids=batch["prompt_input_ids"],
|
1458 |
+
attention_mask=batch["prompt_attention_mask"],
|
1459 |
+
max_length=self.max_length,
|
1460 |
+
do_sample=True,
|
1461 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1462 |
+
)
|
1463 |
+
else:
|
1464 |
+
reference_output = self.ref_model.generate(
|
1465 |
+
input_ids=batch["prompt_input_ids"],
|
1466 |
+
attention_mask=batch["prompt_attention_mask"],
|
1467 |
+
max_length=self.max_length,
|
1468 |
+
do_sample=True,
|
1469 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1470 |
+
)
|
1471 |
+
|
1472 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1473 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1474 |
+
|
1475 |
+
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
1476 |
+
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
1477 |
+
|
1478 |
+
return policy_output_decoded, reference_output_decoded
|
1479 |
+
|
1480 |
+
def prediction_step(
|
1481 |
+
self,
|
1482 |
+
model: Union[PreTrainedModel, nn.Module],
|
1483 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1484 |
+
prediction_loss_only: bool,
|
1485 |
+
ignore_keys: Optional[list[str]] = None,
|
1486 |
+
):
|
1487 |
+
if ignore_keys is None:
|
1488 |
+
if hasattr(model, "config"):
|
1489 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1490 |
+
else:
|
1491 |
+
ignore_keys = []
|
1492 |
+
|
1493 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1494 |
+
with torch.no_grad(), prediction_context_manager:
|
1495 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
1496 |
+
|
1497 |
+
# force log the metrics
|
1498 |
+
if self.accelerator.is_main_process:
|
1499 |
+
self.store_metrics(metrics, train_eval="eval")
|
1500 |
+
|
1501 |
+
if prediction_loss_only:
|
1502 |
+
return (loss.detach(), None, None)
|
1503 |
+
|
1504 |
+
# logits for the chosen and rejected samples from model
|
1505 |
+
logits_dict = {
|
1506 |
+
"eval_logits/chosen": metrics["logits/chosen"],
|
1507 |
+
"eval_logits/rejected": metrics["logits/rejected"],
|
1508 |
+
}
|
1509 |
+
logits = torch.tensor(
|
1510 |
+
[v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
|
1511 |
+
)
|
1512 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1513 |
+
|
1514 |
+
return (loss.detach(), logits, labels)
|
1515 |
+
|
1516 |
+
def evaluation_loop(
|
1517 |
+
self,
|
1518 |
+
dataloader: DataLoader,
|
1519 |
+
description: str,
|
1520 |
+
prediction_loss_only: Optional[bool] = None,
|
1521 |
+
ignore_keys: Optional[list[str]] = None,
|
1522 |
+
metric_key_prefix: str = "eval",
|
1523 |
+
) -> EvalLoopOutput:
|
1524 |
+
"""
|
1525 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
1526 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1527 |
+
|
1528 |
+
Works both with or without labels.
|
1529 |
+
"""
|
1530 |
+
|
1531 |
+
# Sample and save to game log if requested (for one batch to save time)
|
1532 |
+
if self.generate_during_eval:
|
1533 |
+
# Generate random indices within the range of the total number of samples
|
1534 |
+
num_samples = len(dataloader.dataset)
|
1535 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1536 |
+
|
1537 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1538 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1539 |
+
random_batch = self.data_collator(random_batch_dataset)
|
1540 |
+
random_batch = self._prepare_inputs(random_batch)
|
1541 |
+
|
1542 |
+
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
1543 |
+
target_batch = {
|
1544 |
+
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
1545 |
+
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
1546 |
+
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
1547 |
+
}
|
1548 |
+
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
1549 |
+
|
1550 |
+
table = pd.DataFrame(
|
1551 |
+
columns=["Prompt", "Policy", "Ref Model"],
|
1552 |
+
data=[
|
1553 |
+
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
1554 |
+
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
1555 |
+
],
|
1556 |
+
)
|
1557 |
+
if "wandb" in self.args.report_to:
|
1558 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
1559 |
+
|
1560 |
+
if "comet_ml" in self.args.report_to:
|
1561 |
+
log_table_to_comet_experiment(
|
1562 |
+
name="game_log.csv",
|
1563 |
+
table=table,
|
1564 |
+
)
|
1565 |
+
|
1566 |
+
# Base evaluation
|
1567 |
+
initial_output = super().evaluation_loop(
|
1568 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1569 |
+
)
|
1570 |
+
|
1571 |
+
return initial_output
|
1572 |
+
|
1573 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1574 |
+
"""
|
1575 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1576 |
+
|
1577 |
+
Args:
|
1578 |
+
logs (`dict[str, float]`):
|
1579 |
+
The values to log.
|
1580 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1581 |
+
Start time of the training.
|
1582 |
+
"""
|
1583 |
+
# logs either has 'loss' or 'eval_loss'
|
1584 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1585 |
+
# train metrics should have no prefix, eval should have 'eval_'
|
1586 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
1587 |
+
# accumulate average metrics from sums and lengths
|
1588 |
+
for split in ["chosen", "rejected"]:
|
1589 |
+
if f"count/{split}" in self._stored_metrics[train_eval]:
|
1590 |
+
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
1591 |
+
for metric in ["rewards", "logps", "logits"]:
|
1592 |
+
logs[f"{prefix}{metric}/{split}"] = (
|
1593 |
+
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
1594 |
+
/ count_sum
|
1595 |
+
)
|
1596 |
+
# delete obsolete metric
|
1597 |
+
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
1598 |
+
del self._stored_metrics[train_eval][f"count/{split}"]
|
1599 |
+
# calculate reward margin
|
1600 |
+
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
1601 |
+
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
1602 |
+
# Add averaged stored metrics to logs
|
1603 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1604 |
+
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
1605 |
+
del self._stored_metrics[train_eval]
|
1606 |
+
|
1607 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1608 |
+
return super().log(logs, start_time)
|
1609 |
+
else: # transformers<=4.46
|
1610 |
+
return super().log(logs)
|
1611 |
+
|
1612 |
+
def create_model_card(
|
1613 |
+
self,
|
1614 |
+
model_name: Optional[str] = None,
|
1615 |
+
dataset_name: Optional[str] = None,
|
1616 |
+
tags: Union[str, list[str], None] = None,
|
1617 |
+
):
|
1618 |
+
"""
|
1619 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1620 |
+
|
1621 |
+
Args:
|
1622 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1623 |
+
Name of the model.
|
1624 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1625 |
+
Name of the dataset used for training.
|
1626 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1627 |
+
Tags to be associated with the model card.
|
1628 |
+
"""
|
1629 |
+
if not self.is_world_process_zero():
|
1630 |
+
return
|
1631 |
+
|
1632 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1633 |
+
base_model = self.model.config._name_or_path
|
1634 |
+
else:
|
1635 |
+
base_model = None
|
1636 |
+
|
1637 |
+
tags = tags or []
|
1638 |
+
if isinstance(tags, str):
|
1639 |
+
tags = [tags]
|
1640 |
+
|
1641 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1642 |
+
tags.append("unsloth")
|
1643 |
+
|
1644 |
+
citation = textwrap.dedent("""\
|
1645 |
+
@article{ethayarajh2024kto,
|
1646 |
+
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
|
1647 |
+
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
|
1648 |
+
year = 2024,
|
1649 |
+
eprint = {arXiv:2402.01306},
|
1650 |
+
}""")
|
1651 |
+
|
1652 |
+
model_card = generate_model_card(
|
1653 |
+
base_model=base_model,
|
1654 |
+
model_name=model_name,
|
1655 |
+
hub_model_id=self.hub_model_id,
|
1656 |
+
dataset_name=dataset_name,
|
1657 |
+
tags=tags,
|
1658 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1659 |
+
comet_url=get_comet_experiment_url(),
|
1660 |
+
trainer_name="KTO",
|
1661 |
+
trainer_citation=citation,
|
1662 |
+
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
|
1663 |
+
paper_id="2402.01306",
|
1664 |
+
)
|
1665 |
+
|
1666 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1667 |
+
class UnslothKTOTrainer(_UnslothKTOTrainer):
|
1668 |
+
"""
|
1669 |
+
|
1670 |
+
Initialize KTOTrainer.
|
1671 |
+
|
1672 |
+
Args:
|
1673 |
+
model (`transformers.PreTrainedModel`):
|
1674 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1675 |
+
ref_model (`PreTrainedModelWrapper`):
|
1676 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
1677 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
1678 |
+
args (`KTOConfig`):
|
1679 |
+
The arguments to use for training.
|
1680 |
+
train_dataset (`datasets.Dataset`):
|
1681 |
+
The dataset to use for training.
|
1682 |
+
eval_dataset (`datasets.Dataset`):
|
1683 |
+
The dataset to use for evaluation.
|
1684 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1685 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1686 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1687 |
+
reuse the fine-tuned model.
|
1688 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
1689 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1690 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1691 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1692 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1693 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1694 |
+
The callbacks to use for training.
|
1695 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1696 |
+
The optimizer and scheduler to use for training.
|
1697 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1698 |
+
The function to use to preprocess the logits before computing the metrics.
|
1699 |
+
peft_config (`dict`, defaults to `None`):
|
1700 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1701 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1702 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1703 |
+
a dictionary string to metric values.
|
1704 |
+
model_adapter_name (`str`, defaults to `None`):
|
1705 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
1706 |
+
ref_adapter_name (`str`, defaults to `None`):
|
1707 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
1708 |
+
|
1709 |
+
"""
|
1710 |
+
def __init__(
|
1711 |
+
self,
|
1712 |
+
model = None,
|
1713 |
+
ref_model = None,
|
1714 |
+
args = None,
|
1715 |
+
train_dataset = None,
|
1716 |
+
eval_dataset = None,
|
1717 |
+
processing_class = None,
|
1718 |
+
data_collator = None,
|
1719 |
+
model_init = None,
|
1720 |
+
callbacks = None,
|
1721 |
+
preprocess_logits_for_metrics = None,
|
1722 |
+
peft_config = None,
|
1723 |
+
compute_metrics = None,
|
1724 |
+
model_adapter_name = None,
|
1725 |
+
ref_adapter_name = None,
|
1726 |
+
**kwargs
|
1727 |
+
):
|
1728 |
+
if args is None: args = UnslothKTOConfig()
|
1729 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1730 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1731 |
+
force_float32 = False
|
1732 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1733 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1734 |
+
force_float32 = True
|
1735 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1736 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1737 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1738 |
+
from unsloth_zoo.utils import _get_dtype
|
1739 |
+
dtype = _get_dtype(dtype)
|
1740 |
+
float16 = dtype == torch.float16
|
1741 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1742 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1743 |
+
if force_float32:
|
1744 |
+
args.fp16 = False
|
1745 |
+
args.bf16 = False
|
1746 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1747 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1748 |
+
args.fp16 = float16
|
1749 |
+
args.bf16 = not float16
|
1750 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1751 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1752 |
+
args.eval_strategy = 'steps'
|
1753 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1754 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1755 |
+
if ga_steps is not None and ga_steps > 1:
|
1756 |
+
from transformers import __version__ as transformers_version
|
1757 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1758 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1759 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1760 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1761 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1762 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1763 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1764 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1765 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1766 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1767 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1768 |
+
if force_float32:
|
1769 |
+
args.bf16_full_eval = False
|
1770 |
+
args.fp16_full_eval = False
|
1771 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1772 |
+
args.bf16_full_eval = True
|
1773 |
+
args.fp16_full_eval = False
|
1774 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1775 |
+
args.bf16_full_eval = args.bf16
|
1776 |
+
args.fp16_full_eval = args.fp16
|
1777 |
+
_output_logits = False
|
1778 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1779 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1780 |
+
if _output_logits:
|
1781 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1782 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1783 |
+
pass
|
1784 |
+
else:
|
1785 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1786 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1787 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1788 |
+
max_seq_length = model.max_seq_length
|
1789 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1790 |
+
if model is not None and hasattr(model, 'for_training'):
|
1791 |
+
model.for_training()
|
1792 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1793 |
+
if 'processing_class' in locals():
|
1794 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1795 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1796 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1797 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1798 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1799 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1800 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1801 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1802 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1803 |
+
else:
|
1804 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1805 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1806 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1807 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1808 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1809 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1810 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1811 |
+
else:
|
1812 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1813 |
+
other_metrics = []
|
1814 |
+
|
1815 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1816 |
+
PatchRLStatistics('kto_trainer', other_metrics)
|
1817 |
+
|
1818 |
+
super().__init__(
|
1819 |
+
model = model,
|
1820 |
+
ref_model = ref_model,
|
1821 |
+
args = args,
|
1822 |
+
train_dataset = train_dataset,
|
1823 |
+
eval_dataset = eval_dataset,
|
1824 |
+
processing_class = processing_class,
|
1825 |
+
data_collator = data_collator,
|
1826 |
+
model_init = model_init,
|
1827 |
+
callbacks = callbacks,
|
1828 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1829 |
+
peft_config = peft_config,
|
1830 |
+
compute_metrics = compute_metrics,
|
1831 |
+
model_adapter_name = model_adapter_name,
|
1832 |
+
ref_adapter_name = ref_adapter_name,**kwargs)
|
1833 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1834 |
+
self.neftune_hook_handle.remove()
|
1835 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1836 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1837 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1838 |
+
pass
|
1839 |
+
|
1840 |
+
pass
|
unsloth_compiled_cache/UnslothNashMDTrainer.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothNashMDConfig(NashMDConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`NashMDTrainer`].
|
47 |
+
|
48 |
+
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
|
52 |
+
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
|
53 |
+
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
|
54 |
+
epochs.
|
55 |
+
|
56 |
+
"""
|
57 |
+
vllm_sampling_params: Optional[Any] = field(
|
58 |
+
default = None,
|
59 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
60 |
+
)
|
61 |
+
unsloth_num_chunks : Optional[int] = field(
|
62 |
+
default = -1,
|
63 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
64 |
+
)
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
output_dir = None,
|
68 |
+
overwrite_output_dir = None,
|
69 |
+
do_train = False,
|
70 |
+
do_eval = False,
|
71 |
+
do_predict = False,
|
72 |
+
eval_strategy = 'no',
|
73 |
+
prediction_loss_only = False,
|
74 |
+
per_device_train_batch_size = 4,
|
75 |
+
per_device_eval_batch_size = 4,
|
76 |
+
per_gpu_train_batch_size = None,
|
77 |
+
per_gpu_eval_batch_size = None,
|
78 |
+
gradient_accumulation_steps = 2,
|
79 |
+
eval_accumulation_steps = 2,
|
80 |
+
eval_delay = 0,
|
81 |
+
torch_empty_cache_steps = 250,
|
82 |
+
learning_rate = 5e-05,
|
83 |
+
weight_decay = 0.01,
|
84 |
+
adam_beta1 = 0.9,
|
85 |
+
adam_beta2 = 0.999,
|
86 |
+
adam_epsilon = 1e-08,
|
87 |
+
max_grad_norm = 1.0,
|
88 |
+
num_train_epochs = 3.0,
|
89 |
+
max_steps = -1,
|
90 |
+
lr_scheduler_type = 'linear',
|
91 |
+
warmup_ratio = 0.1,
|
92 |
+
warmup_steps = 0,
|
93 |
+
log_level = 'passive',
|
94 |
+
log_level_replica = 'warning',
|
95 |
+
log_on_each_node = True,
|
96 |
+
logging_dir = None,
|
97 |
+
logging_strategy = 'steps',
|
98 |
+
logging_first_step = False,
|
99 |
+
logging_steps = 1,
|
100 |
+
logging_nan_inf_filter = False,
|
101 |
+
save_strategy = 'steps',
|
102 |
+
save_steps = 500,
|
103 |
+
save_total_limit = None,
|
104 |
+
save_safetensors = True,
|
105 |
+
save_on_each_node = False,
|
106 |
+
save_only_model = False,
|
107 |
+
restore_callback_states_from_checkpoint = False,
|
108 |
+
no_cuda = False,
|
109 |
+
use_cpu = False,
|
110 |
+
use_mps_device = False,
|
111 |
+
seed = 3407,
|
112 |
+
data_seed = 3407,
|
113 |
+
jit_mode_eval = False,
|
114 |
+
use_ipex = False,
|
115 |
+
bf16 = False,
|
116 |
+
fp16 = False,
|
117 |
+
fp16_opt_level = 'O1',
|
118 |
+
half_precision_backend = 'auto',
|
119 |
+
bf16_full_eval = False,
|
120 |
+
fp16_full_eval = False,
|
121 |
+
tf32 = None,
|
122 |
+
local_rank = -1,
|
123 |
+
ddp_backend = None,
|
124 |
+
tpu_num_cores = None,
|
125 |
+
tpu_metrics_debug = False,
|
126 |
+
debug = '',
|
127 |
+
dataloader_drop_last = False,
|
128 |
+
eval_steps = None,
|
129 |
+
dataloader_num_workers = 0,
|
130 |
+
dataloader_prefetch_factor = None,
|
131 |
+
past_index = -1,
|
132 |
+
run_name = None,
|
133 |
+
disable_tqdm = None,
|
134 |
+
remove_unused_columns = True,
|
135 |
+
label_names = None,
|
136 |
+
load_best_model_at_end = False,
|
137 |
+
metric_for_best_model = None,
|
138 |
+
greater_is_better = None,
|
139 |
+
ignore_data_skip = False,
|
140 |
+
fsdp = '',
|
141 |
+
fsdp_min_num_params = 0,
|
142 |
+
fsdp_config = None,
|
143 |
+
tp_size = 0,
|
144 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
145 |
+
accelerator_config = None,
|
146 |
+
deepspeed = None,
|
147 |
+
label_smoothing_factor = 0.0,
|
148 |
+
optim = 'adamw_8bit',
|
149 |
+
optim_args = None,
|
150 |
+
adafactor = False,
|
151 |
+
group_by_length = False,
|
152 |
+
length_column_name = 'length',
|
153 |
+
report_to = None,
|
154 |
+
ddp_find_unused_parameters = None,
|
155 |
+
ddp_bucket_cap_mb = None,
|
156 |
+
ddp_broadcast_buffers = None,
|
157 |
+
dataloader_pin_memory = True,
|
158 |
+
dataloader_persistent_workers = False,
|
159 |
+
skip_memory_metrics = True,
|
160 |
+
use_legacy_prediction_loop = False,
|
161 |
+
push_to_hub = False,
|
162 |
+
resume_from_checkpoint = None,
|
163 |
+
hub_model_id = None,
|
164 |
+
hub_strategy = 'every_save',
|
165 |
+
hub_token = None,
|
166 |
+
hub_private_repo = None,
|
167 |
+
hub_always_push = False,
|
168 |
+
gradient_checkpointing = False,
|
169 |
+
gradient_checkpointing_kwargs = None,
|
170 |
+
include_inputs_for_metrics = False,
|
171 |
+
eval_do_concat_batches = True,
|
172 |
+
fp16_backend = 'auto',
|
173 |
+
evaluation_strategy = None,
|
174 |
+
push_to_hub_model_id = None,
|
175 |
+
push_to_hub_organization = None,
|
176 |
+
push_to_hub_token = None,
|
177 |
+
mp_parameters = '',
|
178 |
+
auto_find_batch_size = False,
|
179 |
+
full_determinism = False,
|
180 |
+
torchdynamo = None,
|
181 |
+
ray_scope = 'last',
|
182 |
+
ddp_timeout = 1800,
|
183 |
+
torch_compile = False,
|
184 |
+
torch_compile_backend = None,
|
185 |
+
torch_compile_mode = None,
|
186 |
+
dispatch_batches = None,
|
187 |
+
split_batches = None,
|
188 |
+
include_tokens_per_second = False,
|
189 |
+
include_num_input_tokens_seen = False,
|
190 |
+
neftune_noise_alpha = None,
|
191 |
+
optim_target_modules = None,
|
192 |
+
batch_eval_metrics = False,
|
193 |
+
eval_on_start = False,
|
194 |
+
use_liger_kernel = False,
|
195 |
+
eval_use_gather_object = False,
|
196 |
+
average_tokens_across_devices = False,
|
197 |
+
reward_model_path = None,
|
198 |
+
judge = None,
|
199 |
+
max_new_tokens = 64,
|
200 |
+
max_length = 512,
|
201 |
+
temperature = 0.9,
|
202 |
+
missing_eos_penalty = None,
|
203 |
+
loss_type = 'sigmoid',
|
204 |
+
dataset_num_proc = None,
|
205 |
+
disable_dropout = True,
|
206 |
+
use_vllm = False,
|
207 |
+
ds3_gather_for_generation = True,
|
208 |
+
vllm_sampling_params = None,
|
209 |
+
unsloth_num_chunks = -1,
|
210 |
+
**kwargs,
|
211 |
+
):
|
212 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
213 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
214 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
215 |
+
output_dir = 'unsloth_training_checkpoints'
|
216 |
+
save_strategy = 'no'
|
217 |
+
if dataset_num_proc is None:
|
218 |
+
from multiprocessing import cpu_count
|
219 |
+
dataset_num_proc = cpu_count()
|
220 |
+
|
221 |
+
super().__init__(
|
222 |
+
output_dir = output_dir,
|
223 |
+
overwrite_output_dir = overwrite_output_dir,
|
224 |
+
do_train = do_train,
|
225 |
+
do_eval = do_eval,
|
226 |
+
do_predict = do_predict,
|
227 |
+
eval_strategy = eval_strategy,
|
228 |
+
prediction_loss_only = prediction_loss_only,
|
229 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
230 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
231 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
232 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
233 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
234 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
235 |
+
eval_delay = eval_delay,
|
236 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
237 |
+
learning_rate = learning_rate,
|
238 |
+
weight_decay = weight_decay,
|
239 |
+
adam_beta1 = adam_beta1,
|
240 |
+
adam_beta2 = adam_beta2,
|
241 |
+
adam_epsilon = adam_epsilon,
|
242 |
+
max_grad_norm = max_grad_norm,
|
243 |
+
num_train_epochs = num_train_epochs,
|
244 |
+
max_steps = max_steps,
|
245 |
+
lr_scheduler_type = lr_scheduler_type,
|
246 |
+
warmup_ratio = warmup_ratio,
|
247 |
+
warmup_steps = warmup_steps,
|
248 |
+
log_level = log_level,
|
249 |
+
log_level_replica = log_level_replica,
|
250 |
+
log_on_each_node = log_on_each_node,
|
251 |
+
logging_dir = logging_dir,
|
252 |
+
logging_strategy = logging_strategy,
|
253 |
+
logging_first_step = logging_first_step,
|
254 |
+
logging_steps = logging_steps,
|
255 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
256 |
+
save_strategy = save_strategy,
|
257 |
+
save_steps = save_steps,
|
258 |
+
save_total_limit = save_total_limit,
|
259 |
+
save_safetensors = save_safetensors,
|
260 |
+
save_on_each_node = save_on_each_node,
|
261 |
+
save_only_model = save_only_model,
|
262 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
263 |
+
no_cuda = no_cuda,
|
264 |
+
use_cpu = use_cpu,
|
265 |
+
use_mps_device = use_mps_device,
|
266 |
+
seed = seed,
|
267 |
+
data_seed = data_seed,
|
268 |
+
jit_mode_eval = jit_mode_eval,
|
269 |
+
use_ipex = use_ipex,
|
270 |
+
bf16 = bf16,
|
271 |
+
fp16 = fp16,
|
272 |
+
fp16_opt_level = fp16_opt_level,
|
273 |
+
half_precision_backend = half_precision_backend,
|
274 |
+
bf16_full_eval = bf16_full_eval,
|
275 |
+
fp16_full_eval = fp16_full_eval,
|
276 |
+
tf32 = tf32,
|
277 |
+
local_rank = local_rank,
|
278 |
+
ddp_backend = ddp_backend,
|
279 |
+
tpu_num_cores = tpu_num_cores,
|
280 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
281 |
+
debug = debug,
|
282 |
+
dataloader_drop_last = dataloader_drop_last,
|
283 |
+
eval_steps = eval_steps,
|
284 |
+
dataloader_num_workers = dataloader_num_workers,
|
285 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
286 |
+
past_index = past_index,
|
287 |
+
run_name = run_name,
|
288 |
+
disable_tqdm = disable_tqdm,
|
289 |
+
remove_unused_columns = remove_unused_columns,
|
290 |
+
label_names = label_names,
|
291 |
+
load_best_model_at_end = load_best_model_at_end,
|
292 |
+
metric_for_best_model = metric_for_best_model,
|
293 |
+
greater_is_better = greater_is_better,
|
294 |
+
ignore_data_skip = ignore_data_skip,
|
295 |
+
fsdp = fsdp,
|
296 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
297 |
+
fsdp_config = fsdp_config,
|
298 |
+
tp_size = tp_size,
|
299 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
300 |
+
accelerator_config = accelerator_config,
|
301 |
+
deepspeed = deepspeed,
|
302 |
+
label_smoothing_factor = label_smoothing_factor,
|
303 |
+
optim = optim,
|
304 |
+
optim_args = optim_args,
|
305 |
+
adafactor = adafactor,
|
306 |
+
group_by_length = group_by_length,
|
307 |
+
length_column_name = length_column_name,
|
308 |
+
report_to = report_to,
|
309 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
310 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
311 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
312 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
313 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
314 |
+
skip_memory_metrics = skip_memory_metrics,
|
315 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
316 |
+
push_to_hub = push_to_hub,
|
317 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
318 |
+
hub_model_id = hub_model_id,
|
319 |
+
hub_strategy = hub_strategy,
|
320 |
+
hub_token = hub_token,
|
321 |
+
hub_private_repo = hub_private_repo,
|
322 |
+
hub_always_push = hub_always_push,
|
323 |
+
gradient_checkpointing = gradient_checkpointing,
|
324 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
325 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
326 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
327 |
+
fp16_backend = fp16_backend,
|
328 |
+
evaluation_strategy = evaluation_strategy,
|
329 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
330 |
+
push_to_hub_organization = push_to_hub_organization,
|
331 |
+
push_to_hub_token = push_to_hub_token,
|
332 |
+
mp_parameters = mp_parameters,
|
333 |
+
auto_find_batch_size = auto_find_batch_size,
|
334 |
+
full_determinism = full_determinism,
|
335 |
+
torchdynamo = torchdynamo,
|
336 |
+
ray_scope = ray_scope,
|
337 |
+
ddp_timeout = ddp_timeout,
|
338 |
+
torch_compile = torch_compile,
|
339 |
+
torch_compile_backend = torch_compile_backend,
|
340 |
+
torch_compile_mode = torch_compile_mode,
|
341 |
+
dispatch_batches = dispatch_batches,
|
342 |
+
split_batches = split_batches,
|
343 |
+
include_tokens_per_second = include_tokens_per_second,
|
344 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
345 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
346 |
+
optim_target_modules = optim_target_modules,
|
347 |
+
batch_eval_metrics = batch_eval_metrics,
|
348 |
+
eval_on_start = eval_on_start,
|
349 |
+
use_liger_kernel = use_liger_kernel,
|
350 |
+
eval_use_gather_object = eval_use_gather_object,
|
351 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
352 |
+
reward_model_path = reward_model_path,
|
353 |
+
judge = judge,
|
354 |
+
max_new_tokens = max_new_tokens,
|
355 |
+
max_length = max_length,
|
356 |
+
temperature = temperature,
|
357 |
+
missing_eos_penalty = missing_eos_penalty,
|
358 |
+
loss_type = loss_type,
|
359 |
+
dataset_num_proc = dataset_num_proc,
|
360 |
+
disable_dropout = disable_dropout,
|
361 |
+
use_vllm = use_vllm,
|
362 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
363 |
+
self.vllm_sampling_params = vllm_sampling_params
|
364 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
365 |
+
pass
|
366 |
+
|
367 |
+
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
368 |
+
r""""""
|
369 |
+
|
370 |
+
_tag_names = ["trl", "nash-md"]
|
371 |
+
|
372 |
+
def __init__(
|
373 |
+
self,
|
374 |
+
model: Union[PreTrainedModel, nn.Module] = None,
|
375 |
+
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
376 |
+
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
377 |
+
judge: Optional[BasePairwiseJudge] = None,
|
378 |
+
args: Optional[NashMDConfig] = None,
|
379 |
+
data_collator: Optional[Callable] = None,
|
380 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
381 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
382 |
+
processing_class: Optional[
|
383 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
384 |
+
] = None,
|
385 |
+
peft_config: Optional[dict] = None,
|
386 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
387 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
388 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
389 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
390 |
+
) -> None:
|
391 |
+
super().__init__(
|
392 |
+
model=model,
|
393 |
+
ref_model=ref_model,
|
394 |
+
reward_model=reward_model,
|
395 |
+
judge=judge,
|
396 |
+
args=args,
|
397 |
+
data_collator=data_collator,
|
398 |
+
train_dataset=train_dataset,
|
399 |
+
eval_dataset=eval_dataset,
|
400 |
+
processing_class=processing_class,
|
401 |
+
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
|
402 |
+
peft_config=peft_config,
|
403 |
+
compute_metrics=compute_metrics,
|
404 |
+
callbacks=callbacks,
|
405 |
+
optimizers=optimizers,
|
406 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
407 |
+
)
|
408 |
+
|
409 |
+
self._mixture_coef = self.args.mixture_coef
|
410 |
+
|
411 |
+
# Overwrite the stats dictionary to include NashMD specific statistics
|
412 |
+
self.stats = {
|
413 |
+
# Remove "non_score_reward", "rlhf_reward", "scores_margin"
|
414 |
+
# Add "mixture_coef"
|
415 |
+
"loss/kl": [],
|
416 |
+
"objective/entropy": [],
|
417 |
+
"loss/score": [],
|
418 |
+
"rewards/probabilities": [],
|
419 |
+
"rewards/accuracies": [],
|
420 |
+
"rewards/margins": [],
|
421 |
+
"logps/chosen": [],
|
422 |
+
"logps/rejected": [],
|
423 |
+
"val/model_contain_eos_token": [],
|
424 |
+
"val/ref_contain_eos_token": [],
|
425 |
+
"beta": [],
|
426 |
+
"mixture_coef": [],
|
427 |
+
}
|
428 |
+
if self.reward_model is not None:
|
429 |
+
self.stats["rewards/chosen"] = []
|
430 |
+
self.stats["rewards/rejected"] = []
|
431 |
+
|
432 |
+
@property
|
433 |
+
def mixture_coef(self):
|
434 |
+
if isinstance(self._mixture_coef, list):
|
435 |
+
epoch = self.state.epoch
|
436 |
+
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
437 |
+
else:
|
438 |
+
return self._mixture_coef
|
439 |
+
|
440 |
+
def _generate_completions(self, model, prompts):
|
441 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
442 |
+
model_output = unwrapped_model.generate(
|
443 |
+
input_ids=prompts["input_ids"],
|
444 |
+
attention_mask=prompts["attention_mask"],
|
445 |
+
generation_config=self.generation_config,
|
446 |
+
)
|
447 |
+
|
448 |
+
ref_model = model if self.ref_model is None else self.ref_model
|
449 |
+
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
450 |
+
mixture_model = GeometricMixtureWrapper(
|
451 |
+
model=unwrapped_model,
|
452 |
+
ref_model=unwrapped_ref_model,
|
453 |
+
generation_config=self.generation_config,
|
454 |
+
mixture_coef=self.mixture_coef,
|
455 |
+
device=self.accelerator.device,
|
456 |
+
)
|
457 |
+
|
458 |
+
mixture_output = mixture_model.generate(
|
459 |
+
input_ids=prompts["input_ids"],
|
460 |
+
attention_mask=prompts["attention_mask"],
|
461 |
+
generation_config=self.generation_config,
|
462 |
+
)
|
463 |
+
|
464 |
+
return model_output, mixture_output
|
465 |
+
|
466 |
+
def _process_completions(self, model_output, mixture_output, prompts):
|
467 |
+
context_length = prompts["input_ids"].shape[1]
|
468 |
+
|
469 |
+
# Process model completions
|
470 |
+
model_completion_ids = model_output[:, context_length:]
|
471 |
+
model_completion_ids, model_completion_mask = truncate_right(
|
472 |
+
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
473 |
+
)
|
474 |
+
model_data = {
|
475 |
+
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
476 |
+
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
477 |
+
"raw": prompts["raw"],
|
478 |
+
}
|
479 |
+
|
480 |
+
# Process reference model completions
|
481 |
+
mixture_completion_ids = mixture_output[:, context_length:]
|
482 |
+
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
483 |
+
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
484 |
+
)
|
485 |
+
mixture_data = {
|
486 |
+
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
487 |
+
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
488 |
+
"raw": prompts["raw"],
|
489 |
+
}
|
490 |
+
|
491 |
+
return model_data, mixture_data
|
492 |
+
|
493 |
+
def _compute_rewards(self, model_data, mixture_data, context_length):
|
494 |
+
with torch.no_grad():
|
495 |
+
_, model_scores, _ = get_reward(
|
496 |
+
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
497 |
+
)
|
498 |
+
_, mixture_scores, _ = get_reward(
|
499 |
+
self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
500 |
+
)
|
501 |
+
|
502 |
+
# Apply EOS penalty if needed
|
503 |
+
if self.args.missing_eos_penalty is not None:
|
504 |
+
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
505 |
+
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
506 |
+
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
507 |
+
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
508 |
+
|
509 |
+
return model_scores, mixture_scores
|
510 |
+
|
511 |
+
def _compute_judge(self, model_data, mixture_data, context_length):
|
512 |
+
prompts = model_data["raw"]
|
513 |
+
model_data_completions = self.processing_class.batch_decode(
|
514 |
+
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
515 |
+
)
|
516 |
+
model_data_completions = [completion.strip() for completion in model_data_completions]
|
517 |
+
|
518 |
+
mixture_data_completions = self.processing_class.batch_decode(
|
519 |
+
mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
|
520 |
+
)
|
521 |
+
mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
|
522 |
+
if is_conversational({"prompt": prompts[0]}):
|
523 |
+
model_data_completions = [
|
524 |
+
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
525 |
+
]
|
526 |
+
environment = jinja2.Environment()
|
527 |
+
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
528 |
+
prompts = [template.render(messages=message) for message in prompts]
|
529 |
+
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
530 |
+
|
531 |
+
mixture_data_completions = [
|
532 |
+
[{"role": "assistant", "content": completion}] for completion in mixture_data_completions
|
533 |
+
]
|
534 |
+
mixture_data_completions = [
|
535 |
+
template.render(messages=completion) for completion in mixture_data_completions
|
536 |
+
]
|
537 |
+
|
538 |
+
probability = self.judge.judge(
|
539 |
+
prompts,
|
540 |
+
list(zip(model_data_completions, mixture_data_completions)),
|
541 |
+
return_scores=True,
|
542 |
+
)
|
543 |
+
return torch.tensor(probability, device=model_data["input_ids"].device)
|
544 |
+
|
545 |
+
def _compute_logprobs(self, model, model_data, context_length):
|
546 |
+
def compute_logprobs_for_data(m, data):
|
547 |
+
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
548 |
+
logits = output.logits[:, context_length - 1 : -1]
|
549 |
+
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
550 |
+
return token_logprobs
|
551 |
+
|
552 |
+
# Compute logprobs for model completions under the model
|
553 |
+
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
554 |
+
|
555 |
+
# Compute logprobs of model completions under the reference model
|
556 |
+
with torch.no_grad():
|
557 |
+
if self.ref_model is None:
|
558 |
+
with model.disable_adapter():
|
559 |
+
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
560 |
+
else:
|
561 |
+
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
562 |
+
|
563 |
+
# Mask padding tokens
|
564 |
+
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
565 |
+
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
566 |
+
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
567 |
+
|
568 |
+
return (model_logprobs_model_data, ref_logprobs_model_data)
|
569 |
+
|
570 |
+
def _compute_losses(
|
571 |
+
self,
|
572 |
+
model_logprobs_model_data,
|
573 |
+
ref_logprobs_model_data,
|
574 |
+
probability,
|
575 |
+
):
|
576 |
+
# reinforce score where 0.5 is a control variate
|
577 |
+
score = (probability - 0.5) * model_logprobs_model_data.sum(1)
|
578 |
+
|
579 |
+
# kl divergence via reinforce
|
580 |
+
with torch.no_grad():
|
581 |
+
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
|
582 |
+
kl_div_log = log_ratio.sum(1)
|
583 |
+
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
|
584 |
+
|
585 |
+
# final loss
|
586 |
+
loss = self.beta * kl_div_loss - score
|
587 |
+
|
588 |
+
return loss.mean(), score, kl_div_log
|
589 |
+
|
590 |
+
def _log_statistics(
|
591 |
+
self,
|
592 |
+
model_data,
|
593 |
+
mixture_data,
|
594 |
+
model_logprobs_model_data,
|
595 |
+
ref_logprobs_model_data,
|
596 |
+
probability,
|
597 |
+
score,
|
598 |
+
kl_div,
|
599 |
+
context_length,
|
600 |
+
model_scores=None,
|
601 |
+
mixture_scores=None,
|
602 |
+
):
|
603 |
+
# Helper function to gather and compute mean
|
604 |
+
def gather_mean(tensor):
|
605 |
+
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
606 |
+
|
607 |
+
# Log score
|
608 |
+
self.stats["loss/score"].append(gather_mean(score))
|
609 |
+
# Log KL divergence
|
610 |
+
self.stats["loss/kl"].append(gather_mean(kl_div))
|
611 |
+
|
612 |
+
# Log logprobs
|
613 |
+
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
614 |
+
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
615 |
+
|
616 |
+
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
617 |
+
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
618 |
+
|
619 |
+
# Log rewards
|
620 |
+
if self.reward_model is not None:
|
621 |
+
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
622 |
+
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
623 |
+
|
624 |
+
# Log probabilities
|
625 |
+
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
626 |
+
|
627 |
+
# Calculate entropy for model data
|
628 |
+
entropy_model_data = -model_logprobs_model_data.sum(1)
|
629 |
+
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
630 |
+
|
631 |
+
# Calculate margins
|
632 |
+
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
633 |
+
self.stats["rewards/margins"].append(gather_mean(margin))
|
634 |
+
|
635 |
+
# Calculate accuracy
|
636 |
+
accuracy = (margin > 0).float()
|
637 |
+
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
638 |
+
|
639 |
+
# Log EOS token statistics
|
640 |
+
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
641 |
+
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
642 |
+
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
643 |
+
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
644 |
+
|
645 |
+
# Log beta and mixture coef
|
646 |
+
self.stats["beta"].append(self.beta)
|
647 |
+
self.stats["mixture_coef"].append(self.mixture_coef)
|
648 |
+
|
649 |
+
def training_step(
|
650 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
651 |
+
) -> torch.Tensor:
|
652 |
+
model.train()
|
653 |
+
|
654 |
+
# Apply chat template and tokenize the input
|
655 |
+
batch_size = len(next(iter(inputs.values())))
|
656 |
+
prompts = inputs["prompt"]
|
657 |
+
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
658 |
+
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
659 |
+
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
660 |
+
inputs = self.data_collator(inputs)
|
661 |
+
|
662 |
+
# need the prompt_ only
|
663 |
+
inputs = self._prepare_inputs(inputs)
|
664 |
+
context_length = inputs["prompt_input_ids"].shape[1]
|
665 |
+
prompts = {
|
666 |
+
"input_ids": inputs["prompt_input_ids"],
|
667 |
+
"attention_mask": inputs["prompt_attention_mask"],
|
668 |
+
"raw": prompts,
|
669 |
+
}
|
670 |
+
del inputs
|
671 |
+
|
672 |
+
# Sample completions from both the model and the reference model
|
673 |
+
model_output, mixture_output = self._generate_completions(model, prompts)
|
674 |
+
|
675 |
+
# Process model completions
|
676 |
+
model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
|
677 |
+
|
678 |
+
# Compute rewards
|
679 |
+
if self.reward_model is not None:
|
680 |
+
model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
|
681 |
+
# probability of the model data vs the mixture data
|
682 |
+
probability = F.sigmoid(model_scores - mixture_scores)
|
683 |
+
else:
|
684 |
+
model_scores, mixture_scores = None, None
|
685 |
+
probability = self._compute_judge(model_data, mixture_data, context_length)
|
686 |
+
|
687 |
+
# Compute logprobs
|
688 |
+
model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
|
689 |
+
|
690 |
+
# Compute loss
|
691 |
+
loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
|
692 |
+
|
693 |
+
# Log everything
|
694 |
+
self._log_statistics(
|
695 |
+
model_data,
|
696 |
+
mixture_data,
|
697 |
+
model_logprobs_model_data.detach(),
|
698 |
+
ref_logprobs_model_data,
|
699 |
+
probability,
|
700 |
+
score.detach(),
|
701 |
+
kl_div.detach(),
|
702 |
+
context_length,
|
703 |
+
model_scores,
|
704 |
+
mixture_scores,
|
705 |
+
)
|
706 |
+
|
707 |
+
if (
|
708 |
+
self.args.torch_empty_cache_steps is not None
|
709 |
+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
710 |
+
):
|
711 |
+
empty_cache()
|
712 |
+
|
713 |
+
kwargs = {}
|
714 |
+
# For LOMO optimizers you need to explicitly use the learning rate
|
715 |
+
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
716 |
+
kwargs["learning_rate"] = self._get_learning_rate()
|
717 |
+
|
718 |
+
if self.args.n_gpu > 1:
|
719 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
720 |
+
|
721 |
+
if self.use_apex:
|
722 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
723 |
+
scaled_loss.backward()
|
724 |
+
else:
|
725 |
+
self.accelerator.backward(loss, **kwargs)
|
726 |
+
|
727 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
728 |
+
|
729 |
+
def create_model_card(
|
730 |
+
self,
|
731 |
+
model_name: Optional[str] = None,
|
732 |
+
dataset_name: Optional[str] = None,
|
733 |
+
tags: Union[str, list[str], None] = None,
|
734 |
+
):
|
735 |
+
"""
|
736 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
737 |
+
|
738 |
+
Args:
|
739 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
740 |
+
Name of the model.
|
741 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
742 |
+
Name of the dataset used for training.
|
743 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
744 |
+
Tags to be associated with the model card.
|
745 |
+
"""
|
746 |
+
if not self.is_world_process_zero():
|
747 |
+
return
|
748 |
+
|
749 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
750 |
+
base_model = self.model.config._name_or_path
|
751 |
+
else:
|
752 |
+
base_model = None
|
753 |
+
|
754 |
+
tags = tags or []
|
755 |
+
if isinstance(tags, str):
|
756 |
+
tags = [tags]
|
757 |
+
|
758 |
+
if hasattr(self.model.config, "unsloth_version"):
|
759 |
+
tags.append("unsloth")
|
760 |
+
|
761 |
+
citation = textwrap.dedent("""\
|
762 |
+
@inproceedings{munos2024nash,
|
763 |
+
title = {{Nash Learning from Human Feedback}},
|
764 |
+
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
765 |
+
year = 2024,
|
766 |
+
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
767 |
+
publisher = {OpenReview.net},
|
768 |
+
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
769 |
+
}""")
|
770 |
+
|
771 |
+
model_card = generate_model_card(
|
772 |
+
base_model=base_model,
|
773 |
+
model_name=model_name,
|
774 |
+
hub_model_id=self.hub_model_id,
|
775 |
+
dataset_name=dataset_name,
|
776 |
+
tags=tags,
|
777 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
778 |
+
comet_url=get_comet_experiment_url(),
|
779 |
+
trainer_name="Nash-MD",
|
780 |
+
trainer_citation=citation,
|
781 |
+
paper_title="Nash Learning from Human Feedback",
|
782 |
+
paper_id="2312.00886",
|
783 |
+
)
|
784 |
+
|
785 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
786 |
+
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
787 |
+
"""
|
788 |
+
|
789 |
+
Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
|
790 |
+
|
791 |
+
Args:
|
792 |
+
model (`transformers.PreTrainedModel`):
|
793 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
794 |
+
ref_model (`PreTrainedModelWrapper`):
|
795 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
796 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
797 |
+
reward_model (`transformers.PreTrainedModel`):
|
798 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
799 |
+
judge (`BasePairwiseJudge`):
|
800 |
+
The judge to use for pairwise comparison of model completions.
|
801 |
+
args (`NashMDConfig`):
|
802 |
+
The NashMD config arguments to use for training.
|
803 |
+
data_collator (`transformers.DataCollator`):
|
804 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
805 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
806 |
+
train_dataset (`datasets.Dataset`):
|
807 |
+
The dataset to use for training.
|
808 |
+
eval_dataset (`datasets.Dataset`):
|
809 |
+
The dataset to use for evaluation.
|
810 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
811 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
812 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
813 |
+
reuse the fine-tuned model.
|
814 |
+
peft_config (`dict`):
|
815 |
+
The peft config to use for training.
|
816 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
817 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
818 |
+
a dictionary string to metric values.
|
819 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
820 |
+
The callbacks to use for training.
|
821 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
822 |
+
The optimizer and scheduler to use for training.
|
823 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
824 |
+
The function to use to preprocess the logits before computing the metrics.
|
825 |
+
|
826 |
+
"""
|
827 |
+
def __init__(
|
828 |
+
self,
|
829 |
+
model = None,
|
830 |
+
ref_model = None,
|
831 |
+
reward_model = None,
|
832 |
+
judge = None,
|
833 |
+
args = None,
|
834 |
+
data_collator = None,
|
835 |
+
train_dataset = None,
|
836 |
+
eval_dataset = None,
|
837 |
+
processing_class = None,
|
838 |
+
peft_config = None,
|
839 |
+
compute_metrics = None,
|
840 |
+
callbacks = None,
|
841 |
+
preprocess_logits_for_metrics = None,
|
842 |
+
**kwargs
|
843 |
+
):
|
844 |
+
if args is None: args = UnslothNashMDConfig()
|
845 |
+
use_bf16 = getattr(args, 'bf16', False)
|
846 |
+
use_fp16 = getattr(args, 'fp16', False)
|
847 |
+
force_float32 = False
|
848 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
849 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
850 |
+
force_float32 = True
|
851 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
852 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
853 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
854 |
+
from unsloth_zoo.utils import _get_dtype
|
855 |
+
dtype = _get_dtype(dtype)
|
856 |
+
float16 = dtype == torch.float16
|
857 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
858 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
859 |
+
if force_float32:
|
860 |
+
args.fp16 = False
|
861 |
+
args.bf16 = False
|
862 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
863 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
864 |
+
args.fp16 = float16
|
865 |
+
args.bf16 = not float16
|
866 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
867 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
868 |
+
args.eval_strategy = 'steps'
|
869 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
870 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
871 |
+
if ga_steps is not None and ga_steps > 1:
|
872 |
+
from transformers import __version__ as transformers_version
|
873 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
874 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
875 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
876 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
877 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
878 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
879 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
880 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
881 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
882 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
883 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
884 |
+
if force_float32:
|
885 |
+
args.bf16_full_eval = False
|
886 |
+
args.fp16_full_eval = False
|
887 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
888 |
+
args.bf16_full_eval = True
|
889 |
+
args.fp16_full_eval = False
|
890 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
891 |
+
args.bf16_full_eval = args.bf16
|
892 |
+
args.fp16_full_eval = args.fp16
|
893 |
+
_output_logits = False
|
894 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
895 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
896 |
+
if _output_logits:
|
897 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
898 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
899 |
+
pass
|
900 |
+
else:
|
901 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
902 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
903 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
904 |
+
max_seq_length = model.max_seq_length
|
905 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
906 |
+
if model is not None and hasattr(model, 'for_training'):
|
907 |
+
model.for_training()
|
908 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
909 |
+
if 'processing_class' in locals():
|
910 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
911 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
912 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
913 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
914 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
915 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
916 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
917 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
918 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
919 |
+
else:
|
920 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
921 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
922 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
923 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
924 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
925 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
926 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
927 |
+
else:
|
928 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
929 |
+
other_metrics = []
|
930 |
+
|
931 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
932 |
+
PatchRLStatistics('nash_md_trainer', other_metrics)
|
933 |
+
|
934 |
+
super().__init__(
|
935 |
+
model = model,
|
936 |
+
ref_model = ref_model,
|
937 |
+
reward_model = reward_model,
|
938 |
+
judge = judge,
|
939 |
+
args = args,
|
940 |
+
data_collator = data_collator,
|
941 |
+
train_dataset = train_dataset,
|
942 |
+
eval_dataset = eval_dataset,
|
943 |
+
processing_class = processing_class,
|
944 |
+
peft_config = peft_config,
|
945 |
+
compute_metrics = compute_metrics,
|
946 |
+
callbacks = callbacks,
|
947 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
948 |
+
if hasattr(self, 'neftune_hook_handle'):
|
949 |
+
self.neftune_hook_handle.remove()
|
950 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
951 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
952 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
953 |
+
pass
|
954 |
+
|
955 |
+
pass
|
unsloth_compiled_cache/UnslothORPOTrainer.py
ADDED
@@ -0,0 +1,1543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, transformers, version, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothORPOConfig(ORPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`ORPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
58 |
+
to use the default data collator.
|
59 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
60 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
61 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
62 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
63 |
+
and your model is an encoder-decoder.
|
64 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
65 |
+
Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
|
66 |
+
it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
|
67 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
68 |
+
Whether to disable dropout in the model.
|
69 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
70 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
71 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
72 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
73 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
74 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
75 |
+
This argument is required if you want to use the default data collator.
|
76 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
77 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
78 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
79 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
80 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
81 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
82 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
83 |
+
string.
|
84 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
85 |
+
Number of processes to use for processing the dataset.
|
86 |
+
|
87 |
+
"""
|
88 |
+
vllm_sampling_params: Optional[Any] = field(
|
89 |
+
default = None,
|
90 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
91 |
+
)
|
92 |
+
unsloth_num_chunks : Optional[int] = field(
|
93 |
+
default = -1,
|
94 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
95 |
+
)
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
output_dir = None,
|
99 |
+
overwrite_output_dir = None,
|
100 |
+
do_train = False,
|
101 |
+
do_eval = False,
|
102 |
+
do_predict = False,
|
103 |
+
eval_strategy = 'no',
|
104 |
+
prediction_loss_only = False,
|
105 |
+
per_device_train_batch_size = 4,
|
106 |
+
per_device_eval_batch_size = 4,
|
107 |
+
per_gpu_train_batch_size = None,
|
108 |
+
per_gpu_eval_batch_size = None,
|
109 |
+
gradient_accumulation_steps = 2,
|
110 |
+
eval_accumulation_steps = 2,
|
111 |
+
eval_delay = 0,
|
112 |
+
torch_empty_cache_steps = 250,
|
113 |
+
learning_rate = 5e-05,
|
114 |
+
weight_decay = 0.01,
|
115 |
+
adam_beta1 = 0.9,
|
116 |
+
adam_beta2 = 0.999,
|
117 |
+
adam_epsilon = 1e-08,
|
118 |
+
max_grad_norm = 1.0,
|
119 |
+
num_train_epochs = 3.0,
|
120 |
+
max_steps = -1,
|
121 |
+
lr_scheduler_type = 'linear',
|
122 |
+
warmup_ratio = 0.1,
|
123 |
+
warmup_steps = 0,
|
124 |
+
log_level = 'passive',
|
125 |
+
log_level_replica = 'warning',
|
126 |
+
log_on_each_node = True,
|
127 |
+
logging_dir = None,
|
128 |
+
logging_strategy = 'steps',
|
129 |
+
logging_first_step = False,
|
130 |
+
logging_steps = 1,
|
131 |
+
logging_nan_inf_filter = False,
|
132 |
+
save_strategy = 'steps',
|
133 |
+
save_steps = 500,
|
134 |
+
save_total_limit = None,
|
135 |
+
save_safetensors = True,
|
136 |
+
save_on_each_node = False,
|
137 |
+
save_only_model = False,
|
138 |
+
restore_callback_states_from_checkpoint = False,
|
139 |
+
no_cuda = False,
|
140 |
+
use_cpu = False,
|
141 |
+
use_mps_device = False,
|
142 |
+
seed = 3407,
|
143 |
+
data_seed = 3407,
|
144 |
+
jit_mode_eval = False,
|
145 |
+
use_ipex = False,
|
146 |
+
bf16 = False,
|
147 |
+
fp16 = False,
|
148 |
+
fp16_opt_level = 'O1',
|
149 |
+
half_precision_backend = 'auto',
|
150 |
+
bf16_full_eval = False,
|
151 |
+
fp16_full_eval = False,
|
152 |
+
tf32 = None,
|
153 |
+
local_rank = -1,
|
154 |
+
ddp_backend = None,
|
155 |
+
tpu_num_cores = None,
|
156 |
+
tpu_metrics_debug = False,
|
157 |
+
debug = '',
|
158 |
+
dataloader_drop_last = False,
|
159 |
+
eval_steps = None,
|
160 |
+
dataloader_num_workers = 0,
|
161 |
+
dataloader_prefetch_factor = None,
|
162 |
+
past_index = -1,
|
163 |
+
run_name = None,
|
164 |
+
disable_tqdm = None,
|
165 |
+
remove_unused_columns = True,
|
166 |
+
label_names = None,
|
167 |
+
load_best_model_at_end = False,
|
168 |
+
metric_for_best_model = None,
|
169 |
+
greater_is_better = None,
|
170 |
+
ignore_data_skip = False,
|
171 |
+
fsdp = '',
|
172 |
+
fsdp_min_num_params = 0,
|
173 |
+
fsdp_config = None,
|
174 |
+
tp_size = 0,
|
175 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
176 |
+
accelerator_config = None,
|
177 |
+
deepspeed = None,
|
178 |
+
label_smoothing_factor = 0.0,
|
179 |
+
optim = 'adamw_8bit',
|
180 |
+
optim_args = None,
|
181 |
+
adafactor = False,
|
182 |
+
group_by_length = False,
|
183 |
+
length_column_name = 'length',
|
184 |
+
report_to = None,
|
185 |
+
ddp_find_unused_parameters = None,
|
186 |
+
ddp_bucket_cap_mb = None,
|
187 |
+
ddp_broadcast_buffers = None,
|
188 |
+
dataloader_pin_memory = True,
|
189 |
+
dataloader_persistent_workers = False,
|
190 |
+
skip_memory_metrics = True,
|
191 |
+
use_legacy_prediction_loop = False,
|
192 |
+
push_to_hub = False,
|
193 |
+
resume_from_checkpoint = None,
|
194 |
+
hub_model_id = None,
|
195 |
+
hub_strategy = 'every_save',
|
196 |
+
hub_token = None,
|
197 |
+
hub_private_repo = None,
|
198 |
+
hub_always_push = False,
|
199 |
+
gradient_checkpointing = False,
|
200 |
+
gradient_checkpointing_kwargs = None,
|
201 |
+
include_inputs_for_metrics = False,
|
202 |
+
eval_do_concat_batches = True,
|
203 |
+
fp16_backend = 'auto',
|
204 |
+
evaluation_strategy = None,
|
205 |
+
push_to_hub_model_id = None,
|
206 |
+
push_to_hub_organization = None,
|
207 |
+
push_to_hub_token = None,
|
208 |
+
mp_parameters = '',
|
209 |
+
auto_find_batch_size = False,
|
210 |
+
full_determinism = False,
|
211 |
+
torchdynamo = None,
|
212 |
+
ray_scope = 'last',
|
213 |
+
ddp_timeout = 1800,
|
214 |
+
torch_compile = False,
|
215 |
+
torch_compile_backend = None,
|
216 |
+
torch_compile_mode = None,
|
217 |
+
dispatch_batches = None,
|
218 |
+
split_batches = None,
|
219 |
+
include_tokens_per_second = False,
|
220 |
+
include_num_input_tokens_seen = False,
|
221 |
+
neftune_noise_alpha = None,
|
222 |
+
optim_target_modules = None,
|
223 |
+
batch_eval_metrics = False,
|
224 |
+
eval_on_start = False,
|
225 |
+
use_liger_kernel = False,
|
226 |
+
eval_use_gather_object = False,
|
227 |
+
average_tokens_across_devices = False,
|
228 |
+
max_length = 1024,
|
229 |
+
max_prompt_length = 512,
|
230 |
+
max_completion_length = None,
|
231 |
+
beta = 0.1,
|
232 |
+
disable_dropout = True,
|
233 |
+
label_pad_token_id = -100,
|
234 |
+
padding_value = None,
|
235 |
+
truncation_mode = 'keep_end',
|
236 |
+
generate_during_eval = False,
|
237 |
+
is_encoder_decoder = None,
|
238 |
+
model_init_kwargs = None,
|
239 |
+
dataset_num_proc = None,
|
240 |
+
vllm_sampling_params = None,
|
241 |
+
unsloth_num_chunks = -1,
|
242 |
+
**kwargs,
|
243 |
+
):
|
244 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
245 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
246 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
247 |
+
output_dir = 'unsloth_training_checkpoints'
|
248 |
+
save_strategy = 'no'
|
249 |
+
if dataset_num_proc is None:
|
250 |
+
from multiprocessing import cpu_count
|
251 |
+
dataset_num_proc = cpu_count()
|
252 |
+
|
253 |
+
super().__init__(
|
254 |
+
output_dir = output_dir,
|
255 |
+
overwrite_output_dir = overwrite_output_dir,
|
256 |
+
do_train = do_train,
|
257 |
+
do_eval = do_eval,
|
258 |
+
do_predict = do_predict,
|
259 |
+
eval_strategy = eval_strategy,
|
260 |
+
prediction_loss_only = prediction_loss_only,
|
261 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
262 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
263 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
264 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
265 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
266 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
267 |
+
eval_delay = eval_delay,
|
268 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
269 |
+
learning_rate = learning_rate,
|
270 |
+
weight_decay = weight_decay,
|
271 |
+
adam_beta1 = adam_beta1,
|
272 |
+
adam_beta2 = adam_beta2,
|
273 |
+
adam_epsilon = adam_epsilon,
|
274 |
+
max_grad_norm = max_grad_norm,
|
275 |
+
num_train_epochs = num_train_epochs,
|
276 |
+
max_steps = max_steps,
|
277 |
+
lr_scheduler_type = lr_scheduler_type,
|
278 |
+
warmup_ratio = warmup_ratio,
|
279 |
+
warmup_steps = warmup_steps,
|
280 |
+
log_level = log_level,
|
281 |
+
log_level_replica = log_level_replica,
|
282 |
+
log_on_each_node = log_on_each_node,
|
283 |
+
logging_dir = logging_dir,
|
284 |
+
logging_strategy = logging_strategy,
|
285 |
+
logging_first_step = logging_first_step,
|
286 |
+
logging_steps = logging_steps,
|
287 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
288 |
+
save_strategy = save_strategy,
|
289 |
+
save_steps = save_steps,
|
290 |
+
save_total_limit = save_total_limit,
|
291 |
+
save_safetensors = save_safetensors,
|
292 |
+
save_on_each_node = save_on_each_node,
|
293 |
+
save_only_model = save_only_model,
|
294 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
295 |
+
no_cuda = no_cuda,
|
296 |
+
use_cpu = use_cpu,
|
297 |
+
use_mps_device = use_mps_device,
|
298 |
+
seed = seed,
|
299 |
+
data_seed = data_seed,
|
300 |
+
jit_mode_eval = jit_mode_eval,
|
301 |
+
use_ipex = use_ipex,
|
302 |
+
bf16 = bf16,
|
303 |
+
fp16 = fp16,
|
304 |
+
fp16_opt_level = fp16_opt_level,
|
305 |
+
half_precision_backend = half_precision_backend,
|
306 |
+
bf16_full_eval = bf16_full_eval,
|
307 |
+
fp16_full_eval = fp16_full_eval,
|
308 |
+
tf32 = tf32,
|
309 |
+
local_rank = local_rank,
|
310 |
+
ddp_backend = ddp_backend,
|
311 |
+
tpu_num_cores = tpu_num_cores,
|
312 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
313 |
+
debug = debug,
|
314 |
+
dataloader_drop_last = dataloader_drop_last,
|
315 |
+
eval_steps = eval_steps,
|
316 |
+
dataloader_num_workers = dataloader_num_workers,
|
317 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
318 |
+
past_index = past_index,
|
319 |
+
run_name = run_name,
|
320 |
+
disable_tqdm = disable_tqdm,
|
321 |
+
remove_unused_columns = remove_unused_columns,
|
322 |
+
label_names = label_names,
|
323 |
+
load_best_model_at_end = load_best_model_at_end,
|
324 |
+
metric_for_best_model = metric_for_best_model,
|
325 |
+
greater_is_better = greater_is_better,
|
326 |
+
ignore_data_skip = ignore_data_skip,
|
327 |
+
fsdp = fsdp,
|
328 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
329 |
+
fsdp_config = fsdp_config,
|
330 |
+
tp_size = tp_size,
|
331 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
332 |
+
accelerator_config = accelerator_config,
|
333 |
+
deepspeed = deepspeed,
|
334 |
+
label_smoothing_factor = label_smoothing_factor,
|
335 |
+
optim = optim,
|
336 |
+
optim_args = optim_args,
|
337 |
+
adafactor = adafactor,
|
338 |
+
group_by_length = group_by_length,
|
339 |
+
length_column_name = length_column_name,
|
340 |
+
report_to = report_to,
|
341 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
342 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
343 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
344 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
345 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
346 |
+
skip_memory_metrics = skip_memory_metrics,
|
347 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
348 |
+
push_to_hub = push_to_hub,
|
349 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
350 |
+
hub_model_id = hub_model_id,
|
351 |
+
hub_strategy = hub_strategy,
|
352 |
+
hub_token = hub_token,
|
353 |
+
hub_private_repo = hub_private_repo,
|
354 |
+
hub_always_push = hub_always_push,
|
355 |
+
gradient_checkpointing = gradient_checkpointing,
|
356 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
357 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
358 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
359 |
+
fp16_backend = fp16_backend,
|
360 |
+
evaluation_strategy = evaluation_strategy,
|
361 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
362 |
+
push_to_hub_organization = push_to_hub_organization,
|
363 |
+
push_to_hub_token = push_to_hub_token,
|
364 |
+
mp_parameters = mp_parameters,
|
365 |
+
auto_find_batch_size = auto_find_batch_size,
|
366 |
+
full_determinism = full_determinism,
|
367 |
+
torchdynamo = torchdynamo,
|
368 |
+
ray_scope = ray_scope,
|
369 |
+
ddp_timeout = ddp_timeout,
|
370 |
+
torch_compile = torch_compile,
|
371 |
+
torch_compile_backend = torch_compile_backend,
|
372 |
+
torch_compile_mode = torch_compile_mode,
|
373 |
+
dispatch_batches = dispatch_batches,
|
374 |
+
split_batches = split_batches,
|
375 |
+
include_tokens_per_second = include_tokens_per_second,
|
376 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
377 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
378 |
+
optim_target_modules = optim_target_modules,
|
379 |
+
batch_eval_metrics = batch_eval_metrics,
|
380 |
+
eval_on_start = eval_on_start,
|
381 |
+
use_liger_kernel = use_liger_kernel,
|
382 |
+
eval_use_gather_object = eval_use_gather_object,
|
383 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
384 |
+
max_length = max_length,
|
385 |
+
max_prompt_length = max_prompt_length,
|
386 |
+
max_completion_length = max_completion_length,
|
387 |
+
beta = beta,
|
388 |
+
disable_dropout = disable_dropout,
|
389 |
+
label_pad_token_id = label_pad_token_id,
|
390 |
+
padding_value = padding_value,
|
391 |
+
truncation_mode = truncation_mode,
|
392 |
+
generate_during_eval = generate_during_eval,
|
393 |
+
is_encoder_decoder = is_encoder_decoder,
|
394 |
+
model_init_kwargs = model_init_kwargs,
|
395 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
396 |
+
self.vllm_sampling_params = vllm_sampling_params
|
397 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
398 |
+
pass
|
399 |
+
|
400 |
+
class _UnslothORPOTrainer(Trainer):
|
401 |
+
r""""""
|
402 |
+
|
403 |
+
_tag_names = ["trl", "orpo"]
|
404 |
+
|
405 |
+
def __init__(
|
406 |
+
self,
|
407 |
+
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
408 |
+
args: Optional[ORPOConfig] = None,
|
409 |
+
data_collator: Optional[DataCollator] = None,
|
410 |
+
train_dataset: Optional[Dataset] = None,
|
411 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
412 |
+
processing_class: Optional[
|
413 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
414 |
+
] = None,
|
415 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
416 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
417 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
418 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
419 |
+
peft_config: Optional[dict] = None,
|
420 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
421 |
+
):
|
422 |
+
if args.model_init_kwargs is None:
|
423 |
+
model_init_kwargs = {}
|
424 |
+
elif not isinstance(model, str):
|
425 |
+
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
426 |
+
else:
|
427 |
+
model_init_kwargs = args.model_init_kwargs
|
428 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
429 |
+
if torch_dtype is not None:
|
430 |
+
# Convert to `torch.dtype` if an str is passed
|
431 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
432 |
+
torch_dtype = getattr(torch, torch_dtype)
|
433 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
434 |
+
raise ValueError(
|
435 |
+
f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
436 |
+
)
|
437 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
438 |
+
|
439 |
+
if isinstance(model, str):
|
440 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
441 |
+
|
442 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
443 |
+
# has been called in order to properly call autocast if needed.
|
444 |
+
self._peft_has_been_casted_to_bf16 = False
|
445 |
+
|
446 |
+
if not is_peft_available() and peft_config is not None:
|
447 |
+
raise ValueError(
|
448 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
449 |
+
)
|
450 |
+
elif is_peft_available() and peft_config is not None:
|
451 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
452 |
+
if isinstance(model, PeftModel):
|
453 |
+
model = model.merge_and_unload()
|
454 |
+
|
455 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
456 |
+
_support_gc_kwargs = hasattr(
|
457 |
+
args, "gradient_checkpointing_kwargs"
|
458 |
+
) and "gradient_checkpointing_kwargs" in list(
|
459 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
460 |
+
)
|
461 |
+
|
462 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
463 |
+
|
464 |
+
if _support_gc_kwargs:
|
465 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
466 |
+
|
467 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
468 |
+
elif getattr(args, "gradient_checkpointing", False):
|
469 |
+
# For backward compatibility with older versions of transformers
|
470 |
+
if hasattr(model, "enable_input_require_grads"):
|
471 |
+
model.enable_input_require_grads()
|
472 |
+
else:
|
473 |
+
|
474 |
+
def make_inputs_require_grad(module, input, output):
|
475 |
+
output.requires_grad_(True)
|
476 |
+
|
477 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
478 |
+
|
479 |
+
# get peft model with the given config
|
480 |
+
model = model
|
481 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
482 |
+
peft_module_casting_to_bf16(model)
|
483 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
484 |
+
self._peft_has_been_casted_to_bf16 = True
|
485 |
+
|
486 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
487 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
488 |
+
# fail or completely fail.
|
489 |
+
elif getattr(args, "gradient_checkpointing", False):
|
490 |
+
# For backward compatibility with older versions of transformers
|
491 |
+
if hasattr(model, "enable_input_require_grads"):
|
492 |
+
model.enable_input_require_grads()
|
493 |
+
else:
|
494 |
+
|
495 |
+
def make_inputs_require_grad(module, input, output):
|
496 |
+
output.requires_grad_(True)
|
497 |
+
|
498 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
499 |
+
|
500 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
501 |
+
raise ValueError(
|
502 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
503 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
504 |
+
)
|
505 |
+
|
506 |
+
if model is not None:
|
507 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
508 |
+
elif args.is_encoder_decoder is None:
|
509 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
510 |
+
else:
|
511 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
512 |
+
|
513 |
+
if self.is_encoder_decoder:
|
514 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
515 |
+
self.pad_token_id = model.config.pad_token_id
|
516 |
+
|
517 |
+
if processing_class is None:
|
518 |
+
raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
|
519 |
+
if args.max_length is None:
|
520 |
+
warnings.warn(
|
521 |
+
"`max_length` is not set in the ORPOConfig's init"
|
522 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
523 |
+
UserWarning,
|
524 |
+
)
|
525 |
+
max_length = 512
|
526 |
+
else:
|
527 |
+
max_length = args.max_length
|
528 |
+
if args.max_prompt_length is None:
|
529 |
+
warnings.warn(
|
530 |
+
"`max_prompt_length` is not set in the ORPOConfig's init"
|
531 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
532 |
+
UserWarning,
|
533 |
+
)
|
534 |
+
max_prompt_length = 128
|
535 |
+
else:
|
536 |
+
max_prompt_length = args.max_prompt_length
|
537 |
+
|
538 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
539 |
+
warnings.warn(
|
540 |
+
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
541 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
542 |
+
UserWarning,
|
543 |
+
)
|
544 |
+
self.max_completion_length = 128
|
545 |
+
else:
|
546 |
+
self.max_completion_length = args.max_completion_length
|
547 |
+
|
548 |
+
if data_collator is None:
|
549 |
+
data_collator = DPODataCollatorWithPadding(
|
550 |
+
pad_token_id=processing_class.pad_token_id,
|
551 |
+
label_pad_token_id=args.label_pad_token_id,
|
552 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
553 |
+
)
|
554 |
+
|
555 |
+
if args.remove_unused_columns:
|
556 |
+
args.remove_unused_columns = False
|
557 |
+
# warn users
|
558 |
+
warnings.warn(
|
559 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
560 |
+
" we have set it for you, but you should do it yourself in the future.",
|
561 |
+
UserWarning,
|
562 |
+
)
|
563 |
+
|
564 |
+
self.use_dpo_data_collator = True
|
565 |
+
else:
|
566 |
+
self.use_dpo_data_collator = False
|
567 |
+
|
568 |
+
# Disable dropout in the model and reference model
|
569 |
+
if args.disable_dropout:
|
570 |
+
disable_dropout_in_model(model)
|
571 |
+
|
572 |
+
self.max_length = max_length
|
573 |
+
self.generate_during_eval = args.generate_during_eval
|
574 |
+
self.label_pad_token_id = args.label_pad_token_id
|
575 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
576 |
+
self.max_prompt_length = max_prompt_length
|
577 |
+
self.truncation_mode = args.truncation_mode
|
578 |
+
self.processing_class = processing_class
|
579 |
+
|
580 |
+
self.beta = args.beta
|
581 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
582 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
583 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
584 |
+
warnings.warn(
|
585 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
586 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
587 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
588 |
+
"loss.",
|
589 |
+
UserWarning,
|
590 |
+
)
|
591 |
+
|
592 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
593 |
+
|
594 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
595 |
+
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
|
596 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
597 |
+
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
598 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
599 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
600 |
+
# that the warning has already been issued.
|
601 |
+
model.warnings_issued["estimate_tokens"] = True
|
602 |
+
|
603 |
+
# Compute that only on the main process for faster data processing.
|
604 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
605 |
+
with PartialState().local_main_process_first():
|
606 |
+
# Extract the prompt if needed, and apply the chat template if needed
|
607 |
+
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
608 |
+
train_dataset = train_dataset.map(
|
609 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
610 |
+
)
|
611 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
612 |
+
if eval_dataset is not None:
|
613 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
614 |
+
eval_dataset = eval_dataset.map(
|
615 |
+
maybe_apply_chat_template,
|
616 |
+
fn_kwargs={"tokenizer": processing_class},
|
617 |
+
num_proc=args.dataset_num_proc,
|
618 |
+
)
|
619 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
620 |
+
|
621 |
+
super().__init__(
|
622 |
+
model=model,
|
623 |
+
args=args,
|
624 |
+
data_collator=data_collator,
|
625 |
+
train_dataset=train_dataset,
|
626 |
+
eval_dataset=eval_dataset,
|
627 |
+
processing_class=processing_class,
|
628 |
+
model_init=model_init,
|
629 |
+
compute_metrics=compute_metrics,
|
630 |
+
callbacks=callbacks,
|
631 |
+
optimizers=optimizers,
|
632 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
633 |
+
)
|
634 |
+
|
635 |
+
# Add tags for models that have been loaded with the correct transformers version
|
636 |
+
if hasattr(self.model, "add_model_tags"):
|
637 |
+
self.model.add_model_tags(self._tag_names)
|
638 |
+
|
639 |
+
if not hasattr(self, "accelerator"):
|
640 |
+
raise AttributeError(
|
641 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
642 |
+
)
|
643 |
+
|
644 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
645 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
646 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
647 |
+
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
648 |
+
|
649 |
+
if model is not None:
|
650 |
+
if hasattr(model, "config"):
|
651 |
+
hidden_size = (
|
652 |
+
max(model.config.hidden_sizes)
|
653 |
+
if getattr(model.config, "hidden_sizes", None)
|
654 |
+
else getattr(model.config, "hidden_size", None)
|
655 |
+
)
|
656 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
657 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
658 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
659 |
+
config_kwargs.update(
|
660 |
+
{
|
661 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
662 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
663 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
664 |
+
}
|
665 |
+
)
|
666 |
+
|
667 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
668 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
669 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
670 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
671 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
672 |
+
model.eval()
|
673 |
+
return model
|
674 |
+
|
675 |
+
def build_tokenized_answer(self, prompt, answer):
|
676 |
+
"""
|
677 |
+
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
678 |
+
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
679 |
+
Reference:
|
680 |
+
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
681 |
+
"""
|
682 |
+
|
683 |
+
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
684 |
+
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
685 |
+
|
686 |
+
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
687 |
+
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
688 |
+
|
689 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
690 |
+
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
691 |
+
|
692 |
+
# Prepare input tokens for token by token comparison
|
693 |
+
full_input_ids = np.array(full_tokenized["input_ids"])
|
694 |
+
|
695 |
+
if len(full_input_ids) != len(full_concat_input_ids):
|
696 |
+
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
697 |
+
|
698 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
699 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
700 |
+
# on the last token from the prompt being different when tokenized on its own
|
701 |
+
# vs when done as prompt+answer.
|
702 |
+
response_token_ids_start_idx = len(prompt_input_ids)
|
703 |
+
|
704 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
705 |
+
# last token has changed due to merging.
|
706 |
+
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
707 |
+
response_token_ids_start_idx -= 1
|
708 |
+
|
709 |
+
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
710 |
+
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
711 |
+
|
712 |
+
if len(prompt_input_ids) != len(prompt_attention_mask):
|
713 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
714 |
+
|
715 |
+
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
716 |
+
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
717 |
+
|
718 |
+
return dict(
|
719 |
+
prompt_input_ids=prompt_input_ids,
|
720 |
+
prompt_attention_mask=prompt_attention_mask,
|
721 |
+
input_ids=answer_input_ids,
|
722 |
+
attention_mask=answer_attention_mask,
|
723 |
+
)
|
724 |
+
|
725 |
+
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
726 |
+
"""Tokenize a single row from a ORPO specific dataset.
|
727 |
+
|
728 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
729 |
+
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
730 |
+
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
731 |
+
|
732 |
+
We also create the labels for the chosen/rejected responses, which are of length equal to
|
733 |
+
the sum of the length of the prompt and the chosen/rejected response, with
|
734 |
+
label_pad_token_id for the prompt tokens.
|
735 |
+
"""
|
736 |
+
batch = {}
|
737 |
+
prompt = feature["prompt"]
|
738 |
+
chosen = feature["chosen"]
|
739 |
+
rejected = feature["rejected"]
|
740 |
+
|
741 |
+
if not self.is_encoder_decoder:
|
742 |
+
# Check issues below for more details
|
743 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
744 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
745 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
746 |
+
|
747 |
+
if not isinstance(prompt, str):
|
748 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
749 |
+
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
750 |
+
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
751 |
+
|
752 |
+
if not isinstance(chosen, str):
|
753 |
+
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
754 |
+
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
755 |
+
|
756 |
+
if not isinstance(rejected, str):
|
757 |
+
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
758 |
+
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
759 |
+
|
760 |
+
# Last prompt token might get merged by tokenizer and
|
761 |
+
# it should not be included for generation if that happens
|
762 |
+
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
763 |
+
|
764 |
+
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
765 |
+
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
766 |
+
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
767 |
+
|
768 |
+
for k, v in prompt_tokens.items():
|
769 |
+
prompt_tokens[k] = v[:prompt_len_input_ids]
|
770 |
+
|
771 |
+
# Make sure prompts only have one different token at most an
|
772 |
+
# and length only differs by 1 at most
|
773 |
+
num_diff_tokens = sum(
|
774 |
+
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
775 |
+
)
|
776 |
+
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
777 |
+
if num_diff_tokens > 1 or num_diff_len > 1:
|
778 |
+
raise ValueError(
|
779 |
+
"Chosen and rejected prompt_input_ids might only differ on the "
|
780 |
+
"last token due to tokenizer merge ops."
|
781 |
+
)
|
782 |
+
|
783 |
+
# add BOS token to head of prompt. Avoid adding if it's already there
|
784 |
+
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
785 |
+
self.processing_class.bos_token_id,
|
786 |
+
prompt_len_input_ids,
|
787 |
+
prompt_tokens,
|
788 |
+
chosen_prompt_len_input_ids,
|
789 |
+
chosen_tokens,
|
790 |
+
rejected_prompt_len_input_ids,
|
791 |
+
rejected_tokens,
|
792 |
+
)
|
793 |
+
|
794 |
+
# add EOS token to end of answer. Avoid adding if it's already there
|
795 |
+
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
796 |
+
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
797 |
+
)
|
798 |
+
|
799 |
+
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
800 |
+
|
801 |
+
# if combined sequence is too long, truncate the prompt
|
802 |
+
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
803 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
804 |
+
if self.truncation_mode == "keep_start":
|
805 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
806 |
+
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
807 |
+
elif self.truncation_mode == "keep_end":
|
808 |
+
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
809 |
+
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
810 |
+
else:
|
811 |
+
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
812 |
+
|
813 |
+
# if that's still too long, truncate the response
|
814 |
+
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
815 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
816 |
+
for k in ["input_ids", "attention_mask"]:
|
817 |
+
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
818 |
+
|
819 |
+
# Create labels
|
820 |
+
chosen_sequence_tokens = {
|
821 |
+
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
822 |
+
}
|
823 |
+
rejected_sequence_tokens = {
|
824 |
+
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
825 |
+
}
|
826 |
+
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
827 |
+
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
828 |
+
self.label_pad_token_id
|
829 |
+
] * len(chosen_tokens["prompt_input_ids"])
|
830 |
+
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
831 |
+
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
832 |
+
self.label_pad_token_id
|
833 |
+
] * len(rejected_tokens["prompt_input_ids"])
|
834 |
+
|
835 |
+
for k, toks in {
|
836 |
+
"chosen_": chosen_sequence_tokens,
|
837 |
+
"rejected_": rejected_sequence_tokens,
|
838 |
+
"": prompt_tokens,
|
839 |
+
}.items():
|
840 |
+
for type_key, tokens in toks.items():
|
841 |
+
if type_key == "token_type_ids":
|
842 |
+
continue
|
843 |
+
batch[f"{k}{type_key}"] = tokens
|
844 |
+
|
845 |
+
else:
|
846 |
+
chosen_tokens = self.processing_class(
|
847 |
+
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
848 |
+
)
|
849 |
+
rejected_tokens = self.processing_class(
|
850 |
+
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
851 |
+
)
|
852 |
+
prompt_tokens = self.processing_class(
|
853 |
+
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
854 |
+
)
|
855 |
+
|
856 |
+
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
857 |
+
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
858 |
+
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
859 |
+
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
860 |
+
|
861 |
+
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
862 |
+
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
863 |
+
labels=torch.tensor(batch["rejected_labels"])
|
864 |
+
)
|
865 |
+
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
866 |
+
labels=torch.tensor(batch["chosen_labels"])
|
867 |
+
)
|
868 |
+
|
869 |
+
if is_torch_xla_available():
|
870 |
+
# Pad the sequences to global max_length to avoid TorchXLA recompilation
|
871 |
+
for k in batch:
|
872 |
+
if "labels" in k or self.is_encoder_decoder:
|
873 |
+
pad_value = self.label_pad_token_id
|
874 |
+
elif k.endswith("_input_ids"):
|
875 |
+
pad_value = self.padding_value
|
876 |
+
elif k.endswith("_attention_mask"):
|
877 |
+
pad_value = 0
|
878 |
+
batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
|
879 |
+
return batch
|
880 |
+
|
881 |
+
@staticmethod
|
882 |
+
def concatenated_inputs(
|
883 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
884 |
+
is_encoder_decoder: bool = False,
|
885 |
+
label_pad_token_id: int = -100,
|
886 |
+
padding_value: int = 0,
|
887 |
+
device: Optional[torch.device] = None,
|
888 |
+
) -> dict[str, torch.LongTensor]:
|
889 |
+
"""Concatenate the chosen and rejected inputs into a single tensor.
|
890 |
+
|
891 |
+
Args:
|
892 |
+
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
893 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
894 |
+
label_pad_token_id: The label pad token id.
|
895 |
+
padding_value: The padding value to use for the concatenated inputs_ids.
|
896 |
+
device: The device for the concatenated inputs.
|
897 |
+
|
898 |
+
Returns:
|
899 |
+
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
900 |
+
"""
|
901 |
+
concatenated_batch = {}
|
902 |
+
|
903 |
+
if is_encoder_decoder:
|
904 |
+
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
905 |
+
else:
|
906 |
+
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
907 |
+
|
908 |
+
for k in batch:
|
909 |
+
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
910 |
+
if "labels" in k or is_encoder_decoder:
|
911 |
+
pad_value = label_pad_token_id
|
912 |
+
elif k.endswith("_input_ids"):
|
913 |
+
pad_value = padding_value
|
914 |
+
elif k.endswith("_attention_mask"):
|
915 |
+
pad_value = 0
|
916 |
+
concatenated_key = k.replace("chosen", "concatenated")
|
917 |
+
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
918 |
+
for k in batch:
|
919 |
+
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
920 |
+
if "labels" in k or is_encoder_decoder:
|
921 |
+
pad_value = label_pad_token_id
|
922 |
+
elif k.endswith("_input_ids"):
|
923 |
+
pad_value = padding_value
|
924 |
+
elif k.endswith("_attention_mask"):
|
925 |
+
pad_value = 0
|
926 |
+
concatenated_key = k.replace("rejected", "concatenated")
|
927 |
+
concatenated_batch[concatenated_key] = torch.cat(
|
928 |
+
(
|
929 |
+
concatenated_batch[concatenated_key],
|
930 |
+
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
931 |
+
),
|
932 |
+
dim=0,
|
933 |
+
).to(device=device)
|
934 |
+
|
935 |
+
if is_encoder_decoder:
|
936 |
+
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
937 |
+
concatenated_batch["concatenated_attention_mask"] = (
|
938 |
+
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
939 |
+
)
|
940 |
+
|
941 |
+
return concatenated_batch
|
942 |
+
|
943 |
+
def odds_ratio_loss(
|
944 |
+
self,
|
945 |
+
policy_chosen_logps: torch.FloatTensor,
|
946 |
+
policy_rejected_logps: torch.FloatTensor,
|
947 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
948 |
+
"""Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
|
949 |
+
|
950 |
+
Args:
|
951 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
952 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
953 |
+
|
954 |
+
Returns:
|
955 |
+
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
956 |
+
The losses tensor contains the ORPO loss for each example in the batch.
|
957 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
958 |
+
The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
|
959 |
+
The `log(sigmoid(log_odds_chosen))` for logging purposes.
|
960 |
+
"""
|
961 |
+
|
962 |
+
# Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
|
963 |
+
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
964 |
+
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
|
965 |
+
)
|
966 |
+
ratio = F.logsigmoid(log_odds)
|
967 |
+
losses = self.beta * ratio
|
968 |
+
|
969 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
970 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
971 |
+
|
972 |
+
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
|
973 |
+
|
974 |
+
@staticmethod
|
975 |
+
def get_batch_logps(
|
976 |
+
logits: torch.FloatTensor,
|
977 |
+
labels: torch.LongTensor,
|
978 |
+
average_log_prob: bool = False,
|
979 |
+
label_pad_token_id: int = -100,
|
980 |
+
is_encoder_decoder: bool = False,
|
981 |
+
) -> torch.FloatTensor:
|
982 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
983 |
+
|
984 |
+
Args:
|
985 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
986 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
987 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
988 |
+
label_pad_token_id: The label pad token id.
|
989 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
990 |
+
|
991 |
+
Returns:
|
992 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
993 |
+
"""
|
994 |
+
if logits.shape[:-1] != labels.shape:
|
995 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
996 |
+
|
997 |
+
if not is_encoder_decoder:
|
998 |
+
labels = labels[:, 1:].clone()
|
999 |
+
logits = logits[:, :-1, :]
|
1000 |
+
loss_mask = labels != label_pad_token_id
|
1001 |
+
|
1002 |
+
# dummy token; we'll ignore the losses on these tokens later
|
1003 |
+
labels = torch.where(labels == label_pad_token_id, 0, labels)
|
1004 |
+
|
1005 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
1006 |
+
|
1007 |
+
if average_log_prob:
|
1008 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
1009 |
+
else:
|
1010 |
+
return (per_token_logps * loss_mask).sum(-1)
|
1011 |
+
|
1012 |
+
def concatenated_forward(
|
1013 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
1014 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
1015 |
+
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
1016 |
+
|
1017 |
+
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
1018 |
+
"""
|
1019 |
+
concatenated_batch = self.concatenated_inputs(
|
1020 |
+
batch,
|
1021 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1022 |
+
label_pad_token_id=self.label_pad_token_id,
|
1023 |
+
padding_value=self.padding_value,
|
1024 |
+
device=self.accelerator.device,
|
1025 |
+
)
|
1026 |
+
len_chosen = batch["chosen_labels"].shape[0]
|
1027 |
+
|
1028 |
+
model_kwargs = (
|
1029 |
+
{
|
1030 |
+
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
1031 |
+
}
|
1032 |
+
if self.is_encoder_decoder
|
1033 |
+
else {}
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
if self.aux_loss_enabled:
|
1037 |
+
model_kwargs["output_router_logits"] = True
|
1038 |
+
|
1039 |
+
outputs = model(
|
1040 |
+
concatenated_batch["concatenated_input_ids"],
|
1041 |
+
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
1042 |
+
use_cache=False,
|
1043 |
+
**model_kwargs,
|
1044 |
+
)
|
1045 |
+
all_logits = outputs.logits
|
1046 |
+
|
1047 |
+
def cross_entropy_loss(logits, labels):
|
1048 |
+
if not self.is_encoder_decoder:
|
1049 |
+
# Shift so that tokens < n predict n
|
1050 |
+
logits = logits[..., :-1, :].contiguous()
|
1051 |
+
labels = labels[..., 1:].contiguous()
|
1052 |
+
# Flatten the tokens
|
1053 |
+
loss_fct = nn.CrossEntropyLoss()
|
1054 |
+
logits = logits.view(-1, logits.shape[-1])
|
1055 |
+
labels = labels.view(-1)
|
1056 |
+
# Enable model parallelism
|
1057 |
+
labels = labels.to(logits.device)
|
1058 |
+
loss = loss_fct(logits, labels)
|
1059 |
+
return loss
|
1060 |
+
|
1061 |
+
if self.is_encoder_decoder:
|
1062 |
+
labels = concatenated_batch["concatenated_labels"].clone()
|
1063 |
+
else:
|
1064 |
+
labels = concatenated_batch["concatenated_input_ids"].clone()
|
1065 |
+
attention_mask = concatenated_batch["concatenated_attention_mask"]
|
1066 |
+
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
|
1067 |
+
# orpo chosen nll loss is computed over the full prompt and response
|
1068 |
+
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
1069 |
+
|
1070 |
+
all_logps = self.get_batch_logps(
|
1071 |
+
all_logits,
|
1072 |
+
concatenated_batch["concatenated_labels"],
|
1073 |
+
average_log_prob=True,
|
1074 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
1075 |
+
label_pad_token_id=self.label_pad_token_id,
|
1076 |
+
)
|
1077 |
+
|
1078 |
+
chosen_logps = all_logps[:len_chosen]
|
1079 |
+
rejected_logps = all_logps[len_chosen:]
|
1080 |
+
|
1081 |
+
if not self.is_encoder_decoder:
|
1082 |
+
chosen_logits = all_logits[:len_chosen, :-1, :]
|
1083 |
+
rejected_logits = all_logits[len_chosen:, :-1, :]
|
1084 |
+
else:
|
1085 |
+
chosen_logits = all_logits[:len_chosen]
|
1086 |
+
rejected_logits = all_logits[len_chosen:]
|
1087 |
+
|
1088 |
+
if self.aux_loss_enabled:
|
1089 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
|
1090 |
+
|
1091 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
1092 |
+
|
1093 |
+
def get_batch_loss_metrics(
|
1094 |
+
self,
|
1095 |
+
model,
|
1096 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
1097 |
+
train_eval: Literal["train", "eval"] = "train",
|
1098 |
+
):
|
1099 |
+
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
1100 |
+
metrics = {}
|
1101 |
+
|
1102 |
+
forward_output = self.concatenated_forward(model, batch)
|
1103 |
+
(
|
1104 |
+
policy_chosen_logps,
|
1105 |
+
policy_rejected_logps,
|
1106 |
+
policy_chosen_logits,
|
1107 |
+
policy_rejected_logits,
|
1108 |
+
policy_nll_loss,
|
1109 |
+
) = forward_output[:5]
|
1110 |
+
if self.aux_loss_enabled:
|
1111 |
+
aux_loss = forward_output[5]
|
1112 |
+
|
1113 |
+
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
|
1114 |
+
policy_chosen_logps, policy_rejected_logps
|
1115 |
+
)
|
1116 |
+
# full ORPO loss
|
1117 |
+
loss = policy_nll_loss - losses.mean()
|
1118 |
+
|
1119 |
+
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
1120 |
+
|
1121 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
1122 |
+
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
|
1123 |
+
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
|
1124 |
+
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
|
1125 |
+
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
1126 |
+
chosen_rewards - rejected_rewards
|
1127 |
+
).mean()
|
1128 |
+
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
1129 |
+
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
1130 |
+
metrics[f"{prefix}logits/rejected"] = (
|
1131 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
|
1132 |
+
)
|
1133 |
+
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
|
1134 |
+
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
1135 |
+
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
|
1136 |
+
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
|
1137 |
+
if is_torch_xla_available():
|
1138 |
+
xm.mark_step() # needed because .item() calls
|
1139 |
+
for k, v in metrics.items():
|
1140 |
+
metrics[k] = v.item()
|
1141 |
+
if self.aux_loss_enabled:
|
1142 |
+
loss += self.aux_loss_coef * aux_loss
|
1143 |
+
|
1144 |
+
return loss, metrics
|
1145 |
+
|
1146 |
+
def compute_loss(
|
1147 |
+
self,
|
1148 |
+
model: Union[PreTrainedModel, nn.Module],
|
1149 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1150 |
+
return_outputs=False,
|
1151 |
+
num_items_in_batch=None,
|
1152 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
1153 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1154 |
+
|
1155 |
+
with compute_loss_context_manager:
|
1156 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
1157 |
+
|
1158 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
1159 |
+
loss = loss.to(self.args.device)
|
1160 |
+
|
1161 |
+
# force log the metrics
|
1162 |
+
self.store_metrics(metrics, train_eval="train")
|
1163 |
+
|
1164 |
+
if return_outputs:
|
1165 |
+
return (loss, metrics)
|
1166 |
+
return loss
|
1167 |
+
|
1168 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
1169 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
1170 |
+
|
1171 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
1172 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
1173 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1174 |
+
|
1175 |
+
with generate_context_manager:
|
1176 |
+
policy_output = model.generate(
|
1177 |
+
input_ids=batch["prompt_input_ids"],
|
1178 |
+
attention_mask=batch["prompt_attention_mask"],
|
1179 |
+
max_length=self.max_length,
|
1180 |
+
do_sample=True,
|
1181 |
+
pad_token_id=self.processing_class.pad_token_id,
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
1185 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
1186 |
+
|
1187 |
+
return policy_output_decoded
|
1188 |
+
|
1189 |
+
def prediction_step(
|
1190 |
+
self,
|
1191 |
+
model: Union[PreTrainedModel, nn.Module],
|
1192 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
1193 |
+
prediction_loss_only: bool,
|
1194 |
+
ignore_keys: Optional[list[str]] = None,
|
1195 |
+
):
|
1196 |
+
if not self.use_dpo_data_collator:
|
1197 |
+
warnings.warn(
|
1198 |
+
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
1199 |
+
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
1200 |
+
)
|
1201 |
+
if ignore_keys is None:
|
1202 |
+
if hasattr(model, "config"):
|
1203 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
1204 |
+
else:
|
1205 |
+
ignore_keys = []
|
1206 |
+
|
1207 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
1208 |
+
|
1209 |
+
with torch.no_grad(), prediction_context_manager:
|
1210 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
1211 |
+
|
1212 |
+
# force log the metrics
|
1213 |
+
self.store_metrics(metrics, train_eval="eval")
|
1214 |
+
|
1215 |
+
if prediction_loss_only:
|
1216 |
+
return (loss.detach(), None, None)
|
1217 |
+
|
1218 |
+
# logits for the chosen and rejected samples from model
|
1219 |
+
logits_dict = {
|
1220 |
+
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
1221 |
+
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
1222 |
+
}
|
1223 |
+
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
1224 |
+
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
1225 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
1226 |
+
|
1227 |
+
return (loss.detach(), logits, labels)
|
1228 |
+
|
1229 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
1230 |
+
for key, value in metrics.items():
|
1231 |
+
self._stored_metrics[train_eval][key].append(value)
|
1232 |
+
|
1233 |
+
def evaluation_loop(
|
1234 |
+
self,
|
1235 |
+
dataloader: DataLoader,
|
1236 |
+
description: str,
|
1237 |
+
prediction_loss_only: Optional[bool] = None,
|
1238 |
+
ignore_keys: Optional[list[str]] = None,
|
1239 |
+
metric_key_prefix: str = "eval",
|
1240 |
+
) -> EvalLoopOutput:
|
1241 |
+
"""
|
1242 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
1243 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
1244 |
+
|
1245 |
+
Works both with or without labels.
|
1246 |
+
"""
|
1247 |
+
|
1248 |
+
# Sample and save to game log if requested (for one batch to save time)
|
1249 |
+
if self.generate_during_eval:
|
1250 |
+
# Generate random indices within the range of the total number of samples
|
1251 |
+
num_samples = len(dataloader.dataset)
|
1252 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
1253 |
+
|
1254 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
1255 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
1256 |
+
random_batch = self.data_collator(random_batch_dataset)
|
1257 |
+
random_batch = self._prepare_inputs(random_batch)
|
1258 |
+
|
1259 |
+
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
1260 |
+
|
1261 |
+
table = pd.DataFrame(
|
1262 |
+
columns=["Prompt", "Policy"],
|
1263 |
+
data=[
|
1264 |
+
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
1265 |
+
],
|
1266 |
+
)
|
1267 |
+
if "wandb" in self.args.report_to:
|
1268 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
1269 |
+
|
1270 |
+
if "comet_ml" in self.args.report_to:
|
1271 |
+
log_table_to_comet_experiment(
|
1272 |
+
name="game_log.csv",
|
1273 |
+
table=table,
|
1274 |
+
)
|
1275 |
+
|
1276 |
+
# Base evaluation
|
1277 |
+
initial_output = super().evaluation_loop(
|
1278 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
1279 |
+
)
|
1280 |
+
|
1281 |
+
return initial_output
|
1282 |
+
|
1283 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
1284 |
+
"""
|
1285 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
1286 |
+
|
1287 |
+
Args:
|
1288 |
+
logs (`dict[str, float]`):
|
1289 |
+
The values to log.
|
1290 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
1291 |
+
Start time of the training.
|
1292 |
+
"""
|
1293 |
+
# logs either has 'loss' or 'eval_loss'
|
1294 |
+
train_eval = "train" if "loss" in logs else "eval"
|
1295 |
+
# Add averaged stored metrics to logs
|
1296 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
1297 |
+
logs[key] = torch.tensor(metrics).mean().item()
|
1298 |
+
del self._stored_metrics[train_eval]
|
1299 |
+
|
1300 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
1301 |
+
return super().log(logs, start_time)
|
1302 |
+
else: # transformers<=4.46
|
1303 |
+
return super().log(logs)
|
1304 |
+
|
1305 |
+
def _shift_right(self, input_ids):
|
1306 |
+
if self.decoder_start_token_id is None:
|
1307 |
+
raise ValueError(
|
1308 |
+
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
1309 |
+
)
|
1310 |
+
|
1311 |
+
# shift inputs to the right
|
1312 |
+
if is_torch_fx_proxy(input_ids):
|
1313 |
+
# Item assignment is not supported natively for proxies.
|
1314 |
+
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
1315 |
+
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
1316 |
+
else:
|
1317 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
1318 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
1319 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
1320 |
+
|
1321 |
+
if self.pad_token_id is None:
|
1322 |
+
raise ValueError("model.config.pad_token_id has to be defined.")
|
1323 |
+
# replace possible -100 values in labels by `pad_token_id`
|
1324 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
1325 |
+
|
1326 |
+
return shifted_input_ids
|
1327 |
+
|
1328 |
+
def create_model_card(
|
1329 |
+
self,
|
1330 |
+
model_name: Optional[str] = None,
|
1331 |
+
dataset_name: Optional[str] = None,
|
1332 |
+
tags: Union[str, list[str], None] = None,
|
1333 |
+
):
|
1334 |
+
"""
|
1335 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1336 |
+
|
1337 |
+
Args:
|
1338 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1339 |
+
Name of the model.
|
1340 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1341 |
+
Name of the dataset used for training.
|
1342 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1343 |
+
Tags to be associated with the model card.
|
1344 |
+
"""
|
1345 |
+
if not self.is_world_process_zero():
|
1346 |
+
return
|
1347 |
+
|
1348 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1349 |
+
base_model = self.model.config._name_or_path
|
1350 |
+
else:
|
1351 |
+
base_model = None
|
1352 |
+
|
1353 |
+
tags = tags or []
|
1354 |
+
if isinstance(tags, str):
|
1355 |
+
tags = [tags]
|
1356 |
+
|
1357 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1358 |
+
tags.append("unsloth")
|
1359 |
+
|
1360 |
+
citation = textwrap.dedent("""\
|
1361 |
+
@article{hong2024orpo,
|
1362 |
+
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
|
1363 |
+
author = {Jiwoo Hong and Noah Lee and James Thorne},
|
1364 |
+
year = 2024,
|
1365 |
+
eprint = {arXiv:2403.07691}
|
1366 |
+
}""")
|
1367 |
+
|
1368 |
+
model_card = generate_model_card(
|
1369 |
+
base_model=base_model,
|
1370 |
+
model_name=model_name,
|
1371 |
+
hub_model_id=self.hub_model_id,
|
1372 |
+
dataset_name=dataset_name,
|
1373 |
+
tags=tags,
|
1374 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1375 |
+
comet_url=get_comet_experiment_url(),
|
1376 |
+
trainer_name="ORPO",
|
1377 |
+
trainer_citation=citation,
|
1378 |
+
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
|
1379 |
+
paper_id="2403.07691",
|
1380 |
+
)
|
1381 |
+
|
1382 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1383 |
+
class UnslothORPOTrainer(_UnslothORPOTrainer):
|
1384 |
+
"""
|
1385 |
+
|
1386 |
+
Initialize ORPOTrainer.
|
1387 |
+
|
1388 |
+
Args:
|
1389 |
+
model (`transformers.PreTrainedModel`):
|
1390 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
1391 |
+
args (`ORPOConfig`):
|
1392 |
+
The ORPO config arguments to use for training.
|
1393 |
+
data_collator (`transformers.DataCollator`):
|
1394 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1395 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1396 |
+
train_dataset (`datasets.Dataset`):
|
1397 |
+
The dataset to use for training.
|
1398 |
+
eval_dataset (`datasets.Dataset`):
|
1399 |
+
The dataset to use for evaluation.
|
1400 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1401 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1402 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1403 |
+
reuse the fine-tuned model.
|
1404 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
1405 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
1406 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1407 |
+
The callbacks to use for training.
|
1408 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1409 |
+
The optimizer and scheduler to use for training.
|
1410 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1411 |
+
The function to use to preprocess the logits before computing the metrics.
|
1412 |
+
peft_config (`dict`, defaults to `None`):
|
1413 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
1414 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1415 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1416 |
+
a dictionary string to metric values.
|
1417 |
+
|
1418 |
+
"""
|
1419 |
+
def __init__(
|
1420 |
+
self,
|
1421 |
+
model = None,
|
1422 |
+
args = None,
|
1423 |
+
data_collator = None,
|
1424 |
+
train_dataset = None,
|
1425 |
+
eval_dataset = None,
|
1426 |
+
processing_class = None,
|
1427 |
+
model_init = None,
|
1428 |
+
callbacks = None,
|
1429 |
+
preprocess_logits_for_metrics = None,
|
1430 |
+
peft_config = None,
|
1431 |
+
compute_metrics = None,
|
1432 |
+
**kwargs
|
1433 |
+
):
|
1434 |
+
if args is None: args = UnslothORPOConfig()
|
1435 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1436 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1437 |
+
force_float32 = False
|
1438 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1439 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1440 |
+
force_float32 = True
|
1441 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1442 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1443 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1444 |
+
from unsloth_zoo.utils import _get_dtype
|
1445 |
+
dtype = _get_dtype(dtype)
|
1446 |
+
float16 = dtype == torch.float16
|
1447 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1448 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1449 |
+
if force_float32:
|
1450 |
+
args.fp16 = False
|
1451 |
+
args.bf16 = False
|
1452 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1453 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1454 |
+
args.fp16 = float16
|
1455 |
+
args.bf16 = not float16
|
1456 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1457 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1458 |
+
args.eval_strategy = 'steps'
|
1459 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1460 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1461 |
+
if ga_steps is not None and ga_steps > 1:
|
1462 |
+
from transformers import __version__ as transformers_version
|
1463 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1464 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1465 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1466 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1467 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1468 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1469 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1470 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1471 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1472 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1473 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1474 |
+
if force_float32:
|
1475 |
+
args.bf16_full_eval = False
|
1476 |
+
args.fp16_full_eval = False
|
1477 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1478 |
+
args.bf16_full_eval = True
|
1479 |
+
args.fp16_full_eval = False
|
1480 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1481 |
+
args.bf16_full_eval = args.bf16
|
1482 |
+
args.fp16_full_eval = args.fp16
|
1483 |
+
_output_logits = False
|
1484 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1485 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1486 |
+
if _output_logits:
|
1487 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1488 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1489 |
+
pass
|
1490 |
+
else:
|
1491 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1492 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1493 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1494 |
+
max_seq_length = model.max_seq_length
|
1495 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1496 |
+
if model is not None and hasattr(model, 'for_training'):
|
1497 |
+
model.for_training()
|
1498 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1499 |
+
if 'processing_class' in locals():
|
1500 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1501 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1502 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1503 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1504 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1505 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1506 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1507 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1508 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1509 |
+
else:
|
1510 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1511 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1512 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1513 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1514 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1515 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1516 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1517 |
+
else:
|
1518 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1519 |
+
other_metrics = []
|
1520 |
+
|
1521 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1522 |
+
PatchRLStatistics('orpo_trainer', other_metrics)
|
1523 |
+
|
1524 |
+
super().__init__(
|
1525 |
+
model = model,
|
1526 |
+
args = args,
|
1527 |
+
data_collator = data_collator,
|
1528 |
+
train_dataset = train_dataset,
|
1529 |
+
eval_dataset = eval_dataset,
|
1530 |
+
processing_class = processing_class,
|
1531 |
+
model_init = model_init,
|
1532 |
+
callbacks = callbacks,
|
1533 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1534 |
+
peft_config = peft_config,
|
1535 |
+
compute_metrics = compute_metrics,**kwargs)
|
1536 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1537 |
+
self.neftune_hook_handle.remove()
|
1538 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1539 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1540 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1541 |
+
pass
|
1542 |
+
|
1543 |
+
pass
|
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
ADDED
@@ -0,0 +1,1269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, warnings, wraps, F, is_conversational, os, torch)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
def vLLMSamplingParams(**kwargs):
|
43 |
+
from vllm import SamplingParams
|
44 |
+
sampling_params = SamplingParams(**kwargs)
|
45 |
+
sampling_params._set_kwargs = kwargs
|
46 |
+
return sampling_params
|
47 |
+
@dataclass
|
48 |
+
class UnslothOnlineDPOConfig(OnlineDPOConfig):
|
49 |
+
"""
|
50 |
+
|
51 |
+
Configuration class for the [`OnlineDPOTrainer`].
|
52 |
+
|
53 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
54 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
55 |
+
command line.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
59 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
60 |
+
[`~transformers.TrainingArguments`].
|
61 |
+
reward_model_path (`str` or `None`, *optional*, defaults to `None`):
|
62 |
+
Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
|
63 |
+
judge (`str` or `None`, *optional*, defaults to `None`):
|
64 |
+
Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
|
65 |
+
max_new_tokens (`int`, *optional*, defaults to `64`):
|
66 |
+
Maximum number of tokens to generate per completion.
|
67 |
+
max_length (`int`, *optional*, defaults to `256`):
|
68 |
+
Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
|
69 |
+
sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
|
70 |
+
possible.
|
71 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
72 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
73 |
+
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
|
74 |
+
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
|
75 |
+
to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
|
76 |
+
value.
|
77 |
+
beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
|
78 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
79 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
80 |
+
the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
|
81 |
+
selected for each new epoch and the last β is used for the rest of the epochs.
|
82 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
83 |
+
Type of loss to use. Possible values are:
|
84 |
+
|
85 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
86 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
87 |
+
|
88 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
89 |
+
Number of processes to use for processing the dataset.
|
90 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
91 |
+
Whether to disable dropout in the model and reference model.
|
92 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
93 |
+
Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
|
94 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
95 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
96 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
97 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
98 |
+
|
99 |
+
"""
|
100 |
+
vllm_sampling_params: Optional[Any] = field(
|
101 |
+
default = None,
|
102 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
103 |
+
)
|
104 |
+
unsloth_num_chunks : Optional[int] = field(
|
105 |
+
default = -1,
|
106 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
107 |
+
)
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
output_dir = None,
|
111 |
+
overwrite_output_dir = None,
|
112 |
+
do_train = False,
|
113 |
+
do_eval = False,
|
114 |
+
do_predict = False,
|
115 |
+
eval_strategy = 'no',
|
116 |
+
prediction_loss_only = False,
|
117 |
+
per_device_train_batch_size = 4,
|
118 |
+
per_device_eval_batch_size = 4,
|
119 |
+
per_gpu_train_batch_size = None,
|
120 |
+
per_gpu_eval_batch_size = None,
|
121 |
+
gradient_accumulation_steps = 2,
|
122 |
+
eval_accumulation_steps = 2,
|
123 |
+
eval_delay = 0,
|
124 |
+
torch_empty_cache_steps = 250,
|
125 |
+
learning_rate = 5e-05,
|
126 |
+
weight_decay = 0.01,
|
127 |
+
adam_beta1 = 0.9,
|
128 |
+
adam_beta2 = 0.999,
|
129 |
+
adam_epsilon = 1e-08,
|
130 |
+
max_grad_norm = 1.0,
|
131 |
+
num_train_epochs = 3.0,
|
132 |
+
max_steps = -1,
|
133 |
+
lr_scheduler_type = 'linear',
|
134 |
+
warmup_ratio = 0.1,
|
135 |
+
warmup_steps = 0,
|
136 |
+
log_level = 'passive',
|
137 |
+
log_level_replica = 'warning',
|
138 |
+
log_on_each_node = True,
|
139 |
+
logging_dir = None,
|
140 |
+
logging_strategy = 'steps',
|
141 |
+
logging_first_step = False,
|
142 |
+
logging_steps = 1,
|
143 |
+
logging_nan_inf_filter = False,
|
144 |
+
save_strategy = 'steps',
|
145 |
+
save_steps = 500,
|
146 |
+
save_total_limit = None,
|
147 |
+
save_safetensors = True,
|
148 |
+
save_on_each_node = False,
|
149 |
+
save_only_model = False,
|
150 |
+
restore_callback_states_from_checkpoint = False,
|
151 |
+
no_cuda = False,
|
152 |
+
use_cpu = False,
|
153 |
+
use_mps_device = False,
|
154 |
+
seed = 3407,
|
155 |
+
data_seed = 3407,
|
156 |
+
jit_mode_eval = False,
|
157 |
+
use_ipex = False,
|
158 |
+
bf16 = False,
|
159 |
+
fp16 = False,
|
160 |
+
fp16_opt_level = 'O1',
|
161 |
+
half_precision_backend = 'auto',
|
162 |
+
bf16_full_eval = False,
|
163 |
+
fp16_full_eval = False,
|
164 |
+
tf32 = None,
|
165 |
+
local_rank = -1,
|
166 |
+
ddp_backend = None,
|
167 |
+
tpu_num_cores = None,
|
168 |
+
tpu_metrics_debug = False,
|
169 |
+
debug = '',
|
170 |
+
dataloader_drop_last = False,
|
171 |
+
eval_steps = None,
|
172 |
+
dataloader_num_workers = 0,
|
173 |
+
dataloader_prefetch_factor = None,
|
174 |
+
past_index = -1,
|
175 |
+
run_name = None,
|
176 |
+
disable_tqdm = None,
|
177 |
+
remove_unused_columns = True,
|
178 |
+
label_names = None,
|
179 |
+
load_best_model_at_end = False,
|
180 |
+
metric_for_best_model = None,
|
181 |
+
greater_is_better = None,
|
182 |
+
ignore_data_skip = False,
|
183 |
+
fsdp = '',
|
184 |
+
fsdp_min_num_params = 0,
|
185 |
+
fsdp_config = None,
|
186 |
+
tp_size = 0,
|
187 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
188 |
+
accelerator_config = None,
|
189 |
+
deepspeed = None,
|
190 |
+
label_smoothing_factor = 0.0,
|
191 |
+
optim = 'adamw_8bit',
|
192 |
+
optim_args = None,
|
193 |
+
adafactor = False,
|
194 |
+
group_by_length = False,
|
195 |
+
length_column_name = 'length',
|
196 |
+
report_to = None,
|
197 |
+
ddp_find_unused_parameters = None,
|
198 |
+
ddp_bucket_cap_mb = None,
|
199 |
+
ddp_broadcast_buffers = None,
|
200 |
+
dataloader_pin_memory = True,
|
201 |
+
dataloader_persistent_workers = False,
|
202 |
+
skip_memory_metrics = True,
|
203 |
+
use_legacy_prediction_loop = False,
|
204 |
+
push_to_hub = False,
|
205 |
+
resume_from_checkpoint = None,
|
206 |
+
hub_model_id = None,
|
207 |
+
hub_strategy = 'every_save',
|
208 |
+
hub_token = None,
|
209 |
+
hub_private_repo = None,
|
210 |
+
hub_always_push = False,
|
211 |
+
gradient_checkpointing = False,
|
212 |
+
gradient_checkpointing_kwargs = None,
|
213 |
+
include_inputs_for_metrics = False,
|
214 |
+
eval_do_concat_batches = True,
|
215 |
+
fp16_backend = 'auto',
|
216 |
+
evaluation_strategy = None,
|
217 |
+
push_to_hub_model_id = None,
|
218 |
+
push_to_hub_organization = None,
|
219 |
+
push_to_hub_token = None,
|
220 |
+
mp_parameters = '',
|
221 |
+
auto_find_batch_size = False,
|
222 |
+
full_determinism = False,
|
223 |
+
torchdynamo = None,
|
224 |
+
ray_scope = 'last',
|
225 |
+
ddp_timeout = 1800,
|
226 |
+
torch_compile = False,
|
227 |
+
torch_compile_backend = None,
|
228 |
+
torch_compile_mode = None,
|
229 |
+
dispatch_batches = None,
|
230 |
+
split_batches = None,
|
231 |
+
include_tokens_per_second = False,
|
232 |
+
include_num_input_tokens_seen = False,
|
233 |
+
neftune_noise_alpha = None,
|
234 |
+
optim_target_modules = None,
|
235 |
+
batch_eval_metrics = False,
|
236 |
+
eval_on_start = False,
|
237 |
+
use_liger_kernel = False,
|
238 |
+
eval_use_gather_object = False,
|
239 |
+
average_tokens_across_devices = False,
|
240 |
+
reward_model_path = None,
|
241 |
+
judge = None,
|
242 |
+
max_new_tokens = 64,
|
243 |
+
max_length = 512,
|
244 |
+
temperature = 0.9,
|
245 |
+
missing_eos_penalty = None,
|
246 |
+
loss_type = 'sigmoid',
|
247 |
+
dataset_num_proc = None,
|
248 |
+
disable_dropout = True,
|
249 |
+
use_vllm = False,
|
250 |
+
ds3_gather_for_generation = True,
|
251 |
+
vllm_sampling_params = None,
|
252 |
+
unsloth_num_chunks = -1,
|
253 |
+
**kwargs,
|
254 |
+
):
|
255 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
256 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
257 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
258 |
+
output_dir = 'unsloth_training_checkpoints'
|
259 |
+
save_strategy = 'no'
|
260 |
+
if dataset_num_proc is None:
|
261 |
+
from multiprocessing import cpu_count
|
262 |
+
dataset_num_proc = cpu_count()
|
263 |
+
|
264 |
+
super().__init__(
|
265 |
+
output_dir = output_dir,
|
266 |
+
overwrite_output_dir = overwrite_output_dir,
|
267 |
+
do_train = do_train,
|
268 |
+
do_eval = do_eval,
|
269 |
+
do_predict = do_predict,
|
270 |
+
eval_strategy = eval_strategy,
|
271 |
+
prediction_loss_only = prediction_loss_only,
|
272 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
273 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
274 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
275 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
276 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
277 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
278 |
+
eval_delay = eval_delay,
|
279 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
280 |
+
learning_rate = learning_rate,
|
281 |
+
weight_decay = weight_decay,
|
282 |
+
adam_beta1 = adam_beta1,
|
283 |
+
adam_beta2 = adam_beta2,
|
284 |
+
adam_epsilon = adam_epsilon,
|
285 |
+
max_grad_norm = max_grad_norm,
|
286 |
+
num_train_epochs = num_train_epochs,
|
287 |
+
max_steps = max_steps,
|
288 |
+
lr_scheduler_type = lr_scheduler_type,
|
289 |
+
warmup_ratio = warmup_ratio,
|
290 |
+
warmup_steps = warmup_steps,
|
291 |
+
log_level = log_level,
|
292 |
+
log_level_replica = log_level_replica,
|
293 |
+
log_on_each_node = log_on_each_node,
|
294 |
+
logging_dir = logging_dir,
|
295 |
+
logging_strategy = logging_strategy,
|
296 |
+
logging_first_step = logging_first_step,
|
297 |
+
logging_steps = logging_steps,
|
298 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
299 |
+
save_strategy = save_strategy,
|
300 |
+
save_steps = save_steps,
|
301 |
+
save_total_limit = save_total_limit,
|
302 |
+
save_safetensors = save_safetensors,
|
303 |
+
save_on_each_node = save_on_each_node,
|
304 |
+
save_only_model = save_only_model,
|
305 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
306 |
+
no_cuda = no_cuda,
|
307 |
+
use_cpu = use_cpu,
|
308 |
+
use_mps_device = use_mps_device,
|
309 |
+
seed = seed,
|
310 |
+
data_seed = data_seed,
|
311 |
+
jit_mode_eval = jit_mode_eval,
|
312 |
+
use_ipex = use_ipex,
|
313 |
+
bf16 = bf16,
|
314 |
+
fp16 = fp16,
|
315 |
+
fp16_opt_level = fp16_opt_level,
|
316 |
+
half_precision_backend = half_precision_backend,
|
317 |
+
bf16_full_eval = bf16_full_eval,
|
318 |
+
fp16_full_eval = fp16_full_eval,
|
319 |
+
tf32 = tf32,
|
320 |
+
local_rank = local_rank,
|
321 |
+
ddp_backend = ddp_backend,
|
322 |
+
tpu_num_cores = tpu_num_cores,
|
323 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
324 |
+
debug = debug,
|
325 |
+
dataloader_drop_last = dataloader_drop_last,
|
326 |
+
eval_steps = eval_steps,
|
327 |
+
dataloader_num_workers = dataloader_num_workers,
|
328 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
329 |
+
past_index = past_index,
|
330 |
+
run_name = run_name,
|
331 |
+
disable_tqdm = disable_tqdm,
|
332 |
+
remove_unused_columns = remove_unused_columns,
|
333 |
+
label_names = label_names,
|
334 |
+
load_best_model_at_end = load_best_model_at_end,
|
335 |
+
metric_for_best_model = metric_for_best_model,
|
336 |
+
greater_is_better = greater_is_better,
|
337 |
+
ignore_data_skip = ignore_data_skip,
|
338 |
+
fsdp = fsdp,
|
339 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
340 |
+
fsdp_config = fsdp_config,
|
341 |
+
tp_size = tp_size,
|
342 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
343 |
+
accelerator_config = accelerator_config,
|
344 |
+
deepspeed = deepspeed,
|
345 |
+
label_smoothing_factor = label_smoothing_factor,
|
346 |
+
optim = optim,
|
347 |
+
optim_args = optim_args,
|
348 |
+
adafactor = adafactor,
|
349 |
+
group_by_length = group_by_length,
|
350 |
+
length_column_name = length_column_name,
|
351 |
+
report_to = report_to,
|
352 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
353 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
354 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
355 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
356 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
357 |
+
skip_memory_metrics = skip_memory_metrics,
|
358 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
359 |
+
push_to_hub = push_to_hub,
|
360 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
361 |
+
hub_model_id = hub_model_id,
|
362 |
+
hub_strategy = hub_strategy,
|
363 |
+
hub_token = hub_token,
|
364 |
+
hub_private_repo = hub_private_repo,
|
365 |
+
hub_always_push = hub_always_push,
|
366 |
+
gradient_checkpointing = gradient_checkpointing,
|
367 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
368 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
369 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
370 |
+
fp16_backend = fp16_backend,
|
371 |
+
evaluation_strategy = evaluation_strategy,
|
372 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
373 |
+
push_to_hub_organization = push_to_hub_organization,
|
374 |
+
push_to_hub_token = push_to_hub_token,
|
375 |
+
mp_parameters = mp_parameters,
|
376 |
+
auto_find_batch_size = auto_find_batch_size,
|
377 |
+
full_determinism = full_determinism,
|
378 |
+
torchdynamo = torchdynamo,
|
379 |
+
ray_scope = ray_scope,
|
380 |
+
ddp_timeout = ddp_timeout,
|
381 |
+
torch_compile = torch_compile,
|
382 |
+
torch_compile_backend = torch_compile_backend,
|
383 |
+
torch_compile_mode = torch_compile_mode,
|
384 |
+
dispatch_batches = dispatch_batches,
|
385 |
+
split_batches = split_batches,
|
386 |
+
include_tokens_per_second = include_tokens_per_second,
|
387 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
388 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
389 |
+
optim_target_modules = optim_target_modules,
|
390 |
+
batch_eval_metrics = batch_eval_metrics,
|
391 |
+
eval_on_start = eval_on_start,
|
392 |
+
use_liger_kernel = use_liger_kernel,
|
393 |
+
eval_use_gather_object = eval_use_gather_object,
|
394 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
395 |
+
reward_model_path = reward_model_path,
|
396 |
+
judge = judge,
|
397 |
+
max_new_tokens = max_new_tokens,
|
398 |
+
max_length = max_length,
|
399 |
+
temperature = temperature,
|
400 |
+
missing_eos_penalty = missing_eos_penalty,
|
401 |
+
loss_type = loss_type,
|
402 |
+
dataset_num_proc = dataset_num_proc,
|
403 |
+
disable_dropout = disable_dropout,
|
404 |
+
use_vllm = use_vllm,
|
405 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
406 |
+
self.vllm_sampling_params = vllm_sampling_params
|
407 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
408 |
+
pass
|
409 |
+
|
410 |
+
class _UnslothOnlineDPOTrainer(Trainer):
|
411 |
+
r""""""
|
412 |
+
|
413 |
+
_tag_names = ["trl", "online-dpo"]
|
414 |
+
|
415 |
+
def __init__(
|
416 |
+
self,
|
417 |
+
model: Union[PreTrainedModel, nn.Module],
|
418 |
+
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
|
419 |
+
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
420 |
+
judge: Optional[BasePairwiseJudge] = None,
|
421 |
+
args: Optional[OnlineDPOConfig] = None,
|
422 |
+
data_collator: Optional[DataCollator] = None,
|
423 |
+
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
424 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
|
425 |
+
processing_class: Optional[
|
426 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
427 |
+
] = None,
|
428 |
+
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
429 |
+
peft_config: Optional[dict] = None,
|
430 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
431 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
432 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
433 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
434 |
+
) -> None:
|
435 |
+
|
436 |
+
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): args.use_vllm = True
|
437 |
+
if ref_model is model:
|
438 |
+
raise ValueError(
|
439 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
440 |
+
"same as `model`, either omit the `ref_model` argument or pass `None`."
|
441 |
+
)
|
442 |
+
|
443 |
+
self.ref_model = ref_model
|
444 |
+
|
445 |
+
if reward_model is not None and judge is not None:
|
446 |
+
warnings.warn(
|
447 |
+
"Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
|
448 |
+
"Ignoring `judge` and using `reward_model`.",
|
449 |
+
UserWarning,
|
450 |
+
)
|
451 |
+
judge = None
|
452 |
+
elif reward_model is None and judge is None:
|
453 |
+
raise ValueError("Either `reward_model` or `judge` must be provided.")
|
454 |
+
|
455 |
+
self.reward_model = reward_model
|
456 |
+
self.reward_processing_class = reward_processing_class
|
457 |
+
self.judge = judge
|
458 |
+
|
459 |
+
if args.missing_eos_penalty is not None and judge is not None:
|
460 |
+
raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
|
461 |
+
|
462 |
+
if args is None:
|
463 |
+
raise ValueError("`args` must be provided.")
|
464 |
+
|
465 |
+
# Check that the processing_class is provided
|
466 |
+
if processing_class is None:
|
467 |
+
raise ValueError("`processing_class` must be provided.")
|
468 |
+
|
469 |
+
# Convert to PEFT model if peft_config is provided
|
470 |
+
if False:
|
471 |
+
# Check if PEFT is available
|
472 |
+
if not is_peft_available():
|
473 |
+
raise ImportError(
|
474 |
+
"PEFT is not available and passed `peft_config`. Please install PEFT with "
|
475 |
+
"`pip install peft` to use it."
|
476 |
+
)
|
477 |
+
|
478 |
+
# If the model is already a PeftModel, we need to merge and unload it.
|
479 |
+
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
|
480 |
+
if isinstance(model, PeftModel):
|
481 |
+
model = model.merge_and_unload()
|
482 |
+
|
483 |
+
# Get peft model with the given config
|
484 |
+
model = model
|
485 |
+
|
486 |
+
# Disable dropout in the model and reference model
|
487 |
+
if args.disable_dropout:
|
488 |
+
disable_dropout_in_model(model)
|
489 |
+
if self.ref_model is not None:
|
490 |
+
disable_dropout_in_model(self.ref_model)
|
491 |
+
|
492 |
+
# Handle the ref_model
|
493 |
+
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
|
494 |
+
# get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
|
495 |
+
# the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
|
496 |
+
if ref_model is None: # No ref model provided, the most common case
|
497 |
+
if False:
|
498 |
+
self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
|
499 |
+
else:
|
500 |
+
self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
|
501 |
+
else: # rare case, the user provided a ref model
|
502 |
+
self.ref_model = ref_model
|
503 |
+
self.ref_model.eval()
|
504 |
+
|
505 |
+
# Disable the gradient and set the reward model in eval mode
|
506 |
+
if self.reward_model is not None:
|
507 |
+
self.reward_model.eval()
|
508 |
+
|
509 |
+
# Define the collator is not provided
|
510 |
+
if data_collator is None:
|
511 |
+
data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
|
512 |
+
|
513 |
+
self.max_length = args.max_length
|
514 |
+
|
515 |
+
self.stats = {
|
516 |
+
"objective/kl": [],
|
517 |
+
"objective/entropy": [],
|
518 |
+
"objective/non_score_reward": [],
|
519 |
+
"rewards/chosen": [],
|
520 |
+
"rewards/rejected": [],
|
521 |
+
"rewards/accuracies": [],
|
522 |
+
"rewards/margins": [],
|
523 |
+
"logps/chosen": [],
|
524 |
+
"logps/rejected": [],
|
525 |
+
"val/contain_eos_token": [],
|
526 |
+
"beta": [],
|
527 |
+
}
|
528 |
+
if self.reward_model is not None:
|
529 |
+
self.stats["objective/rlhf_reward"] = []
|
530 |
+
self.stats["objective/scores_margin"] = []
|
531 |
+
self.stats["objective/scores"] = []
|
532 |
+
|
533 |
+
if args.use_vllm:
|
534 |
+
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
|
535 |
+
n=2, max_tokens=args.max_new_tokens,
|
536 |
+
temperature=args.temperature,
|
537 |
+
top_k=50,
|
538 |
+
top_p=1.0,
|
539 |
+
detokenize=False,**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),)
|
540 |
+
else:
|
541 |
+
self.generation_config = GenerationConfig(
|
542 |
+
max_new_tokens=args.max_new_tokens,
|
543 |
+
temperature=args.temperature,
|
544 |
+
top_k=50,
|
545 |
+
top_p=1.0,
|
546 |
+
do_sample=True,
|
547 |
+
use_cache=False if args.gradient_checkpointing else True,
|
548 |
+
)
|
549 |
+
|
550 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
551 |
+
# input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
|
552 |
+
# the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
553 |
+
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
554 |
+
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
555 |
+
# that the warning has already been issued.
|
556 |
+
model.warnings_issued["estimate_tokens"] = True
|
557 |
+
|
558 |
+
super().__init__(
|
559 |
+
model=model,
|
560 |
+
args=args,
|
561 |
+
data_collator=data_collator,
|
562 |
+
train_dataset=train_dataset,
|
563 |
+
eval_dataset=eval_dataset,
|
564 |
+
processing_class=processing_class,
|
565 |
+
compute_metrics=compute_metrics,
|
566 |
+
callbacks=callbacks,
|
567 |
+
optimizers=optimizers,
|
568 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
569 |
+
)
|
570 |
+
|
571 |
+
# Add tags for models that have been loaded with the correct transformers version
|
572 |
+
if hasattr(self.model, "add_model_tags"):
|
573 |
+
self.model.add_model_tags(self._tag_names)
|
574 |
+
|
575 |
+
self._beta = args.beta
|
576 |
+
|
577 |
+
# Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator
|
578 |
+
if self.is_deepspeed_enabled:
|
579 |
+
if self.reward_model is not None:
|
580 |
+
self.reward_model = prepare_deepspeed(
|
581 |
+
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
582 |
+
)
|
583 |
+
if self.ref_model is not None:
|
584 |
+
self.ref_model = prepare_deepspeed(
|
585 |
+
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
if self.ref_model is not None:
|
589 |
+
self.ref_model = self.ref_model.to(self.accelerator.device)
|
590 |
+
if self.reward_model is not None:
|
591 |
+
self.reward_model = self.reward_model.to(self.accelerator.device)
|
592 |
+
|
593 |
+
@property
|
594 |
+
def beta(self):
|
595 |
+
if isinstance(self._beta, list):
|
596 |
+
epoch = self.state.epoch
|
597 |
+
return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
|
598 |
+
else:
|
599 |
+
return self._beta
|
600 |
+
|
601 |
+
@staticmethod
|
602 |
+
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
|
603 |
+
"""Tokenize a single row from a DPO specific dataset."""
|
604 |
+
if not is_encoder_decoder:
|
605 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=False)
|
606 |
+
# Add BOS token to head of prompt. Avoid adding if it's already there
|
607 |
+
if tokenizer.bos_token_id is not None:
|
608 |
+
prompt_len_input_ids = len(batch["input_ids"])
|
609 |
+
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
|
610 |
+
batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
|
611 |
+
batch["attention_mask"] = [1] + batch["attention_mask"]
|
612 |
+
else:
|
613 |
+
batch = tokenizer(feature["prompt"], add_special_tokens=True)
|
614 |
+
batch = {f"prompt_{key}": value for key, value in batch.items()}
|
615 |
+
return batch
|
616 |
+
|
617 |
+
# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
|
618 |
+
@wraps(Trainer.get_train_dataloader)
|
619 |
+
def get_train_dataloader(self) -> DataLoader:
|
620 |
+
if self.train_dataset is None:
|
621 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
622 |
+
|
623 |
+
train_dataset = self.train_dataset
|
624 |
+
data_collator = self.data_collator
|
625 |
+
dataloader_params = {
|
626 |
+
"batch_size": self._train_batch_size,
|
627 |
+
"collate_fn": data_collator,
|
628 |
+
"num_workers": self.args.dataloader_num_workers,
|
629 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
630 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
631 |
+
}
|
632 |
+
|
633 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
634 |
+
dataloader_params["sampler"] = self._get_train_sampler()
|
635 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
636 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
637 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
638 |
+
|
639 |
+
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
640 |
+
|
641 |
+
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
642 |
+
@wraps(Trainer.get_eval_dataloader)
|
643 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
644 |
+
if eval_dataset is None and self.eval_dataset is None:
|
645 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
646 |
+
|
647 |
+
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
648 |
+
# don't change during training
|
649 |
+
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
650 |
+
if (
|
651 |
+
hasattr(self, "_eval_dataloaders")
|
652 |
+
and dataloader_key in self._eval_dataloaders
|
653 |
+
and self.args.dataloader_persistent_workers
|
654 |
+
):
|
655 |
+
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
656 |
+
|
657 |
+
eval_dataset = (
|
658 |
+
self.eval_dataset[eval_dataset]
|
659 |
+
if isinstance(eval_dataset, str)
|
660 |
+
else eval_dataset
|
661 |
+
if eval_dataset is not None
|
662 |
+
else self.eval_dataset
|
663 |
+
)
|
664 |
+
data_collator = self.data_collator
|
665 |
+
|
666 |
+
dataloader_params = {
|
667 |
+
"batch_size": self.args.eval_batch_size,
|
668 |
+
"collate_fn": data_collator,
|
669 |
+
"num_workers": self.args.dataloader_num_workers,
|
670 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
671 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
672 |
+
}
|
673 |
+
|
674 |
+
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
675 |
+
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
676 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
677 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
678 |
+
|
679 |
+
# accelerator.free_memory() will destroy the references, so
|
680 |
+
# we need to store the non-prepared version
|
681 |
+
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
682 |
+
if self.args.dataloader_persistent_workers:
|
683 |
+
if hasattr(self, "_eval_dataloaders"):
|
684 |
+
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
685 |
+
else:
|
686 |
+
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
687 |
+
|
688 |
+
return self.accelerator.prepare(eval_dataloader)
|
689 |
+
|
690 |
+
def _generate_vllm(self, model, prompts):
|
691 |
+
eos_token_id = self.processing_class.eos_token_id
|
692 |
+
pad_token_id = self.processing_class.pad_token_id
|
693 |
+
|
694 |
+
# Load the latest weights
|
695 |
+
|
696 |
+
pass
|
697 |
+
|
698 |
+
pass
|
699 |
+
|
700 |
+
if is_conversational({"prompt": prompts[0]}):
|
701 |
+
outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
702 |
+
else:
|
703 |
+
outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
704 |
+
|
705 |
+
completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
|
706 |
+
prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
|
707 |
+
|
708 |
+
# Create mask and pad the prompt and completion
|
709 |
+
max_prompt_length = max(len(ids) for ids in prompt_ids)
|
710 |
+
prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
|
711 |
+
prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
|
712 |
+
max_tokens = self.generation_config.max_tokens
|
713 |
+
completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
|
714 |
+
completion_ids = [
|
715 |
+
ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
|
716 |
+
for ids in completion_ids
|
717 |
+
]
|
718 |
+
completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
|
719 |
+
|
720 |
+
# Convert to tensors
|
721 |
+
prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
|
722 |
+
prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
|
723 |
+
completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
|
724 |
+
completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
|
725 |
+
|
726 |
+
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
727 |
+
|
728 |
+
def _generate(self, model, prompts):
|
729 |
+
eos_token_id = self.processing_class.eos_token_id
|
730 |
+
pad_token_id = self.processing_class.pad_token_id
|
731 |
+
|
732 |
+
# Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
|
733 |
+
# policies with different tokenizers / chat templates.
|
734 |
+
inputs = [{"prompt": prompt} for prompt in prompts]
|
735 |
+
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
736 |
+
inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
737 |
+
inputs = self.data_collator(inputs)
|
738 |
+
|
739 |
+
# Sample 2 completions per prompt of size `max_new_tokens` from the model
|
740 |
+
inputs = self._prepare_inputs(inputs)
|
741 |
+
prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
|
742 |
+
prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
|
743 |
+
with unwrap_model_for_generation(
|
744 |
+
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
745 |
+
) as unwrapped_model:
|
746 |
+
output = unwrapped_model.generate(
|
747 |
+
input_ids=prompt_ids,
|
748 |
+
attention_mask=prompt_mask,
|
749 |
+
generation_config=self.generation_config,
|
750 |
+
)
|
751 |
+
|
752 |
+
completion_ids = output[:, prompt_ids.size(1) :]
|
753 |
+
completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
|
754 |
+
|
755 |
+
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
756 |
+
|
757 |
+
def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
|
758 |
+
# Get the number of tokens to truncate from prompt
|
759 |
+
num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
|
760 |
+
|
761 |
+
# Truncate left to avoid oom
|
762 |
+
prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
|
763 |
+
prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
|
764 |
+
|
765 |
+
# Concat the prompt and completion
|
766 |
+
prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
|
767 |
+
prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
|
768 |
+
|
769 |
+
# Get the logprobs of the completions from the model
|
770 |
+
output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
|
771 |
+
|
772 |
+
# There is 1 offset, because the model predict the next token
|
773 |
+
logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
|
774 |
+
|
775 |
+
# Take the completion tokens logprob
|
776 |
+
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
|
777 |
+
return logprobs
|
778 |
+
|
779 |
+
def training_step(
|
780 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
781 |
+
) -> torch.Tensor:
|
782 |
+
model.train()
|
783 |
+
|
784 |
+
prompts = inputs["prompt"]
|
785 |
+
batch_size = len(prompts)
|
786 |
+
|
787 |
+
if self.args.use_vllm:
|
788 |
+
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
|
789 |
+
else:
|
790 |
+
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
|
791 |
+
|
792 |
+
contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
|
793 |
+
|
794 |
+
logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
795 |
+
with torch.no_grad():
|
796 |
+
if self.ref_model is not None:
|
797 |
+
ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
798 |
+
else: # peft case: we just need to disable the adapter
|
799 |
+
with self.model.disable_adapter():
|
800 |
+
ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
801 |
+
|
802 |
+
# Decode the completions, and format them if the input is conversational
|
803 |
+
device = logprobs.device
|
804 |
+
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
805 |
+
if is_conversational({"prompt": prompts[0]}):
|
806 |
+
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
807 |
+
|
808 |
+
# Get the reward from the reward model or judge
|
809 |
+
if self.judge is not None:
|
810 |
+
# Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
|
811 |
+
# directly understandable by the judge and could alter its judgment. To avoid this and make the judge
|
812 |
+
# independent of the model's chat template, we use the raw conversation data, and apply our own chat
|
813 |
+
# template to it.
|
814 |
+
if is_conversational({"prompt": prompts[0]}):
|
815 |
+
environment = jinja2.Environment()
|
816 |
+
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
817 |
+
prompts = [template.render(messages=prompt) for prompt in prompts]
|
818 |
+
completions = [template.render(messages=completion) for completion in completions]
|
819 |
+
|
820 |
+
ranks_of_first_completion = self.judge.judge(
|
821 |
+
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
|
822 |
+
)
|
823 |
+
|
824 |
+
# convert ranks to a True/False mask:
|
825 |
+
# when rank == 0, it means the first completion is the best
|
826 |
+
# when rank == 1, it means the second completion is the best
|
827 |
+
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
|
828 |
+
else:
|
829 |
+
# The reward model may not have the same chat template or tokenizer as the model, so we need to use the
|
830 |
+
# raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
|
831 |
+
prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
|
832 |
+
if is_conversational({"prompt": prompts[0]}):
|
833 |
+
examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
|
834 |
+
examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
|
835 |
+
prompts = [example["prompt"] for example in examples]
|
836 |
+
completions = [example["completion"] for example in examples]
|
837 |
+
|
838 |
+
# Tokenize the prompts
|
839 |
+
prompts_ids = self.reward_processing_class(
|
840 |
+
prompts, padding=True, return_tensors="pt", padding_side="left"
|
841 |
+
)["input_ids"].to(device)
|
842 |
+
context_length = prompts_ids.shape[1]
|
843 |
+
|
844 |
+
# Tokenize the completions
|
845 |
+
completions_ids = self.reward_processing_class(
|
846 |
+
completions, padding=True, return_tensors="pt", padding_side="right"
|
847 |
+
)["input_ids"].to(device)
|
848 |
+
|
849 |
+
# Concatenate the prompts and completions and get the reward
|
850 |
+
prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
|
851 |
+
with torch.inference_mode():
|
852 |
+
_, scores, _ = get_reward(
|
853 |
+
self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
|
854 |
+
)
|
855 |
+
|
856 |
+
# Filter completion. Ensure that the sample contains stop_token_id
|
857 |
+
# Completions not passing that filter will receive a lower score.
|
858 |
+
if self.args.missing_eos_penalty is not None:
|
859 |
+
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
860 |
+
|
861 |
+
# Split the scores in 2 (the prompts of the first half are the same as the second half)
|
862 |
+
first_half, second_half = scores.split(batch_size)
|
863 |
+
|
864 |
+
# Get the indices of the chosen and rejected examples
|
865 |
+
mask = first_half >= second_half
|
866 |
+
|
867 |
+
batch_range = torch.arange(batch_size, device=device)
|
868 |
+
chosen_indices = batch_range + (~mask * batch_size)
|
869 |
+
rejected_indices = batch_range + (mask * batch_size)
|
870 |
+
|
871 |
+
# Build tensor so that the first half is the chosen examples and the second half the rejected examples
|
872 |
+
cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
|
873 |
+
cr_logprobs = logprobs[cr_indices]
|
874 |
+
cr_ref_logprobs = ref_logprobs[cr_indices]
|
875 |
+
|
876 |
+
# mask out the padding tokens
|
877 |
+
padding_mask = ~completion_mask.bool()
|
878 |
+
cr_padding_mask = padding_mask[cr_indices]
|
879 |
+
|
880 |
+
cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
|
881 |
+
cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
|
882 |
+
|
883 |
+
# Split the chosen and rejected examples
|
884 |
+
chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
|
885 |
+
chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
|
886 |
+
pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
|
887 |
+
ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
|
888 |
+
|
889 |
+
logits = pi_logratios - ref_logratios
|
890 |
+
|
891 |
+
if self.args.loss_type == "sigmoid":
|
892 |
+
losses = -F.logsigmoid(self.beta * logits)
|
893 |
+
elif self.args.loss_type == "ipo":
|
894 |
+
losses = (logits - 1 / (2 * self.beta)) ** 2
|
895 |
+
else:
|
896 |
+
raise NotImplementedError(f"invalid loss type {self.loss_type}")
|
897 |
+
|
898 |
+
loss = losses.mean()
|
899 |
+
|
900 |
+
# Log everything
|
901 |
+
if self.reward_model is not None:
|
902 |
+
scores_margin = scores[chosen_indices] - scores[rejected_indices]
|
903 |
+
self.stats["objective/scores_margin"].append(
|
904 |
+
self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
|
905 |
+
)
|
906 |
+
self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
|
907 |
+
self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
|
908 |
+
self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
|
909 |
+
self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
|
910 |
+
|
911 |
+
kl = logprobs - ref_logprobs
|
912 |
+
mean_kl = kl.sum(1).mean()
|
913 |
+
self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
914 |
+
non_score_reward = (-self.beta * kl).sum(1)
|
915 |
+
mean_non_score_reward = non_score_reward.mean()
|
916 |
+
self.stats["objective/non_score_reward"].append(
|
917 |
+
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
918 |
+
)
|
919 |
+
if self.reward_model is not None:
|
920 |
+
rlhf_reward = scores + non_score_reward
|
921 |
+
self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
|
922 |
+
mean_entropy = -logprobs.sum(1).mean()
|
923 |
+
self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
|
924 |
+
chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
|
925 |
+
gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
|
926 |
+
self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
|
927 |
+
rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
|
928 |
+
gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
|
929 |
+
self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
|
930 |
+
margin = gathered_chosen_rewards - gathered_rejected_rewards
|
931 |
+
self.stats["rewards/margins"].append(margin.mean().item())
|
932 |
+
accuracy = margin > 0
|
933 |
+
self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
|
934 |
+
self.stats["beta"].append(self.beta)
|
935 |
+
|
936 |
+
if (
|
937 |
+
self.args.torch_empty_cache_steps is not None
|
938 |
+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
939 |
+
):
|
940 |
+
empty_cache()
|
941 |
+
|
942 |
+
kwargs = {}
|
943 |
+
|
944 |
+
# For LOMO optimizers you need to explicitly use the learnign rate
|
945 |
+
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
946 |
+
kwargs["learning_rate"] = self._get_learning_rate()
|
947 |
+
|
948 |
+
if self.args.n_gpu > 1:
|
949 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
950 |
+
|
951 |
+
if self.use_apex:
|
952 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
953 |
+
scaled_loss.backward()
|
954 |
+
else:
|
955 |
+
self.accelerator.backward(loss, **kwargs)
|
956 |
+
|
957 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
958 |
+
|
959 |
+
# Same as Trainer._maybe_log_save_evaluate but log our metrics
|
960 |
+
# start_time defaults to None to allow compatibility with transformers<=4.46
|
961 |
+
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
|
962 |
+
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
963 |
+
logs: dict[str, float] = {}
|
964 |
+
|
965 |
+
# all_gather + mean() to get average loss over all processes
|
966 |
+
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
|
967 |
+
|
968 |
+
# reset tr_loss to zero
|
969 |
+
tr_loss -= tr_loss
|
970 |
+
|
971 |
+
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
972 |
+
if grad_norm is not None:
|
973 |
+
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
974 |
+
logs["learning_rate"] = self._get_learning_rate()
|
975 |
+
|
976 |
+
# Add our metrics
|
977 |
+
for key, val in self.stats.items():
|
978 |
+
logs[key] = sum(val) / len(val)
|
979 |
+
self.stats = {key: [] for key in self.stats} # reset stats
|
980 |
+
|
981 |
+
self._total_loss_scalar += tr_loss_scalar
|
982 |
+
self._globalstep_last_logged = self.state.global_step
|
983 |
+
self.store_flos()
|
984 |
+
|
985 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
986 |
+
self.log(logs, start_time)
|
987 |
+
else: # transformers<=4.46
|
988 |
+
self.log(logs)
|
989 |
+
|
990 |
+
metrics = None
|
991 |
+
if self.control.should_evaluate:
|
992 |
+
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
993 |
+
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
|
994 |
+
|
995 |
+
if self.args.save_strategy == "best":
|
996 |
+
self.control.should_save = is_new_best_metric
|
997 |
+
|
998 |
+
if self.control.should_save:
|
999 |
+
self._save_checkpoint(model, trial)
|
1000 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
1001 |
+
|
1002 |
+
# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
|
1003 |
+
# This can be removed once the minimum transformers version is updated to 4.47.
|
1004 |
+
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
|
1005 |
+
def _determine_best_metric(self, metrics, trial):
|
1006 |
+
"""
|
1007 |
+
Determine if the model should be saved based on the evaluation metrics.
|
1008 |
+
If args.metric_for_best_model is not set, the loss is used.
|
1009 |
+
Returns:
|
1010 |
+
bool: True if a new best metric was found, else False
|
1011 |
+
"""
|
1012 |
+
is_new_best_metric = False
|
1013 |
+
|
1014 |
+
if self.args.metric_for_best_model is not None:
|
1015 |
+
metric_to_check = self.args.metric_for_best_model
|
1016 |
+
|
1017 |
+
if not metric_to_check.startswith("eval_"):
|
1018 |
+
metric_to_check = f"eval_{metric_to_check}"
|
1019 |
+
|
1020 |
+
try:
|
1021 |
+
metric_value = metrics[metric_to_check]
|
1022 |
+
except KeyError as exc:
|
1023 |
+
raise KeyError(
|
1024 |
+
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
|
1025 |
+
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
|
1026 |
+
) from exc
|
1027 |
+
|
1028 |
+
operator = np.greater if self.args.greater_is_better else np.less
|
1029 |
+
|
1030 |
+
if self.state.best_metric is None:
|
1031 |
+
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
1032 |
+
|
1033 |
+
if operator(metric_value, self.state.best_metric):
|
1034 |
+
run_dir = self._get_output_dir(trial=trial)
|
1035 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
1036 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
1037 |
+
self.state.best_metric = metric_value
|
1038 |
+
self.state.best_model_checkpoint = output_dir
|
1039 |
+
|
1040 |
+
is_new_best_metric = True
|
1041 |
+
|
1042 |
+
return is_new_best_metric
|
1043 |
+
|
1044 |
+
def create_model_card(
|
1045 |
+
self,
|
1046 |
+
model_name: Optional[str] = None,
|
1047 |
+
dataset_name: Optional[str] = None,
|
1048 |
+
tags: Union[str, list[str], None] = None,
|
1049 |
+
):
|
1050 |
+
"""
|
1051 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1052 |
+
|
1053 |
+
Args:
|
1054 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1055 |
+
Name of the model.
|
1056 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1057 |
+
Name of the dataset used for training.
|
1058 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1059 |
+
Tags to be associated with the model card.
|
1060 |
+
"""
|
1061 |
+
if not self.is_world_process_zero():
|
1062 |
+
return
|
1063 |
+
|
1064 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1065 |
+
base_model = self.model.config._name_or_path
|
1066 |
+
else:
|
1067 |
+
base_model = None
|
1068 |
+
|
1069 |
+
tags = tags or []
|
1070 |
+
if isinstance(tags, str):
|
1071 |
+
tags = [tags]
|
1072 |
+
|
1073 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1074 |
+
tags.append("unsloth")
|
1075 |
+
|
1076 |
+
citation = textwrap.dedent("""\
|
1077 |
+
@article{guo2024direct,
|
1078 |
+
title = {{Direct Language Model Alignment from Online AI Feedback}},
|
1079 |
+
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
|
1080 |
+
year = 2024,
|
1081 |
+
eprint = {arXiv:2402.04792}
|
1082 |
+
}""")
|
1083 |
+
|
1084 |
+
model_card = generate_model_card(
|
1085 |
+
base_model=base_model,
|
1086 |
+
model_name=model_name,
|
1087 |
+
hub_model_id=self.hub_model_id,
|
1088 |
+
dataset_name=dataset_name,
|
1089 |
+
tags=tags,
|
1090 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1091 |
+
comet_url=get_comet_experiment_url(),
|
1092 |
+
trainer_name="Online DPO",
|
1093 |
+
trainer_citation=citation,
|
1094 |
+
paper_title="Direct Language Model Alignment from Online AI Feedback",
|
1095 |
+
paper_id="2402.04792",
|
1096 |
+
)
|
1097 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1098 |
+
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
|
1099 |
+
"""
|
1100 |
+
|
1101 |
+
Initialize OnlineDPOTrainer.
|
1102 |
+
|
1103 |
+
Args:
|
1104 |
+
model (`transformers.PreTrainedModel` or `torch.nn.Module`):
|
1105 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
1106 |
+
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1107 |
+
The reference model to use for training. If None is specified, the reference model will be created from
|
1108 |
+
the model.
|
1109 |
+
reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
1110 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
1111 |
+
judge (`BasePairwiseJudge`):
|
1112 |
+
The judge to use for pairwise comparison of model completions.
|
1113 |
+
args (`OnlineDPOConfig`):
|
1114 |
+
The online DPO config arguments to use for training.
|
1115 |
+
data_collator (`transformers.DataCollator`):
|
1116 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
1117 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
1118 |
+
train_dataset (`datasets.Dataset`):
|
1119 |
+
The dataset to use for training.
|
1120 |
+
eval_dataset (`datasets.Dataset`):
|
1121 |
+
The dataset to use for evaluation.
|
1122 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
1123 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
1124 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
1125 |
+
reuse the fine-tuned model.
|
1126 |
+
peft_config (`dict`):
|
1127 |
+
The peft config to use for training.
|
1128 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
1129 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
1130 |
+
a dictionary string to metric values.
|
1131 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
1132 |
+
The callbacks to use for training.
|
1133 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
1134 |
+
The optimizer and scheduler to use for training.
|
1135 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
1136 |
+
The function to use to preprocess the logits before computing the metrics.
|
1137 |
+
|
1138 |
+
"""
|
1139 |
+
def __init__(
|
1140 |
+
self,
|
1141 |
+
model,
|
1142 |
+
ref_model = None,
|
1143 |
+
reward_model = None,
|
1144 |
+
judge = None,
|
1145 |
+
args = None,
|
1146 |
+
data_collator = None,
|
1147 |
+
train_dataset = None,
|
1148 |
+
eval_dataset = None,
|
1149 |
+
processing_class = None,
|
1150 |
+
reward_processing_class = None,
|
1151 |
+
peft_config = None,
|
1152 |
+
compute_metrics = None,
|
1153 |
+
callbacks = None,
|
1154 |
+
preprocess_logits_for_metrics = None,
|
1155 |
+
**kwargs
|
1156 |
+
):
|
1157 |
+
if args is None: args = UnslothOnlineDPOConfig()
|
1158 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1159 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1160 |
+
force_float32 = False
|
1161 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1162 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1163 |
+
force_float32 = True
|
1164 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1165 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1166 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1167 |
+
from unsloth_zoo.utils import _get_dtype
|
1168 |
+
dtype = _get_dtype(dtype)
|
1169 |
+
float16 = dtype == torch.float16
|
1170 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1171 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1172 |
+
if force_float32:
|
1173 |
+
args.fp16 = False
|
1174 |
+
args.bf16 = False
|
1175 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1176 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1177 |
+
args.fp16 = float16
|
1178 |
+
args.bf16 = not float16
|
1179 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1180 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1181 |
+
args.eval_strategy = 'steps'
|
1182 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1183 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1184 |
+
if ga_steps is not None and ga_steps > 1:
|
1185 |
+
from transformers import __version__ as transformers_version
|
1186 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1187 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1188 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1189 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1190 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1191 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1192 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1193 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1194 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1195 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1196 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1197 |
+
if force_float32:
|
1198 |
+
args.bf16_full_eval = False
|
1199 |
+
args.fp16_full_eval = False
|
1200 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1201 |
+
args.bf16_full_eval = True
|
1202 |
+
args.fp16_full_eval = False
|
1203 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1204 |
+
args.bf16_full_eval = args.bf16
|
1205 |
+
args.fp16_full_eval = args.fp16
|
1206 |
+
_output_logits = False
|
1207 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1208 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1209 |
+
if _output_logits:
|
1210 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1211 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1212 |
+
pass
|
1213 |
+
else:
|
1214 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1215 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1216 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1217 |
+
max_seq_length = model.max_seq_length
|
1218 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1219 |
+
if model is not None and hasattr(model, 'for_training'):
|
1220 |
+
model.for_training()
|
1221 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1222 |
+
if 'processing_class' in locals():
|
1223 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1224 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1225 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1226 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1227 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1228 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1229 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1230 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1231 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1232 |
+
else:
|
1233 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1234 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1235 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1236 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1237 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1238 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1239 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1240 |
+
else:
|
1241 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1242 |
+
other_metrics = []
|
1243 |
+
|
1244 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1245 |
+
PatchRLStatistics('online_dpo_trainer', other_metrics)
|
1246 |
+
|
1247 |
+
super().__init__(
|
1248 |
+
model = model,
|
1249 |
+
ref_model = ref_model,
|
1250 |
+
reward_model = reward_model,
|
1251 |
+
judge = judge,
|
1252 |
+
args = args,
|
1253 |
+
data_collator = data_collator,
|
1254 |
+
train_dataset = train_dataset,
|
1255 |
+
eval_dataset = eval_dataset,
|
1256 |
+
processing_class = processing_class,
|
1257 |
+
reward_processing_class = reward_processing_class,
|
1258 |
+
peft_config = peft_config,
|
1259 |
+
compute_metrics = compute_metrics,
|
1260 |
+
callbacks = callbacks,
|
1261 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
1262 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1263 |
+
self.neftune_hook_handle.remove()
|
1264 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1265 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1266 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1267 |
+
pass
|
1268 |
+
|
1269 |
+
pass
|
unsloth_compiled_cache/UnslothPPOTrainer.py
ADDED
@@ -0,0 +1,1259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothPPOConfig(PPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`PPOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
54 |
+
Name of this experiment.
|
55 |
+
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
56 |
+
Path to the reward model.
|
57 |
+
model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
58 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
59 |
+
ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
60 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
61 |
+
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
62 |
+
Number of epochs to train.
|
63 |
+
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
64 |
+
Whether to whiten the rewards.
|
65 |
+
kl_coef (`float`, *optional*, defaults to `0.05`):
|
66 |
+
KL coefficient.
|
67 |
+
cliprange (`float`, *optional*, defaults to `0.2`):
|
68 |
+
Clip range.
|
69 |
+
vf_coef (`float`, *optional*, defaults to `0.1`):
|
70 |
+
Value function coefficient.
|
71 |
+
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
72 |
+
Clip range for the value function.
|
73 |
+
gamma (`float`, *optional*, defaults to `1.0`):
|
74 |
+
Discount factor.
|
75 |
+
lam (`float`, *optional*, defaults to `0.95`):
|
76 |
+
Lambda value for GAE.
|
77 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
78 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
79 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
80 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
81 |
+
|
82 |
+
"""
|
83 |
+
vllm_sampling_params: Optional[Any] = field(
|
84 |
+
default = None,
|
85 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
86 |
+
)
|
87 |
+
unsloth_num_chunks : Optional[int] = field(
|
88 |
+
default = -1,
|
89 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
90 |
+
)
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
output_dir = None,
|
94 |
+
overwrite_output_dir = None,
|
95 |
+
do_train = False,
|
96 |
+
do_eval = False,
|
97 |
+
do_predict = False,
|
98 |
+
eval_strategy = 'no',
|
99 |
+
prediction_loss_only = False,
|
100 |
+
per_device_train_batch_size = 4,
|
101 |
+
per_device_eval_batch_size = 4,
|
102 |
+
per_gpu_train_batch_size = None,
|
103 |
+
per_gpu_eval_batch_size = None,
|
104 |
+
gradient_accumulation_steps = 2,
|
105 |
+
eval_accumulation_steps = 2,
|
106 |
+
eval_delay = 0,
|
107 |
+
torch_empty_cache_steps = 250,
|
108 |
+
learning_rate = 5e-05,
|
109 |
+
weight_decay = 0.01,
|
110 |
+
adam_beta1 = 0.9,
|
111 |
+
adam_beta2 = 0.999,
|
112 |
+
adam_epsilon = 1e-08,
|
113 |
+
max_grad_norm = 1.0,
|
114 |
+
num_train_epochs = 3.0,
|
115 |
+
max_steps = -1,
|
116 |
+
lr_scheduler_type = 'linear',
|
117 |
+
warmup_ratio = 0.1,
|
118 |
+
warmup_steps = 0,
|
119 |
+
log_level = 'passive',
|
120 |
+
log_level_replica = 'warning',
|
121 |
+
log_on_each_node = True,
|
122 |
+
logging_dir = None,
|
123 |
+
logging_strategy = 'steps',
|
124 |
+
logging_first_step = False,
|
125 |
+
logging_steps = 1,
|
126 |
+
logging_nan_inf_filter = False,
|
127 |
+
save_strategy = 'steps',
|
128 |
+
save_steps = 500,
|
129 |
+
save_total_limit = None,
|
130 |
+
save_safetensors = True,
|
131 |
+
save_on_each_node = False,
|
132 |
+
save_only_model = False,
|
133 |
+
restore_callback_states_from_checkpoint = False,
|
134 |
+
no_cuda = False,
|
135 |
+
use_cpu = False,
|
136 |
+
use_mps_device = False,
|
137 |
+
seed = 3407,
|
138 |
+
data_seed = 3407,
|
139 |
+
jit_mode_eval = False,
|
140 |
+
use_ipex = False,
|
141 |
+
bf16 = False,
|
142 |
+
fp16 = False,
|
143 |
+
fp16_opt_level = 'O1',
|
144 |
+
half_precision_backend = 'auto',
|
145 |
+
bf16_full_eval = False,
|
146 |
+
fp16_full_eval = False,
|
147 |
+
tf32 = None,
|
148 |
+
local_rank = -1,
|
149 |
+
ddp_backend = None,
|
150 |
+
tpu_num_cores = None,
|
151 |
+
tpu_metrics_debug = False,
|
152 |
+
debug = '',
|
153 |
+
dataloader_drop_last = False,
|
154 |
+
eval_steps = None,
|
155 |
+
dataloader_num_workers = 0,
|
156 |
+
dataloader_prefetch_factor = None,
|
157 |
+
past_index = -1,
|
158 |
+
run_name = None,
|
159 |
+
disable_tqdm = None,
|
160 |
+
remove_unused_columns = True,
|
161 |
+
label_names = None,
|
162 |
+
load_best_model_at_end = False,
|
163 |
+
metric_for_best_model = None,
|
164 |
+
greater_is_better = None,
|
165 |
+
ignore_data_skip = False,
|
166 |
+
fsdp = '',
|
167 |
+
fsdp_min_num_params = 0,
|
168 |
+
fsdp_config = None,
|
169 |
+
tp_size = 0,
|
170 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
171 |
+
accelerator_config = None,
|
172 |
+
deepspeed = None,
|
173 |
+
label_smoothing_factor = 0.0,
|
174 |
+
optim = 'adamw_8bit',
|
175 |
+
optim_args = None,
|
176 |
+
adafactor = False,
|
177 |
+
group_by_length = False,
|
178 |
+
length_column_name = 'length',
|
179 |
+
report_to = None,
|
180 |
+
ddp_find_unused_parameters = None,
|
181 |
+
ddp_bucket_cap_mb = None,
|
182 |
+
ddp_broadcast_buffers = None,
|
183 |
+
dataloader_pin_memory = True,
|
184 |
+
dataloader_persistent_workers = False,
|
185 |
+
skip_memory_metrics = True,
|
186 |
+
use_legacy_prediction_loop = False,
|
187 |
+
push_to_hub = False,
|
188 |
+
resume_from_checkpoint = None,
|
189 |
+
hub_model_id = None,
|
190 |
+
hub_strategy = 'every_save',
|
191 |
+
hub_token = None,
|
192 |
+
hub_private_repo = None,
|
193 |
+
hub_always_push = False,
|
194 |
+
gradient_checkpointing = False,
|
195 |
+
gradient_checkpointing_kwargs = None,
|
196 |
+
include_inputs_for_metrics = False,
|
197 |
+
eval_do_concat_batches = True,
|
198 |
+
fp16_backend = 'auto',
|
199 |
+
evaluation_strategy = None,
|
200 |
+
push_to_hub_model_id = None,
|
201 |
+
push_to_hub_organization = None,
|
202 |
+
push_to_hub_token = None,
|
203 |
+
mp_parameters = '',
|
204 |
+
auto_find_batch_size = False,
|
205 |
+
full_determinism = False,
|
206 |
+
torchdynamo = None,
|
207 |
+
ray_scope = 'last',
|
208 |
+
ddp_timeout = 1800,
|
209 |
+
torch_compile = False,
|
210 |
+
torch_compile_backend = None,
|
211 |
+
torch_compile_mode = None,
|
212 |
+
dispatch_batches = None,
|
213 |
+
split_batches = None,
|
214 |
+
include_tokens_per_second = False,
|
215 |
+
include_num_input_tokens_seen = False,
|
216 |
+
neftune_noise_alpha = None,
|
217 |
+
optim_target_modules = None,
|
218 |
+
batch_eval_metrics = False,
|
219 |
+
eval_on_start = False,
|
220 |
+
use_liger_kernel = False,
|
221 |
+
eval_use_gather_object = False,
|
222 |
+
average_tokens_across_devices = False,
|
223 |
+
dataset_num_proc = None,
|
224 |
+
num_mini_batches = 1,
|
225 |
+
total_episodes = None,
|
226 |
+
local_rollout_forward_batch_size = 64,
|
227 |
+
num_sample_generations = 10,
|
228 |
+
response_length = 53,
|
229 |
+
stop_token = None,
|
230 |
+
stop_token_id = None,
|
231 |
+
temperature = 0.7,
|
232 |
+
missing_eos_penalty = None,
|
233 |
+
sft_model_path = 'EleutherAI/pythia-160m',
|
234 |
+
world_size = None,
|
235 |
+
num_total_batches = None,
|
236 |
+
micro_batch_size = None,
|
237 |
+
local_batch_size = None,
|
238 |
+
batch_size = None,
|
239 |
+
local_mini_batch_size = None,
|
240 |
+
mini_batch_size = None,
|
241 |
+
exp_name = 'ppo_config',
|
242 |
+
reward_model_path = 'EleutherAI/pythia-160m',
|
243 |
+
model_adapter_name = None,
|
244 |
+
ref_adapter_name = None,
|
245 |
+
num_ppo_epochs = 4,
|
246 |
+
whiten_rewards = False,
|
247 |
+
kl_coef = 0.05,
|
248 |
+
cliprange = 0.2,
|
249 |
+
vf_coef = 0.1,
|
250 |
+
cliprange_value = 0.2,
|
251 |
+
gamma = 1.0,
|
252 |
+
lam = 0.95,
|
253 |
+
ds3_gather_for_generation = True,
|
254 |
+
vllm_sampling_params = None,
|
255 |
+
unsloth_num_chunks = -1,
|
256 |
+
**kwargs,
|
257 |
+
):
|
258 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
259 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
260 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
261 |
+
output_dir = 'unsloth_training_checkpoints'
|
262 |
+
save_strategy = 'no'
|
263 |
+
if dataset_num_proc is None:
|
264 |
+
from multiprocessing import cpu_count
|
265 |
+
dataset_num_proc = cpu_count()
|
266 |
+
|
267 |
+
super().__init__(
|
268 |
+
output_dir = output_dir,
|
269 |
+
overwrite_output_dir = overwrite_output_dir,
|
270 |
+
do_train = do_train,
|
271 |
+
do_eval = do_eval,
|
272 |
+
do_predict = do_predict,
|
273 |
+
eval_strategy = eval_strategy,
|
274 |
+
prediction_loss_only = prediction_loss_only,
|
275 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
276 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
277 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
278 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
279 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
280 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
281 |
+
eval_delay = eval_delay,
|
282 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
283 |
+
learning_rate = learning_rate,
|
284 |
+
weight_decay = weight_decay,
|
285 |
+
adam_beta1 = adam_beta1,
|
286 |
+
adam_beta2 = adam_beta2,
|
287 |
+
adam_epsilon = adam_epsilon,
|
288 |
+
max_grad_norm = max_grad_norm,
|
289 |
+
num_train_epochs = num_train_epochs,
|
290 |
+
max_steps = max_steps,
|
291 |
+
lr_scheduler_type = lr_scheduler_type,
|
292 |
+
warmup_ratio = warmup_ratio,
|
293 |
+
warmup_steps = warmup_steps,
|
294 |
+
log_level = log_level,
|
295 |
+
log_level_replica = log_level_replica,
|
296 |
+
log_on_each_node = log_on_each_node,
|
297 |
+
logging_dir = logging_dir,
|
298 |
+
logging_strategy = logging_strategy,
|
299 |
+
logging_first_step = logging_first_step,
|
300 |
+
logging_steps = logging_steps,
|
301 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
302 |
+
save_strategy = save_strategy,
|
303 |
+
save_steps = save_steps,
|
304 |
+
save_total_limit = save_total_limit,
|
305 |
+
save_safetensors = save_safetensors,
|
306 |
+
save_on_each_node = save_on_each_node,
|
307 |
+
save_only_model = save_only_model,
|
308 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
309 |
+
no_cuda = no_cuda,
|
310 |
+
use_cpu = use_cpu,
|
311 |
+
use_mps_device = use_mps_device,
|
312 |
+
seed = seed,
|
313 |
+
data_seed = data_seed,
|
314 |
+
jit_mode_eval = jit_mode_eval,
|
315 |
+
use_ipex = use_ipex,
|
316 |
+
bf16 = bf16,
|
317 |
+
fp16 = fp16,
|
318 |
+
fp16_opt_level = fp16_opt_level,
|
319 |
+
half_precision_backend = half_precision_backend,
|
320 |
+
bf16_full_eval = bf16_full_eval,
|
321 |
+
fp16_full_eval = fp16_full_eval,
|
322 |
+
tf32 = tf32,
|
323 |
+
local_rank = local_rank,
|
324 |
+
ddp_backend = ddp_backend,
|
325 |
+
tpu_num_cores = tpu_num_cores,
|
326 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
327 |
+
debug = debug,
|
328 |
+
dataloader_drop_last = dataloader_drop_last,
|
329 |
+
eval_steps = eval_steps,
|
330 |
+
dataloader_num_workers = dataloader_num_workers,
|
331 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
332 |
+
past_index = past_index,
|
333 |
+
run_name = run_name,
|
334 |
+
disable_tqdm = disable_tqdm,
|
335 |
+
remove_unused_columns = remove_unused_columns,
|
336 |
+
label_names = label_names,
|
337 |
+
load_best_model_at_end = load_best_model_at_end,
|
338 |
+
metric_for_best_model = metric_for_best_model,
|
339 |
+
greater_is_better = greater_is_better,
|
340 |
+
ignore_data_skip = ignore_data_skip,
|
341 |
+
fsdp = fsdp,
|
342 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
343 |
+
fsdp_config = fsdp_config,
|
344 |
+
tp_size = tp_size,
|
345 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
346 |
+
accelerator_config = accelerator_config,
|
347 |
+
deepspeed = deepspeed,
|
348 |
+
label_smoothing_factor = label_smoothing_factor,
|
349 |
+
optim = optim,
|
350 |
+
optim_args = optim_args,
|
351 |
+
adafactor = adafactor,
|
352 |
+
group_by_length = group_by_length,
|
353 |
+
length_column_name = length_column_name,
|
354 |
+
report_to = report_to,
|
355 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
356 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
357 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
358 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
359 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
360 |
+
skip_memory_metrics = skip_memory_metrics,
|
361 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
362 |
+
push_to_hub = push_to_hub,
|
363 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
364 |
+
hub_model_id = hub_model_id,
|
365 |
+
hub_strategy = hub_strategy,
|
366 |
+
hub_token = hub_token,
|
367 |
+
hub_private_repo = hub_private_repo,
|
368 |
+
hub_always_push = hub_always_push,
|
369 |
+
gradient_checkpointing = gradient_checkpointing,
|
370 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
371 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
372 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
373 |
+
fp16_backend = fp16_backend,
|
374 |
+
evaluation_strategy = evaluation_strategy,
|
375 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
376 |
+
push_to_hub_organization = push_to_hub_organization,
|
377 |
+
push_to_hub_token = push_to_hub_token,
|
378 |
+
mp_parameters = mp_parameters,
|
379 |
+
auto_find_batch_size = auto_find_batch_size,
|
380 |
+
full_determinism = full_determinism,
|
381 |
+
torchdynamo = torchdynamo,
|
382 |
+
ray_scope = ray_scope,
|
383 |
+
ddp_timeout = ddp_timeout,
|
384 |
+
torch_compile = torch_compile,
|
385 |
+
torch_compile_backend = torch_compile_backend,
|
386 |
+
torch_compile_mode = torch_compile_mode,
|
387 |
+
dispatch_batches = dispatch_batches,
|
388 |
+
split_batches = split_batches,
|
389 |
+
include_tokens_per_second = include_tokens_per_second,
|
390 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
391 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
392 |
+
optim_target_modules = optim_target_modules,
|
393 |
+
batch_eval_metrics = batch_eval_metrics,
|
394 |
+
eval_on_start = eval_on_start,
|
395 |
+
use_liger_kernel = use_liger_kernel,
|
396 |
+
eval_use_gather_object = eval_use_gather_object,
|
397 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
398 |
+
dataset_num_proc = dataset_num_proc,
|
399 |
+
num_mini_batches = num_mini_batches,
|
400 |
+
total_episodes = total_episodes,
|
401 |
+
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
402 |
+
num_sample_generations = num_sample_generations,
|
403 |
+
response_length = response_length,
|
404 |
+
stop_token = stop_token,
|
405 |
+
stop_token_id = stop_token_id,
|
406 |
+
temperature = temperature,
|
407 |
+
missing_eos_penalty = missing_eos_penalty,
|
408 |
+
sft_model_path = sft_model_path,
|
409 |
+
world_size = world_size,
|
410 |
+
num_total_batches = num_total_batches,
|
411 |
+
micro_batch_size = micro_batch_size,
|
412 |
+
local_batch_size = local_batch_size,
|
413 |
+
batch_size = batch_size,
|
414 |
+
local_mini_batch_size = local_mini_batch_size,
|
415 |
+
mini_batch_size = mini_batch_size,
|
416 |
+
exp_name = exp_name,
|
417 |
+
reward_model_path = reward_model_path,
|
418 |
+
model_adapter_name = model_adapter_name,
|
419 |
+
ref_adapter_name = ref_adapter_name,
|
420 |
+
num_ppo_epochs = num_ppo_epochs,
|
421 |
+
whiten_rewards = whiten_rewards,
|
422 |
+
kl_coef = kl_coef,
|
423 |
+
cliprange = cliprange,
|
424 |
+
vf_coef = vf_coef,
|
425 |
+
cliprange_value = cliprange_value,
|
426 |
+
gamma = gamma,
|
427 |
+
lam = lam,
|
428 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
429 |
+
self.vllm_sampling_params = vllm_sampling_params
|
430 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
431 |
+
pass
|
432 |
+
|
433 |
+
class _UnslothPPOTrainer(Trainer):
|
434 |
+
_tag_names = ["trl", "ppo"]
|
435 |
+
|
436 |
+
def __init__(
|
437 |
+
self,
|
438 |
+
args: PPOConfig,
|
439 |
+
processing_class: Optional[
|
440 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
441 |
+
],
|
442 |
+
model: nn.Module,
|
443 |
+
ref_model: Optional[nn.Module],
|
444 |
+
reward_model: nn.Module,
|
445 |
+
train_dataset: Dataset,
|
446 |
+
value_model: Optional[nn.Module] = None,
|
447 |
+
data_collator: Optional[DataCollatorWithPadding] = None,
|
448 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
449 |
+
# less commonly used
|
450 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
451 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
452 |
+
peft_config: Optional["PeftConfig"] = None,
|
453 |
+
) -> None:
|
454 |
+
if ref_model is model:
|
455 |
+
raise ValueError(
|
456 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
457 |
+
"same as `model`, you must make a copy of it, or `None` if you use peft."
|
458 |
+
)
|
459 |
+
|
460 |
+
self.args = args
|
461 |
+
self.processing_class = processing_class
|
462 |
+
self.policy_model = model
|
463 |
+
|
464 |
+
# Define the collator if not provided
|
465 |
+
if data_collator is None:
|
466 |
+
data_collator = DataCollatorWithPadding(self.processing_class)
|
467 |
+
|
468 |
+
# Handle stop token settings: update policy model's generation_config to use provided stop token
|
469 |
+
if args.stop_token and args.stop_token_id:
|
470 |
+
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
|
471 |
+
elif args.stop_token:
|
472 |
+
if args.stop_token == "eos":
|
473 |
+
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
|
474 |
+
else:
|
475 |
+
raise ValueError(
|
476 |
+
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
|
477 |
+
)
|
478 |
+
else:
|
479 |
+
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
|
480 |
+
|
481 |
+
# peft support
|
482 |
+
if not is_peft_available() and peft_config is not None:
|
483 |
+
raise ImportError(
|
484 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
485 |
+
)
|
486 |
+
elif is_peft_available() and peft_config is not None:
|
487 |
+
# if model is a peft model and we have a peft_confg, we merge and unload it first
|
488 |
+
if isinstance(self.policy_model, PeftModel):
|
489 |
+
self.policy_model = self.policy_model.merge_and_unload()
|
490 |
+
|
491 |
+
# get peft model with the given config
|
492 |
+
self.policy_model = get_peft_model(self.policy_model, peft_config)
|
493 |
+
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
|
494 |
+
peft_module_casting_to_bf16(self.policy_model)
|
495 |
+
|
496 |
+
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
|
497 |
+
self.model_adapter_name = args.model_adapter_name
|
498 |
+
self.ref_adapter_name = args.ref_adapter_name
|
499 |
+
|
500 |
+
if ref_model:
|
501 |
+
self.ref_model = ref_model
|
502 |
+
elif self.is_peft_model:
|
503 |
+
self.ref_model = None
|
504 |
+
else:
|
505 |
+
self.ref_model = create_reference_model(self.policy_model)
|
506 |
+
|
507 |
+
self.reward_model = reward_model
|
508 |
+
self.train_dataset = train_dataset
|
509 |
+
self.train_dataset_len = len(train_dataset)
|
510 |
+
self.value_model = value_model
|
511 |
+
self.data_collator = data_collator
|
512 |
+
self.eval_dataset = eval_dataset
|
513 |
+
self.optimizer, self.lr_scheduler = optimizers
|
514 |
+
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
515 |
+
|
516 |
+
#########
|
517 |
+
# calculate various batch sizes
|
518 |
+
#########
|
519 |
+
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
520 |
+
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
521 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
522 |
+
self.accelerator = accelerator
|
523 |
+
args.world_size = accelerator.num_processes
|
524 |
+
args.local_batch_size = (
|
525 |
+
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
526 |
+
)
|
527 |
+
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
528 |
+
args.batch_size = int(args.local_batch_size * args.world_size)
|
529 |
+
args.mini_batch_size = exact_div(
|
530 |
+
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
531 |
+
)
|
532 |
+
args.local_mini_batch_size = exact_div(
|
533 |
+
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
534 |
+
)
|
535 |
+
if args.whiten_rewards:
|
536 |
+
assert (
|
537 |
+
args.local_mini_batch_size >= 8
|
538 |
+
), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
|
539 |
+
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
|
540 |
+
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
|
541 |
+
args.num_total_batches = math.ceil(
|
542 |
+
args.total_episodes / args.batch_size
|
543 |
+
) # we may train for more than `total_episodes`
|
544 |
+
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
545 |
+
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
546 |
+
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
547 |
+
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
548 |
+
if args.num_sample_generations > 0:
|
549 |
+
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
550 |
+
self.local_dataloader_batch_size = args.local_batch_size
|
551 |
+
|
552 |
+
#########
|
553 |
+
# setup model, optimizer, and others
|
554 |
+
#########
|
555 |
+
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
|
556 |
+
if module is not None:
|
557 |
+
disable_dropout_in_model(module)
|
558 |
+
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
|
559 |
+
self.model.config = self.policy_model.config # needed for pushing to hub
|
560 |
+
self.create_optimizer_and_scheduler(
|
561 |
+
num_training_steps=args.num_total_batches
|
562 |
+
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
|
563 |
+
|
564 |
+
#########
|
565 |
+
### trainer specifics
|
566 |
+
#########
|
567 |
+
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
568 |
+
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
569 |
+
self.callback_handler = CallbackHandler(
|
570 |
+
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
571 |
+
)
|
572 |
+
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
573 |
+
self.control = TrainerControl()
|
574 |
+
self.state = OnlineTrainerState(
|
575 |
+
is_local_process_zero=self.is_local_process_zero(),
|
576 |
+
is_world_process_zero=self.is_world_process_zero(),
|
577 |
+
stateful_callbacks=[
|
578 |
+
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
579 |
+
],
|
580 |
+
)
|
581 |
+
self.current_flos = 0
|
582 |
+
self.hp_search_backend = None
|
583 |
+
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
584 |
+
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
585 |
+
# Create distant repo and output directory if needed
|
586 |
+
self.hub_model_id = None
|
587 |
+
if self.args.push_to_hub:
|
588 |
+
self.init_hf_repo()
|
589 |
+
if self.args.should_save:
|
590 |
+
os.makedirs(self.args.output_dir, exist_ok=True)
|
591 |
+
|
592 |
+
# Add tags for models that have been loaded with the correct transformers version
|
593 |
+
if hasattr(self.model, "add_model_tags"):
|
594 |
+
self.model.add_model_tags(self._tag_names)
|
595 |
+
|
596 |
+
#########
|
597 |
+
### setup dataloader
|
598 |
+
#########
|
599 |
+
self.dataloader = DataLoader(
|
600 |
+
self.train_dataset,
|
601 |
+
batch_size=self.local_dataloader_batch_size,
|
602 |
+
shuffle=True,
|
603 |
+
collate_fn=self.data_collator,
|
604 |
+
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
605 |
+
)
|
606 |
+
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
|
607 |
+
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
608 |
+
torch.manual_seed(args.seed)
|
609 |
+
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
610 |
+
torch.manual_seed(self.local_seed) # reset the local seed again
|
611 |
+
|
612 |
+
self.eval_dataloader = DataLoader(
|
613 |
+
self.eval_dataset,
|
614 |
+
batch_size=args.per_device_eval_batch_size,
|
615 |
+
collate_fn=self.data_collator,
|
616 |
+
drop_last=True,
|
617 |
+
) # no need to shuffle eval dataset
|
618 |
+
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
619 |
+
|
620 |
+
if self.is_deepspeed_enabled:
|
621 |
+
self.reward_model = prepare_deepspeed(
|
622 |
+
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
623 |
+
)
|
624 |
+
|
625 |
+
if self.ref_model is None:
|
626 |
+
if not self.is_peft_model:
|
627 |
+
raise ValueError("No reference model and model is not a Peft model.")
|
628 |
+
else:
|
629 |
+
self.ref_model = prepare_deepspeed(
|
630 |
+
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
631 |
+
)
|
632 |
+
else:
|
633 |
+
if self.ref_model is None:
|
634 |
+
if not self.is_peft_model:
|
635 |
+
raise ValueError("No reference model and model is not a Peft model.")
|
636 |
+
else:
|
637 |
+
self.ref_model = self.ref_model.to(self.accelerator.device)
|
638 |
+
self.reward_model = self.reward_model.to(self.accelerator.device)
|
639 |
+
|
640 |
+
def get_train_dataloader(self) -> DataLoader:
|
641 |
+
return self.dataloader
|
642 |
+
|
643 |
+
def get_eval_dataloader(self) -> DataLoader:
|
644 |
+
return self.eval_dataloader
|
645 |
+
|
646 |
+
@contextmanager
|
647 |
+
def null_ref_context(self):
|
648 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
649 |
+
with (
|
650 |
+
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
|
651 |
+
if self.is_peft_model and not self.ref_adapter_name
|
652 |
+
else nullcontext()
|
653 |
+
):
|
654 |
+
if self.ref_adapter_name:
|
655 |
+
self.model.policy.set_adapter(self.ref_adapter_name)
|
656 |
+
yield
|
657 |
+
if self.ref_adapter_name:
|
658 |
+
self.model.policy.set_adapter(self.model_adapter_name or "default")
|
659 |
+
|
660 |
+
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
661 |
+
backup_model = self.model
|
662 |
+
self.model = self.model.policy # save only the policy
|
663 |
+
|
664 |
+
if self.is_deepspeed_enabled:
|
665 |
+
backup_deepspeed = self.deepspeed
|
666 |
+
self.deepspeed = self.model
|
667 |
+
|
668 |
+
super().save_model(output_dir, _internal_call)
|
669 |
+
|
670 |
+
self.model = backup_model
|
671 |
+
|
672 |
+
if self.is_deepspeed_enabled:
|
673 |
+
self.deepspeed = backup_deepspeed
|
674 |
+
|
675 |
+
def train(self):
|
676 |
+
args = self.args
|
677 |
+
accelerator = self.accelerator
|
678 |
+
optimizer = self.optimizer
|
679 |
+
model = self.model
|
680 |
+
ref_policy = self.ref_model
|
681 |
+
reward_model = self.reward_model
|
682 |
+
processing_class = self.processing_class
|
683 |
+
dataloader = self.dataloader
|
684 |
+
device = accelerator.device
|
685 |
+
|
686 |
+
def repeat_generator():
|
687 |
+
while True:
|
688 |
+
yield from dataloader
|
689 |
+
|
690 |
+
iter_dataloader = iter(repeat_generator())
|
691 |
+
generation_config = GenerationConfig(
|
692 |
+
max_new_tokens=args.response_length,
|
693 |
+
temperature=(args.temperature + 1e-7),
|
694 |
+
top_k=0.0,
|
695 |
+
top_p=1.0,
|
696 |
+
do_sample=True,
|
697 |
+
)
|
698 |
+
|
699 |
+
accelerator.print("===training policy===")
|
700 |
+
start_time = time.time()
|
701 |
+
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
702 |
+
approxkl_stats = torch.zeros(stats_shape, device=device)
|
703 |
+
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
704 |
+
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
705 |
+
vf_loss_stats = torch.zeros(stats_shape, device=device)
|
706 |
+
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
707 |
+
entropy_stats = torch.zeros(stats_shape, device=device)
|
708 |
+
ratio_stats = torch.zeros(stats_shape, device=device)
|
709 |
+
model.train()
|
710 |
+
|
711 |
+
# trainer state initialization
|
712 |
+
self.state.global_step = 0
|
713 |
+
self.state.episode = 0
|
714 |
+
self.state.max_steps = args.num_total_batches * args.num_mini_batches
|
715 |
+
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
716 |
+
# Compute absolute values for logging, eval, and save if given as ratio
|
717 |
+
if args.logging_steps is not None:
|
718 |
+
if args.logging_steps < 1:
|
719 |
+
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
720 |
+
else:
|
721 |
+
self.state.logging_steps = args.logging_steps
|
722 |
+
if args.eval_steps is not None:
|
723 |
+
if args.eval_steps < 1:
|
724 |
+
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
725 |
+
else:
|
726 |
+
self.state.eval_steps = args.eval_steps
|
727 |
+
if args.save_steps is not None:
|
728 |
+
if args.save_steps < 1:
|
729 |
+
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
730 |
+
else:
|
731 |
+
self.state.save_steps = args.save_steps
|
732 |
+
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
733 |
+
|
734 |
+
# backward compatibility
|
735 |
+
if self.is_deepspeed_enabled:
|
736 |
+
self.deepspeed = self.model
|
737 |
+
self.model_wrapped = self.model
|
738 |
+
|
739 |
+
for update in range(1, args.num_total_batches + 1):
|
740 |
+
self.state.episode += 1 * args.batch_size
|
741 |
+
data = next(iter_dataloader)
|
742 |
+
with torch.no_grad():
|
743 |
+
queries = data["input_ids"].to(device)
|
744 |
+
context_length = queries.shape[1]
|
745 |
+
responses = []
|
746 |
+
postprocessed_responses = []
|
747 |
+
logprobs = []
|
748 |
+
ref_logprobs = []
|
749 |
+
scores = []
|
750 |
+
sequence_lengths = []
|
751 |
+
values = []
|
752 |
+
with unwrap_model_for_generation(
|
753 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
754 |
+
) as unwrapped_model:
|
755 |
+
query_responses, logitss = batch_generation(
|
756 |
+
unwrapped_model.policy,
|
757 |
+
queries,
|
758 |
+
args.local_rollout_forward_batch_size,
|
759 |
+
processing_class.pad_token_id,
|
760 |
+
generation_config,
|
761 |
+
)
|
762 |
+
|
763 |
+
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
764 |
+
query = queries[i : i + args.local_rollout_forward_batch_size]
|
765 |
+
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
766 |
+
response = query_response[:, context_length:]
|
767 |
+
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
768 |
+
logprob = selective_log_softmax(logits, response)
|
769 |
+
del logits
|
770 |
+
torch.cuda.empty_cache()
|
771 |
+
|
772 |
+
if ref_policy is None:
|
773 |
+
with self.null_ref_context():
|
774 |
+
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
|
775 |
+
else:
|
776 |
+
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
777 |
+
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
778 |
+
ref_logits /= args.temperature + 1e-7
|
779 |
+
ref_logprob = selective_log_softmax(ref_logits, response)
|
780 |
+
del ref_output, ref_logits
|
781 |
+
torch.cuda.empty_cache()
|
782 |
+
|
783 |
+
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
784 |
+
postprocessed_response = response
|
785 |
+
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
786 |
+
postprocessed_response = truncate_response(
|
787 |
+
self.stop_token_id, processing_class.pad_token_id, response
|
788 |
+
)
|
789 |
+
|
790 |
+
# Response Processing 2. run reward model on the truncated responses
|
791 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
792 |
+
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
793 |
+
unwrapped_value_model = accelerator.unwrap_model(model).value_model
|
794 |
+
full_value, _, _ = get_reward(
|
795 |
+
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
|
796 |
+
)
|
797 |
+
value = full_value[:, context_length - 1 : -1].squeeze(-1)
|
798 |
+
_, score, _ = get_reward(
|
799 |
+
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
800 |
+
)
|
801 |
+
|
802 |
+
responses.append(response)
|
803 |
+
postprocessed_responses.append(postprocessed_response)
|
804 |
+
logprobs.append(logprob)
|
805 |
+
ref_logprobs.append(ref_logprob)
|
806 |
+
sequence_lengths.append(sequence_length)
|
807 |
+
scores.append(score)
|
808 |
+
values.append(value)
|
809 |
+
responses = torch.cat(responses, 0)
|
810 |
+
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
811 |
+
logprobs = torch.cat(logprobs, 0)
|
812 |
+
ref_logprobs = torch.cat(ref_logprobs, 0)
|
813 |
+
sequence_lengths = torch.cat(sequence_lengths, 0)
|
814 |
+
scores = torch.cat(scores, 0)
|
815 |
+
values = torch.cat(values, 0)
|
816 |
+
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
|
817 |
+
torch.cuda.empty_cache()
|
818 |
+
gc.collect()
|
819 |
+
|
820 |
+
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
|
821 |
+
# Completions not passing that filter will receive a lower score.
|
822 |
+
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
|
823 |
+
if self.args.missing_eos_penalty is not None:
|
824 |
+
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
825 |
+
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
826 |
+
|
827 |
+
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
828 |
+
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
829 |
+
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
830 |
+
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
831 |
+
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
832 |
+
sequence_lengths_p1 = sequence_lengths + 1
|
833 |
+
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
|
834 |
+
values = torch.masked_fill(values, padding_mask_p1, 0)
|
835 |
+
|
836 |
+
# 4. compute rewards
|
837 |
+
kl = logprobs - ref_logprobs
|
838 |
+
non_score_reward = -args.kl_coef * kl
|
839 |
+
rewards = non_score_reward.clone()
|
840 |
+
actual_start = torch.arange(rewards.size(0), device=rewards.device)
|
841 |
+
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
|
842 |
+
rewards[[actual_start, actual_end]] += scores
|
843 |
+
|
844 |
+
# 5. whiten rewards
|
845 |
+
if args.whiten_rewards:
|
846 |
+
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
|
847 |
+
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
|
848 |
+
|
849 |
+
# 6. compute advantages and returns
|
850 |
+
lastgaelam = 0
|
851 |
+
advantages_reversed = []
|
852 |
+
gen_length = responses.shape[1]
|
853 |
+
for t in reversed(range(gen_length)):
|
854 |
+
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
|
855 |
+
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
|
856 |
+
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
857 |
+
advantages_reversed.append(lastgaelam)
|
858 |
+
advantages = torch.stack(advantages_reversed[::-1], axis=1)
|
859 |
+
returns = advantages + values
|
860 |
+
advantages = masked_whiten(advantages, ~padding_mask)
|
861 |
+
advantages = torch.masked_fill(advantages, padding_mask, 0)
|
862 |
+
torch.cuda.empty_cache()
|
863 |
+
|
864 |
+
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
865 |
+
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
866 |
+
b_inds = np.random.permutation(args.local_batch_size)
|
867 |
+
minibatch_idx = 0
|
868 |
+
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
869 |
+
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
870 |
+
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
871 |
+
gradient_accumulation_idx = 0
|
872 |
+
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
873 |
+
with accelerator.accumulate(model):
|
874 |
+
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
875 |
+
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
876 |
+
mb_advantage = advantages[micro_batch_inds]
|
877 |
+
mb_responses = responses[micro_batch_inds]
|
878 |
+
mb_query_responses = query_responses[micro_batch_inds]
|
879 |
+
mb_logprobs = logprobs[micro_batch_inds]
|
880 |
+
mb_return = returns[micro_batch_inds]
|
881 |
+
mb_values = values[micro_batch_inds]
|
882 |
+
|
883 |
+
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
|
884 |
+
logits = output.logits[:, context_length - 1 : -1]
|
885 |
+
logits /= args.temperature + 1e-7
|
886 |
+
new_logprobs = selective_log_softmax(logits, mb_responses)
|
887 |
+
new_logprobs = torch.masked_fill(
|
888 |
+
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
889 |
+
)
|
890 |
+
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
|
891 |
+
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
|
892 |
+
vpredclipped = torch.clamp(
|
893 |
+
vpred,
|
894 |
+
mb_values - args.cliprange_value,
|
895 |
+
mb_values + args.cliprange_value,
|
896 |
+
)
|
897 |
+
vf_losses1 = torch.square(vpred - mb_return)
|
898 |
+
vf_losses2 = torch.square(vpredclipped - mb_return)
|
899 |
+
vf_loss_max = torch.max(vf_losses1, vf_losses2)
|
900 |
+
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
|
901 |
+
vf_clipfrac = masked_mean(
|
902 |
+
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
|
903 |
+
)
|
904 |
+
logprobs_diff = new_logprobs - mb_logprobs
|
905 |
+
ratio = torch.exp(logprobs_diff)
|
906 |
+
pg_losses = -mb_advantage * ratio
|
907 |
+
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
908 |
+
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
909 |
+
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
|
910 |
+
loss = pg_loss + args.vf_coef * vf_loss
|
911 |
+
accelerator.backward(loss)
|
912 |
+
optimizer.step()
|
913 |
+
optimizer.zero_grad()
|
914 |
+
with torch.no_grad():
|
915 |
+
pg_clipfrac = masked_mean(
|
916 |
+
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
|
917 |
+
)
|
918 |
+
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
919 |
+
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
920 |
+
approxkl = 0.5 * (logprobs_diff**2).mean()
|
921 |
+
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
922 |
+
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
923 |
+
pg_clipfrac
|
924 |
+
)
|
925 |
+
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
926 |
+
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
|
927 |
+
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
928 |
+
vf_clipfrac
|
929 |
+
)
|
930 |
+
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
931 |
+
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
|
932 |
+
gradient_accumulation_idx += 1
|
933 |
+
minibatch_idx += 1
|
934 |
+
# del everything and empty cache
|
935 |
+
# fmt: off
|
936 |
+
del (
|
937 |
+
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
|
938 |
+
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
|
939 |
+
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
|
940 |
+
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
|
941 |
+
)
|
942 |
+
# fmt: on
|
943 |
+
torch.cuda.empty_cache()
|
944 |
+
with torch.no_grad():
|
945 |
+
mean_kl = kl.sum(1).mean()
|
946 |
+
mean_entropy = (-logprobs).sum(1).mean()
|
947 |
+
mean_non_score_reward = non_score_reward.sum(1).mean()
|
948 |
+
rlhf_reward = mean_non_score_reward + scores.mean()
|
949 |
+
eps = int(self.state.episode / (time.time() - start_time))
|
950 |
+
metrics = {}
|
951 |
+
metrics["eps"] = eps
|
952 |
+
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
953 |
+
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
954 |
+
metrics["objective/non_score_reward"] = (
|
955 |
+
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
956 |
+
)
|
957 |
+
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
958 |
+
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
959 |
+
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
960 |
+
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
961 |
+
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
962 |
+
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
|
963 |
+
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
964 |
+
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
965 |
+
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
966 |
+
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
967 |
+
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
968 |
+
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
969 |
+
metrics["episode"] = self.state.episode
|
970 |
+
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
|
971 |
+
self.state.global_step += 1
|
972 |
+
self.log(metrics)
|
973 |
+
|
974 |
+
self.lr_scheduler.step()
|
975 |
+
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
976 |
+
if self.control.should_save:
|
977 |
+
self._save_checkpoint(model, trial=None)
|
978 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
979 |
+
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
|
980 |
+
torch.cuda.empty_cache()
|
981 |
+
gc.collect()
|
982 |
+
|
983 |
+
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
984 |
+
self.generate_completions(sampling=True)
|
985 |
+
torch.cuda.empty_cache()
|
986 |
+
del (
|
987 |
+
query_responses,
|
988 |
+
responses,
|
989 |
+
postprocessed_responses,
|
990 |
+
logprobs,
|
991 |
+
ref_logprobs,
|
992 |
+
values,
|
993 |
+
sequence_lengths,
|
994 |
+
contain_eos_token,
|
995 |
+
sequence_lengths_p1,
|
996 |
+
response_idxs,
|
997 |
+
padding_mask,
|
998 |
+
padding_mask_p1,
|
999 |
+
rewards,
|
1000 |
+
actual_start,
|
1001 |
+
actual_end,
|
1002 |
+
advantages,
|
1003 |
+
returns,
|
1004 |
+
)
|
1005 |
+
torch.cuda.empty_cache()
|
1006 |
+
|
1007 |
+
# HF trainer specifics
|
1008 |
+
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
1009 |
+
if self.control.should_save:
|
1010 |
+
self._save_checkpoint(model, trial=None, metrics=None)
|
1011 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
1012 |
+
|
1013 |
+
def generate_completions(self, sampling: bool = False):
|
1014 |
+
args = self.args
|
1015 |
+
processing_class = self.processing_class
|
1016 |
+
generation_config = GenerationConfig(
|
1017 |
+
max_new_tokens=self.args.response_length,
|
1018 |
+
temperature=(0.01 + 1e-7),
|
1019 |
+
top_k=0.0,
|
1020 |
+
top_p=1.0,
|
1021 |
+
do_sample=True,
|
1022 |
+
)
|
1023 |
+
|
1024 |
+
table = defaultdict(list)
|
1025 |
+
with unwrap_model_for_generation(
|
1026 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
1027 |
+
) as unwrapped_model:
|
1028 |
+
for batch in self.eval_dataloader:
|
1029 |
+
query = batch["input_ids"]
|
1030 |
+
with torch.no_grad():
|
1031 |
+
context_length = query.shape[1]
|
1032 |
+
query_response, _ = batch_generation(
|
1033 |
+
unwrapped_model.policy,
|
1034 |
+
query,
|
1035 |
+
query.shape[0],
|
1036 |
+
processing_class.pad_token_id,
|
1037 |
+
generation_config,
|
1038 |
+
)
|
1039 |
+
response = query_response[:, context_length:]
|
1040 |
+
postprocessed_response = response
|
1041 |
+
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
1042 |
+
postprocessed_response = truncate_response(
|
1043 |
+
self.stop_token_id, processing_class.pad_token_id, response
|
1044 |
+
)
|
1045 |
+
table["query"].extend(
|
1046 |
+
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
1047 |
+
)
|
1048 |
+
table["model response"].extend(
|
1049 |
+
gather_object(processing_class.batch_decode(postprocessed_response))
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
1053 |
+
_, score, _ = get_reward(
|
1054 |
+
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
1055 |
+
)
|
1056 |
+
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
1057 |
+
|
1058 |
+
if sampling:
|
1059 |
+
break
|
1060 |
+
df = pd.DataFrame(table)
|
1061 |
+
|
1062 |
+
if self.accelerator.is_main_process:
|
1063 |
+
print_rich_table(df.iloc[0 : 0 + 5])
|
1064 |
+
if "wandb" in args.report_to:
|
1065 |
+
import wandb
|
1066 |
+
|
1067 |
+
if wandb.run is not None:
|
1068 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
1069 |
+
|
1070 |
+
if "comet_ml" in args.report_to:
|
1071 |
+
log_table_to_comet_experiment(
|
1072 |
+
name="completions.csv",
|
1073 |
+
table=df,
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
def create_model_card(
|
1077 |
+
self,
|
1078 |
+
model_name: Optional[str] = None,
|
1079 |
+
dataset_name: Optional[str] = None,
|
1080 |
+
tags: Union[str, list[str], None] = None,
|
1081 |
+
):
|
1082 |
+
"""
|
1083 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1084 |
+
|
1085 |
+
Args:
|
1086 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1087 |
+
Name of the model.
|
1088 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1089 |
+
Name of the dataset used for training.
|
1090 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1091 |
+
Tags to be associated with the model card.
|
1092 |
+
"""
|
1093 |
+
if not self.is_world_process_zero():
|
1094 |
+
return
|
1095 |
+
|
1096 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1097 |
+
base_model = self.model.config._name_or_path
|
1098 |
+
else:
|
1099 |
+
base_model = None
|
1100 |
+
|
1101 |
+
tags = tags or []
|
1102 |
+
if isinstance(tags, str):
|
1103 |
+
tags = [tags]
|
1104 |
+
|
1105 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1106 |
+
tags.append("unsloth")
|
1107 |
+
|
1108 |
+
citation = textwrap.dedent("""\
|
1109 |
+
@article{mziegler2019fine-tuning,
|
1110 |
+
title = {{Fine-Tuning Language Models from Human Preferences}},
|
1111 |
+
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
|
1112 |
+
year = 2019,
|
1113 |
+
eprint = {arXiv:1909.08593}
|
1114 |
+
}""")
|
1115 |
+
|
1116 |
+
model_card = generate_model_card(
|
1117 |
+
base_model=base_model,
|
1118 |
+
model_name=model_name,
|
1119 |
+
hub_model_id=self.hub_model_id,
|
1120 |
+
dataset_name=dataset_name,
|
1121 |
+
tags=tags,
|
1122 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1123 |
+
comet_url=get_comet_experiment_url(),
|
1124 |
+
trainer_name="PPO",
|
1125 |
+
trainer_citation=citation,
|
1126 |
+
paper_title="Fine-Tuning Language Models from Human Preferences",
|
1127 |
+
paper_id="1909.08593",
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1131 |
+
class UnslothPPOTrainer(_UnslothPPOTrainer):
|
1132 |
+
"""
|
1133 |
+
|
1134 |
+
"""
|
1135 |
+
def __init__(
|
1136 |
+
self,
|
1137 |
+
args,
|
1138 |
+
processing_class,
|
1139 |
+
model,
|
1140 |
+
ref_model,
|
1141 |
+
reward_model,
|
1142 |
+
train_dataset,
|
1143 |
+
value_model = None,
|
1144 |
+
data_collator = None,
|
1145 |
+
eval_dataset = None,
|
1146 |
+
callbacks = None,
|
1147 |
+
peft_config = None,
|
1148 |
+
**kwargs
|
1149 |
+
):
|
1150 |
+
if args is None: args = UnslothPPOConfig()
|
1151 |
+
use_bf16 = getattr(args, 'bf16', False)
|
1152 |
+
use_fp16 = getattr(args, 'fp16', False)
|
1153 |
+
force_float32 = False
|
1154 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
1155 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
1156 |
+
force_float32 = True
|
1157 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
1158 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
1159 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
1160 |
+
from unsloth_zoo.utils import _get_dtype
|
1161 |
+
dtype = _get_dtype(dtype)
|
1162 |
+
float16 = dtype == torch.float16
|
1163 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
1164 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
1165 |
+
if force_float32:
|
1166 |
+
args.fp16 = False
|
1167 |
+
args.bf16 = False
|
1168 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
1169 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
1170 |
+
args.fp16 = float16
|
1171 |
+
args.bf16 = not float16
|
1172 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
1173 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
1174 |
+
args.eval_strategy = 'steps'
|
1175 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
1176 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
1177 |
+
if ga_steps is not None and ga_steps > 1:
|
1178 |
+
from transformers import __version__ as transformers_version
|
1179 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
1180 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
1181 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
1182 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
1183 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
1184 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
1185 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
1186 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
1187 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
1188 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
1189 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
1190 |
+
if force_float32:
|
1191 |
+
args.bf16_full_eval = False
|
1192 |
+
args.fp16_full_eval = False
|
1193 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
1194 |
+
args.bf16_full_eval = True
|
1195 |
+
args.fp16_full_eval = False
|
1196 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
1197 |
+
args.bf16_full_eval = args.bf16
|
1198 |
+
args.fp16_full_eval = args.fp16
|
1199 |
+
_output_logits = False
|
1200 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1201 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1202 |
+
if _output_logits:
|
1203 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1204 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1205 |
+
pass
|
1206 |
+
else:
|
1207 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1208 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1209 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1210 |
+
max_seq_length = model.max_seq_length
|
1211 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1212 |
+
if model is not None and hasattr(model, 'for_training'):
|
1213 |
+
model.for_training()
|
1214 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1215 |
+
if 'processing_class' in locals():
|
1216 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1217 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1218 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1219 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1220 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1221 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1222 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1223 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1224 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1225 |
+
else:
|
1226 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1227 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1228 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1229 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1230 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1231 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1232 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1233 |
+
else:
|
1234 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1235 |
+
other_metrics = []
|
1236 |
+
|
1237 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1238 |
+
PatchRLStatistics('ppo_trainer', other_metrics)
|
1239 |
+
|
1240 |
+
super().__init__(
|
1241 |
+
args = args,
|
1242 |
+
processing_class = processing_class,
|
1243 |
+
model = model,
|
1244 |
+
ref_model = ref_model,
|
1245 |
+
reward_model = reward_model,
|
1246 |
+
train_dataset = train_dataset,
|
1247 |
+
value_model = value_model,
|
1248 |
+
data_collator = data_collator,
|
1249 |
+
eval_dataset = eval_dataset,
|
1250 |
+
callbacks = callbacks,
|
1251 |
+
peft_config = peft_config,**kwargs)
|
1252 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1253 |
+
self.neftune_hook_handle.remove()
|
1254 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1255 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1256 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1257 |
+
pass
|
1258 |
+
|
1259 |
+
pass
|
unsloth_compiled_cache/UnslothPRMTrainer.py
ADDED
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothPRMConfig(PRMConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`PRMTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
learning_rate (`float`, *optional*, defaults to `1e-5`):
|
54 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
55 |
+
[`~transformers.TrainingArguments`].
|
56 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
57 |
+
Maximum length of the sequences (prompt + completion) used for truncation.
|
58 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
59 |
+
Maximum length of the prompt used for truncation.
|
60 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
61 |
+
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
62 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
63 |
+
Whether to disable dropout in the model.
|
64 |
+
step_separator (`str`, *optional*, defaults to `"\n"`):
|
65 |
+
Separator used to separate each step of the reasoning process.
|
66 |
+
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
67 |
+
Whether to train only on the last step.
|
68 |
+
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
69 |
+
Number of processes to use for processing the dataset.
|
70 |
+
|
71 |
+
"""
|
72 |
+
vllm_sampling_params: Optional[Any] = field(
|
73 |
+
default = None,
|
74 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
75 |
+
)
|
76 |
+
unsloth_num_chunks : Optional[int] = field(
|
77 |
+
default = -1,
|
78 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
79 |
+
)
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
output_dir = None,
|
83 |
+
overwrite_output_dir = None,
|
84 |
+
do_train = False,
|
85 |
+
do_eval = False,
|
86 |
+
do_predict = False,
|
87 |
+
eval_strategy = 'no',
|
88 |
+
prediction_loss_only = False,
|
89 |
+
per_device_train_batch_size = 4,
|
90 |
+
per_device_eval_batch_size = 4,
|
91 |
+
per_gpu_train_batch_size = None,
|
92 |
+
per_gpu_eval_batch_size = None,
|
93 |
+
gradient_accumulation_steps = 2,
|
94 |
+
eval_accumulation_steps = 2,
|
95 |
+
eval_delay = 0,
|
96 |
+
torch_empty_cache_steps = 250,
|
97 |
+
learning_rate = 5e-05,
|
98 |
+
weight_decay = 0.01,
|
99 |
+
adam_beta1 = 0.9,
|
100 |
+
adam_beta2 = 0.999,
|
101 |
+
adam_epsilon = 1e-08,
|
102 |
+
max_grad_norm = 1.0,
|
103 |
+
num_train_epochs = 3.0,
|
104 |
+
max_steps = -1,
|
105 |
+
lr_scheduler_type = 'linear',
|
106 |
+
warmup_ratio = 0.1,
|
107 |
+
warmup_steps = 0,
|
108 |
+
log_level = 'passive',
|
109 |
+
log_level_replica = 'warning',
|
110 |
+
log_on_each_node = True,
|
111 |
+
logging_dir = None,
|
112 |
+
logging_strategy = 'steps',
|
113 |
+
logging_first_step = False,
|
114 |
+
logging_steps = 1,
|
115 |
+
logging_nan_inf_filter = False,
|
116 |
+
save_strategy = 'steps',
|
117 |
+
save_steps = 500,
|
118 |
+
save_total_limit = None,
|
119 |
+
save_safetensors = True,
|
120 |
+
save_on_each_node = False,
|
121 |
+
save_only_model = False,
|
122 |
+
restore_callback_states_from_checkpoint = False,
|
123 |
+
no_cuda = False,
|
124 |
+
use_cpu = False,
|
125 |
+
use_mps_device = False,
|
126 |
+
seed = 3407,
|
127 |
+
data_seed = 3407,
|
128 |
+
jit_mode_eval = False,
|
129 |
+
use_ipex = False,
|
130 |
+
bf16 = False,
|
131 |
+
fp16 = False,
|
132 |
+
fp16_opt_level = 'O1',
|
133 |
+
half_precision_backend = 'auto',
|
134 |
+
bf16_full_eval = False,
|
135 |
+
fp16_full_eval = False,
|
136 |
+
tf32 = None,
|
137 |
+
local_rank = -1,
|
138 |
+
ddp_backend = None,
|
139 |
+
tpu_num_cores = None,
|
140 |
+
tpu_metrics_debug = False,
|
141 |
+
debug = '',
|
142 |
+
dataloader_drop_last = False,
|
143 |
+
eval_steps = None,
|
144 |
+
dataloader_num_workers = 0,
|
145 |
+
dataloader_prefetch_factor = None,
|
146 |
+
past_index = -1,
|
147 |
+
run_name = None,
|
148 |
+
disable_tqdm = None,
|
149 |
+
remove_unused_columns = True,
|
150 |
+
label_names = None,
|
151 |
+
load_best_model_at_end = False,
|
152 |
+
metric_for_best_model = None,
|
153 |
+
greater_is_better = None,
|
154 |
+
ignore_data_skip = False,
|
155 |
+
fsdp = '',
|
156 |
+
fsdp_min_num_params = 0,
|
157 |
+
fsdp_config = None,
|
158 |
+
tp_size = 0,
|
159 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
160 |
+
accelerator_config = None,
|
161 |
+
deepspeed = None,
|
162 |
+
label_smoothing_factor = 0.0,
|
163 |
+
optim = 'adamw_8bit',
|
164 |
+
optim_args = None,
|
165 |
+
adafactor = False,
|
166 |
+
group_by_length = False,
|
167 |
+
length_column_name = 'length',
|
168 |
+
report_to = None,
|
169 |
+
ddp_find_unused_parameters = None,
|
170 |
+
ddp_bucket_cap_mb = None,
|
171 |
+
ddp_broadcast_buffers = None,
|
172 |
+
dataloader_pin_memory = True,
|
173 |
+
dataloader_persistent_workers = False,
|
174 |
+
skip_memory_metrics = True,
|
175 |
+
use_legacy_prediction_loop = False,
|
176 |
+
push_to_hub = False,
|
177 |
+
resume_from_checkpoint = None,
|
178 |
+
hub_model_id = None,
|
179 |
+
hub_strategy = 'every_save',
|
180 |
+
hub_token = None,
|
181 |
+
hub_private_repo = None,
|
182 |
+
hub_always_push = False,
|
183 |
+
gradient_checkpointing = False,
|
184 |
+
gradient_checkpointing_kwargs = None,
|
185 |
+
include_inputs_for_metrics = False,
|
186 |
+
eval_do_concat_batches = True,
|
187 |
+
fp16_backend = 'auto',
|
188 |
+
evaluation_strategy = None,
|
189 |
+
push_to_hub_model_id = None,
|
190 |
+
push_to_hub_organization = None,
|
191 |
+
push_to_hub_token = None,
|
192 |
+
mp_parameters = '',
|
193 |
+
auto_find_batch_size = False,
|
194 |
+
full_determinism = False,
|
195 |
+
torchdynamo = None,
|
196 |
+
ray_scope = 'last',
|
197 |
+
ddp_timeout = 1800,
|
198 |
+
torch_compile = False,
|
199 |
+
torch_compile_backend = None,
|
200 |
+
torch_compile_mode = None,
|
201 |
+
dispatch_batches = None,
|
202 |
+
split_batches = None,
|
203 |
+
include_tokens_per_second = False,
|
204 |
+
include_num_input_tokens_seen = False,
|
205 |
+
neftune_noise_alpha = None,
|
206 |
+
optim_target_modules = None,
|
207 |
+
batch_eval_metrics = False,
|
208 |
+
eval_on_start = False,
|
209 |
+
use_liger_kernel = False,
|
210 |
+
eval_use_gather_object = False,
|
211 |
+
average_tokens_across_devices = False,
|
212 |
+
max_length = 1024,
|
213 |
+
max_prompt_length = 512,
|
214 |
+
max_completion_length = None,
|
215 |
+
disable_dropout = True,
|
216 |
+
step_separator = '\
|
217 |
+
',
|
218 |
+
train_on_last_step_only = False,
|
219 |
+
dataset_num_proc = None,
|
220 |
+
vllm_sampling_params = None,
|
221 |
+
unsloth_num_chunks = -1,
|
222 |
+
**kwargs,
|
223 |
+
):
|
224 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
225 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
226 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
227 |
+
output_dir = 'unsloth_training_checkpoints'
|
228 |
+
save_strategy = 'no'
|
229 |
+
if dataset_num_proc is None:
|
230 |
+
from multiprocessing import cpu_count
|
231 |
+
dataset_num_proc = cpu_count()
|
232 |
+
|
233 |
+
super().__init__(
|
234 |
+
output_dir = output_dir,
|
235 |
+
overwrite_output_dir = overwrite_output_dir,
|
236 |
+
do_train = do_train,
|
237 |
+
do_eval = do_eval,
|
238 |
+
do_predict = do_predict,
|
239 |
+
eval_strategy = eval_strategy,
|
240 |
+
prediction_loss_only = prediction_loss_only,
|
241 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
242 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
243 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
244 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
245 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
246 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
247 |
+
eval_delay = eval_delay,
|
248 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
249 |
+
learning_rate = learning_rate,
|
250 |
+
weight_decay = weight_decay,
|
251 |
+
adam_beta1 = adam_beta1,
|
252 |
+
adam_beta2 = adam_beta2,
|
253 |
+
adam_epsilon = adam_epsilon,
|
254 |
+
max_grad_norm = max_grad_norm,
|
255 |
+
num_train_epochs = num_train_epochs,
|
256 |
+
max_steps = max_steps,
|
257 |
+
lr_scheduler_type = lr_scheduler_type,
|
258 |
+
warmup_ratio = warmup_ratio,
|
259 |
+
warmup_steps = warmup_steps,
|
260 |
+
log_level = log_level,
|
261 |
+
log_level_replica = log_level_replica,
|
262 |
+
log_on_each_node = log_on_each_node,
|
263 |
+
logging_dir = logging_dir,
|
264 |
+
logging_strategy = logging_strategy,
|
265 |
+
logging_first_step = logging_first_step,
|
266 |
+
logging_steps = logging_steps,
|
267 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
268 |
+
save_strategy = save_strategy,
|
269 |
+
save_steps = save_steps,
|
270 |
+
save_total_limit = save_total_limit,
|
271 |
+
save_safetensors = save_safetensors,
|
272 |
+
save_on_each_node = save_on_each_node,
|
273 |
+
save_only_model = save_only_model,
|
274 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
275 |
+
no_cuda = no_cuda,
|
276 |
+
use_cpu = use_cpu,
|
277 |
+
use_mps_device = use_mps_device,
|
278 |
+
seed = seed,
|
279 |
+
data_seed = data_seed,
|
280 |
+
jit_mode_eval = jit_mode_eval,
|
281 |
+
use_ipex = use_ipex,
|
282 |
+
bf16 = bf16,
|
283 |
+
fp16 = fp16,
|
284 |
+
fp16_opt_level = fp16_opt_level,
|
285 |
+
half_precision_backend = half_precision_backend,
|
286 |
+
bf16_full_eval = bf16_full_eval,
|
287 |
+
fp16_full_eval = fp16_full_eval,
|
288 |
+
tf32 = tf32,
|
289 |
+
local_rank = local_rank,
|
290 |
+
ddp_backend = ddp_backend,
|
291 |
+
tpu_num_cores = tpu_num_cores,
|
292 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
293 |
+
debug = debug,
|
294 |
+
dataloader_drop_last = dataloader_drop_last,
|
295 |
+
eval_steps = eval_steps,
|
296 |
+
dataloader_num_workers = dataloader_num_workers,
|
297 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
298 |
+
past_index = past_index,
|
299 |
+
run_name = run_name,
|
300 |
+
disable_tqdm = disable_tqdm,
|
301 |
+
remove_unused_columns = remove_unused_columns,
|
302 |
+
label_names = label_names,
|
303 |
+
load_best_model_at_end = load_best_model_at_end,
|
304 |
+
metric_for_best_model = metric_for_best_model,
|
305 |
+
greater_is_better = greater_is_better,
|
306 |
+
ignore_data_skip = ignore_data_skip,
|
307 |
+
fsdp = fsdp,
|
308 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
309 |
+
fsdp_config = fsdp_config,
|
310 |
+
tp_size = tp_size,
|
311 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
312 |
+
accelerator_config = accelerator_config,
|
313 |
+
deepspeed = deepspeed,
|
314 |
+
label_smoothing_factor = label_smoothing_factor,
|
315 |
+
optim = optim,
|
316 |
+
optim_args = optim_args,
|
317 |
+
adafactor = adafactor,
|
318 |
+
group_by_length = group_by_length,
|
319 |
+
length_column_name = length_column_name,
|
320 |
+
report_to = report_to,
|
321 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
322 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
323 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
324 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
325 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
326 |
+
skip_memory_metrics = skip_memory_metrics,
|
327 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
328 |
+
push_to_hub = push_to_hub,
|
329 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
330 |
+
hub_model_id = hub_model_id,
|
331 |
+
hub_strategy = hub_strategy,
|
332 |
+
hub_token = hub_token,
|
333 |
+
hub_private_repo = hub_private_repo,
|
334 |
+
hub_always_push = hub_always_push,
|
335 |
+
gradient_checkpointing = gradient_checkpointing,
|
336 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
337 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
338 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
339 |
+
fp16_backend = fp16_backend,
|
340 |
+
evaluation_strategy = evaluation_strategy,
|
341 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
342 |
+
push_to_hub_organization = push_to_hub_organization,
|
343 |
+
push_to_hub_token = push_to_hub_token,
|
344 |
+
mp_parameters = mp_parameters,
|
345 |
+
auto_find_batch_size = auto_find_batch_size,
|
346 |
+
full_determinism = full_determinism,
|
347 |
+
torchdynamo = torchdynamo,
|
348 |
+
ray_scope = ray_scope,
|
349 |
+
ddp_timeout = ddp_timeout,
|
350 |
+
torch_compile = torch_compile,
|
351 |
+
torch_compile_backend = torch_compile_backend,
|
352 |
+
torch_compile_mode = torch_compile_mode,
|
353 |
+
dispatch_batches = dispatch_batches,
|
354 |
+
split_batches = split_batches,
|
355 |
+
include_tokens_per_second = include_tokens_per_second,
|
356 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
357 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
358 |
+
optim_target_modules = optim_target_modules,
|
359 |
+
batch_eval_metrics = batch_eval_metrics,
|
360 |
+
eval_on_start = eval_on_start,
|
361 |
+
use_liger_kernel = use_liger_kernel,
|
362 |
+
eval_use_gather_object = eval_use_gather_object,
|
363 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
364 |
+
max_length = max_length,
|
365 |
+
max_prompt_length = max_prompt_length,
|
366 |
+
max_completion_length = max_completion_length,
|
367 |
+
disable_dropout = disable_dropout,
|
368 |
+
step_separator = step_separator,
|
369 |
+
train_on_last_step_only = train_on_last_step_only,
|
370 |
+
dataset_num_proc = dataset_num_proc,**kwargs)
|
371 |
+
self.vllm_sampling_params = vllm_sampling_params
|
372 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
373 |
+
pass
|
374 |
+
|
375 |
+
class _UnslothPRMTrainer(Trainer):
|
376 |
+
""""""
|
377 |
+
|
378 |
+
_tag_names = ["trl", "prm"]
|
379 |
+
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
383 |
+
args: Optional[PRMConfig] = None,
|
384 |
+
data_collator: Optional[DataCollator] = None,
|
385 |
+
train_dataset: Optional[Dataset] = None,
|
386 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
387 |
+
processing_class: Optional[
|
388 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
389 |
+
] = None,
|
390 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
391 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
392 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
393 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
394 |
+
None,
|
395 |
+
None,
|
396 |
+
),
|
397 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
398 |
+
peft_config: Optional[dict] = None,
|
399 |
+
):
|
400 |
+
if not is_peft_available() and peft_config is not None:
|
401 |
+
raise ValueError(
|
402 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
403 |
+
)
|
404 |
+
elif is_peft_available() and peft_config is not None:
|
405 |
+
if not isinstance(model, PeftModel):
|
406 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
407 |
+
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
408 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
409 |
+
)
|
410 |
+
|
411 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
412 |
+
|
413 |
+
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
414 |
+
warnings.warn(
|
415 |
+
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
416 |
+
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
417 |
+
)
|
418 |
+
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
419 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
420 |
+
|
421 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
422 |
+
|
423 |
+
model = model
|
424 |
+
|
425 |
+
# Disable dropout in the model
|
426 |
+
if args.disable_dropout:
|
427 |
+
disable_dropout_in_model(model)
|
428 |
+
|
429 |
+
if compute_metrics is None:
|
430 |
+
compute_metrics = compute_accuracy
|
431 |
+
|
432 |
+
if data_collator is None:
|
433 |
+
if processing_class is None:
|
434 |
+
raise ValueError(
|
435 |
+
"A processing_class must be specified when using the default DataCollatorForTokenClassification"
|
436 |
+
)
|
437 |
+
data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
|
438 |
+
|
439 |
+
if "input_ids" not in train_dataset.column_names:
|
440 |
+
with PartialState().local_main_process_first():
|
441 |
+
fn_kwargs = {
|
442 |
+
"tokenizer": processing_class,
|
443 |
+
"step_separator": args.step_separator,
|
444 |
+
"max_length": args.max_length,
|
445 |
+
"max_prompt_length": args.max_prompt_length,
|
446 |
+
"max_completion_length": args.max_completion_length,
|
447 |
+
"train_on_last_step_only": args.train_on_last_step_only,
|
448 |
+
}
|
449 |
+
train_fn_kwargs = {**fn_kwargs, "is_eval": False}
|
450 |
+
train_dataset = train_dataset.map(
|
451 |
+
self.tokenize_row,
|
452 |
+
fn_kwargs=train_fn_kwargs,
|
453 |
+
num_proc=args.dataset_num_proc,
|
454 |
+
remove_columns=train_dataset.features,
|
455 |
+
desc="Tokenizing train dataset",
|
456 |
+
features=features.Features( # needed to avoid map to cast labels to bool
|
457 |
+
{
|
458 |
+
"labels": features.Sequence(features.Value("int64")),
|
459 |
+
"input_ids": features.Sequence(features.Value("int64")),
|
460 |
+
}
|
461 |
+
),
|
462 |
+
)
|
463 |
+
|
464 |
+
eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
|
465 |
+
if eval_dataset is not None:
|
466 |
+
eval_dataset = eval_dataset.map(
|
467 |
+
self.tokenize_row,
|
468 |
+
fn_kwargs=eval_fn_kwargs,
|
469 |
+
num_proc=args.dataset_num_proc,
|
470 |
+
remove_columns=eval_dataset.features,
|
471 |
+
desc="Tokenizing eval dataset",
|
472 |
+
features=features.Features( # needed to avoid map to cast labels to bool
|
473 |
+
{
|
474 |
+
"labels": features.Sequence(features.Value("int64")),
|
475 |
+
"input_ids": features.Sequence(features.Value("int64")),
|
476 |
+
}
|
477 |
+
),
|
478 |
+
)
|
479 |
+
|
480 |
+
super().__init__(
|
481 |
+
model=model,
|
482 |
+
args=args,
|
483 |
+
data_collator=data_collator,
|
484 |
+
train_dataset=train_dataset,
|
485 |
+
eval_dataset=eval_dataset,
|
486 |
+
processing_class=processing_class,
|
487 |
+
model_init=model_init,
|
488 |
+
compute_metrics=compute_metrics,
|
489 |
+
callbacks=callbacks,
|
490 |
+
optimizers=optimizers,
|
491 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
492 |
+
)
|
493 |
+
|
494 |
+
# Add tags for models that have been loaded with the correct transformers version
|
495 |
+
if hasattr(self.model, "add_model_tags"):
|
496 |
+
self.model.add_model_tags(self._tag_names)
|
497 |
+
|
498 |
+
@staticmethod
|
499 |
+
def tokenize_row(
|
500 |
+
features,
|
501 |
+
tokenizer,
|
502 |
+
step_separator,
|
503 |
+
max_length,
|
504 |
+
max_prompt_length,
|
505 |
+
max_completion_length,
|
506 |
+
train_on_last_step_only,
|
507 |
+
is_eval,
|
508 |
+
):
|
509 |
+
r"""
|
510 |
+
Tokenize a row of the dataset.
|
511 |
+
|
512 |
+
Args:
|
513 |
+
features (`dict[str, str]`):
|
514 |
+
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
515 |
+
tokenizer (`PreTrainedTokenizerBase`):
|
516 |
+
Tokenizer used to process the data.
|
517 |
+
step_separator (`str`):
|
518 |
+
Separator between steps in the completion.
|
519 |
+
max_length (`int` or `None`):
|
520 |
+
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
521 |
+
max_prompt_length (`int` or `None`):
|
522 |
+
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
523 |
+
max_completion_length (`int` or `None`):
|
524 |
+
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
525 |
+
train_on_last_step_only (`bool`):
|
526 |
+
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
527 |
+
token of the completion.
|
528 |
+
is_eval (`bool`):
|
529 |
+
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
|
530 |
+
|
531 |
+
Returns:
|
532 |
+
`dict[str, list[int]]`:
|
533 |
+
Tokenized sequences with the keys `"input_ids"`, and `"labels".
|
534 |
+
|
535 |
+
Example:
|
536 |
+
```python
|
537 |
+
>>> from transformers import AutoTokenizer
|
538 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
539 |
+
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
|
540 |
+
... "completions": ["11 is greater than 8.",
|
541 |
+
... "Hence, 9.11 > 9.8."],
|
542 |
+
... "labels": [True, False]}
|
543 |
+
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
|
544 |
+
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
545 |
+
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
546 |
+
```
|
547 |
+
"""
|
548 |
+
# Tokenize the prompt and completions
|
549 |
+
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
550 |
+
completions_ids = [
|
551 |
+
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
552 |
+
]
|
553 |
+
if train_on_last_step_only and not is_eval:
|
554 |
+
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
555 |
+
else:
|
556 |
+
labels = [int(label) for label in features["labels"]]
|
557 |
+
|
558 |
+
# Get the ID of the separator token and add it to the completions
|
559 |
+
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
560 |
+
completions_ids = [completion + separator_ids for completion in completions_ids]
|
561 |
+
|
562 |
+
# Create the label
|
563 |
+
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
564 |
+
|
565 |
+
# Join the completions and labels steps
|
566 |
+
completion_ids = list(chain(*completions_ids))
|
567 |
+
labels = list(chain(*labels))
|
568 |
+
|
569 |
+
if tokenizer.bos_token_id is not None:
|
570 |
+
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
571 |
+
|
572 |
+
# Truncate prompt and completion sequences
|
573 |
+
if max_prompt_length is not None:
|
574 |
+
prompt_ids = prompt_ids[-max_prompt_length:]
|
575 |
+
if max_completion_length is not None:
|
576 |
+
completion_ids = completion_ids[:max_completion_length]
|
577 |
+
labels = labels[:max_completion_length]
|
578 |
+
|
579 |
+
input_ids = prompt_ids + completion_ids
|
580 |
+
labels = [-100] * len(prompt_ids) + labels
|
581 |
+
|
582 |
+
if max_length is not None:
|
583 |
+
input_ids = input_ids[:max_length]
|
584 |
+
labels = labels[:max_length]
|
585 |
+
|
586 |
+
return {"input_ids": input_ids, "labels": labels}
|
587 |
+
|
588 |
+
def create_model_card(
|
589 |
+
self,
|
590 |
+
model_name: Optional[str] = None,
|
591 |
+
dataset_name: Optional[str] = None,
|
592 |
+
tags: Union[str, list[str], None] = None,
|
593 |
+
):
|
594 |
+
"""
|
595 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
596 |
+
|
597 |
+
Args:
|
598 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
599 |
+
Name of the model.
|
600 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
601 |
+
Name of the dataset used for training.
|
602 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
603 |
+
Tags to be associated with the model card.
|
604 |
+
"""
|
605 |
+
if not self.is_world_process_zero():
|
606 |
+
return
|
607 |
+
|
608 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
609 |
+
base_model = self.model.config._name_or_path
|
610 |
+
else:
|
611 |
+
base_model = None
|
612 |
+
|
613 |
+
tags = tags or []
|
614 |
+
if isinstance(tags, str):
|
615 |
+
tags = [tags]
|
616 |
+
|
617 |
+
if hasattr(self.model.config, "unsloth_version"):
|
618 |
+
tags.append("unsloth")
|
619 |
+
|
620 |
+
citation = textwrap.dedent("""\
|
621 |
+
@article{uesato2022solving,
|
622 |
+
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
|
623 |
+
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
|
624 |
+
year = 2022,
|
625 |
+
journal = {arXiv preprint arXiv:2211.14275}
|
626 |
+
}""")
|
627 |
+
|
628 |
+
model_card = generate_model_card(
|
629 |
+
base_model=base_model,
|
630 |
+
model_name=model_name,
|
631 |
+
hub_model_id=self.hub_model_id,
|
632 |
+
dataset_name=dataset_name,
|
633 |
+
tags=tags,
|
634 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
635 |
+
trainer_name="PRM",
|
636 |
+
trainer_citation=citation,
|
637 |
+
paper_title="Solving math word problems with process-and outcome-based feedback",
|
638 |
+
)
|
639 |
+
|
640 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
641 |
+
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
642 |
+
"""
|
643 |
+
|
644 |
+
Initialize PRMTrainer.
|
645 |
+
|
646 |
+
Args:
|
647 |
+
model (`transformers.PreTrainedModel`):
|
648 |
+
The model to train, preferably an `AutoModelForTokenClassification`.
|
649 |
+
args (`PRMConfig`):
|
650 |
+
The arguments to use for training.
|
651 |
+
data_collator (`transformers.DataCollator`):
|
652 |
+
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
|
653 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
654 |
+
train_dataset (`datasets.Dataset`):
|
655 |
+
The dataset to use for training.
|
656 |
+
eval_dataset (`datasets.Dataset`):
|
657 |
+
The dataset to use for evaluation.
|
658 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
659 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
660 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
661 |
+
reuse the fine-tuned model.
|
662 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
663 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
664 |
+
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
665 |
+
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
666 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
667 |
+
The callbacks to use for training.
|
668 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
669 |
+
The optimizer and scheduler to use for training.
|
670 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
671 |
+
The function to use to preprocess the logits before computing the metrics.
|
672 |
+
peft_config (`dict`, defaults to `None`):
|
673 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
674 |
+
|
675 |
+
"""
|
676 |
+
def __init__(
|
677 |
+
self,
|
678 |
+
model = None,
|
679 |
+
args = None,
|
680 |
+
data_collator = None,
|
681 |
+
train_dataset = None,
|
682 |
+
eval_dataset = None,
|
683 |
+
processing_class = None,
|
684 |
+
model_init = None,
|
685 |
+
compute_metrics = None,
|
686 |
+
callbacks = None,
|
687 |
+
preprocess_logits_for_metrics = None,
|
688 |
+
peft_config = None,
|
689 |
+
**kwargs
|
690 |
+
):
|
691 |
+
if args is None: args = UnslothPRMConfig()
|
692 |
+
use_bf16 = getattr(args, 'bf16', False)
|
693 |
+
use_fp16 = getattr(args, 'fp16', False)
|
694 |
+
force_float32 = False
|
695 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
696 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
697 |
+
force_float32 = True
|
698 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
699 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
700 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
701 |
+
from unsloth_zoo.utils import _get_dtype
|
702 |
+
dtype = _get_dtype(dtype)
|
703 |
+
float16 = dtype == torch.float16
|
704 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
705 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
706 |
+
if force_float32:
|
707 |
+
args.fp16 = False
|
708 |
+
args.bf16 = False
|
709 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
710 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
711 |
+
args.fp16 = float16
|
712 |
+
args.bf16 = not float16
|
713 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
714 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
715 |
+
args.eval_strategy = 'steps'
|
716 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
717 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
718 |
+
if ga_steps is not None and ga_steps > 1:
|
719 |
+
from transformers import __version__ as transformers_version
|
720 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
721 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
722 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
723 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
724 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
725 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
726 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
727 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
728 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
729 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
730 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
731 |
+
if force_float32:
|
732 |
+
args.bf16_full_eval = False
|
733 |
+
args.fp16_full_eval = False
|
734 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
735 |
+
args.bf16_full_eval = True
|
736 |
+
args.fp16_full_eval = False
|
737 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
738 |
+
args.bf16_full_eval = args.bf16
|
739 |
+
args.fp16_full_eval = args.fp16
|
740 |
+
_output_logits = False
|
741 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
742 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
743 |
+
if _output_logits:
|
744 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
745 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
746 |
+
pass
|
747 |
+
else:
|
748 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
749 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
750 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
751 |
+
max_seq_length = model.max_seq_length
|
752 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
753 |
+
if model is not None and hasattr(model, 'for_training'):
|
754 |
+
model.for_training()
|
755 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
756 |
+
if 'processing_class' in locals():
|
757 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
758 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
759 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
760 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
761 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
762 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
763 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
764 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
765 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
766 |
+
else:
|
767 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
768 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
769 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
770 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
771 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
772 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
773 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
774 |
+
else:
|
775 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
776 |
+
other_metrics = []
|
777 |
+
|
778 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
779 |
+
PatchRLStatistics('prm_trainer', other_metrics)
|
780 |
+
|
781 |
+
super().__init__(
|
782 |
+
model = model,
|
783 |
+
args = args,
|
784 |
+
data_collator = data_collator,
|
785 |
+
train_dataset = train_dataset,
|
786 |
+
eval_dataset = eval_dataset,
|
787 |
+
processing_class = processing_class,
|
788 |
+
model_init = model_init,
|
789 |
+
compute_metrics = compute_metrics,
|
790 |
+
callbacks = callbacks,
|
791 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
792 |
+
peft_config = peft_config,**kwargs)
|
793 |
+
if hasattr(self, 'neftune_hook_handle'):
|
794 |
+
self.neftune_hook_handle.remove()
|
795 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
796 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
797 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
798 |
+
pass
|
799 |
+
|
800 |
+
pass
|
unsloth_compiled_cache/UnslothRLOOTrainer.py
ADDED
@@ -0,0 +1,1133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, textwrap, time, torch, truncate_response, unwrap_model_for_generation)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothRLOOConfig(RLOOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`RLOOTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
|
54 |
+
Name of this experiment.
|
55 |
+
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
56 |
+
Path to the reward model.
|
57 |
+
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
58 |
+
Number of epochs to train.
|
59 |
+
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
60 |
+
Whether to whiten the rewards.
|
61 |
+
kl_coef (`float`, *optional*, defaults to `0.05`):
|
62 |
+
KL coefficient.
|
63 |
+
cliprange (`float`, *optional*, defaults to `0.2`):
|
64 |
+
Clip range.
|
65 |
+
rloo_k (`int`, *optional*, defaults to `2`):
|
66 |
+
REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
|
67 |
+
normalize_reward (`bool`, *optional*, defaults to `False`):
|
68 |
+
Whether to normalize rewards.
|
69 |
+
reward_clip_range (`float`, *optional*, defaults to `10.0`):
|
70 |
+
Clip range for rewards.
|
71 |
+
normalize_advantage (`bool`, *optional*, defaults to `False`):
|
72 |
+
Whether to normalize advantages.
|
73 |
+
token_level_kl (`bool`, *optional*, defaults to `True`):
|
74 |
+
Whether to use token-level KL penalty or sequence-level KL penalty.
|
75 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
76 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
77 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
78 |
+
capacity of a single GPU, albeit at the cost of slower generation.
|
79 |
+
|
80 |
+
"""
|
81 |
+
vllm_sampling_params: Optional[Any] = field(
|
82 |
+
default = None,
|
83 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
84 |
+
)
|
85 |
+
unsloth_num_chunks : Optional[int] = field(
|
86 |
+
default = -1,
|
87 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
88 |
+
)
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
output_dir = None,
|
92 |
+
overwrite_output_dir = None,
|
93 |
+
do_train = False,
|
94 |
+
do_eval = False,
|
95 |
+
do_predict = False,
|
96 |
+
eval_strategy = 'no',
|
97 |
+
prediction_loss_only = False,
|
98 |
+
per_device_train_batch_size = 4,
|
99 |
+
per_device_eval_batch_size = 4,
|
100 |
+
per_gpu_train_batch_size = None,
|
101 |
+
per_gpu_eval_batch_size = None,
|
102 |
+
gradient_accumulation_steps = 2,
|
103 |
+
eval_accumulation_steps = 2,
|
104 |
+
eval_delay = 0,
|
105 |
+
torch_empty_cache_steps = 250,
|
106 |
+
learning_rate = 5e-05,
|
107 |
+
weight_decay = 0.01,
|
108 |
+
adam_beta1 = 0.9,
|
109 |
+
adam_beta2 = 0.999,
|
110 |
+
adam_epsilon = 1e-08,
|
111 |
+
max_grad_norm = 1.0,
|
112 |
+
num_train_epochs = 3.0,
|
113 |
+
max_steps = -1,
|
114 |
+
lr_scheduler_type = 'linear',
|
115 |
+
warmup_ratio = 0.1,
|
116 |
+
warmup_steps = 0,
|
117 |
+
log_level = 'passive',
|
118 |
+
log_level_replica = 'warning',
|
119 |
+
log_on_each_node = True,
|
120 |
+
logging_dir = None,
|
121 |
+
logging_strategy = 'steps',
|
122 |
+
logging_first_step = False,
|
123 |
+
logging_steps = 1,
|
124 |
+
logging_nan_inf_filter = False,
|
125 |
+
save_strategy = 'steps',
|
126 |
+
save_steps = 500,
|
127 |
+
save_total_limit = None,
|
128 |
+
save_safetensors = True,
|
129 |
+
save_on_each_node = False,
|
130 |
+
save_only_model = False,
|
131 |
+
restore_callback_states_from_checkpoint = False,
|
132 |
+
no_cuda = False,
|
133 |
+
use_cpu = False,
|
134 |
+
use_mps_device = False,
|
135 |
+
seed = 3407,
|
136 |
+
data_seed = 3407,
|
137 |
+
jit_mode_eval = False,
|
138 |
+
use_ipex = False,
|
139 |
+
bf16 = False,
|
140 |
+
fp16 = False,
|
141 |
+
fp16_opt_level = 'O1',
|
142 |
+
half_precision_backend = 'auto',
|
143 |
+
bf16_full_eval = False,
|
144 |
+
fp16_full_eval = False,
|
145 |
+
tf32 = None,
|
146 |
+
local_rank = -1,
|
147 |
+
ddp_backend = None,
|
148 |
+
tpu_num_cores = None,
|
149 |
+
tpu_metrics_debug = False,
|
150 |
+
debug = '',
|
151 |
+
dataloader_drop_last = False,
|
152 |
+
eval_steps = None,
|
153 |
+
dataloader_num_workers = 0,
|
154 |
+
dataloader_prefetch_factor = None,
|
155 |
+
past_index = -1,
|
156 |
+
run_name = None,
|
157 |
+
disable_tqdm = None,
|
158 |
+
remove_unused_columns = True,
|
159 |
+
label_names = None,
|
160 |
+
load_best_model_at_end = False,
|
161 |
+
metric_for_best_model = None,
|
162 |
+
greater_is_better = None,
|
163 |
+
ignore_data_skip = False,
|
164 |
+
fsdp = '',
|
165 |
+
fsdp_min_num_params = 0,
|
166 |
+
fsdp_config = None,
|
167 |
+
tp_size = 0,
|
168 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
169 |
+
accelerator_config = None,
|
170 |
+
deepspeed = None,
|
171 |
+
label_smoothing_factor = 0.0,
|
172 |
+
optim = 'adamw_8bit',
|
173 |
+
optim_args = None,
|
174 |
+
adafactor = False,
|
175 |
+
group_by_length = False,
|
176 |
+
length_column_name = 'length',
|
177 |
+
report_to = None,
|
178 |
+
ddp_find_unused_parameters = None,
|
179 |
+
ddp_bucket_cap_mb = None,
|
180 |
+
ddp_broadcast_buffers = None,
|
181 |
+
dataloader_pin_memory = True,
|
182 |
+
dataloader_persistent_workers = False,
|
183 |
+
skip_memory_metrics = True,
|
184 |
+
use_legacy_prediction_loop = False,
|
185 |
+
push_to_hub = False,
|
186 |
+
resume_from_checkpoint = None,
|
187 |
+
hub_model_id = None,
|
188 |
+
hub_strategy = 'every_save',
|
189 |
+
hub_token = None,
|
190 |
+
hub_private_repo = None,
|
191 |
+
hub_always_push = False,
|
192 |
+
gradient_checkpointing = False,
|
193 |
+
gradient_checkpointing_kwargs = None,
|
194 |
+
include_inputs_for_metrics = False,
|
195 |
+
eval_do_concat_batches = True,
|
196 |
+
fp16_backend = 'auto',
|
197 |
+
evaluation_strategy = None,
|
198 |
+
push_to_hub_model_id = None,
|
199 |
+
push_to_hub_organization = None,
|
200 |
+
push_to_hub_token = None,
|
201 |
+
mp_parameters = '',
|
202 |
+
auto_find_batch_size = False,
|
203 |
+
full_determinism = False,
|
204 |
+
torchdynamo = None,
|
205 |
+
ray_scope = 'last',
|
206 |
+
ddp_timeout = 1800,
|
207 |
+
torch_compile = False,
|
208 |
+
torch_compile_backend = None,
|
209 |
+
torch_compile_mode = None,
|
210 |
+
dispatch_batches = None,
|
211 |
+
split_batches = None,
|
212 |
+
include_tokens_per_second = False,
|
213 |
+
include_num_input_tokens_seen = False,
|
214 |
+
neftune_noise_alpha = None,
|
215 |
+
optim_target_modules = None,
|
216 |
+
batch_eval_metrics = False,
|
217 |
+
eval_on_start = False,
|
218 |
+
use_liger_kernel = False,
|
219 |
+
eval_use_gather_object = False,
|
220 |
+
average_tokens_across_devices = False,
|
221 |
+
dataset_num_proc = None,
|
222 |
+
num_mini_batches = 1,
|
223 |
+
total_episodes = None,
|
224 |
+
local_rollout_forward_batch_size = 64,
|
225 |
+
num_sample_generations = 10,
|
226 |
+
response_length = 53,
|
227 |
+
stop_token = None,
|
228 |
+
stop_token_id = None,
|
229 |
+
temperature = 0.7,
|
230 |
+
missing_eos_penalty = None,
|
231 |
+
sft_model_path = 'EleutherAI/pythia-160m',
|
232 |
+
world_size = None,
|
233 |
+
num_total_batches = None,
|
234 |
+
micro_batch_size = None,
|
235 |
+
local_batch_size = None,
|
236 |
+
batch_size = None,
|
237 |
+
local_mini_batch_size = None,
|
238 |
+
mini_batch_size = None,
|
239 |
+
exp_name = 'rloo_config',
|
240 |
+
reward_model_path = 'EleutherAI/pythia-160m',
|
241 |
+
num_ppo_epochs = 4,
|
242 |
+
whiten_rewards = False,
|
243 |
+
kl_coef = 0.05,
|
244 |
+
cliprange = 0.2,
|
245 |
+
rloo_k = 2,
|
246 |
+
normalize_reward = False,
|
247 |
+
reward_clip_range = 10.0,
|
248 |
+
normalize_advantage = False,
|
249 |
+
token_level_kl = False,
|
250 |
+
ds3_gather_for_generation = True,
|
251 |
+
vllm_sampling_params = None,
|
252 |
+
unsloth_num_chunks = -1,
|
253 |
+
**kwargs,
|
254 |
+
):
|
255 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
256 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
257 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
258 |
+
output_dir = 'unsloth_training_checkpoints'
|
259 |
+
save_strategy = 'no'
|
260 |
+
if dataset_num_proc is None:
|
261 |
+
from multiprocessing import cpu_count
|
262 |
+
dataset_num_proc = cpu_count()
|
263 |
+
|
264 |
+
super().__init__(
|
265 |
+
output_dir = output_dir,
|
266 |
+
overwrite_output_dir = overwrite_output_dir,
|
267 |
+
do_train = do_train,
|
268 |
+
do_eval = do_eval,
|
269 |
+
do_predict = do_predict,
|
270 |
+
eval_strategy = eval_strategy,
|
271 |
+
prediction_loss_only = prediction_loss_only,
|
272 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
273 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
274 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
275 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
276 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
277 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
278 |
+
eval_delay = eval_delay,
|
279 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
280 |
+
learning_rate = learning_rate,
|
281 |
+
weight_decay = weight_decay,
|
282 |
+
adam_beta1 = adam_beta1,
|
283 |
+
adam_beta2 = adam_beta2,
|
284 |
+
adam_epsilon = adam_epsilon,
|
285 |
+
max_grad_norm = max_grad_norm,
|
286 |
+
num_train_epochs = num_train_epochs,
|
287 |
+
max_steps = max_steps,
|
288 |
+
lr_scheduler_type = lr_scheduler_type,
|
289 |
+
warmup_ratio = warmup_ratio,
|
290 |
+
warmup_steps = warmup_steps,
|
291 |
+
log_level = log_level,
|
292 |
+
log_level_replica = log_level_replica,
|
293 |
+
log_on_each_node = log_on_each_node,
|
294 |
+
logging_dir = logging_dir,
|
295 |
+
logging_strategy = logging_strategy,
|
296 |
+
logging_first_step = logging_first_step,
|
297 |
+
logging_steps = logging_steps,
|
298 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
299 |
+
save_strategy = save_strategy,
|
300 |
+
save_steps = save_steps,
|
301 |
+
save_total_limit = save_total_limit,
|
302 |
+
save_safetensors = save_safetensors,
|
303 |
+
save_on_each_node = save_on_each_node,
|
304 |
+
save_only_model = save_only_model,
|
305 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
306 |
+
no_cuda = no_cuda,
|
307 |
+
use_cpu = use_cpu,
|
308 |
+
use_mps_device = use_mps_device,
|
309 |
+
seed = seed,
|
310 |
+
data_seed = data_seed,
|
311 |
+
jit_mode_eval = jit_mode_eval,
|
312 |
+
use_ipex = use_ipex,
|
313 |
+
bf16 = bf16,
|
314 |
+
fp16 = fp16,
|
315 |
+
fp16_opt_level = fp16_opt_level,
|
316 |
+
half_precision_backend = half_precision_backend,
|
317 |
+
bf16_full_eval = bf16_full_eval,
|
318 |
+
fp16_full_eval = fp16_full_eval,
|
319 |
+
tf32 = tf32,
|
320 |
+
local_rank = local_rank,
|
321 |
+
ddp_backend = ddp_backend,
|
322 |
+
tpu_num_cores = tpu_num_cores,
|
323 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
324 |
+
debug = debug,
|
325 |
+
dataloader_drop_last = dataloader_drop_last,
|
326 |
+
eval_steps = eval_steps,
|
327 |
+
dataloader_num_workers = dataloader_num_workers,
|
328 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
329 |
+
past_index = past_index,
|
330 |
+
run_name = run_name,
|
331 |
+
disable_tqdm = disable_tqdm,
|
332 |
+
remove_unused_columns = remove_unused_columns,
|
333 |
+
label_names = label_names,
|
334 |
+
load_best_model_at_end = load_best_model_at_end,
|
335 |
+
metric_for_best_model = metric_for_best_model,
|
336 |
+
greater_is_better = greater_is_better,
|
337 |
+
ignore_data_skip = ignore_data_skip,
|
338 |
+
fsdp = fsdp,
|
339 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
340 |
+
fsdp_config = fsdp_config,
|
341 |
+
tp_size = tp_size,
|
342 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
343 |
+
accelerator_config = accelerator_config,
|
344 |
+
deepspeed = deepspeed,
|
345 |
+
label_smoothing_factor = label_smoothing_factor,
|
346 |
+
optim = optim,
|
347 |
+
optim_args = optim_args,
|
348 |
+
adafactor = adafactor,
|
349 |
+
group_by_length = group_by_length,
|
350 |
+
length_column_name = length_column_name,
|
351 |
+
report_to = report_to,
|
352 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
353 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
354 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
355 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
356 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
357 |
+
skip_memory_metrics = skip_memory_metrics,
|
358 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
359 |
+
push_to_hub = push_to_hub,
|
360 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
361 |
+
hub_model_id = hub_model_id,
|
362 |
+
hub_strategy = hub_strategy,
|
363 |
+
hub_token = hub_token,
|
364 |
+
hub_private_repo = hub_private_repo,
|
365 |
+
hub_always_push = hub_always_push,
|
366 |
+
gradient_checkpointing = gradient_checkpointing,
|
367 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
368 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
369 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
370 |
+
fp16_backend = fp16_backend,
|
371 |
+
evaluation_strategy = evaluation_strategy,
|
372 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
373 |
+
push_to_hub_organization = push_to_hub_organization,
|
374 |
+
push_to_hub_token = push_to_hub_token,
|
375 |
+
mp_parameters = mp_parameters,
|
376 |
+
auto_find_batch_size = auto_find_batch_size,
|
377 |
+
full_determinism = full_determinism,
|
378 |
+
torchdynamo = torchdynamo,
|
379 |
+
ray_scope = ray_scope,
|
380 |
+
ddp_timeout = ddp_timeout,
|
381 |
+
torch_compile = torch_compile,
|
382 |
+
torch_compile_backend = torch_compile_backend,
|
383 |
+
torch_compile_mode = torch_compile_mode,
|
384 |
+
dispatch_batches = dispatch_batches,
|
385 |
+
split_batches = split_batches,
|
386 |
+
include_tokens_per_second = include_tokens_per_second,
|
387 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
388 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
389 |
+
optim_target_modules = optim_target_modules,
|
390 |
+
batch_eval_metrics = batch_eval_metrics,
|
391 |
+
eval_on_start = eval_on_start,
|
392 |
+
use_liger_kernel = use_liger_kernel,
|
393 |
+
eval_use_gather_object = eval_use_gather_object,
|
394 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
395 |
+
dataset_num_proc = dataset_num_proc,
|
396 |
+
num_mini_batches = num_mini_batches,
|
397 |
+
total_episodes = total_episodes,
|
398 |
+
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
399 |
+
num_sample_generations = num_sample_generations,
|
400 |
+
response_length = response_length,
|
401 |
+
stop_token = stop_token,
|
402 |
+
stop_token_id = stop_token_id,
|
403 |
+
temperature = temperature,
|
404 |
+
missing_eos_penalty = missing_eos_penalty,
|
405 |
+
sft_model_path = sft_model_path,
|
406 |
+
world_size = world_size,
|
407 |
+
num_total_batches = num_total_batches,
|
408 |
+
micro_batch_size = micro_batch_size,
|
409 |
+
local_batch_size = local_batch_size,
|
410 |
+
batch_size = batch_size,
|
411 |
+
local_mini_batch_size = local_mini_batch_size,
|
412 |
+
mini_batch_size = mini_batch_size,
|
413 |
+
exp_name = exp_name,
|
414 |
+
reward_model_path = reward_model_path,
|
415 |
+
num_ppo_epochs = num_ppo_epochs,
|
416 |
+
whiten_rewards = whiten_rewards,
|
417 |
+
kl_coef = kl_coef,
|
418 |
+
cliprange = cliprange,
|
419 |
+
rloo_k = rloo_k,
|
420 |
+
normalize_reward = normalize_reward,
|
421 |
+
reward_clip_range = reward_clip_range,
|
422 |
+
normalize_advantage = normalize_advantage,
|
423 |
+
token_level_kl = token_level_kl,
|
424 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
425 |
+
self.vllm_sampling_params = vllm_sampling_params
|
426 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
427 |
+
pass
|
428 |
+
|
429 |
+
class _UnslothRLOOTrainer(Trainer):
|
430 |
+
_tag_names = ["trl", "rloo"]
|
431 |
+
|
432 |
+
def __init__(
|
433 |
+
self,
|
434 |
+
config: RLOOConfig,
|
435 |
+
processing_class: Optional[
|
436 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
437 |
+
],
|
438 |
+
policy: nn.Module,
|
439 |
+
ref_policy: nn.Module,
|
440 |
+
reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
|
441 |
+
train_dataset: Dataset,
|
442 |
+
data_collator: Optional[DataCollatorWithPadding] = None,
|
443 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
444 |
+
# less commonly used
|
445 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
446 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
447 |
+
) -> None:
|
448 |
+
if ref_policy is policy:
|
449 |
+
raise ValueError(
|
450 |
+
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
|
451 |
+
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
|
452 |
+
)
|
453 |
+
|
454 |
+
self.args = config
|
455 |
+
args = config
|
456 |
+
self.processing_class = processing_class
|
457 |
+
self.policy = policy
|
458 |
+
|
459 |
+
# Define the collator if not provided
|
460 |
+
if data_collator is None:
|
461 |
+
data_collator = DataCollatorWithPadding(self.processing_class)
|
462 |
+
|
463 |
+
self.policy.generation_config.eos_token_id = (
|
464 |
+
None # disable `pad_token_id` and `eos_token_id` because we just want to
|
465 |
+
)
|
466 |
+
self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
|
467 |
+
|
468 |
+
self.ref_policy = ref_policy
|
469 |
+
self.reward_model = reward_model
|
470 |
+
self.train_dataset = train_dataset
|
471 |
+
self.train_dataset_len = len(train_dataset)
|
472 |
+
self.data_collator = data_collator
|
473 |
+
self.eval_dataset = eval_dataset
|
474 |
+
self.optimizer, self.lr_scheduler = optimizers
|
475 |
+
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
476 |
+
|
477 |
+
#########
|
478 |
+
# calculate various batch sizes
|
479 |
+
#########
|
480 |
+
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
481 |
+
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
482 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
483 |
+
self.accelerator = accelerator
|
484 |
+
args.world_size = accelerator.num_processes
|
485 |
+
args.local_batch_size = (
|
486 |
+
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
487 |
+
)
|
488 |
+
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
489 |
+
args.batch_size = int(args.local_batch_size * args.world_size)
|
490 |
+
args.mini_batch_size = exact_div(
|
491 |
+
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
492 |
+
)
|
493 |
+
args.local_mini_batch_size = exact_div(
|
494 |
+
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
495 |
+
)
|
496 |
+
args.num_total_batches = math.ceil(
|
497 |
+
args.total_episodes / args.batch_size
|
498 |
+
) # we may train for more than `total_episodes`
|
499 |
+
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
500 |
+
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
501 |
+
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
502 |
+
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
503 |
+
if args.num_sample_generations > 0:
|
504 |
+
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
505 |
+
self.local_dataloader_batch_size = exact_div(
|
506 |
+
args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
|
507 |
+
) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
|
508 |
+
|
509 |
+
#########
|
510 |
+
# setup model, optimizer, and others
|
511 |
+
#########
|
512 |
+
for module in [policy, ref_policy, reward_model]:
|
513 |
+
if isinstance(module, nn.Module):
|
514 |
+
disable_dropout_in_model(module)
|
515 |
+
if args.stop_token and args.stop_token == "eos":
|
516 |
+
args.stop_token_id = self.processing_class.eos_token_id
|
517 |
+
self.model = policy
|
518 |
+
self.create_optimizer_and_scheduler(
|
519 |
+
num_training_steps=args.num_total_batches
|
520 |
+
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
|
521 |
+
|
522 |
+
#########
|
523 |
+
### trainer specifics
|
524 |
+
#########
|
525 |
+
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
526 |
+
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
527 |
+
self.callback_handler = CallbackHandler(
|
528 |
+
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
529 |
+
)
|
530 |
+
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
531 |
+
self.control = TrainerControl()
|
532 |
+
self.state = OnlineTrainerState(
|
533 |
+
is_local_process_zero=self.is_local_process_zero(),
|
534 |
+
is_world_process_zero=self.is_world_process_zero(),
|
535 |
+
stateful_callbacks=[
|
536 |
+
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
537 |
+
],
|
538 |
+
)
|
539 |
+
|
540 |
+
self.current_flos = 0
|
541 |
+
self.hp_search_backend = None
|
542 |
+
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
543 |
+
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
544 |
+
# Create distant repo and output directory if needed
|
545 |
+
self.hub_model_id = None
|
546 |
+
if self.args.push_to_hub:
|
547 |
+
self.init_hf_repo()
|
548 |
+
if self.args.should_save:
|
549 |
+
os.makedirs(self.args.output_dir, exist_ok=True)
|
550 |
+
self.backup_model = None
|
551 |
+
|
552 |
+
# Add tags for models that have been loaded with the correct transformers version
|
553 |
+
if hasattr(self.model, "add_model_tags"):
|
554 |
+
self.model.add_model_tags(self._tag_names)
|
555 |
+
|
556 |
+
#########
|
557 |
+
### setup dataloader
|
558 |
+
#########
|
559 |
+
self.dataloader = DataLoader(
|
560 |
+
self.train_dataset,
|
561 |
+
batch_size=self.local_dataloader_batch_size,
|
562 |
+
shuffle=True,
|
563 |
+
collate_fn=self.data_collator,
|
564 |
+
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
565 |
+
)
|
566 |
+
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
|
567 |
+
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
568 |
+
torch.manual_seed(args.seed)
|
569 |
+
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
570 |
+
torch.manual_seed(self.local_seed) # reset the local seed again
|
571 |
+
|
572 |
+
self.eval_dataloader = DataLoader(
|
573 |
+
self.eval_dataset,
|
574 |
+
batch_size=args.per_device_eval_batch_size,
|
575 |
+
collate_fn=self.data_collator,
|
576 |
+
drop_last=True,
|
577 |
+
) # no need to shuffle eval dataset
|
578 |
+
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
579 |
+
|
580 |
+
if self.is_deepspeed_enabled:
|
581 |
+
if isinstance(self.reward_model, nn.Module):
|
582 |
+
self.reward_model = prepare_deepspeed(
|
583 |
+
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
584 |
+
)
|
585 |
+
self.ref_policy = prepare_deepspeed(
|
586 |
+
self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
|
587 |
+
)
|
588 |
+
self.deepspeed = self.model
|
589 |
+
else:
|
590 |
+
self.ref_policy = self.ref_policy.to(self.accelerator.device)
|
591 |
+
if isinstance(self.reward_model, nn.Module):
|
592 |
+
self.reward_model = self.reward_model.to(self.accelerator.device)
|
593 |
+
|
594 |
+
def get_train_dataloader(self) -> DataLoader:
|
595 |
+
return self.dataloader
|
596 |
+
|
597 |
+
def get_eval_dataloader(self) -> DataLoader:
|
598 |
+
return self.eval_dataloader
|
599 |
+
|
600 |
+
def train(self):
|
601 |
+
args = self.args
|
602 |
+
accelerator = self.accelerator
|
603 |
+
optimizer = self.optimizer
|
604 |
+
model = self.model
|
605 |
+
self.model_wrapped = self.model
|
606 |
+
ref_policy = self.ref_policy
|
607 |
+
reward_model = self.reward_model
|
608 |
+
processing_class = self.processing_class
|
609 |
+
dataloader = self.dataloader
|
610 |
+
device = accelerator.device
|
611 |
+
|
612 |
+
def repeat_generator():
|
613 |
+
while True:
|
614 |
+
yield from dataloader
|
615 |
+
|
616 |
+
iter_dataloader = iter(repeat_generator())
|
617 |
+
generation_config = GenerationConfig(
|
618 |
+
max_new_tokens=args.response_length,
|
619 |
+
temperature=(args.temperature + 1e-7),
|
620 |
+
top_k=0.0,
|
621 |
+
top_p=1.0,
|
622 |
+
do_sample=True,
|
623 |
+
)
|
624 |
+
|
625 |
+
accelerator.print("===training policy===")
|
626 |
+
start_time = time.time()
|
627 |
+
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
628 |
+
approxkl_stats = torch.zeros(stats_shape, device=device)
|
629 |
+
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
630 |
+
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
631 |
+
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
632 |
+
entropy_stats = torch.zeros(stats_shape, device=device)
|
633 |
+
ratio_stats = torch.zeros(stats_shape, device=device)
|
634 |
+
model.train()
|
635 |
+
|
636 |
+
# trainer state initialization
|
637 |
+
self.state.global_step = 0
|
638 |
+
self.state.episode = 0
|
639 |
+
self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
|
640 |
+
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
641 |
+
# Compute absolute values for logging, eval, and save if given as ratio
|
642 |
+
if args.logging_steps is not None:
|
643 |
+
if args.logging_steps < 1:
|
644 |
+
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
645 |
+
else:
|
646 |
+
self.state.logging_steps = args.logging_steps
|
647 |
+
if args.eval_steps is not None:
|
648 |
+
if args.eval_steps < 1:
|
649 |
+
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
650 |
+
else:
|
651 |
+
self.state.eval_steps = args.eval_steps
|
652 |
+
if args.save_steps is not None:
|
653 |
+
if args.save_steps < 1:
|
654 |
+
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
655 |
+
else:
|
656 |
+
self.state.save_steps = args.save_steps
|
657 |
+
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
658 |
+
|
659 |
+
for update in range(1, args.num_total_batches + 1):
|
660 |
+
self.state.episode += 1 * args.batch_size
|
661 |
+
data = next(iter_dataloader)
|
662 |
+
with torch.no_grad():
|
663 |
+
queries = data["input_ids"].to(device)
|
664 |
+
queries = queries.repeat(args.rloo_k, 1)
|
665 |
+
context_length = queries.shape[1]
|
666 |
+
responses = []
|
667 |
+
postprocessed_responses = []
|
668 |
+
logprobs = []
|
669 |
+
ref_logprobs = []
|
670 |
+
scores = []
|
671 |
+
sequence_lengths = []
|
672 |
+
|
673 |
+
# Generate responses and compute logprobs
|
674 |
+
with unwrap_model_for_generation(
|
675 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
676 |
+
) as unwrapped_model:
|
677 |
+
query_responses, logitss = batch_generation(
|
678 |
+
unwrapped_model,
|
679 |
+
queries,
|
680 |
+
args.local_rollout_forward_batch_size,
|
681 |
+
processing_class.pad_token_id,
|
682 |
+
generation_config,
|
683 |
+
)
|
684 |
+
|
685 |
+
# Process responses in batches
|
686 |
+
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
687 |
+
query = queries[i : i + args.local_rollout_forward_batch_size]
|
688 |
+
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
689 |
+
response = query_response[:, context_length:]
|
690 |
+
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
691 |
+
logprob = selective_log_softmax(logits, response)
|
692 |
+
del logits
|
693 |
+
torch.cuda.empty_cache()
|
694 |
+
|
695 |
+
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
696 |
+
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
697 |
+
ref_logits /= args.temperature + 1e-7
|
698 |
+
ref_logprob = selective_log_softmax(ref_logits, response)
|
699 |
+
del ref_output, ref_logits
|
700 |
+
torch.cuda.empty_cache()
|
701 |
+
|
702 |
+
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
703 |
+
postprocessed_response = response
|
704 |
+
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
705 |
+
postprocessed_response = truncate_response(
|
706 |
+
args.stop_token_id, processing_class.pad_token_id, response
|
707 |
+
)
|
708 |
+
|
709 |
+
# Response Processing 2. run reward model on the truncated responses
|
710 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
711 |
+
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
712 |
+
|
713 |
+
if isinstance(reward_model, nn.Module):
|
714 |
+
_, score, _ = get_reward(
|
715 |
+
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
716 |
+
)
|
717 |
+
else:
|
718 |
+
score = torch.tensor(
|
719 |
+
reward_model(
|
720 |
+
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
721 |
+
),
|
722 |
+
dtype=torch.float,
|
723 |
+
).to(device)
|
724 |
+
|
725 |
+
# Store batch results
|
726 |
+
responses.append(response)
|
727 |
+
postprocessed_responses.append(postprocessed_response)
|
728 |
+
logprobs.append(logprob)
|
729 |
+
ref_logprobs.append(ref_logprob)
|
730 |
+
sequence_lengths.append(sequence_length)
|
731 |
+
scores.append(score)
|
732 |
+
|
733 |
+
# Concatenate all batched results
|
734 |
+
responses = torch.cat(responses, 0)
|
735 |
+
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
736 |
+
logprobs = torch.cat(logprobs, 0)
|
737 |
+
ref_logprobs = torch.cat(ref_logprobs, 0)
|
738 |
+
sequence_lengths = torch.cat(sequence_lengths, 0)
|
739 |
+
scores = torch.cat(scores, 0)
|
740 |
+
del (logprob, ref_logprob, score)
|
741 |
+
torch.cuda.empty_cache()
|
742 |
+
gc.collect()
|
743 |
+
|
744 |
+
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
|
745 |
+
# responses not passing that filter will receive a low (fixed) score
|
746 |
+
# only query humans on responses that pass that filter
|
747 |
+
contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
|
748 |
+
if args.missing_eos_penalty is not None:
|
749 |
+
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
750 |
+
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
751 |
+
|
752 |
+
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
753 |
+
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
754 |
+
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
755 |
+
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
756 |
+
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
757 |
+
|
758 |
+
# 4. compute rewards
|
759 |
+
# Compute KL divergence
|
760 |
+
kl = logprobs - ref_logprobs
|
761 |
+
|
762 |
+
# Normalize rewards
|
763 |
+
if args.normalize_reward:
|
764 |
+
scores = (scores - scores.mean()) / (scores.std() + 1e-8)
|
765 |
+
scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
|
766 |
+
|
767 |
+
# Compute total reward with KL penalty
|
768 |
+
if args.token_level_kl:
|
769 |
+
# Token-level KL penalty: apply KL penalty per token
|
770 |
+
kl_reward = -args.kl_coef * kl
|
771 |
+
|
772 |
+
# Get the index of the last non-padded token for each sequence
|
773 |
+
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
|
774 |
+
last_reward = torch.zeros_like(kl)
|
775 |
+
# Ensure scores has correct shape and type
|
776 |
+
scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
|
777 |
+
last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
|
778 |
+
|
779 |
+
# Combine KL reward and last reward
|
780 |
+
non_score_reward = kl_reward.sum(1) # Keep this for logging
|
781 |
+
reward = last_reward + kl_reward
|
782 |
+
rlhf_reward = reward.sum(1) # Sum across sequence length
|
783 |
+
else:
|
784 |
+
# Sequence-level KL penalty: sum KL across tokens first
|
785 |
+
sequence_kl = kl.sum(1)
|
786 |
+
non_score_reward = -args.kl_coef * sequence_kl
|
787 |
+
rlhf_reward = non_score_reward + scores
|
788 |
+
|
789 |
+
# vectorized RLOO advantages implementation
|
790 |
+
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
|
791 |
+
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
|
792 |
+
advantages = rlhf_reward - baseline
|
793 |
+
advantages = advantages.flatten()
|
794 |
+
|
795 |
+
# Normalize advantages
|
796 |
+
if args.normalize_advantage:
|
797 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
798 |
+
|
799 |
+
torch.cuda.empty_cache()
|
800 |
+
|
801 |
+
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
802 |
+
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
803 |
+
b_inds = np.random.permutation(args.local_batch_size)
|
804 |
+
minibatch_idx = 0
|
805 |
+
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
806 |
+
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
807 |
+
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
808 |
+
gradient_accumulation_idx = 0
|
809 |
+
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
810 |
+
with accelerator.accumulate(model):
|
811 |
+
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
812 |
+
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
813 |
+
|
814 |
+
# Get batch data
|
815 |
+
mb_advantage = advantages[micro_batch_inds]
|
816 |
+
mb_responses = responses[micro_batch_inds]
|
817 |
+
mb_query_responses = query_responses[micro_batch_inds]
|
818 |
+
mb_logprobs = logprobs[micro_batch_inds]
|
819 |
+
|
820 |
+
# Forward pass
|
821 |
+
output = forward(model, mb_query_responses, processing_class.pad_token_id)
|
822 |
+
logits = output.logits[:, context_length - 1 : -1]
|
823 |
+
logits /= args.temperature + 1e-7
|
824 |
+
|
825 |
+
# Compute new logprobs
|
826 |
+
new_logprobs = selective_log_softmax(logits, mb_responses)
|
827 |
+
new_logprobs = torch.masked_fill(
|
828 |
+
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
829 |
+
)
|
830 |
+
|
831 |
+
# Compute probability ratios
|
832 |
+
new_ratio = (new_logprobs - mb_logprobs).exp()
|
833 |
+
new_logprobs = new_logprobs.sum(1)
|
834 |
+
mb_logprobs = mb_logprobs.sum(1)
|
835 |
+
logprobs_diff = new_logprobs - mb_logprobs
|
836 |
+
ratio = torch.exp(logprobs_diff)
|
837 |
+
|
838 |
+
# PPO clipped loss
|
839 |
+
pg_losses = -mb_advantage * ratio
|
840 |
+
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
841 |
+
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
842 |
+
pg_loss = pg_loss_max.mean()
|
843 |
+
|
844 |
+
# Final loss
|
845 |
+
loss = pg_loss
|
846 |
+
|
847 |
+
# Optimization step
|
848 |
+
accelerator.backward(loss)
|
849 |
+
optimizer.step()
|
850 |
+
optimizer.zero_grad()
|
851 |
+
|
852 |
+
with torch.no_grad():
|
853 |
+
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
|
854 |
+
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
855 |
+
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
856 |
+
approxkl = 0.5 * (logprobs_diff**2).mean()
|
857 |
+
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
858 |
+
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
859 |
+
pg_clipfrac
|
860 |
+
)
|
861 |
+
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
862 |
+
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
863 |
+
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
|
864 |
+
gradient_accumulation_idx += 1
|
865 |
+
minibatch_idx += 1
|
866 |
+
|
867 |
+
# del everything and empty cache
|
868 |
+
# fmt: off
|
869 |
+
del (
|
870 |
+
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
|
871 |
+
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
|
872 |
+
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
|
873 |
+
)
|
874 |
+
# fmt: on
|
875 |
+
torch.cuda.empty_cache()
|
876 |
+
|
877 |
+
# Compute metrics
|
878 |
+
with torch.no_grad():
|
879 |
+
mean_kl = kl.sum(1).mean()
|
880 |
+
mean_entropy = (-logprobs).sum(1).mean()
|
881 |
+
mean_non_score_reward = non_score_reward.mean()
|
882 |
+
eps = int(self.state.episode / (time.time() - start_time))
|
883 |
+
metrics = {}
|
884 |
+
metrics["eps"] = eps
|
885 |
+
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
886 |
+
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
887 |
+
metrics["objective/non_score_reward"] = (
|
888 |
+
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
889 |
+
)
|
890 |
+
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
891 |
+
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
892 |
+
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
893 |
+
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
894 |
+
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
895 |
+
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
896 |
+
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
897 |
+
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
898 |
+
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
899 |
+
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
900 |
+
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
901 |
+
metrics["episode"] = self.state.episode
|
902 |
+
self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
|
903 |
+
self.log(metrics)
|
904 |
+
del kl, mean_kl, mean_entropy, scores
|
905 |
+
|
906 |
+
self.lr_scheduler.step()
|
907 |
+
self.state.global_step += 1
|
908 |
+
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
909 |
+
if self.control.should_save:
|
910 |
+
self._save_checkpoint(model, trial=None)
|
911 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
912 |
+
torch.cuda.empty_cache()
|
913 |
+
gc.collect()
|
914 |
+
|
915 |
+
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
916 |
+
self.generate_completions(sampling=True)
|
917 |
+
|
918 |
+
# HF trainer specifics
|
919 |
+
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
920 |
+
if self.control.should_save:
|
921 |
+
self._save_checkpoint(model, trial=None, metrics=None)
|
922 |
+
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
923 |
+
|
924 |
+
def generate_completions(self, sampling: bool = False):
|
925 |
+
args = self.args
|
926 |
+
processing_class = self.processing_class
|
927 |
+
generation_config = GenerationConfig(
|
928 |
+
max_new_tokens=self.args.response_length,
|
929 |
+
temperature=(0.01 + 1e-7),
|
930 |
+
top_k=0.0,
|
931 |
+
top_p=1.0,
|
932 |
+
do_sample=True,
|
933 |
+
)
|
934 |
+
|
935 |
+
table = defaultdict(list)
|
936 |
+
with unwrap_model_for_generation(
|
937 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
938 |
+
) as unwrapped_model:
|
939 |
+
for batch in self.eval_dataloader:
|
940 |
+
query = batch["input_ids"]
|
941 |
+
with torch.no_grad():
|
942 |
+
context_length = query.shape[1]
|
943 |
+
query_response, _ = batch_generation(
|
944 |
+
unwrapped_model,
|
945 |
+
query,
|
946 |
+
query.shape[0],
|
947 |
+
processing_class.pad_token_id,
|
948 |
+
generation_config,
|
949 |
+
)
|
950 |
+
response = query_response[:, context_length:]
|
951 |
+
postprocessed_response = response
|
952 |
+
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
953 |
+
postprocessed_response = truncate_response(
|
954 |
+
args.stop_token_id, processing_class.pad_token_id, response
|
955 |
+
)
|
956 |
+
table["query"].extend(
|
957 |
+
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
958 |
+
)
|
959 |
+
table["model response"].extend(
|
960 |
+
gather_object(processing_class.batch_decode(postprocessed_response))
|
961 |
+
)
|
962 |
+
|
963 |
+
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
964 |
+
|
965 |
+
if isinstance(self.reward_model, nn.Module):
|
966 |
+
_, score, _ = get_reward(
|
967 |
+
self.reward_model,
|
968 |
+
postprocessed_query_response,
|
969 |
+
processing_class.pad_token_id,
|
970 |
+
context_length,
|
971 |
+
)
|
972 |
+
else:
|
973 |
+
score = torch.tensor(
|
974 |
+
self.reward_model(
|
975 |
+
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
976 |
+
),
|
977 |
+
dtype=torch.float,
|
978 |
+
).to(postprocessed_query_response.device)
|
979 |
+
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
980 |
+
|
981 |
+
if sampling:
|
982 |
+
break
|
983 |
+
df = pd.DataFrame(table)
|
984 |
+
|
985 |
+
if self.accelerator.is_main_process:
|
986 |
+
print_rich_table(df.iloc[0 : 0 + 5])
|
987 |
+
if "wandb" in args.report_to:
|
988 |
+
import wandb
|
989 |
+
|
990 |
+
if wandb.run is not None:
|
991 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
992 |
+
|
993 |
+
if "comet_ml" in args.report_to:
|
994 |
+
log_table_to_comet_experiment(
|
995 |
+
name="completions.csv",
|
996 |
+
table=df,
|
997 |
+
)
|
998 |
+
|
999 |
+
def create_model_card(
|
1000 |
+
self,
|
1001 |
+
model_name: Optional[str] = None,
|
1002 |
+
dataset_name: Optional[str] = None,
|
1003 |
+
tags: Union[str, list[str], None] = None,
|
1004 |
+
):
|
1005 |
+
"""
|
1006 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
1007 |
+
|
1008 |
+
Args:
|
1009 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
1010 |
+
Name of the model.
|
1011 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
1012 |
+
Name of the dataset used for training.
|
1013 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
1014 |
+
Tags to be associated with the model card.
|
1015 |
+
"""
|
1016 |
+
if not self.is_world_process_zero():
|
1017 |
+
return
|
1018 |
+
|
1019 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
1020 |
+
base_model = self.model.config._name_or_path
|
1021 |
+
else:
|
1022 |
+
base_model = None
|
1023 |
+
|
1024 |
+
tags = tags or []
|
1025 |
+
if isinstance(tags, str):
|
1026 |
+
tags = [tags]
|
1027 |
+
|
1028 |
+
if hasattr(self.model.config, "unsloth_version"):
|
1029 |
+
tags.append("unsloth")
|
1030 |
+
|
1031 |
+
citation = textwrap.dedent("""\
|
1032 |
+
@inproceedings{ahmadian2024back,
|
1033 |
+
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
|
1034 |
+
author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
|
1035 |
+
year = 2024,
|
1036 |
+
booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
|
1037 |
+
publisher = {Association for Computational Linguistics},
|
1038 |
+
pages = {12248--12267},
|
1039 |
+
editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
|
1040 |
+
}""")
|
1041 |
+
|
1042 |
+
model_card = generate_model_card(
|
1043 |
+
base_model=base_model,
|
1044 |
+
model_name=model_name,
|
1045 |
+
hub_model_id=self.hub_model_id,
|
1046 |
+
dataset_name=dataset_name,
|
1047 |
+
tags=tags,
|
1048 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
1049 |
+
comet_url=get_comet_experiment_url(),
|
1050 |
+
trainer_name="RLOO",
|
1051 |
+
trainer_citation=citation,
|
1052 |
+
paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
|
1053 |
+
paper_id="2402.14740",
|
1054 |
+
)
|
1055 |
+
|
1056 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
1057 |
+
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
|
1058 |
+
"""
|
1059 |
+
|
1060 |
+
"""
|
1061 |
+
def __init__(
|
1062 |
+
self,
|
1063 |
+
config,
|
1064 |
+
processing_class,
|
1065 |
+
policy,
|
1066 |
+
ref_policy,
|
1067 |
+
reward_model,
|
1068 |
+
train_dataset,
|
1069 |
+
data_collator = None,
|
1070 |
+
eval_dataset = None,
|
1071 |
+
callbacks = None,
|
1072 |
+
**kwargs
|
1073 |
+
):
|
1074 |
+
if args is None: args = UnslothRLOOConfig()
|
1075 |
+
_output_logits = False
|
1076 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
1077 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
1078 |
+
if _output_logits:
|
1079 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
1080 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
1081 |
+
pass
|
1082 |
+
else:
|
1083 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
1084 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
1085 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
1086 |
+
max_seq_length = model.max_seq_length
|
1087 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
1088 |
+
if model is not None and hasattr(model, 'for_training'):
|
1089 |
+
model.for_training()
|
1090 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
1091 |
+
if 'processing_class' in locals():
|
1092 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
1093 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
1094 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
1095 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
1096 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1097 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
1098 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
1099 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
1100 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
1101 |
+
else:
|
1102 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
1103 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
1104 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
1105 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
1106 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
1107 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
1108 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
1109 |
+
else:
|
1110 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
1111 |
+
other_metrics = []
|
1112 |
+
|
1113 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
1114 |
+
PatchRLStatistics('rloo_trainer', other_metrics)
|
1115 |
+
|
1116 |
+
super().__init__(
|
1117 |
+
config = config,
|
1118 |
+
processing_class = processing_class,
|
1119 |
+
policy = policy,
|
1120 |
+
ref_policy = ref_policy,
|
1121 |
+
reward_model = reward_model,
|
1122 |
+
train_dataset = train_dataset,
|
1123 |
+
data_collator = data_collator,
|
1124 |
+
eval_dataset = eval_dataset,
|
1125 |
+
callbacks = callbacks,**kwargs)
|
1126 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1127 |
+
self.neftune_hook_handle.remove()
|
1128 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1129 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1130 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1131 |
+
pass
|
1132 |
+
|
1133 |
+
pass
|
unsloth_compiled_cache/UnslothRewardTrainer.py
ADDED
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, warnings)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothRewardConfig(RewardConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`RewardTrainer`].
|
47 |
+
|
48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
50 |
+
command line.
|
51 |
+
|
52 |
+
Parameters:
|
53 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
54 |
+
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
55 |
+
limit. This argument is required if you want to use the default data collator.
|
56 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
57 |
+
Whether to disable dropout in the model.
|
58 |
+
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
59 |
+
Number of processes to use for processing the dataset.
|
60 |
+
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
|
61 |
+
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
62 |
+
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
63 |
+
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
64 |
+
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
|
65 |
+
the dataset is pretokenized.
|
66 |
+
|
67 |
+
"""
|
68 |
+
vllm_sampling_params: Optional[Any] = field(
|
69 |
+
default = None,
|
70 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
71 |
+
)
|
72 |
+
unsloth_num_chunks : Optional[int] = field(
|
73 |
+
default = -1,
|
74 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
75 |
+
)
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
output_dir = None,
|
79 |
+
overwrite_output_dir = None,
|
80 |
+
do_train = False,
|
81 |
+
do_eval = False,
|
82 |
+
do_predict = False,
|
83 |
+
eval_strategy = 'no',
|
84 |
+
prediction_loss_only = False,
|
85 |
+
per_device_train_batch_size = 4,
|
86 |
+
per_device_eval_batch_size = 4,
|
87 |
+
per_gpu_train_batch_size = None,
|
88 |
+
per_gpu_eval_batch_size = None,
|
89 |
+
gradient_accumulation_steps = 2,
|
90 |
+
eval_accumulation_steps = 2,
|
91 |
+
eval_delay = 0,
|
92 |
+
torch_empty_cache_steps = 250,
|
93 |
+
learning_rate = 5e-05,
|
94 |
+
weight_decay = 0.01,
|
95 |
+
adam_beta1 = 0.9,
|
96 |
+
adam_beta2 = 0.999,
|
97 |
+
adam_epsilon = 1e-08,
|
98 |
+
max_grad_norm = 1.0,
|
99 |
+
num_train_epochs = 3.0,
|
100 |
+
max_steps = -1,
|
101 |
+
lr_scheduler_type = 'linear',
|
102 |
+
warmup_ratio = 0.1,
|
103 |
+
warmup_steps = 0,
|
104 |
+
log_level = 'passive',
|
105 |
+
log_level_replica = 'warning',
|
106 |
+
log_on_each_node = True,
|
107 |
+
logging_dir = None,
|
108 |
+
logging_strategy = 'steps',
|
109 |
+
logging_first_step = False,
|
110 |
+
logging_steps = 1,
|
111 |
+
logging_nan_inf_filter = False,
|
112 |
+
save_strategy = 'steps',
|
113 |
+
save_steps = 500,
|
114 |
+
save_total_limit = None,
|
115 |
+
save_safetensors = True,
|
116 |
+
save_on_each_node = False,
|
117 |
+
save_only_model = False,
|
118 |
+
restore_callback_states_from_checkpoint = False,
|
119 |
+
no_cuda = False,
|
120 |
+
use_cpu = False,
|
121 |
+
use_mps_device = False,
|
122 |
+
seed = 3407,
|
123 |
+
data_seed = 3407,
|
124 |
+
jit_mode_eval = False,
|
125 |
+
use_ipex = False,
|
126 |
+
bf16 = False,
|
127 |
+
fp16 = False,
|
128 |
+
fp16_opt_level = 'O1',
|
129 |
+
half_precision_backend = 'auto',
|
130 |
+
bf16_full_eval = False,
|
131 |
+
fp16_full_eval = False,
|
132 |
+
tf32 = None,
|
133 |
+
local_rank = -1,
|
134 |
+
ddp_backend = None,
|
135 |
+
tpu_num_cores = None,
|
136 |
+
tpu_metrics_debug = False,
|
137 |
+
debug = '',
|
138 |
+
dataloader_drop_last = False,
|
139 |
+
eval_steps = None,
|
140 |
+
dataloader_num_workers = 0,
|
141 |
+
dataloader_prefetch_factor = None,
|
142 |
+
past_index = -1,
|
143 |
+
run_name = None,
|
144 |
+
disable_tqdm = None,
|
145 |
+
remove_unused_columns = False,
|
146 |
+
label_names = None,
|
147 |
+
load_best_model_at_end = False,
|
148 |
+
metric_for_best_model = None,
|
149 |
+
greater_is_better = None,
|
150 |
+
ignore_data_skip = False,
|
151 |
+
fsdp = '',
|
152 |
+
fsdp_min_num_params = 0,
|
153 |
+
fsdp_config = None,
|
154 |
+
tp_size = 0,
|
155 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
156 |
+
accelerator_config = None,
|
157 |
+
deepspeed = None,
|
158 |
+
label_smoothing_factor = 0.0,
|
159 |
+
optim = 'adamw_8bit',
|
160 |
+
optim_args = None,
|
161 |
+
adafactor = False,
|
162 |
+
group_by_length = False,
|
163 |
+
length_column_name = 'length',
|
164 |
+
report_to = None,
|
165 |
+
ddp_find_unused_parameters = None,
|
166 |
+
ddp_bucket_cap_mb = None,
|
167 |
+
ddp_broadcast_buffers = None,
|
168 |
+
dataloader_pin_memory = True,
|
169 |
+
dataloader_persistent_workers = False,
|
170 |
+
skip_memory_metrics = True,
|
171 |
+
use_legacy_prediction_loop = False,
|
172 |
+
push_to_hub = False,
|
173 |
+
resume_from_checkpoint = None,
|
174 |
+
hub_model_id = None,
|
175 |
+
hub_strategy = 'every_save',
|
176 |
+
hub_token = None,
|
177 |
+
hub_private_repo = None,
|
178 |
+
hub_always_push = False,
|
179 |
+
gradient_checkpointing = False,
|
180 |
+
gradient_checkpointing_kwargs = None,
|
181 |
+
include_inputs_for_metrics = False,
|
182 |
+
eval_do_concat_batches = True,
|
183 |
+
fp16_backend = 'auto',
|
184 |
+
evaluation_strategy = None,
|
185 |
+
push_to_hub_model_id = None,
|
186 |
+
push_to_hub_organization = None,
|
187 |
+
push_to_hub_token = None,
|
188 |
+
mp_parameters = '',
|
189 |
+
auto_find_batch_size = False,
|
190 |
+
full_determinism = False,
|
191 |
+
torchdynamo = None,
|
192 |
+
ray_scope = 'last',
|
193 |
+
ddp_timeout = 1800,
|
194 |
+
torch_compile = False,
|
195 |
+
torch_compile_backend = None,
|
196 |
+
torch_compile_mode = None,
|
197 |
+
dispatch_batches = None,
|
198 |
+
split_batches = None,
|
199 |
+
include_tokens_per_second = False,
|
200 |
+
include_num_input_tokens_seen = False,
|
201 |
+
neftune_noise_alpha = None,
|
202 |
+
optim_target_modules = None,
|
203 |
+
batch_eval_metrics = False,
|
204 |
+
eval_on_start = False,
|
205 |
+
use_liger_kernel = False,
|
206 |
+
eval_use_gather_object = False,
|
207 |
+
average_tokens_across_devices = False,
|
208 |
+
max_length = 1024,
|
209 |
+
disable_dropout = True,
|
210 |
+
dataset_num_proc = None,
|
211 |
+
center_rewards_coefficient = None,
|
212 |
+
vllm_sampling_params = None,
|
213 |
+
unsloth_num_chunks = -1,
|
214 |
+
**kwargs,
|
215 |
+
):
|
216 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
217 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
218 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
219 |
+
output_dir = 'unsloth_training_checkpoints'
|
220 |
+
save_strategy = 'no'
|
221 |
+
if dataset_num_proc is None:
|
222 |
+
from multiprocessing import cpu_count
|
223 |
+
dataset_num_proc = cpu_count()
|
224 |
+
|
225 |
+
super().__init__(
|
226 |
+
output_dir = output_dir,
|
227 |
+
overwrite_output_dir = overwrite_output_dir,
|
228 |
+
do_train = do_train,
|
229 |
+
do_eval = do_eval,
|
230 |
+
do_predict = do_predict,
|
231 |
+
eval_strategy = eval_strategy,
|
232 |
+
prediction_loss_only = prediction_loss_only,
|
233 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
234 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
235 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
236 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
237 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
238 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
239 |
+
eval_delay = eval_delay,
|
240 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
241 |
+
learning_rate = learning_rate,
|
242 |
+
weight_decay = weight_decay,
|
243 |
+
adam_beta1 = adam_beta1,
|
244 |
+
adam_beta2 = adam_beta2,
|
245 |
+
adam_epsilon = adam_epsilon,
|
246 |
+
max_grad_norm = max_grad_norm,
|
247 |
+
num_train_epochs = num_train_epochs,
|
248 |
+
max_steps = max_steps,
|
249 |
+
lr_scheduler_type = lr_scheduler_type,
|
250 |
+
warmup_ratio = warmup_ratio,
|
251 |
+
warmup_steps = warmup_steps,
|
252 |
+
log_level = log_level,
|
253 |
+
log_level_replica = log_level_replica,
|
254 |
+
log_on_each_node = log_on_each_node,
|
255 |
+
logging_dir = logging_dir,
|
256 |
+
logging_strategy = logging_strategy,
|
257 |
+
logging_first_step = logging_first_step,
|
258 |
+
logging_steps = logging_steps,
|
259 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
260 |
+
save_strategy = save_strategy,
|
261 |
+
save_steps = save_steps,
|
262 |
+
save_total_limit = save_total_limit,
|
263 |
+
save_safetensors = save_safetensors,
|
264 |
+
save_on_each_node = save_on_each_node,
|
265 |
+
save_only_model = save_only_model,
|
266 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
267 |
+
no_cuda = no_cuda,
|
268 |
+
use_cpu = use_cpu,
|
269 |
+
use_mps_device = use_mps_device,
|
270 |
+
seed = seed,
|
271 |
+
data_seed = data_seed,
|
272 |
+
jit_mode_eval = jit_mode_eval,
|
273 |
+
use_ipex = use_ipex,
|
274 |
+
bf16 = bf16,
|
275 |
+
fp16 = fp16,
|
276 |
+
fp16_opt_level = fp16_opt_level,
|
277 |
+
half_precision_backend = half_precision_backend,
|
278 |
+
bf16_full_eval = bf16_full_eval,
|
279 |
+
fp16_full_eval = fp16_full_eval,
|
280 |
+
tf32 = tf32,
|
281 |
+
local_rank = local_rank,
|
282 |
+
ddp_backend = ddp_backend,
|
283 |
+
tpu_num_cores = tpu_num_cores,
|
284 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
285 |
+
debug = debug,
|
286 |
+
dataloader_drop_last = dataloader_drop_last,
|
287 |
+
eval_steps = eval_steps,
|
288 |
+
dataloader_num_workers = dataloader_num_workers,
|
289 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
290 |
+
past_index = past_index,
|
291 |
+
run_name = run_name,
|
292 |
+
disable_tqdm = disable_tqdm,
|
293 |
+
remove_unused_columns = remove_unused_columns,
|
294 |
+
label_names = label_names,
|
295 |
+
load_best_model_at_end = load_best_model_at_end,
|
296 |
+
metric_for_best_model = metric_for_best_model,
|
297 |
+
greater_is_better = greater_is_better,
|
298 |
+
ignore_data_skip = ignore_data_skip,
|
299 |
+
fsdp = fsdp,
|
300 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
301 |
+
fsdp_config = fsdp_config,
|
302 |
+
tp_size = tp_size,
|
303 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
304 |
+
accelerator_config = accelerator_config,
|
305 |
+
deepspeed = deepspeed,
|
306 |
+
label_smoothing_factor = label_smoothing_factor,
|
307 |
+
optim = optim,
|
308 |
+
optim_args = optim_args,
|
309 |
+
adafactor = adafactor,
|
310 |
+
group_by_length = group_by_length,
|
311 |
+
length_column_name = length_column_name,
|
312 |
+
report_to = report_to,
|
313 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
314 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
315 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
316 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
317 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
318 |
+
skip_memory_metrics = skip_memory_metrics,
|
319 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
320 |
+
push_to_hub = push_to_hub,
|
321 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
322 |
+
hub_model_id = hub_model_id,
|
323 |
+
hub_strategy = hub_strategy,
|
324 |
+
hub_token = hub_token,
|
325 |
+
hub_private_repo = hub_private_repo,
|
326 |
+
hub_always_push = hub_always_push,
|
327 |
+
gradient_checkpointing = gradient_checkpointing,
|
328 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
329 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
330 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
331 |
+
fp16_backend = fp16_backend,
|
332 |
+
evaluation_strategy = evaluation_strategy,
|
333 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
334 |
+
push_to_hub_organization = push_to_hub_organization,
|
335 |
+
push_to_hub_token = push_to_hub_token,
|
336 |
+
mp_parameters = mp_parameters,
|
337 |
+
auto_find_batch_size = auto_find_batch_size,
|
338 |
+
full_determinism = full_determinism,
|
339 |
+
torchdynamo = torchdynamo,
|
340 |
+
ray_scope = ray_scope,
|
341 |
+
ddp_timeout = ddp_timeout,
|
342 |
+
torch_compile = torch_compile,
|
343 |
+
torch_compile_backend = torch_compile_backend,
|
344 |
+
torch_compile_mode = torch_compile_mode,
|
345 |
+
dispatch_batches = dispatch_batches,
|
346 |
+
split_batches = split_batches,
|
347 |
+
include_tokens_per_second = include_tokens_per_second,
|
348 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
349 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
350 |
+
optim_target_modules = optim_target_modules,
|
351 |
+
batch_eval_metrics = batch_eval_metrics,
|
352 |
+
eval_on_start = eval_on_start,
|
353 |
+
use_liger_kernel = use_liger_kernel,
|
354 |
+
eval_use_gather_object = eval_use_gather_object,
|
355 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
356 |
+
max_length = max_length,
|
357 |
+
disable_dropout = disable_dropout,
|
358 |
+
dataset_num_proc = dataset_num_proc,
|
359 |
+
center_rewards_coefficient = center_rewards_coefficient,**kwargs)
|
360 |
+
self.vllm_sampling_params = vllm_sampling_params
|
361 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
362 |
+
pass
|
363 |
+
|
364 |
+
class _UnslothRewardTrainer(Trainer):
|
365 |
+
_tag_names = ["trl", "reward-trainer"]
|
366 |
+
|
367 |
+
def __init__(
|
368 |
+
self,
|
369 |
+
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
370 |
+
args: Optional[RewardConfig] = None,
|
371 |
+
data_collator: Optional[DataCollator] = None,
|
372 |
+
train_dataset: Optional[Dataset] = None,
|
373 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
374 |
+
processing_class: Optional[
|
375 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
376 |
+
] = None,
|
377 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
378 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
379 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
380 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
381 |
+
None,
|
382 |
+
None,
|
383 |
+
),
|
384 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
385 |
+
peft_config: Optional[dict] = None,
|
386 |
+
):
|
387 |
+
"""
|
388 |
+
Initialize RewardTrainer.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
model (`transformers.PreTrainedModel`):
|
392 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
393 |
+
args (`RewardConfig`):
|
394 |
+
The arguments to use for training.
|
395 |
+
data_collator (`transformers.DataCollator`):
|
396 |
+
The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
|
397 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
398 |
+
train_dataset (`datasets.Dataset`):
|
399 |
+
The dataset to use for training.
|
400 |
+
eval_dataset (`datasets.Dataset`):
|
401 |
+
The dataset to use for evaluation.
|
402 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
403 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
404 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
405 |
+
reuse the fine-tuned model.
|
406 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
407 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
408 |
+
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
409 |
+
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
410 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
411 |
+
The callbacks to use for training.
|
412 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
413 |
+
The optimizer and scheduler to use for training.
|
414 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
415 |
+
The function to use to preprocess the logits before computing the metrics.
|
416 |
+
peft_config (`dict`, defaults to `None`):
|
417 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
418 |
+
"""
|
419 |
+
if not is_peft_available() and peft_config is not None:
|
420 |
+
raise ValueError(
|
421 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
422 |
+
)
|
423 |
+
elif is_peft_available() and peft_config is not None:
|
424 |
+
if not isinstance(model, PeftModel):
|
425 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
426 |
+
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
427 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
428 |
+
)
|
429 |
+
|
430 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
431 |
+
|
432 |
+
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
433 |
+
warnings.warn(
|
434 |
+
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
435 |
+
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
|
436 |
+
UserWarning,
|
437 |
+
)
|
438 |
+
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
439 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
440 |
+
|
441 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
442 |
+
|
443 |
+
model = model
|
444 |
+
|
445 |
+
# Disable dropout in the model
|
446 |
+
if args.disable_dropout:
|
447 |
+
disable_dropout_in_model(model)
|
448 |
+
|
449 |
+
if compute_metrics is None:
|
450 |
+
compute_metrics = compute_accuracy
|
451 |
+
|
452 |
+
if data_collator is None:
|
453 |
+
if processing_class is None:
|
454 |
+
raise ValueError(
|
455 |
+
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
|
456 |
+
)
|
457 |
+
|
458 |
+
max_length = args.max_length
|
459 |
+
|
460 |
+
data_collator = RewardDataCollatorWithPadding(processing_class)
|
461 |
+
|
462 |
+
if args.remove_unused_columns:
|
463 |
+
try: # for bc before https://github.com/huggingface/transformers/pull/25435
|
464 |
+
args.remove_unused_columns = False
|
465 |
+
except FrozenInstanceError:
|
466 |
+
args = replace(args, remove_unused_columns=False)
|
467 |
+
# warn users
|
468 |
+
warnings.warn(
|
469 |
+
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
|
470 |
+
" we have set it for you, but you should do it yourself in the future.",
|
471 |
+
UserWarning,
|
472 |
+
)
|
473 |
+
|
474 |
+
self.use_reward_data_collator = True
|
475 |
+
else:
|
476 |
+
self.use_reward_data_collator = False
|
477 |
+
|
478 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
479 |
+
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
|
480 |
+
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
|
481 |
+
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
482 |
+
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
483 |
+
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
484 |
+
# issued.
|
485 |
+
model.warnings_issued["estimate_tokens"] = True
|
486 |
+
|
487 |
+
if "input_ids_chosen" not in train_dataset.column_names:
|
488 |
+
with PartialState().local_main_process_first():
|
489 |
+
fn_kwargs = {"tokenizer": processing_class}
|
490 |
+
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
|
491 |
+
train_dataset = train_dataset.map(
|
492 |
+
_tokenize,
|
493 |
+
batched=True,
|
494 |
+
fn_kwargs=fn_kwargs,
|
495 |
+
num_proc=args.dataset_num_proc,
|
496 |
+
)
|
497 |
+
# This filter is important because otherwise you get samples that exceed the model's context length and
|
498 |
+
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
499 |
+
# user might get surprised if N samples are missing from training.
|
500 |
+
train_dataset = train_dataset.filter(
|
501 |
+
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
|
502 |
+
num_proc=args.dataset_num_proc,
|
503 |
+
)
|
504 |
+
if eval_dataset is not None:
|
505 |
+
eval_dataset = eval_dataset.map(
|
506 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
|
507 |
+
)
|
508 |
+
eval_dataset = eval_dataset.map(
|
509 |
+
_tokenize,
|
510 |
+
fn_kwargs=fn_kwargs,
|
511 |
+
batched=True,
|
512 |
+
num_proc=args.dataset_num_proc,
|
513 |
+
)
|
514 |
+
# This filter is important because otherwise you get samples that exceed the model's context length and
|
515 |
+
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
516 |
+
# user might get surprised if N samples are missing from training.
|
517 |
+
eval_dataset = eval_dataset.filter(
|
518 |
+
lambda x: len(x["input_ids_chosen"]) <= max_length
|
519 |
+
and len(x["input_ids_rejected"]) <= max_length,
|
520 |
+
num_proc=args.dataset_num_proc,
|
521 |
+
)
|
522 |
+
|
523 |
+
super().__init__(
|
524 |
+
model=model,
|
525 |
+
args=args,
|
526 |
+
data_collator=data_collator,
|
527 |
+
train_dataset=train_dataset,
|
528 |
+
eval_dataset=eval_dataset,
|
529 |
+
processing_class=processing_class,
|
530 |
+
model_init=model_init,
|
531 |
+
compute_metrics=compute_metrics,
|
532 |
+
callbacks=callbacks,
|
533 |
+
optimizers=optimizers,
|
534 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
535 |
+
)
|
536 |
+
|
537 |
+
# Add tags for models that have been loaded with the correct transformers version
|
538 |
+
if hasattr(self.model, "add_model_tags"):
|
539 |
+
self.model.add_model_tags(self._tag_names)
|
540 |
+
|
541 |
+
def compute_loss(
|
542 |
+
self,
|
543 |
+
model: Union[PreTrainedModel, nn.Module],
|
544 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
545 |
+
return_outputs=False,
|
546 |
+
num_items_in_batch=None,
|
547 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
548 |
+
rewards_chosen = model(
|
549 |
+
input_ids=inputs["input_ids_chosen"],
|
550 |
+
attention_mask=inputs["attention_mask_chosen"],
|
551 |
+
return_dict=True,
|
552 |
+
)["logits"]
|
553 |
+
rewards_rejected = model(
|
554 |
+
input_ids=inputs["input_ids_rejected"],
|
555 |
+
attention_mask=inputs["attention_mask_rejected"],
|
556 |
+
return_dict=True,
|
557 |
+
)["logits"]
|
558 |
+
# calculate loss, optionally modulate with margin
|
559 |
+
if "margin" in inputs:
|
560 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
|
561 |
+
else:
|
562 |
+
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
563 |
+
|
564 |
+
if self.args.center_rewards_coefficient is not None:
|
565 |
+
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
|
566 |
+
|
567 |
+
if return_outputs:
|
568 |
+
return loss, {
|
569 |
+
"rewards_chosen": rewards_chosen,
|
570 |
+
"rewards_rejected": rewards_rejected,
|
571 |
+
}
|
572 |
+
return loss
|
573 |
+
|
574 |
+
def prediction_step(
|
575 |
+
self,
|
576 |
+
model: Union[PreTrainedModel, nn.Module],
|
577 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
578 |
+
prediction_loss_only: bool,
|
579 |
+
ignore_keys: Optional[list[str]] = None,
|
580 |
+
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
581 |
+
inputs = self._prepare_inputs(inputs)
|
582 |
+
if ignore_keys is None:
|
583 |
+
if hasattr(self.model, "config"):
|
584 |
+
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
585 |
+
else:
|
586 |
+
ignore_keys = []
|
587 |
+
|
588 |
+
with torch.no_grad():
|
589 |
+
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
|
590 |
+
|
591 |
+
if prediction_loss_only:
|
592 |
+
return (loss, None, None)
|
593 |
+
|
594 |
+
loss = loss.detach()
|
595 |
+
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
|
596 |
+
logits = nested_detach(logits)
|
597 |
+
# Stack accepted against rejected, mean over logits
|
598 |
+
# and softmax to get preferences between accepted and rejected to sum to 1
|
599 |
+
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
|
600 |
+
|
601 |
+
labels = torch.zeros(logits.shape[0])
|
602 |
+
labels = self._prepare_inputs(labels)
|
603 |
+
|
604 |
+
return loss, logits, labels
|
605 |
+
|
606 |
+
def evaluate(self, *args, **kwargs):
|
607 |
+
num_print_samples = kwargs.pop("num_print_samples", 4)
|
608 |
+
self.visualize_samples(num_print_samples)
|
609 |
+
return super().evaluate(*args, **kwargs)
|
610 |
+
|
611 |
+
def visualize_samples(self, num_print_samples: int):
|
612 |
+
"""
|
613 |
+
Visualize the reward model logits prediction
|
614 |
+
|
615 |
+
Args:
|
616 |
+
num_print_samples (`int`, defaults to `4`):
|
617 |
+
The number of samples to print. Set to `-1` to print all samples.
|
618 |
+
"""
|
619 |
+
eval_dataloader = self.get_eval_dataloader()
|
620 |
+
table = defaultdict(list)
|
621 |
+
for _, inputs in enumerate(eval_dataloader):
|
622 |
+
_, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
|
623 |
+
chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
|
624 |
+
rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
|
625 |
+
table["chosen_text"].extend(gather_object(chosen_text))
|
626 |
+
table["rejected_text"].extend(gather_object(rejected_text))
|
627 |
+
table["logits"].extend(
|
628 |
+
gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
|
629 |
+
)
|
630 |
+
if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
|
631 |
+
break
|
632 |
+
df = pd.DataFrame(table)
|
633 |
+
if self.accelerator.process_index == 0:
|
634 |
+
print_rich_table(df[:num_print_samples])
|
635 |
+
if "wandb" in self.args.report_to:
|
636 |
+
import wandb
|
637 |
+
|
638 |
+
if wandb.run is not None:
|
639 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
640 |
+
|
641 |
+
if "comet_ml" in self.args.report_to:
|
642 |
+
log_table_to_comet_experiment(
|
643 |
+
name="completions.csv",
|
644 |
+
table=df,
|
645 |
+
)
|
646 |
+
|
647 |
+
def create_model_card(
|
648 |
+
self,
|
649 |
+
model_name: Optional[str] = None,
|
650 |
+
dataset_name: Optional[str] = None,
|
651 |
+
tags: Union[str, list[str], None] = None,
|
652 |
+
):
|
653 |
+
"""
|
654 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
655 |
+
|
656 |
+
Args:
|
657 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
658 |
+
Name of the model.
|
659 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
660 |
+
Name of the dataset used for training.
|
661 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
662 |
+
Tags to be associated with the model card.
|
663 |
+
"""
|
664 |
+
if not self.is_world_process_zero():
|
665 |
+
return
|
666 |
+
|
667 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
668 |
+
base_model = self.model.config._name_or_path
|
669 |
+
else:
|
670 |
+
base_model = None
|
671 |
+
|
672 |
+
tags = tags or []
|
673 |
+
if isinstance(tags, str):
|
674 |
+
tags = [tags]
|
675 |
+
|
676 |
+
if hasattr(self.model.config, "unsloth_version"):
|
677 |
+
tags.append("unsloth")
|
678 |
+
|
679 |
+
model_card = generate_model_card(
|
680 |
+
base_model=base_model,
|
681 |
+
model_name=model_name,
|
682 |
+
hub_model_id=self.hub_model_id,
|
683 |
+
dataset_name=dataset_name,
|
684 |
+
tags=tags,
|
685 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
686 |
+
comet_url=get_comet_experiment_url(),
|
687 |
+
trainer_name="Reward",
|
688 |
+
)
|
689 |
+
|
690 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
691 |
+
class UnslothRewardTrainer(_UnslothRewardTrainer):
|
692 |
+
"""
|
693 |
+
|
694 |
+
"""
|
695 |
+
def __init__(
|
696 |
+
self,
|
697 |
+
model = None,
|
698 |
+
args = None,
|
699 |
+
data_collator = None,
|
700 |
+
train_dataset = None,
|
701 |
+
eval_dataset = None,
|
702 |
+
processing_class = None,
|
703 |
+
model_init = None,
|
704 |
+
compute_metrics = None,
|
705 |
+
callbacks = None,
|
706 |
+
preprocess_logits_for_metrics = None,
|
707 |
+
peft_config = None,
|
708 |
+
**kwargs
|
709 |
+
):
|
710 |
+
if args is None: args = UnslothRewardConfig()
|
711 |
+
use_bf16 = getattr(args, 'bf16', False)
|
712 |
+
use_fp16 = getattr(args, 'fp16', False)
|
713 |
+
force_float32 = False
|
714 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
715 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
716 |
+
force_float32 = True
|
717 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
718 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
719 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
720 |
+
from unsloth_zoo.utils import _get_dtype
|
721 |
+
dtype = _get_dtype(dtype)
|
722 |
+
float16 = dtype == torch.float16
|
723 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
724 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
725 |
+
if force_float32:
|
726 |
+
args.fp16 = False
|
727 |
+
args.bf16 = False
|
728 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
729 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
730 |
+
args.fp16 = float16
|
731 |
+
args.bf16 = not float16
|
732 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
733 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
734 |
+
args.eval_strategy = 'steps'
|
735 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
736 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
737 |
+
if ga_steps is not None and ga_steps > 1:
|
738 |
+
from transformers import __version__ as transformers_version
|
739 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
740 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
741 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
742 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
743 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
744 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
745 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
746 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
747 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
748 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
749 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
750 |
+
if force_float32:
|
751 |
+
args.bf16_full_eval = False
|
752 |
+
args.fp16_full_eval = False
|
753 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
754 |
+
args.bf16_full_eval = True
|
755 |
+
args.fp16_full_eval = False
|
756 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
757 |
+
args.bf16_full_eval = args.bf16
|
758 |
+
args.fp16_full_eval = args.fp16
|
759 |
+
_output_logits = False
|
760 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
761 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
762 |
+
if _output_logits:
|
763 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
764 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
765 |
+
pass
|
766 |
+
else:
|
767 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
768 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
769 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
770 |
+
max_seq_length = model.max_seq_length
|
771 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
772 |
+
if model is not None and hasattr(model, 'for_training'):
|
773 |
+
model.for_training()
|
774 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
775 |
+
if 'processing_class' in locals():
|
776 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
777 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
778 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
779 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
780 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
781 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
782 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
783 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
784 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
785 |
+
else:
|
786 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
787 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
788 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
789 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
790 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
791 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
792 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
793 |
+
else:
|
794 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
795 |
+
other_metrics = []
|
796 |
+
|
797 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
798 |
+
PatchRLStatistics('reward_trainer', other_metrics)
|
799 |
+
|
800 |
+
super().__init__(
|
801 |
+
model = model,
|
802 |
+
args = args,
|
803 |
+
data_collator = data_collator,
|
804 |
+
train_dataset = train_dataset,
|
805 |
+
eval_dataset = eval_dataset,
|
806 |
+
processing_class = processing_class,
|
807 |
+
model_init = model_init,
|
808 |
+
compute_metrics = compute_metrics,
|
809 |
+
callbacks = callbacks,
|
810 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
811 |
+
peft_config = peft_config,**kwargs)
|
812 |
+
if hasattr(self, 'neftune_hook_handle'):
|
813 |
+
self.neftune_hook_handle.remove()
|
814 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
815 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
816 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
817 |
+
pass
|
818 |
+
|
819 |
+
pass
|
unsloth_compiled_cache/UnslothSFTTrainer.py
ADDED
@@ -0,0 +1,1027 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothSFTConfig(SFTConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`SFTTrainer`].
|
47 |
+
|
48 |
+
Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
|
49 |
+
[`~transformers.TrainingArguments`] documentation.
|
50 |
+
|
51 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
52 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
53 |
+
command line.
|
54 |
+
|
55 |
+
Parameters:
|
56 |
+
> Parameters that control the model
|
57 |
+
|
58 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
59 |
+
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
60 |
+
argument of the [`SFTTrainer`] is provided as a string.
|
61 |
+
use_liger (`bool`, *optional*, defaults to `False`):
|
62 |
+
Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
|
63 |
+
|
64 |
+
> Parameters that control the data preprocessing
|
65 |
+
|
66 |
+
dataset_text_field (`str`, *optional*, defaults to `"text"`):
|
67 |
+
Name of the column that contains text data in the dataset.
|
68 |
+
dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
69 |
+
Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
|
70 |
+
`skip_prepare_dataset`.
|
71 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
72 |
+
Number of processes to use for processing the dataset.
|
73 |
+
max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
|
74 |
+
Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
|
75 |
+
right.
|
76 |
+
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
|
77 |
+
packing (`bool`, *optional*, defaults to `False`):
|
78 |
+
Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
|
79 |
+
length.
|
80 |
+
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
|
81 |
+
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
|
82 |
+
|
83 |
+
> Parameters that control the training
|
84 |
+
|
85 |
+
learning_rate (`float`, *optional*, defaults to `2e-5`):
|
86 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
87 |
+
[`~transformers.TrainingArguments`].
|
88 |
+
|
89 |
+
"""
|
90 |
+
vllm_sampling_params: Optional[Any] = field(
|
91 |
+
default = None,
|
92 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
93 |
+
)
|
94 |
+
unsloth_num_chunks : Optional[int] = field(
|
95 |
+
default = -1,
|
96 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
97 |
+
)
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
output_dir = None,
|
101 |
+
overwrite_output_dir = None,
|
102 |
+
do_train = False,
|
103 |
+
do_eval = False,
|
104 |
+
do_predict = False,
|
105 |
+
eval_strategy = 'no',
|
106 |
+
prediction_loss_only = False,
|
107 |
+
per_device_train_batch_size = 4,
|
108 |
+
per_device_eval_batch_size = 4,
|
109 |
+
per_gpu_train_batch_size = None,
|
110 |
+
per_gpu_eval_batch_size = None,
|
111 |
+
gradient_accumulation_steps = 2,
|
112 |
+
eval_accumulation_steps = 2,
|
113 |
+
eval_delay = 0,
|
114 |
+
torch_empty_cache_steps = 250,
|
115 |
+
learning_rate = 5e-05,
|
116 |
+
weight_decay = 0.01,
|
117 |
+
adam_beta1 = 0.9,
|
118 |
+
adam_beta2 = 0.999,
|
119 |
+
adam_epsilon = 1e-08,
|
120 |
+
max_grad_norm = 1.0,
|
121 |
+
num_train_epochs = 3.0,
|
122 |
+
max_steps = -1,
|
123 |
+
lr_scheduler_type = 'linear',
|
124 |
+
warmup_ratio = 0.1,
|
125 |
+
warmup_steps = 0,
|
126 |
+
log_level = 'passive',
|
127 |
+
log_level_replica = 'warning',
|
128 |
+
log_on_each_node = True,
|
129 |
+
logging_dir = None,
|
130 |
+
logging_strategy = 'steps',
|
131 |
+
logging_first_step = False,
|
132 |
+
logging_steps = 1,
|
133 |
+
logging_nan_inf_filter = False,
|
134 |
+
save_strategy = 'steps',
|
135 |
+
save_steps = 500,
|
136 |
+
save_total_limit = None,
|
137 |
+
save_safetensors = True,
|
138 |
+
save_on_each_node = False,
|
139 |
+
save_only_model = False,
|
140 |
+
restore_callback_states_from_checkpoint = False,
|
141 |
+
no_cuda = False,
|
142 |
+
use_cpu = False,
|
143 |
+
use_mps_device = False,
|
144 |
+
seed = 3407,
|
145 |
+
data_seed = 3407,
|
146 |
+
jit_mode_eval = False,
|
147 |
+
use_ipex = False,
|
148 |
+
bf16 = False,
|
149 |
+
fp16 = False,
|
150 |
+
fp16_opt_level = 'O1',
|
151 |
+
half_precision_backend = 'auto',
|
152 |
+
bf16_full_eval = False,
|
153 |
+
fp16_full_eval = False,
|
154 |
+
tf32 = None,
|
155 |
+
local_rank = -1,
|
156 |
+
ddp_backend = None,
|
157 |
+
tpu_num_cores = None,
|
158 |
+
tpu_metrics_debug = False,
|
159 |
+
debug = '',
|
160 |
+
dataloader_drop_last = False,
|
161 |
+
eval_steps = None,
|
162 |
+
dataloader_num_workers = 0,
|
163 |
+
dataloader_prefetch_factor = None,
|
164 |
+
past_index = -1,
|
165 |
+
run_name = None,
|
166 |
+
disable_tqdm = None,
|
167 |
+
remove_unused_columns = True,
|
168 |
+
label_names = None,
|
169 |
+
load_best_model_at_end = False,
|
170 |
+
metric_for_best_model = None,
|
171 |
+
greater_is_better = None,
|
172 |
+
ignore_data_skip = False,
|
173 |
+
fsdp = '',
|
174 |
+
fsdp_min_num_params = 0,
|
175 |
+
fsdp_config = None,
|
176 |
+
tp_size = 0,
|
177 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
178 |
+
accelerator_config = None,
|
179 |
+
deepspeed = None,
|
180 |
+
label_smoothing_factor = 0.0,
|
181 |
+
optim = 'adamw_8bit',
|
182 |
+
optim_args = None,
|
183 |
+
adafactor = False,
|
184 |
+
group_by_length = False,
|
185 |
+
length_column_name = 'length',
|
186 |
+
report_to = None,
|
187 |
+
ddp_find_unused_parameters = None,
|
188 |
+
ddp_bucket_cap_mb = None,
|
189 |
+
ddp_broadcast_buffers = None,
|
190 |
+
dataloader_pin_memory = True,
|
191 |
+
dataloader_persistent_workers = False,
|
192 |
+
skip_memory_metrics = True,
|
193 |
+
use_legacy_prediction_loop = False,
|
194 |
+
push_to_hub = False,
|
195 |
+
resume_from_checkpoint = None,
|
196 |
+
hub_model_id = None,
|
197 |
+
hub_strategy = 'every_save',
|
198 |
+
hub_token = None,
|
199 |
+
hub_private_repo = None,
|
200 |
+
hub_always_push = False,
|
201 |
+
gradient_checkpointing = False,
|
202 |
+
gradient_checkpointing_kwargs = None,
|
203 |
+
include_inputs_for_metrics = False,
|
204 |
+
eval_do_concat_batches = True,
|
205 |
+
fp16_backend = 'auto',
|
206 |
+
evaluation_strategy = None,
|
207 |
+
push_to_hub_model_id = None,
|
208 |
+
push_to_hub_organization = None,
|
209 |
+
push_to_hub_token = None,
|
210 |
+
mp_parameters = '',
|
211 |
+
auto_find_batch_size = False,
|
212 |
+
full_determinism = False,
|
213 |
+
torchdynamo = None,
|
214 |
+
ray_scope = 'last',
|
215 |
+
ddp_timeout = 1800,
|
216 |
+
torch_compile = False,
|
217 |
+
torch_compile_backend = None,
|
218 |
+
torch_compile_mode = None,
|
219 |
+
dispatch_batches = None,
|
220 |
+
split_batches = None,
|
221 |
+
include_tokens_per_second = False,
|
222 |
+
include_num_input_tokens_seen = False,
|
223 |
+
neftune_noise_alpha = None,
|
224 |
+
optim_target_modules = None,
|
225 |
+
batch_eval_metrics = False,
|
226 |
+
eval_on_start = False,
|
227 |
+
use_liger_kernel = False,
|
228 |
+
eval_use_gather_object = False,
|
229 |
+
average_tokens_across_devices = False,
|
230 |
+
model_init_kwargs = None,
|
231 |
+
use_liger = False,
|
232 |
+
dataset_text_field = 'text',
|
233 |
+
dataset_kwargs = None,
|
234 |
+
dataset_num_proc = None,
|
235 |
+
max_seq_length = None,
|
236 |
+
packing = False,
|
237 |
+
eval_packing = None,
|
238 |
+
dataset_batch_size = None,
|
239 |
+
num_of_sequences = None,
|
240 |
+
chars_per_token = None,
|
241 |
+
vllm_sampling_params = None,
|
242 |
+
unsloth_num_chunks = -1,
|
243 |
+
**kwargs,
|
244 |
+
):
|
245 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
246 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
247 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
248 |
+
output_dir = 'unsloth_training_checkpoints'
|
249 |
+
save_strategy = 'no'
|
250 |
+
if dataset_num_proc is None:
|
251 |
+
from multiprocessing import cpu_count
|
252 |
+
dataset_num_proc = cpu_count()
|
253 |
+
|
254 |
+
super().__init__(
|
255 |
+
output_dir = output_dir,
|
256 |
+
overwrite_output_dir = overwrite_output_dir,
|
257 |
+
do_train = do_train,
|
258 |
+
do_eval = do_eval,
|
259 |
+
do_predict = do_predict,
|
260 |
+
eval_strategy = eval_strategy,
|
261 |
+
prediction_loss_only = prediction_loss_only,
|
262 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
263 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
264 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
265 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
266 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
267 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
268 |
+
eval_delay = eval_delay,
|
269 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
270 |
+
learning_rate = learning_rate,
|
271 |
+
weight_decay = weight_decay,
|
272 |
+
adam_beta1 = adam_beta1,
|
273 |
+
adam_beta2 = adam_beta2,
|
274 |
+
adam_epsilon = adam_epsilon,
|
275 |
+
max_grad_norm = max_grad_norm,
|
276 |
+
num_train_epochs = num_train_epochs,
|
277 |
+
max_steps = max_steps,
|
278 |
+
lr_scheduler_type = lr_scheduler_type,
|
279 |
+
warmup_ratio = warmup_ratio,
|
280 |
+
warmup_steps = warmup_steps,
|
281 |
+
log_level = log_level,
|
282 |
+
log_level_replica = log_level_replica,
|
283 |
+
log_on_each_node = log_on_each_node,
|
284 |
+
logging_dir = logging_dir,
|
285 |
+
logging_strategy = logging_strategy,
|
286 |
+
logging_first_step = logging_first_step,
|
287 |
+
logging_steps = logging_steps,
|
288 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
289 |
+
save_strategy = save_strategy,
|
290 |
+
save_steps = save_steps,
|
291 |
+
save_total_limit = save_total_limit,
|
292 |
+
save_safetensors = save_safetensors,
|
293 |
+
save_on_each_node = save_on_each_node,
|
294 |
+
save_only_model = save_only_model,
|
295 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
296 |
+
no_cuda = no_cuda,
|
297 |
+
use_cpu = use_cpu,
|
298 |
+
use_mps_device = use_mps_device,
|
299 |
+
seed = seed,
|
300 |
+
data_seed = data_seed,
|
301 |
+
jit_mode_eval = jit_mode_eval,
|
302 |
+
use_ipex = use_ipex,
|
303 |
+
bf16 = bf16,
|
304 |
+
fp16 = fp16,
|
305 |
+
fp16_opt_level = fp16_opt_level,
|
306 |
+
half_precision_backend = half_precision_backend,
|
307 |
+
bf16_full_eval = bf16_full_eval,
|
308 |
+
fp16_full_eval = fp16_full_eval,
|
309 |
+
tf32 = tf32,
|
310 |
+
local_rank = local_rank,
|
311 |
+
ddp_backend = ddp_backend,
|
312 |
+
tpu_num_cores = tpu_num_cores,
|
313 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
314 |
+
debug = debug,
|
315 |
+
dataloader_drop_last = dataloader_drop_last,
|
316 |
+
eval_steps = eval_steps,
|
317 |
+
dataloader_num_workers = dataloader_num_workers,
|
318 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
319 |
+
past_index = past_index,
|
320 |
+
run_name = run_name,
|
321 |
+
disable_tqdm = disable_tqdm,
|
322 |
+
remove_unused_columns = remove_unused_columns,
|
323 |
+
label_names = label_names,
|
324 |
+
load_best_model_at_end = load_best_model_at_end,
|
325 |
+
metric_for_best_model = metric_for_best_model,
|
326 |
+
greater_is_better = greater_is_better,
|
327 |
+
ignore_data_skip = ignore_data_skip,
|
328 |
+
fsdp = fsdp,
|
329 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
330 |
+
fsdp_config = fsdp_config,
|
331 |
+
tp_size = tp_size,
|
332 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
333 |
+
accelerator_config = accelerator_config,
|
334 |
+
deepspeed = deepspeed,
|
335 |
+
label_smoothing_factor = label_smoothing_factor,
|
336 |
+
optim = optim,
|
337 |
+
optim_args = optim_args,
|
338 |
+
adafactor = adafactor,
|
339 |
+
group_by_length = group_by_length,
|
340 |
+
length_column_name = length_column_name,
|
341 |
+
report_to = report_to,
|
342 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
343 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
344 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
345 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
346 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
347 |
+
skip_memory_metrics = skip_memory_metrics,
|
348 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
349 |
+
push_to_hub = push_to_hub,
|
350 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
351 |
+
hub_model_id = hub_model_id,
|
352 |
+
hub_strategy = hub_strategy,
|
353 |
+
hub_token = hub_token,
|
354 |
+
hub_private_repo = hub_private_repo,
|
355 |
+
hub_always_push = hub_always_push,
|
356 |
+
gradient_checkpointing = gradient_checkpointing,
|
357 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
358 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
359 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
360 |
+
fp16_backend = fp16_backend,
|
361 |
+
evaluation_strategy = evaluation_strategy,
|
362 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
363 |
+
push_to_hub_organization = push_to_hub_organization,
|
364 |
+
push_to_hub_token = push_to_hub_token,
|
365 |
+
mp_parameters = mp_parameters,
|
366 |
+
auto_find_batch_size = auto_find_batch_size,
|
367 |
+
full_determinism = full_determinism,
|
368 |
+
torchdynamo = torchdynamo,
|
369 |
+
ray_scope = ray_scope,
|
370 |
+
ddp_timeout = ddp_timeout,
|
371 |
+
torch_compile = torch_compile,
|
372 |
+
torch_compile_backend = torch_compile_backend,
|
373 |
+
torch_compile_mode = torch_compile_mode,
|
374 |
+
dispatch_batches = dispatch_batches,
|
375 |
+
split_batches = split_batches,
|
376 |
+
include_tokens_per_second = include_tokens_per_second,
|
377 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
378 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
379 |
+
optim_target_modules = optim_target_modules,
|
380 |
+
batch_eval_metrics = batch_eval_metrics,
|
381 |
+
eval_on_start = eval_on_start,
|
382 |
+
use_liger_kernel = use_liger_kernel,
|
383 |
+
eval_use_gather_object = eval_use_gather_object,
|
384 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
385 |
+
model_init_kwargs = model_init_kwargs,
|
386 |
+
use_liger = use_liger,
|
387 |
+
dataset_text_field = dataset_text_field,
|
388 |
+
dataset_kwargs = dataset_kwargs,
|
389 |
+
dataset_num_proc = dataset_num_proc,
|
390 |
+
max_seq_length = max_seq_length,
|
391 |
+
packing = packing,
|
392 |
+
eval_packing = eval_packing,
|
393 |
+
dataset_batch_size = dataset_batch_size,
|
394 |
+
num_of_sequences = num_of_sequences,
|
395 |
+
chars_per_token = chars_per_token,**kwargs)
|
396 |
+
self.vllm_sampling_params = vllm_sampling_params
|
397 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
398 |
+
pass
|
399 |
+
|
400 |
+
class _UnslothSFTTrainer(Trainer):
|
401 |
+
""""""
|
402 |
+
|
403 |
+
_tag_names = ["trl", "sft"]
|
404 |
+
|
405 |
+
@deprecate_kwarg(
|
406 |
+
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
|
407 |
+
)
|
408 |
+
def __init__(
|
409 |
+
self,
|
410 |
+
model: Union[str, nn.Module, PreTrainedModel],
|
411 |
+
args: Optional[Union[SFTConfig, TrainingArguments]] = None,
|
412 |
+
data_collator: Optional[DataCollator] = None, # type: ignore
|
413 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
414 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
415 |
+
processing_class: Optional[
|
416 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
417 |
+
] = None,
|
418 |
+
compute_loss_func: Optional[Callable] = None,
|
419 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
420 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
421 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
422 |
+
optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
|
423 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
424 |
+
peft_config: Optional["PeftConfig"] = None,
|
425 |
+
formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
|
426 |
+
):
|
427 |
+
# Args
|
428 |
+
if args is None:
|
429 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
430 |
+
model_name = model_name.split("/")[-1]
|
431 |
+
args = SFTConfig(f"{model_name}-SFT")
|
432 |
+
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
|
433 |
+
dict_args = args.to_dict()
|
434 |
+
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
435 |
+
dict_args.pop("push_to_hub_token")
|
436 |
+
args = SFTConfig(**dict_args)
|
437 |
+
|
438 |
+
# Model
|
439 |
+
if args.model_init_kwargs is not None and not isinstance(model, str):
|
440 |
+
warnings.warn(
|
441 |
+
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
|
442 |
+
"The `model_init_kwargs` will be ignored."
|
443 |
+
)
|
444 |
+
if isinstance(model, str):
|
445 |
+
model = self._create_model_from_path(model, args)
|
446 |
+
|
447 |
+
# PEFT configuration and model wrapping
|
448 |
+
if False:
|
449 |
+
model = self._prepare_peft_model(model, peft_config, args)
|
450 |
+
|
451 |
+
# Handle the tokenizer
|
452 |
+
if processing_class is None:
|
453 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
|
454 |
+
if processing_class.pad_token is None:
|
455 |
+
processing_class.pad_token = processing_class.eos_token # required for padding when collating data
|
456 |
+
|
457 |
+
# Dataset
|
458 |
+
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
|
459 |
+
if preprocess_dataset:
|
460 |
+
train_dataset = self._prepare_dataset(
|
461 |
+
train_dataset, processing_class, args, args.packing, formatting_func, "train"
|
462 |
+
)
|
463 |
+
if eval_dataset is not None:
|
464 |
+
packing = args.packing if args.eval_packing is None else args.eval_packing
|
465 |
+
if isinstance(eval_dataset, dict):
|
466 |
+
eval_dataset = {
|
467 |
+
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
|
468 |
+
for key, dataset in eval_dataset.items()
|
469 |
+
}
|
470 |
+
else:
|
471 |
+
eval_dataset = self._prepare_dataset(
|
472 |
+
eval_dataset, processing_class, args, packing, formatting_func, "eval"
|
473 |
+
)
|
474 |
+
|
475 |
+
# Data collator
|
476 |
+
if data_collator is None:
|
477 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
|
478 |
+
|
479 |
+
# Initialize the metrics
|
480 |
+
self._metrics = defaultdict(list)
|
481 |
+
|
482 |
+
# Initialize the Trainer. Parent class will handle:
|
483 |
+
# - DeepSpeed configuration (through create_accelerator_and_postprocess)
|
484 |
+
# - FSDP setup
|
485 |
+
# - Distributed training setup
|
486 |
+
# - Optimizer and scheduler creation
|
487 |
+
# Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
|
488 |
+
super_init_kwargs = {}
|
489 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
490 |
+
super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
|
491 |
+
else:
|
492 |
+
if optimizer_cls_and_kwargs is not None:
|
493 |
+
warnings.warn(
|
494 |
+
"The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
|
495 |
+
"The default optimizer will be used. "
|
496 |
+
"Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
|
497 |
+
)
|
498 |
+
super().__init__(
|
499 |
+
model=model,
|
500 |
+
args=args,
|
501 |
+
data_collator=data_collator,
|
502 |
+
train_dataset=train_dataset,
|
503 |
+
eval_dataset=eval_dataset,
|
504 |
+
processing_class=processing_class,
|
505 |
+
compute_loss_func=compute_loss_func,
|
506 |
+
compute_metrics=compute_metrics,
|
507 |
+
callbacks=callbacks,
|
508 |
+
optimizers=optimizers,
|
509 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
510 |
+
**super_init_kwargs,
|
511 |
+
)
|
512 |
+
|
513 |
+
# Add tags for models that have been loaded with the correct transformers version
|
514 |
+
if hasattr(self.model, "add_model_tags"):
|
515 |
+
self.model.add_model_tags(self._tag_names)
|
516 |
+
|
517 |
+
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
|
518 |
+
"""Creates a model from a path or model identifier."""
|
519 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
520 |
+
# Handle torch dtype
|
521 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
522 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
523 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
524 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
525 |
+
torch_dtype = getattr(torch, torch_dtype)
|
526 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
527 |
+
else:
|
528 |
+
raise ValueError(
|
529 |
+
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
|
530 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
531 |
+
)
|
532 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
533 |
+
if args.gradient_checkpointing:
|
534 |
+
model_init_kwargs["use_cache"] = False
|
535 |
+
|
536 |
+
# Create model
|
537 |
+
if args.use_liger:
|
538 |
+
if not is_liger_kernel_available():
|
539 |
+
raise ImportError("Please install Liger-kernel for use_liger=True")
|
540 |
+
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
541 |
+
else:
|
542 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
543 |
+
return model
|
544 |
+
|
545 |
+
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
|
546 |
+
"""Prepares a model for PEFT training."""
|
547 |
+
if not is_peft_available():
|
548 |
+
raise ImportError("To use PeftModel, you need to install the `peft` library.")
|
549 |
+
|
550 |
+
if not isinstance(peft_config, PeftConfig):
|
551 |
+
raise ValueError(
|
552 |
+
f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
|
553 |
+
"to pass a PeftConfig object to the SFTTrainer."
|
554 |
+
)
|
555 |
+
|
556 |
+
if isinstance(model, PeftModel):
|
557 |
+
return model
|
558 |
+
|
559 |
+
# Handle quantized models (QLoRA)
|
560 |
+
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
|
561 |
+
|
562 |
+
is_sharded_qlora = False
|
563 |
+
if getattr(model, "is_loaded_in_4bit", False):
|
564 |
+
# Check if model is sharded (FSDP/DS-Zero3)
|
565 |
+
for _, param in model.named_parameters():
|
566 |
+
if param.__class__.__name__ == "Params4bit":
|
567 |
+
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
|
568 |
+
break
|
569 |
+
|
570 |
+
# Prepare model for kbit training if needed
|
571 |
+
if is_qlora and not is_sharded_qlora:
|
572 |
+
model = self._prepare_model_for_kbit_training(model, args)
|
573 |
+
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
|
574 |
+
args = dataclasses.replace(args, gradient_checkpointing=False)
|
575 |
+
elif args.gradient_checkpointing:
|
576 |
+
model = self._enable_gradient_checkpointing(model, args)
|
577 |
+
|
578 |
+
# Create PEFT model
|
579 |
+
if (
|
580 |
+
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
|
581 |
+
and getattr(model, "is_loaded_in_4bit", False)
|
582 |
+
and is_sharded_qlora
|
583 |
+
):
|
584 |
+
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
|
585 |
+
else:
|
586 |
+
model = get_peft_model(model, peft_config)
|
587 |
+
|
588 |
+
# Handle bf16 casting for 4-bit models
|
589 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
|
590 |
+
peft_module_casting_to_bf16(model)
|
591 |
+
|
592 |
+
return model
|
593 |
+
|
594 |
+
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
595 |
+
"""Prepares a quantized model for kbit training."""
|
596 |
+
prepare_model_kwargs = {
|
597 |
+
"use_gradient_checkpointing": args.gradient_checkpointing,
|
598 |
+
"gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
|
599 |
+
}
|
600 |
+
|
601 |
+
return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
602 |
+
|
603 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
604 |
+
"""Enables gradient checkpointing for the model."""
|
605 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
606 |
+
use_reentrant = (
|
607 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
608 |
+
)
|
609 |
+
|
610 |
+
if use_reentrant:
|
611 |
+
if hasattr(model, "enable_input_require_grads"):
|
612 |
+
model.enable_input_require_grads()
|
613 |
+
else:
|
614 |
+
|
615 |
+
def make_inputs_require_grad(module, input, output):
|
616 |
+
output.requires_grad_(True)
|
617 |
+
|
618 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
619 |
+
|
620 |
+
return model
|
621 |
+
|
622 |
+
def _prepare_dataset(
|
623 |
+
self,
|
624 |
+
dataset: Union[Dataset, IterableDataset],
|
625 |
+
processing_class,
|
626 |
+
args,
|
627 |
+
packing: bool,
|
628 |
+
formatting_func: Optional[Callable[[dict], str]],
|
629 |
+
dataset_name: str,
|
630 |
+
) -> Union[Dataset, IterableDataset]:
|
631 |
+
# All Unsloth Zoo code licensed under LGPLv3
|
632 |
+
if isinstance(dataset, ConstantLengthDataset): return dataset
|
633 |
+
|
634 |
+
map_kwargs = {}
|
635 |
+
use_desc = isinstance(dataset, Dataset)
|
636 |
+
is_vlm = hasattr(processing_class, "tokenizer")
|
637 |
+
tokenizer = processing_class
|
638 |
+
if is_vlm: tokenizer = processing_class.tokenizer
|
639 |
+
|
640 |
+
# Get max length
|
641 |
+
max_seq_length = getattr(args, "max_length", 0)
|
642 |
+
if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
|
643 |
+
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
|
644 |
+
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
|
645 |
+
if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
|
646 |
+
dataset_text_field = getattr(args, "dataset_text_field", "text")
|
647 |
+
do_truncation = max_seq_length != 0
|
648 |
+
do_formatting_func = False
|
649 |
+
do_tokenize = True
|
650 |
+
|
651 |
+
# Get correct column names
|
652 |
+
column_names = set(next(iter(dataset)).keys())
|
653 |
+
used_column_names = ["input_ids"]
|
654 |
+
if "attention_mask" in column_names:
|
655 |
+
used_column_names.append("attention_mask")
|
656 |
+
|
657 |
+
# Check if already tokenized so skip
|
658 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
659 |
+
if "labels" in column_names:
|
660 |
+
# Most likely forgot data collator!
|
661 |
+
if is_vlm and not hasattr(tokenizer, "pad"):
|
662 |
+
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
663 |
+
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
664 |
+
self.data_collator = DataCollatorForSeq2Seq(tokenizer)
|
665 |
+
used_column_names.append("labels")
|
666 |
+
do_tokenize = False
|
667 |
+
elif "input_ids" in column_names:
|
668 |
+
# Skip dataset prep, and set data collator
|
669 |
+
if is_vlm and not hasattr(tokenizer, "pad"):
|
670 |
+
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
671 |
+
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
672 |
+
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
673 |
+
do_tokenize = False
|
674 |
+
elif dataset_text_field not in column_names:
|
675 |
+
do_formatting_func = True
|
676 |
+
if formatting_func is None:
|
677 |
+
raise RuntimeError("Unsloth: You must specify a `formatting_func`")
|
678 |
+
pass
|
679 |
+
|
680 |
+
if do_tokenize:
|
681 |
+
# Check double BOS tokens
|
682 |
+
if do_formatting_func:
|
683 |
+
test_text = formatting_func(dataset[0])
|
684 |
+
if not isinstance(test_text, list):
|
685 |
+
raise ValueError(
|
686 |
+
"Unsloth: The `formatting_func` should return a list of processed strings."
|
687 |
+
)
|
688 |
+
test_text = test_text[0]
|
689 |
+
else:
|
690 |
+
test_text = dataset[0][dataset_text_field]
|
691 |
+
|
692 |
+
# Get chat template
|
693 |
+
chat_template = getattr(processing_class, 'chat_template', '')
|
694 |
+
if chat_template == '' and is_vlm:
|
695 |
+
chat_template = getattr(tokenizer, 'chat_template', '')
|
696 |
+
if chat_template is None:
|
697 |
+
chat_template = ''
|
698 |
+
|
699 |
+
# Get bos_token
|
700 |
+
add_special_tokens = True
|
701 |
+
bos_token_1 = getattr(processing_class, 'bos_token', None)
|
702 |
+
bos_token_2 = getattr(tokenizer, 'bos_token', None)
|
703 |
+
bos_token = bos_token_1 or bos_token_2
|
704 |
+
|
705 |
+
if bos_token is not None:
|
706 |
+
if test_text.startswith(bos_token) or bos_token in chat_template:
|
707 |
+
add_special_tokens = False
|
708 |
+
print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
|
709 |
+
pass
|
710 |
+
|
711 |
+
# Create tokenize function
|
712 |
+
def _tokenize(example):
|
713 |
+
return tokenizer(
|
714 |
+
example[dataset_text_field] if not do_formatting_func else formatting_func(example),
|
715 |
+
truncation = do_truncation,
|
716 |
+
max_length = max_seq_length,
|
717 |
+
return_token_type_ids = False,
|
718 |
+
add_special_tokens = add_special_tokens,
|
719 |
+
)
|
720 |
+
pass
|
721 |
+
|
722 |
+
map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
|
723 |
+
if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
|
724 |
+
dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
|
725 |
+
|
726 |
+
# If VLM, switch data collator since .pad is needed!
|
727 |
+
if is_vlm and not hasattr(processing_class, "pad"):
|
728 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
729 |
+
self.data_collator = data_collator
|
730 |
+
pass
|
731 |
+
pass
|
732 |
+
if packing:
|
733 |
+
print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
|
734 |
+
return dataset
|
735 |
+
|
736 |
+
if max_seq_length == 0:
|
737 |
+
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
|
738 |
+
|
739 |
+
if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
|
740 |
+
dataset = dataset.select_columns(used_column_names).map(
|
741 |
+
pack_examples,
|
742 |
+
batched = True,
|
743 |
+
fn_kwargs = {"seq_length": max_seq_length,},
|
744 |
+
**map_kwargs,
|
745 |
+
)
|
746 |
+
pass
|
747 |
+
return dataset
|
748 |
+
|
749 |
+
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
|
750 |
+
outputs = super().compute_loss(
|
751 |
+
model,
|
752 |
+
inputs,
|
753 |
+
return_outputs = return_outputs,
|
754 |
+
num_items_in_batch = num_items_in_batch,
|
755 |
+
)
|
756 |
+
return outputs
|
757 |
+
|
758 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
759 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
760 |
+
|
761 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
762 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
763 |
+
if next(iter(logs.keys())).startswith("eval_"):
|
764 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
765 |
+
|
766 |
+
logs = {**logs, **metrics}
|
767 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
768 |
+
super().log(logs, start_time)
|
769 |
+
else: # transformers<=4.46
|
770 |
+
super().log(logs)
|
771 |
+
self._metrics.clear()
|
772 |
+
|
773 |
+
def create_model_card(
|
774 |
+
self,
|
775 |
+
model_name: Optional[str] = None,
|
776 |
+
dataset_name: Optional[str] = None,
|
777 |
+
tags: Union[str, list[str], None] = None,
|
778 |
+
):
|
779 |
+
"""
|
780 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
781 |
+
|
782 |
+
Args:
|
783 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
784 |
+
Name of the model.
|
785 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
786 |
+
Name of the dataset used for training.
|
787 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
788 |
+
Tags to be associated with the model card.
|
789 |
+
"""
|
790 |
+
if not self.is_world_process_zero():
|
791 |
+
return
|
792 |
+
|
793 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
794 |
+
base_model = self.model.config._name_or_path
|
795 |
+
else:
|
796 |
+
base_model = None
|
797 |
+
|
798 |
+
tags = tags or []
|
799 |
+
if isinstance(tags, str):
|
800 |
+
tags = [tags]
|
801 |
+
|
802 |
+
if hasattr(self.model.config, "unsloth_version"):
|
803 |
+
tags.append("unsloth")
|
804 |
+
|
805 |
+
model_card = generate_model_card(
|
806 |
+
base_model=base_model,
|
807 |
+
model_name=model_name,
|
808 |
+
hub_model_id=self.hub_model_id,
|
809 |
+
dataset_name=dataset_name,
|
810 |
+
tags=tags,
|
811 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
812 |
+
comet_url=get_comet_experiment_url(),
|
813 |
+
trainer_name="SFT",
|
814 |
+
)
|
815 |
+
|
816 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
817 |
+
class UnslothSFTTrainer(_UnslothSFTTrainer):
|
818 |
+
"""
|
819 |
+
|
820 |
+
Trainer for Supervised Fine-Tuning (SFT) method.
|
821 |
+
|
822 |
+
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
|
823 |
+
|
824 |
+
Example:
|
825 |
+
|
826 |
+
```python
|
827 |
+
from datasets import load_dataset
|
828 |
+
from trl import SFTTrainer
|
829 |
+
|
830 |
+
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
|
831 |
+
|
832 |
+
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
|
833 |
+
trainer.train()
|
834 |
+
```
|
835 |
+
|
836 |
+
Args:
|
837 |
+
model (`Union[str, PreTrainedModel]`):
|
838 |
+
Model to be trained. Can be either:
|
839 |
+
|
840 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
841 |
+
a path to a *directory* containing model weights saved using
|
842 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
843 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
844 |
+
in `args.model_init_kwargs`.
|
845 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
846 |
+
args ([`SFTConfig`], *optional*, defaults to `None`):
|
847 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
848 |
+
data_collator (`DataCollator`, *optional*):
|
849 |
+
Function to use to form a batch from a list of elements of the prcessed `train_dataset` or `eval_dataset`.
|
850 |
+
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
|
851 |
+
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
|
852 |
+
tokenizer.
|
853 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
854 |
+
Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
|
855 |
+
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
|
856 |
+
|
857 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
858 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
859 |
+
and content).
|
860 |
+
|
861 |
+
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
|
862 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
863 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
864 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
865 |
+
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
866 |
+
with [`~transformers.AutoTokenizer.from_pretrained`].
|
867 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
868 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
869 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
870 |
+
|
871 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
872 |
+
method.
|
873 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
874 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
875 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
876 |
+
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
|
877 |
+
A tuple containing the optimizer class and keyword arguments to use.
|
878 |
+
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
|
879 |
+
|
880 |
+
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
|
881 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
|
882 |
+
A function that preprocess the logits right before caching them at each evaluation step. Must take two
|
883 |
+
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
|
884 |
+
by this function will be reflected in the predictions received by `compute_metrics`.
|
885 |
+
|
886 |
+
Note that the labels (second parameter) will be `None` if the dataset does not have them.
|
887 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
888 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
889 |
+
formatting_func (`Optional[Callable]`):
|
890 |
+
Formatting function applied to the dataset before tokenization.
|
891 |
+
|
892 |
+
"""
|
893 |
+
def __init__(
|
894 |
+
self,
|
895 |
+
model,
|
896 |
+
args = None,
|
897 |
+
data_collator = None,
|
898 |
+
train_dataset = None,
|
899 |
+
eval_dataset = None,
|
900 |
+
processing_class = None,
|
901 |
+
compute_loss_func = None,
|
902 |
+
compute_metrics = None,
|
903 |
+
callbacks = None,
|
904 |
+
optimizer_cls_and_kwargs = None,
|
905 |
+
preprocess_logits_for_metrics = None,
|
906 |
+
peft_config = None,
|
907 |
+
formatting_func = None,
|
908 |
+
**kwargs
|
909 |
+
):
|
910 |
+
if args is None: args = UnslothSFTConfig()
|
911 |
+
use_bf16 = getattr(args, 'bf16', False)
|
912 |
+
use_fp16 = getattr(args, 'fp16', False)
|
913 |
+
force_float32 = False
|
914 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
915 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
916 |
+
force_float32 = True
|
917 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
918 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
919 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
920 |
+
from unsloth_zoo.utils import _get_dtype
|
921 |
+
dtype = _get_dtype(dtype)
|
922 |
+
float16 = dtype == torch.float16
|
923 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
924 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
925 |
+
if force_float32:
|
926 |
+
args.fp16 = False
|
927 |
+
args.bf16 = False
|
928 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
929 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
930 |
+
args.fp16 = float16
|
931 |
+
args.bf16 = not float16
|
932 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
933 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
934 |
+
args.eval_strategy = 'steps'
|
935 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
936 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
937 |
+
if ga_steps is not None and ga_steps > 1:
|
938 |
+
from transformers import __version__ as transformers_version
|
939 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
940 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
941 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
942 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
943 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
944 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
945 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
946 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
947 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
948 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
949 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
950 |
+
if force_float32:
|
951 |
+
args.bf16_full_eval = False
|
952 |
+
args.fp16_full_eval = False
|
953 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
954 |
+
args.bf16_full_eval = True
|
955 |
+
args.fp16_full_eval = False
|
956 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
957 |
+
args.bf16_full_eval = args.bf16
|
958 |
+
args.fp16_full_eval = args.fp16
|
959 |
+
_output_logits = False
|
960 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
961 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
962 |
+
if _output_logits:
|
963 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
964 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
965 |
+
pass
|
966 |
+
else:
|
967 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
968 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
969 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
970 |
+
max_seq_length = model.max_seq_length
|
971 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
972 |
+
if model is not None and hasattr(model, 'for_training'):
|
973 |
+
model.for_training()
|
974 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
975 |
+
if 'processing_class' in locals():
|
976 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
977 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
978 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
979 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
980 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
981 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
982 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
983 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
984 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
985 |
+
else:
|
986 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
987 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
988 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
989 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
990 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
991 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
992 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
993 |
+
else:
|
994 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
995 |
+
other_metrics = []
|
996 |
+
|
997 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
998 |
+
PatchRLStatistics('sft_trainer', other_metrics)
|
999 |
+
IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
|
1000 |
+
from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
|
1001 |
+
from unsloth_zoo.training_utils import fix_zero_training_loss
|
1002 |
+
if 'tokenizer' not in locals(): tokenizer = processing_class
|
1003 |
+
fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
|
1004 |
+
fix_zero_training_loss(model, tokenizer, train_dataset)
|
1005 |
+
|
1006 |
+
super().__init__(
|
1007 |
+
model = model,
|
1008 |
+
args = args,
|
1009 |
+
data_collator = data_collator,
|
1010 |
+
train_dataset = train_dataset,
|
1011 |
+
eval_dataset = eval_dataset,
|
1012 |
+
processing_class = processing_class,
|
1013 |
+
compute_loss_func = compute_loss_func,
|
1014 |
+
compute_metrics = compute_metrics,
|
1015 |
+
callbacks = callbacks,
|
1016 |
+
optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
|
1017 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
1018 |
+
peft_config = peft_config,
|
1019 |
+
formatting_func = formatting_func,**kwargs)
|
1020 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1021 |
+
self.neftune_hook_handle.remove()
|
1022 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1023 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1024 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1025 |
+
pass
|
1026 |
+
|
1027 |
+
pass
|
unsloth_compiled_cache/UnslothXPOTrainer.py
ADDED
@@ -0,0 +1,1010 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
2025.3.15
|
3 |
+
2025.3.17
|
4 |
+
4.50.0.dev0
|
5 |
+
0.15.2
|
6 |
+
__UNSLOTH_VERSIONING__
|
7 |
+
"""
|
8 |
+
from torch import Tensor
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation)
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
from typing import *
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from packaging.version import Version
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from contextlib import nullcontext
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
24 |
+
|
25 |
+
torch_compile_options = {
|
26 |
+
"epilogue_fusion" : True,
|
27 |
+
"max_autotune" : False,
|
28 |
+
"shape_padding" : True,
|
29 |
+
"trace.enabled" : False,
|
30 |
+
"triton.cudagraphs" : False,
|
31 |
+
}
|
32 |
+
|
33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
34 |
+
def selective_log_softmax(logits, index):
|
35 |
+
logits = logits.to(torch.float32)
|
36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
37 |
+
# loop to reduce peak mem consumption
|
38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
41 |
+
return per_token_logps
|
42 |
+
@dataclass
|
43 |
+
class UnslothXPOConfig(XPOConfig):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Configuration class for the [`XPOTrainer`].
|
47 |
+
|
48 |
+
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
|
52 |
+
Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
|
53 |
+
and the last alpha is used for the rest of the epochs.
|
54 |
+
|
55 |
+
"""
|
56 |
+
vllm_sampling_params: Optional[Any] = field(
|
57 |
+
default = None,
|
58 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
59 |
+
)
|
60 |
+
unsloth_num_chunks : Optional[int] = field(
|
61 |
+
default = -1,
|
62 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
63 |
+
)
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
output_dir = None,
|
67 |
+
overwrite_output_dir = None,
|
68 |
+
do_train = False,
|
69 |
+
do_eval = False,
|
70 |
+
do_predict = False,
|
71 |
+
eval_strategy = 'no',
|
72 |
+
prediction_loss_only = False,
|
73 |
+
per_device_train_batch_size = 4,
|
74 |
+
per_device_eval_batch_size = 4,
|
75 |
+
per_gpu_train_batch_size = None,
|
76 |
+
per_gpu_eval_batch_size = None,
|
77 |
+
gradient_accumulation_steps = 2,
|
78 |
+
eval_accumulation_steps = 2,
|
79 |
+
eval_delay = 0,
|
80 |
+
torch_empty_cache_steps = 250,
|
81 |
+
learning_rate = 5e-05,
|
82 |
+
weight_decay = 0.01,
|
83 |
+
adam_beta1 = 0.9,
|
84 |
+
adam_beta2 = 0.999,
|
85 |
+
adam_epsilon = 1e-08,
|
86 |
+
max_grad_norm = 1.0,
|
87 |
+
num_train_epochs = 3.0,
|
88 |
+
max_steps = -1,
|
89 |
+
lr_scheduler_type = 'linear',
|
90 |
+
warmup_ratio = 0.1,
|
91 |
+
warmup_steps = 0,
|
92 |
+
log_level = 'passive',
|
93 |
+
log_level_replica = 'warning',
|
94 |
+
log_on_each_node = True,
|
95 |
+
logging_dir = None,
|
96 |
+
logging_strategy = 'steps',
|
97 |
+
logging_first_step = False,
|
98 |
+
logging_steps = 1,
|
99 |
+
logging_nan_inf_filter = False,
|
100 |
+
save_strategy = 'steps',
|
101 |
+
save_steps = 500,
|
102 |
+
save_total_limit = None,
|
103 |
+
save_safetensors = True,
|
104 |
+
save_on_each_node = False,
|
105 |
+
save_only_model = False,
|
106 |
+
restore_callback_states_from_checkpoint = False,
|
107 |
+
no_cuda = False,
|
108 |
+
use_cpu = False,
|
109 |
+
use_mps_device = False,
|
110 |
+
seed = 3407,
|
111 |
+
data_seed = 3407,
|
112 |
+
jit_mode_eval = False,
|
113 |
+
use_ipex = False,
|
114 |
+
bf16 = False,
|
115 |
+
fp16 = False,
|
116 |
+
fp16_opt_level = 'O1',
|
117 |
+
half_precision_backend = 'auto',
|
118 |
+
bf16_full_eval = False,
|
119 |
+
fp16_full_eval = False,
|
120 |
+
tf32 = None,
|
121 |
+
local_rank = -1,
|
122 |
+
ddp_backend = None,
|
123 |
+
tpu_num_cores = None,
|
124 |
+
tpu_metrics_debug = False,
|
125 |
+
debug = '',
|
126 |
+
dataloader_drop_last = False,
|
127 |
+
eval_steps = None,
|
128 |
+
dataloader_num_workers = 0,
|
129 |
+
dataloader_prefetch_factor = None,
|
130 |
+
past_index = -1,
|
131 |
+
run_name = None,
|
132 |
+
disable_tqdm = None,
|
133 |
+
remove_unused_columns = True,
|
134 |
+
label_names = None,
|
135 |
+
load_best_model_at_end = False,
|
136 |
+
metric_for_best_model = None,
|
137 |
+
greater_is_better = None,
|
138 |
+
ignore_data_skip = False,
|
139 |
+
fsdp = '',
|
140 |
+
fsdp_min_num_params = 0,
|
141 |
+
fsdp_config = None,
|
142 |
+
tp_size = 0,
|
143 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
144 |
+
accelerator_config = None,
|
145 |
+
deepspeed = None,
|
146 |
+
label_smoothing_factor = 0.0,
|
147 |
+
optim = 'adamw_8bit',
|
148 |
+
optim_args = None,
|
149 |
+
adafactor = False,
|
150 |
+
group_by_length = False,
|
151 |
+
length_column_name = 'length',
|
152 |
+
report_to = None,
|
153 |
+
ddp_find_unused_parameters = None,
|
154 |
+
ddp_bucket_cap_mb = None,
|
155 |
+
ddp_broadcast_buffers = None,
|
156 |
+
dataloader_pin_memory = True,
|
157 |
+
dataloader_persistent_workers = False,
|
158 |
+
skip_memory_metrics = True,
|
159 |
+
use_legacy_prediction_loop = False,
|
160 |
+
push_to_hub = False,
|
161 |
+
resume_from_checkpoint = None,
|
162 |
+
hub_model_id = None,
|
163 |
+
hub_strategy = 'every_save',
|
164 |
+
hub_token = None,
|
165 |
+
hub_private_repo = None,
|
166 |
+
hub_always_push = False,
|
167 |
+
gradient_checkpointing = False,
|
168 |
+
gradient_checkpointing_kwargs = None,
|
169 |
+
include_inputs_for_metrics = False,
|
170 |
+
eval_do_concat_batches = True,
|
171 |
+
fp16_backend = 'auto',
|
172 |
+
evaluation_strategy = None,
|
173 |
+
push_to_hub_model_id = None,
|
174 |
+
push_to_hub_organization = None,
|
175 |
+
push_to_hub_token = None,
|
176 |
+
mp_parameters = '',
|
177 |
+
auto_find_batch_size = False,
|
178 |
+
full_determinism = False,
|
179 |
+
torchdynamo = None,
|
180 |
+
ray_scope = 'last',
|
181 |
+
ddp_timeout = 1800,
|
182 |
+
torch_compile = False,
|
183 |
+
torch_compile_backend = None,
|
184 |
+
torch_compile_mode = None,
|
185 |
+
dispatch_batches = None,
|
186 |
+
split_batches = None,
|
187 |
+
include_tokens_per_second = False,
|
188 |
+
include_num_input_tokens_seen = False,
|
189 |
+
neftune_noise_alpha = None,
|
190 |
+
optim_target_modules = None,
|
191 |
+
batch_eval_metrics = False,
|
192 |
+
eval_on_start = False,
|
193 |
+
use_liger_kernel = False,
|
194 |
+
eval_use_gather_object = False,
|
195 |
+
average_tokens_across_devices = False,
|
196 |
+
reward_model_path = None,
|
197 |
+
judge = None,
|
198 |
+
max_new_tokens = 64,
|
199 |
+
max_length = 512,
|
200 |
+
temperature = 0.9,
|
201 |
+
missing_eos_penalty = None,
|
202 |
+
loss_type = 'sigmoid',
|
203 |
+
dataset_num_proc = None,
|
204 |
+
disable_dropout = True,
|
205 |
+
use_vllm = False,
|
206 |
+
ds3_gather_for_generation = True,
|
207 |
+
vllm_sampling_params = None,
|
208 |
+
unsloth_num_chunks = -1,
|
209 |
+
**kwargs,
|
210 |
+
):
|
211 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
212 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
213 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
214 |
+
output_dir = 'unsloth_training_checkpoints'
|
215 |
+
save_strategy = 'no'
|
216 |
+
if dataset_num_proc is None:
|
217 |
+
from multiprocessing import cpu_count
|
218 |
+
dataset_num_proc = cpu_count()
|
219 |
+
|
220 |
+
super().__init__(
|
221 |
+
output_dir = output_dir,
|
222 |
+
overwrite_output_dir = overwrite_output_dir,
|
223 |
+
do_train = do_train,
|
224 |
+
do_eval = do_eval,
|
225 |
+
do_predict = do_predict,
|
226 |
+
eval_strategy = eval_strategy,
|
227 |
+
prediction_loss_only = prediction_loss_only,
|
228 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
229 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
230 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
231 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
232 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
233 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
234 |
+
eval_delay = eval_delay,
|
235 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
236 |
+
learning_rate = learning_rate,
|
237 |
+
weight_decay = weight_decay,
|
238 |
+
adam_beta1 = adam_beta1,
|
239 |
+
adam_beta2 = adam_beta2,
|
240 |
+
adam_epsilon = adam_epsilon,
|
241 |
+
max_grad_norm = max_grad_norm,
|
242 |
+
num_train_epochs = num_train_epochs,
|
243 |
+
max_steps = max_steps,
|
244 |
+
lr_scheduler_type = lr_scheduler_type,
|
245 |
+
warmup_ratio = warmup_ratio,
|
246 |
+
warmup_steps = warmup_steps,
|
247 |
+
log_level = log_level,
|
248 |
+
log_level_replica = log_level_replica,
|
249 |
+
log_on_each_node = log_on_each_node,
|
250 |
+
logging_dir = logging_dir,
|
251 |
+
logging_strategy = logging_strategy,
|
252 |
+
logging_first_step = logging_first_step,
|
253 |
+
logging_steps = logging_steps,
|
254 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
255 |
+
save_strategy = save_strategy,
|
256 |
+
save_steps = save_steps,
|
257 |
+
save_total_limit = save_total_limit,
|
258 |
+
save_safetensors = save_safetensors,
|
259 |
+
save_on_each_node = save_on_each_node,
|
260 |
+
save_only_model = save_only_model,
|
261 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
262 |
+
no_cuda = no_cuda,
|
263 |
+
use_cpu = use_cpu,
|
264 |
+
use_mps_device = use_mps_device,
|
265 |
+
seed = seed,
|
266 |
+
data_seed = data_seed,
|
267 |
+
jit_mode_eval = jit_mode_eval,
|
268 |
+
use_ipex = use_ipex,
|
269 |
+
bf16 = bf16,
|
270 |
+
fp16 = fp16,
|
271 |
+
fp16_opt_level = fp16_opt_level,
|
272 |
+
half_precision_backend = half_precision_backend,
|
273 |
+
bf16_full_eval = bf16_full_eval,
|
274 |
+
fp16_full_eval = fp16_full_eval,
|
275 |
+
tf32 = tf32,
|
276 |
+
local_rank = local_rank,
|
277 |
+
ddp_backend = ddp_backend,
|
278 |
+
tpu_num_cores = tpu_num_cores,
|
279 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
280 |
+
debug = debug,
|
281 |
+
dataloader_drop_last = dataloader_drop_last,
|
282 |
+
eval_steps = eval_steps,
|
283 |
+
dataloader_num_workers = dataloader_num_workers,
|
284 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
285 |
+
past_index = past_index,
|
286 |
+
run_name = run_name,
|
287 |
+
disable_tqdm = disable_tqdm,
|
288 |
+
remove_unused_columns = remove_unused_columns,
|
289 |
+
label_names = label_names,
|
290 |
+
load_best_model_at_end = load_best_model_at_end,
|
291 |
+
metric_for_best_model = metric_for_best_model,
|
292 |
+
greater_is_better = greater_is_better,
|
293 |
+
ignore_data_skip = ignore_data_skip,
|
294 |
+
fsdp = fsdp,
|
295 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
296 |
+
fsdp_config = fsdp_config,
|
297 |
+
tp_size = tp_size,
|
298 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
299 |
+
accelerator_config = accelerator_config,
|
300 |
+
deepspeed = deepspeed,
|
301 |
+
label_smoothing_factor = label_smoothing_factor,
|
302 |
+
optim = optim,
|
303 |
+
optim_args = optim_args,
|
304 |
+
adafactor = adafactor,
|
305 |
+
group_by_length = group_by_length,
|
306 |
+
length_column_name = length_column_name,
|
307 |
+
report_to = report_to,
|
308 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
309 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
310 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
311 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
312 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
313 |
+
skip_memory_metrics = skip_memory_metrics,
|
314 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
315 |
+
push_to_hub = push_to_hub,
|
316 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
317 |
+
hub_model_id = hub_model_id,
|
318 |
+
hub_strategy = hub_strategy,
|
319 |
+
hub_token = hub_token,
|
320 |
+
hub_private_repo = hub_private_repo,
|
321 |
+
hub_always_push = hub_always_push,
|
322 |
+
gradient_checkpointing = gradient_checkpointing,
|
323 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
324 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
325 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
326 |
+
fp16_backend = fp16_backend,
|
327 |
+
evaluation_strategy = evaluation_strategy,
|
328 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
329 |
+
push_to_hub_organization = push_to_hub_organization,
|
330 |
+
push_to_hub_token = push_to_hub_token,
|
331 |
+
mp_parameters = mp_parameters,
|
332 |
+
auto_find_batch_size = auto_find_batch_size,
|
333 |
+
full_determinism = full_determinism,
|
334 |
+
torchdynamo = torchdynamo,
|
335 |
+
ray_scope = ray_scope,
|
336 |
+
ddp_timeout = ddp_timeout,
|
337 |
+
torch_compile = torch_compile,
|
338 |
+
torch_compile_backend = torch_compile_backend,
|
339 |
+
torch_compile_mode = torch_compile_mode,
|
340 |
+
dispatch_batches = dispatch_batches,
|
341 |
+
split_batches = split_batches,
|
342 |
+
include_tokens_per_second = include_tokens_per_second,
|
343 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
344 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
345 |
+
optim_target_modules = optim_target_modules,
|
346 |
+
batch_eval_metrics = batch_eval_metrics,
|
347 |
+
eval_on_start = eval_on_start,
|
348 |
+
use_liger_kernel = use_liger_kernel,
|
349 |
+
eval_use_gather_object = eval_use_gather_object,
|
350 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
351 |
+
reward_model_path = reward_model_path,
|
352 |
+
judge = judge,
|
353 |
+
max_new_tokens = max_new_tokens,
|
354 |
+
max_length = max_length,
|
355 |
+
temperature = temperature,
|
356 |
+
missing_eos_penalty = missing_eos_penalty,
|
357 |
+
loss_type = loss_type,
|
358 |
+
dataset_num_proc = dataset_num_proc,
|
359 |
+
disable_dropout = disable_dropout,
|
360 |
+
use_vllm = use_vllm,
|
361 |
+
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
362 |
+
self.vllm_sampling_params = vllm_sampling_params
|
363 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
364 |
+
pass
|
365 |
+
|
366 |
+
class _UnslothXPOTrainer(OnlineDPOTrainer):
|
367 |
+
r""""""
|
368 |
+
|
369 |
+
_tag_names = ["trl", "xpo"]
|
370 |
+
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
model: Union[PreTrainedModel, nn.Module] = None,
|
374 |
+
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
375 |
+
reward_model: Optional[nn.Module] = None,
|
376 |
+
judge: Optional[BasePairwiseJudge] = None,
|
377 |
+
args: Optional[XPOConfig] = None,
|
378 |
+
data_collator: Optional[Callable] = None,
|
379 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
380 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
381 |
+
processing_class: Optional[
|
382 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
383 |
+
] = None,
|
384 |
+
peft_config: Optional[dict] = None,
|
385 |
+
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
386 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
387 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
388 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
389 |
+
) -> None:
|
390 |
+
super().__init__(
|
391 |
+
model=model,
|
392 |
+
ref_model=ref_model,
|
393 |
+
judge=judge,
|
394 |
+
reward_model=reward_model,
|
395 |
+
args=args,
|
396 |
+
data_collator=data_collator,
|
397 |
+
train_dataset=train_dataset,
|
398 |
+
eval_dataset=eval_dataset,
|
399 |
+
processing_class=processing_class,
|
400 |
+
reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
|
401 |
+
peft_config=peft_config,
|
402 |
+
compute_metrics=compute_metrics,
|
403 |
+
callbacks=callbacks,
|
404 |
+
optimizers=optimizers,
|
405 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
406 |
+
)
|
407 |
+
|
408 |
+
self._alpha = self.args.alpha
|
409 |
+
|
410 |
+
# Overwrite the stats dictionary to include XPO specific statistics
|
411 |
+
self.stats = {
|
412 |
+
# Remove "non_score_reward", "rlhf_reward", "scores"
|
413 |
+
# Add "loss/dpo", "loss/xpo"
|
414 |
+
"loss/dpo": [],
|
415 |
+
"loss/xpo": [],
|
416 |
+
"objective/kl": [],
|
417 |
+
"objective/entropy": [],
|
418 |
+
"rewards/chosen": [],
|
419 |
+
"rewards/rejected": [],
|
420 |
+
"rewards/accuracies": [],
|
421 |
+
"rewards/margins": [],
|
422 |
+
"logps/chosen": [],
|
423 |
+
"logps/rejected": [],
|
424 |
+
# Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
|
425 |
+
"val/model_contain_eos_token": [],
|
426 |
+
"val/ref_contain_eos_token": [],
|
427 |
+
"alpha": [],
|
428 |
+
"beta": [],
|
429 |
+
}
|
430 |
+
if self.reward_model is not None:
|
431 |
+
# Replace "scores" by "model_scores" and "ref_scores"
|
432 |
+
self.stats["objective/model_scores"] = []
|
433 |
+
self.stats["objective/ref_scores"] = []
|
434 |
+
self.stats["objective/scores_margin"] = []
|
435 |
+
|
436 |
+
@property
|
437 |
+
def alpha(self):
|
438 |
+
if isinstance(self._alpha, list):
|
439 |
+
epoch = self.state.epoch
|
440 |
+
return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
|
441 |
+
else:
|
442 |
+
return self._alpha
|
443 |
+
|
444 |
+
def _generate_completions(self, prompts, model):
|
445 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
446 |
+
model_output = unwrapped_model.generate(
|
447 |
+
input_ids=prompts["input_ids"],
|
448 |
+
attention_mask=prompts["attention_mask"],
|
449 |
+
generation_config=self.generation_config,
|
450 |
+
)
|
451 |
+
|
452 |
+
ref_model = model if self.ref_model is None else self.ref_model
|
453 |
+
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
454 |
+
ref_output = unwrapped_ref_model.generate(
|
455 |
+
input_ids=prompts["input_ids"],
|
456 |
+
attention_mask=prompts["attention_mask"],
|
457 |
+
generation_config=self.generation_config,
|
458 |
+
)
|
459 |
+
|
460 |
+
return model_output, ref_output
|
461 |
+
|
462 |
+
def _process_completions(self, model_output, ref_output, prompts):
|
463 |
+
context_length = prompts["input_ids"].shape[1]
|
464 |
+
|
465 |
+
# Process model completions
|
466 |
+
model_completion_ids = model_output[:, context_length:]
|
467 |
+
model_completion_ids, model_completion_mask = truncate_right(
|
468 |
+
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
469 |
+
)
|
470 |
+
model_data = {
|
471 |
+
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
472 |
+
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
473 |
+
"raw": prompts["raw"],
|
474 |
+
}
|
475 |
+
|
476 |
+
# Process reference model completions
|
477 |
+
ref_completion_ids = ref_output[:, context_length:]
|
478 |
+
ref_completion_ids, ref_completion_mask = truncate_right(
|
479 |
+
ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
480 |
+
)
|
481 |
+
ref_data = {
|
482 |
+
"input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
|
483 |
+
"attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
|
484 |
+
"raw": prompts["raw"],
|
485 |
+
}
|
486 |
+
|
487 |
+
return model_data, ref_data
|
488 |
+
|
489 |
+
def _compute_rewards(self, model_data, ref_data, context_length):
|
490 |
+
with torch.no_grad():
|
491 |
+
_, model_scores, _ = get_reward(
|
492 |
+
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
493 |
+
)
|
494 |
+
_, ref_scores, _ = get_reward(
|
495 |
+
self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
|
496 |
+
)
|
497 |
+
|
498 |
+
# Apply EOS penalty if needed
|
499 |
+
if self.args.missing_eos_penalty is not None:
|
500 |
+
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
501 |
+
ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
502 |
+
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
503 |
+
ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
|
504 |
+
|
505 |
+
return model_scores, ref_scores
|
506 |
+
|
507 |
+
def _compute_judge(self, model_data, ref_data, context_length):
|
508 |
+
prompts = model_data["raw"]
|
509 |
+
model_data_completions = self.processing_class.batch_decode(
|
510 |
+
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
511 |
+
)
|
512 |
+
model_data_completions = [completion.strip() for completion in model_data_completions]
|
513 |
+
|
514 |
+
ref_data_completions = self.processing_class.batch_decode(
|
515 |
+
ref_data["input_ids"][:, context_length:], skip_special_tokens=True
|
516 |
+
)
|
517 |
+
ref_data_completions = [completion.strip() for completion in ref_data_completions]
|
518 |
+
|
519 |
+
if is_conversational({"prompt": prompts[0]}):
|
520 |
+
model_data_completions = [
|
521 |
+
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
522 |
+
]
|
523 |
+
environment = jinja2.Environment()
|
524 |
+
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
525 |
+
prompts = [template.render(messages=message) for message in prompts]
|
526 |
+
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
527 |
+
|
528 |
+
ref_data_completions = [
|
529 |
+
[{"role": "assistant", "content": completion}] for completion in ref_data_completions
|
530 |
+
]
|
531 |
+
ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
|
532 |
+
|
533 |
+
ranks_of_first_completion = self.judge.judge(
|
534 |
+
prompts,
|
535 |
+
list(zip(model_data_completions, ref_data_completions)),
|
536 |
+
)
|
537 |
+
# convert ranks to a True/False mask:
|
538 |
+
# when rank == 0, it means the first completion is the best
|
539 |
+
# when rank == 1, it means the second completion is the best
|
540 |
+
return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
|
541 |
+
|
542 |
+
def _compute_logprobs(self, model, model_data, ref_data, context_length):
|
543 |
+
def compute_logprobs_for_data(m, data):
|
544 |
+
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
545 |
+
logits = output.logits[:, context_length - 1 : -1]
|
546 |
+
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
547 |
+
return token_logprobs
|
548 |
+
|
549 |
+
# Compute logprobs for model completions
|
550 |
+
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
551 |
+
# Compute logprobs for model on reference completions (for XPO loss)
|
552 |
+
model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
553 |
+
|
554 |
+
# Compute logprobs for reference model completions
|
555 |
+
with torch.no_grad():
|
556 |
+
if self.ref_model is None:
|
557 |
+
with model.disable_adapter():
|
558 |
+
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
559 |
+
ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
560 |
+
else:
|
561 |
+
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
562 |
+
ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
|
563 |
+
|
564 |
+
# Mask padding tokens
|
565 |
+
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
566 |
+
ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
|
567 |
+
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
568 |
+
model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
569 |
+
ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
570 |
+
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
571 |
+
|
572 |
+
return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
|
573 |
+
|
574 |
+
def _compute_losses(
|
575 |
+
self,
|
576 |
+
model_logprobs_model_data,
|
577 |
+
model_logprobs_ref_data,
|
578 |
+
ref_logprobs_ref_data,
|
579 |
+
ref_logprobs_model_data,
|
580 |
+
chosen_mask,
|
581 |
+
):
|
582 |
+
# Compute log probs
|
583 |
+
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
584 |
+
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
585 |
+
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
586 |
+
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
587 |
+
|
588 |
+
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
589 |
+
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
590 |
+
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
591 |
+
|
592 |
+
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
593 |
+
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
594 |
+
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
595 |
+
|
596 |
+
# Compute logits as the difference between chosen and rejected log ratios
|
597 |
+
logits = chosen_log_ratios - rejected_log_ratios
|
598 |
+
|
599 |
+
if self.args.loss_type == "sigmoid":
|
600 |
+
dpo_losses = -F.logsigmoid(self.beta * logits)
|
601 |
+
elif self.args.loss_type == "ipo":
|
602 |
+
dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
|
603 |
+
else:
|
604 |
+
raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
|
605 |
+
|
606 |
+
# Compute XPO specific loss
|
607 |
+
xpo_losses = self.alpha * model_logprobs_ref_data_sum
|
608 |
+
|
609 |
+
# Total loss
|
610 |
+
loss = (dpo_losses + xpo_losses).mean()
|
611 |
+
|
612 |
+
return loss, dpo_losses, xpo_losses
|
613 |
+
|
614 |
+
def _log_statistics(
|
615 |
+
self,
|
616 |
+
model_data,
|
617 |
+
ref_data,
|
618 |
+
model_logprobs_model_data,
|
619 |
+
model_logprobs_ref_data,
|
620 |
+
ref_logprobs_ref_data,
|
621 |
+
ref_logprobs_model_data,
|
622 |
+
chosen_mask,
|
623 |
+
dpo_losses,
|
624 |
+
xpo_losses,
|
625 |
+
context_length,
|
626 |
+
model_scores=None,
|
627 |
+
ref_scores=None,
|
628 |
+
):
|
629 |
+
# Helper function to gather and compute mean
|
630 |
+
def gather_mean(tensor):
|
631 |
+
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
632 |
+
|
633 |
+
# Log losses
|
634 |
+
self.stats["loss/dpo"].append(gather_mean(dpo_losses))
|
635 |
+
self.stats["loss/xpo"].append(gather_mean(xpo_losses))
|
636 |
+
|
637 |
+
# Log scores
|
638 |
+
if self.reward_model is not None:
|
639 |
+
self.stats["objective/model_scores"].append(gather_mean(model_scores))
|
640 |
+
self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
|
641 |
+
self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
|
642 |
+
|
643 |
+
# Log logprobs
|
644 |
+
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
645 |
+
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
646 |
+
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
647 |
+
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
648 |
+
|
649 |
+
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
650 |
+
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
651 |
+
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
652 |
+
|
653 |
+
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
654 |
+
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
655 |
+
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
656 |
+
|
657 |
+
self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
|
658 |
+
self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
|
659 |
+
|
660 |
+
# Log rewards
|
661 |
+
# Compute various statistics
|
662 |
+
chosen_rewards = chosen_log_ratios * self.beta
|
663 |
+
rejected_rewards = rejected_log_ratios * self.beta
|
664 |
+
self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
|
665 |
+
self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
|
666 |
+
|
667 |
+
# Calculate KL divergence for model and ref data
|
668 |
+
kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
|
669 |
+
kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
|
670 |
+
mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
|
671 |
+
self.stats["objective/kl"].append(gather_mean(mean_kl))
|
672 |
+
|
673 |
+
# Calculate entropy for model and ref data
|
674 |
+
entropy_model_data = -model_logprobs_model_data.sum(1)
|
675 |
+
entropy_ref_data = -model_logprobs_ref_data.sum(1)
|
676 |
+
mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
|
677 |
+
self.stats["objective/entropy"].append(gather_mean(mean_entropy))
|
678 |
+
|
679 |
+
# Calculate margins
|
680 |
+
margin = chosen_rewards - rejected_rewards
|
681 |
+
self.stats["rewards/margins"].append(gather_mean(margin.mean()))
|
682 |
+
|
683 |
+
# Calculate accuracy
|
684 |
+
accuracy = (margin > 0).float()
|
685 |
+
self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
|
686 |
+
|
687 |
+
# Log EOS token statistics
|
688 |
+
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
689 |
+
ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
690 |
+
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
691 |
+
self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
|
692 |
+
|
693 |
+
# Log alpha and beta
|
694 |
+
self.stats["alpha"].append(self.alpha)
|
695 |
+
self.stats["beta"].append(self.beta)
|
696 |
+
|
697 |
+
def training_step(
|
698 |
+
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
699 |
+
) -> torch.Tensor:
|
700 |
+
model.train()
|
701 |
+
|
702 |
+
# Apply chat template and tokenize the input
|
703 |
+
batch_size = len(next(iter(inputs.values())))
|
704 |
+
prompts = inputs["prompt"]
|
705 |
+
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
706 |
+
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
707 |
+
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
708 |
+
inputs = self.data_collator(inputs)
|
709 |
+
|
710 |
+
# need the prompt_ only
|
711 |
+
inputs = self._prepare_inputs(inputs)
|
712 |
+
context_length = inputs["prompt_input_ids"].shape[1]
|
713 |
+
prompts = {
|
714 |
+
"input_ids": inputs["prompt_input_ids"],
|
715 |
+
"attention_mask": inputs["prompt_attention_mask"],
|
716 |
+
"raw": prompts,
|
717 |
+
}
|
718 |
+
del inputs
|
719 |
+
|
720 |
+
# Sample completions from both the model and the reference model
|
721 |
+
model_output, ref_output = self._generate_completions(prompts, model)
|
722 |
+
|
723 |
+
# Process model completions
|
724 |
+
model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
|
725 |
+
|
726 |
+
# Compute rewards
|
727 |
+
if self.reward_model is not None:
|
728 |
+
model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
|
729 |
+
chosen_mask = model_scores >= ref_scores
|
730 |
+
else:
|
731 |
+
model_scores, ref_scores = None, None
|
732 |
+
chosen_mask = self._compute_judge(model_data, ref_data, context_length)
|
733 |
+
|
734 |
+
# Compute logprobs
|
735 |
+
model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
|
736 |
+
self._compute_logprobs(model, model_data, ref_data, context_length)
|
737 |
+
)
|
738 |
+
|
739 |
+
# Compute loss
|
740 |
+
loss, dpo_losses, xpo_losses = self._compute_losses(
|
741 |
+
model_logprobs_model_data,
|
742 |
+
model_logprobs_ref_data,
|
743 |
+
ref_logprobs_ref_data,
|
744 |
+
ref_logprobs_model_data,
|
745 |
+
chosen_mask,
|
746 |
+
)
|
747 |
+
|
748 |
+
# Log everything
|
749 |
+
self._log_statistics(
|
750 |
+
model_data,
|
751 |
+
ref_data,
|
752 |
+
model_logprobs_model_data.detach(),
|
753 |
+
model_logprobs_ref_data.detach(),
|
754 |
+
ref_logprobs_ref_data,
|
755 |
+
ref_logprobs_model_data,
|
756 |
+
chosen_mask,
|
757 |
+
dpo_losses.detach(),
|
758 |
+
xpo_losses.detach(),
|
759 |
+
context_length,
|
760 |
+
model_scores,
|
761 |
+
ref_scores,
|
762 |
+
)
|
763 |
+
|
764 |
+
if (
|
765 |
+
self.args.torch_empty_cache_steps is not None
|
766 |
+
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
767 |
+
):
|
768 |
+
empty_cache()
|
769 |
+
|
770 |
+
kwargs = {}
|
771 |
+
# For LOMO optimizers you need to explicitly use the learning rate
|
772 |
+
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
773 |
+
kwargs["learning_rate"] = self._get_learning_rate()
|
774 |
+
|
775 |
+
if self.args.n_gpu > 1:
|
776 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
777 |
+
|
778 |
+
if self.use_apex:
|
779 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
780 |
+
scaled_loss.backward()
|
781 |
+
else:
|
782 |
+
self.accelerator.backward(loss, **kwargs)
|
783 |
+
|
784 |
+
return loss.detach() / self.args.gradient_accumulation_steps
|
785 |
+
|
786 |
+
def create_model_card(
|
787 |
+
self,
|
788 |
+
model_name: Optional[str] = None,
|
789 |
+
dataset_name: Optional[str] = None,
|
790 |
+
tags: Union[str, list[str], None] = None,
|
791 |
+
):
|
792 |
+
"""
|
793 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
794 |
+
|
795 |
+
Args:
|
796 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
797 |
+
Name of the model.
|
798 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
799 |
+
Name of the dataset used for training.
|
800 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
801 |
+
Tags to be associated with the model card.
|
802 |
+
"""
|
803 |
+
if not self.is_world_process_zero():
|
804 |
+
return
|
805 |
+
|
806 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
807 |
+
base_model = self.model.config._name_or_path
|
808 |
+
else:
|
809 |
+
base_model = None
|
810 |
+
|
811 |
+
tags = tags or []
|
812 |
+
if isinstance(tags, str):
|
813 |
+
tags = [tags]
|
814 |
+
|
815 |
+
if hasattr(self.model.config, "unsloth_version"):
|
816 |
+
tags.append("unsloth")
|
817 |
+
|
818 |
+
citation = textwrap.dedent("""\
|
819 |
+
@article{jung2024binary,
|
820 |
+
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
|
821 |
+
author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
|
822 |
+
year = 2024,
|
823 |
+
eprint = {arXiv:2405.21046}
|
824 |
+
}""")
|
825 |
+
|
826 |
+
model_card = generate_model_card(
|
827 |
+
base_model=base_model,
|
828 |
+
model_name=model_name,
|
829 |
+
hub_model_id=self.hub_model_id,
|
830 |
+
dataset_name=dataset_name,
|
831 |
+
tags=tags,
|
832 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
833 |
+
comet_url=get_comet_experiment_url(),
|
834 |
+
trainer_name="XPO",
|
835 |
+
trainer_citation=citation,
|
836 |
+
paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
|
837 |
+
paper_id="2405.21046",
|
838 |
+
)
|
839 |
+
|
840 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
841 |
+
class UnslothXPOTrainer(_UnslothXPOTrainer):
|
842 |
+
"""
|
843 |
+
|
844 |
+
Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
|
845 |
+
|
846 |
+
Args:
|
847 |
+
model (`transformers.PreTrainedModel`):
|
848 |
+
The model to train, preferably an `AutoModelForCausalLM`.
|
849 |
+
ref_model (`PreTrainedModelWrapper`):
|
850 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
851 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
852 |
+
reward_model (`transformers.PreTrainedModel`):
|
853 |
+
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
854 |
+
judge (`BasePairwiseJudge`):
|
855 |
+
The judge to use for pairwise comparison of model completions.
|
856 |
+
args (`XPOConfig`):
|
857 |
+
The XPO config arguments to use for training.
|
858 |
+
data_collator (`transformers.DataCollator`):
|
859 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
860 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
861 |
+
train_dataset (`datasets.Dataset`):
|
862 |
+
The dataset to use for training.
|
863 |
+
eval_dataset (`datasets.Dataset`):
|
864 |
+
The dataset to use for evaluation.
|
865 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
866 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
867 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
868 |
+
reuse the fine-tuned model.
|
869 |
+
peft_config (`dict`):
|
870 |
+
The peft config to use for training.
|
871 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
872 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
873 |
+
a dictionary string to metric values.
|
874 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
875 |
+
The callbacks to use for training.
|
876 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
877 |
+
The optimizer and scheduler to use for training.
|
878 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
879 |
+
The function to use to preprocess the logits before computing the metrics.
|
880 |
+
|
881 |
+
"""
|
882 |
+
def __init__(
|
883 |
+
self,
|
884 |
+
model = None,
|
885 |
+
ref_model = None,
|
886 |
+
reward_model = None,
|
887 |
+
judge = None,
|
888 |
+
args = None,
|
889 |
+
data_collator = None,
|
890 |
+
train_dataset = None,
|
891 |
+
eval_dataset = None,
|
892 |
+
processing_class = None,
|
893 |
+
peft_config = None,
|
894 |
+
compute_metrics = None,
|
895 |
+
callbacks = None,
|
896 |
+
preprocess_logits_for_metrics = None,
|
897 |
+
**kwargs
|
898 |
+
):
|
899 |
+
if args is None: args = UnslothXPOConfig()
|
900 |
+
use_bf16 = getattr(args, 'bf16', False)
|
901 |
+
use_fp16 = getattr(args, 'fp16', False)
|
902 |
+
force_float32 = False
|
903 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
904 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
905 |
+
force_float32 = True
|
906 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
907 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
908 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
909 |
+
from unsloth_zoo.utils import _get_dtype
|
910 |
+
dtype = _get_dtype(dtype)
|
911 |
+
float16 = dtype == torch.float16
|
912 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
913 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
914 |
+
if force_float32:
|
915 |
+
args.fp16 = False
|
916 |
+
args.bf16 = False
|
917 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
918 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
919 |
+
args.fp16 = float16
|
920 |
+
args.bf16 = not float16
|
921 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
922 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
923 |
+
args.eval_strategy = 'steps'
|
924 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
925 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
926 |
+
if ga_steps is not None and ga_steps > 1:
|
927 |
+
from transformers import __version__ as transformers_version
|
928 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
929 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
930 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
931 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
932 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
933 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
934 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
935 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
936 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
937 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
938 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
939 |
+
if force_float32:
|
940 |
+
args.bf16_full_eval = False
|
941 |
+
args.fp16_full_eval = False
|
942 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
943 |
+
args.bf16_full_eval = True
|
944 |
+
args.fp16_full_eval = False
|
945 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
946 |
+
args.bf16_full_eval = args.bf16
|
947 |
+
args.fp16_full_eval = args.fp16
|
948 |
+
_output_logits = False
|
949 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
950 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
951 |
+
if _output_logits:
|
952 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
953 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
954 |
+
pass
|
955 |
+
else:
|
956 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
957 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
958 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
959 |
+
max_seq_length = model.max_seq_length
|
960 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
961 |
+
if model is not None and hasattr(model, 'for_training'):
|
962 |
+
model.for_training()
|
963 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
964 |
+
if 'processing_class' in locals():
|
965 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
966 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
967 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
968 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
969 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
970 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
971 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
972 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
973 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
974 |
+
else:
|
975 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
976 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
977 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
978 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
979 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
980 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
981 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
982 |
+
else:
|
983 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
984 |
+
other_metrics = []
|
985 |
+
|
986 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
987 |
+
PatchRLStatistics('xpo_trainer', other_metrics)
|
988 |
+
|
989 |
+
super().__init__(
|
990 |
+
model = model,
|
991 |
+
ref_model = ref_model,
|
992 |
+
reward_model = reward_model,
|
993 |
+
judge = judge,
|
994 |
+
args = args,
|
995 |
+
data_collator = data_collator,
|
996 |
+
train_dataset = train_dataset,
|
997 |
+
eval_dataset = eval_dataset,
|
998 |
+
processing_class = processing_class,
|
999 |
+
peft_config = peft_config,
|
1000 |
+
compute_metrics = compute_metrics,
|
1001 |
+
callbacks = callbacks,
|
1002 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
1003 |
+
if hasattr(self, 'neftune_hook_handle'):
|
1004 |
+
self.neftune_hook_handle.remove()
|
1005 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
1006 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
1007 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
1008 |
+
pass
|
1009 |
+
|
1010 |
+
pass
|
unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
ADDED
Binary file (32.9 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc
ADDED
Binary file (91.7 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc
ADDED
Binary file (75.6 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
ADDED
Binary file (45.5 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e432dd59a309384c1403b9b00dbd3c130f21ba431cd431132cbbe27003d587c4
|
3 |
+
size 103569
|
unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc
ADDED
Binary file (37.7 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
ADDED
Binary file (78.5 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc
ADDED
Binary file (87.4 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
ADDED
Binary file (47.3 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc
ADDED
Binary file (75.6 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
ADDED
Binary file (67.1 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc
ADDED
Binary file (62.7 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc
ADDED
Binary file (36.4 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
ADDED
Binary file (54.2 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc
ADDED
Binary file (38.9 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc
ADDED
Binary file (47.8 kB). View file
|
|
unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc
ADDED
Binary file (49.9 kB). View file
|
|