Commit d8dd7a1 (verified) · 0 Parent(s)
first commit

Files changed:
- .gitignore +98 -0
- README.md +291 -0
- config.py +28 -0
- config/train_smollm3.py +107 -0
- config/train_smollm3_dpo.py +38 -0
- config/train_smollm3_long_context.py +38 -0
- create_sample_dataset.py +41 -0
- data.py +238 -0
- model.py +188 -0
- requirements.txt +35 -0
- test_setup.py +206 -0
- train.py +144 -0
- trainer.py +242 -0
.gitignore
ADDED
@@ -0,0 +1,98 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyTorch
*.pth
*.pt
*.ckpt

# Jupyter Notebook
.ipynb_checkpoints

# Environment
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs
*.log
logs/
tensorboard_logs/

# Model outputs
output/
checkpoints/
models/
wandb/

# Datasets
data/
datasets/
my_dataset/
test_dataset/

# Temporary files
tmp/
temp/
*.tmp
*.temp

# Hugging Face cache
.cache/
transformers_cache/

# Accelerate
accelerate_config.yaml

# Training outputs
runs/
*.json
!config/*.json
!*.json.example

# Evaluation results
eval_results/
test_results/

# Documentation
docs/_build/
README.md
ADDED
@@ -0,0 +1,291 @@
# SmolLM3 Fine-tuning for FlexAI Console

This repository provides a complete setup for fine-tuning SmolLM3 models using the FlexAI console, following the nanoGPT structure but adapted for modern transformer models.

## Overview

SmolLM3 is a 3B-parameter transformer decoder model optimized for efficiency, long-context reasoning, and multilingual support. This setup allows you to fine-tune SmolLM3 for various tasks including:

- **Supervised Fine-tuning (SFT)**: Adapt the model for instruction following
- **Direct Preference Optimization (DPO)**: Improve model alignment
- **Long-context fine-tuning**: Support for up to 128k tokens
- **Tool calling**: Fine-tune for function calling capabilities

## Quick Start

### 1. Repository Setup

The repository follows the FlexAI console structure with the following key files:

- `train.py`: Main entry point script
- `config/train_smollm3.py`: Default configuration
- `model.py`: Model wrapper and loading
- `data.py`: Dataset handling and preprocessing
- `trainer.py`: Training loop and trainer setup
- `requirements.txt`: Dependencies

### 2. FlexAI Console Configuration

When setting up a Fine Tuning Job in the FlexAI console, use these settings:

#### Basic Configuration
- **Name**: `smollm3-finetune`
- **Cluster**: Your organization's designated cluster
- **Checkpoint**: (Optional) Previous training job checkpoint
- **Node Count**: 1
- **Accelerator Count**: 1-8 (depending on your needs)

#### Repository Settings
- **Repository URL**: `https://github.com/your-username/flexai-finetune`
- **Repository Revision**: `main`

#### Dataset Configuration
- **Datasets**: Your dataset (mounted under `/input`)
- **Mount Directory**: `my_dataset`

#### Entry Point
```
train.py config/train_smollm3.py --dataset_dir=my_dataset --init_from=resume --out_dir=/input-checkpoint --max_iters=1500
```

### 3. Dataset Format

The script supports multiple dataset formats:

#### Chat Format (Recommended)
```json
[
  {
    "messages": [
      {"role": "user", "content": "What is machine learning?"},
      {"role": "assistant", "content": "Machine learning is a subset of AI..."}
    ]
  }
]
```

#### Instruction Format
```json
[
  {
    "instruction": "What is machine learning?",
    "output": "Machine learning is a subset of AI..."
  }
]
```

#### User-Assistant Format
```json
[
  {
    "user": "What is machine learning?",
    "assistant": "Machine learning is a subset of AI..."
  }
]
```

+
### 4. Configuration Options
|
88 |
+
|
89 |
+
The default configuration in `config/train_smollm3.py` includes:
|
90 |
+
|
91 |
+
```python
|
92 |
+
@dataclass
|
93 |
+
class SmolLM3Config:
|
94 |
+
# Model configuration
|
95 |
+
model_name: str = "HuggingFaceTB/SmolLM3-3B"
|
96 |
+
max_seq_length: int = 4096
|
97 |
+
use_flash_attention: bool = True
|
98 |
+
|
99 |
+
# Training configuration
|
100 |
+
batch_size: int = 4
|
101 |
+
gradient_accumulation_steps: int = 4
|
102 |
+
learning_rate: float = 2e-5
|
103 |
+
max_iters: int = 1000
|
104 |
+
|
105 |
+
# Mixed precision
|
106 |
+
fp16: bool = True
|
107 |
+
bf16: bool = False
|
108 |
+
```
|
109 |
+
|
110 |
+
### 5. Command Line Arguments
|
111 |
+
|
112 |
+
The `train.py` script accepts various arguments:
|
113 |
+
|
114 |
+
```bash
|
115 |
+
# Basic usage
|
116 |
+
python train.py config/train_smollm3.py
|
117 |
+
|
118 |
+
# With custom parameters
|
119 |
+
python train.py config/train_smollm3.py \
|
120 |
+
--dataset_dir=my_dataset \
|
121 |
+
--out_dir=/output-checkpoint \
|
122 |
+
--init_from=resume \
|
123 |
+
--max_iters=1500 \
|
124 |
+
--batch_size=8 \
|
125 |
+
--learning_rate=1e-5 \
|
126 |
+
--max_seq_length=8192
|
127 |
+
```
|
128 |
+
|
129 |
+
## Advanced Usage
|
130 |
+
|
131 |
+
### 1. Custom Configuration
|
132 |
+
|
133 |
+
Create a custom configuration file:
|
134 |
+
|
135 |
+
```python
|
136 |
+
# config/my_config.py
|
137 |
+
from config.train_smollm3 import SmolLM3Config
|
138 |
+
|
139 |
+
config = SmolLM3Config(
|
140 |
+
model_name="HuggingFaceTB/SmolLM3-3B-Instruct",
|
141 |
+
max_seq_length=8192,
|
142 |
+
batch_size=2,
|
143 |
+
learning_rate=1e-5,
|
144 |
+
max_iters=2000,
|
145 |
+
use_flash_attention=True,
|
146 |
+
fp16=True
|
147 |
+
)
|
148 |
+
```
|
149 |
+
|
150 |
+
### 2. Long-Context Fine-tuning
|
151 |
+
|
152 |
+
For long-context tasks (up to 128k tokens):
|
153 |
+
|
154 |
+
```python
|
155 |
+
config = SmolLM3Config(
|
156 |
+
max_seq_length=131072, # 128k tokens
|
157 |
+
model_name="HuggingFaceTB/SmolLM3-3B",
|
158 |
+
use_flash_attention=True,
|
159 |
+
gradient_checkpointing=True
|
160 |
+
)
|
161 |
+
```
|
162 |
+
|
163 |
+
### 3. DPO Training
|
164 |
+
|
165 |
+
For preference optimization, use the DPO trainer:
|
166 |
+
|
167 |
+
```python
|
168 |
+
from trainer import SmolLM3DPOTrainer
|
169 |
+
|
170 |
+
dpo_trainer = SmolLM3DPOTrainer(
|
171 |
+
model=model,
|
172 |
+
dataset=dataset,
|
173 |
+
config=config,
|
174 |
+
output_dir="./dpo-output"
|
175 |
+
)
|
176 |
+
|
177 |
+
dpo_trainer.train()
|
178 |
+
```
|
179 |
+
|
180 |
+
### 4. Tool Calling Fine-tuning
|
181 |
+
|
182 |
+
Include tool calling examples in your dataset:
|
183 |
+
|
184 |
+
```json
|
185 |
+
[
|
186 |
+
{
|
187 |
+
"messages": [
|
188 |
+
{"role": "user", "content": "What's the weather in New York?"},
|
189 |
+
{"role": "assistant", "content": "<tool_call>\n<invoke name=\"get_weather\">\n<parameter name=\"location\">New York</parameter>\n</invoke>\n</tool_call>"},
|
190 |
+
{"role": "tool", "content": "The weather in New York is 72°F and sunny."},
|
191 |
+
{"role": "assistant", "content": "The weather in New York is currently 72°F and sunny."}
|
192 |
+
]
|
193 |
+
}
|
194 |
+
]
|
195 |
+
```
|
196 |
+
|
197 |
+
## Model Variants
|
198 |
+
|
199 |
+
SmolLM3 comes in several variants:
|
200 |
+
|
201 |
+
- **SmolLM3-3B-Base**: Base model for general fine-tuning
|
202 |
+
- **SmolLM3-3B**: Instruction-tuned model
|
203 |
+
- **SmolLM3-3B-Instruct**: Enhanced instruction model
|
204 |
+
- **Quantized versions**: Available for deployment
|
205 |
+
|
206 |
+
## Hardware Requirements
|
207 |
+
|
208 |
+
### Minimum Requirements
|
209 |
+
- **GPU**: 16GB+ VRAM (for 3B model)
|
210 |
+
- **RAM**: 32GB+ system memory
|
211 |
+
- **Storage**: 50GB+ free space
|
212 |
+
|
213 |
+
### Recommended
|
214 |
+
- **GPU**: A100/H100 or similar
|
215 |
+
- **RAM**: 64GB+ system memory
|
216 |
+
- **Storage**: 100GB+ SSD
|
217 |
+
|
218 |
+
## Troubleshooting
|
219 |
+
|
220 |
+
### Common Issues
|
221 |
+
|
222 |
+
1. **Out of Memory (OOM)**
|
223 |
+
- Reduce `batch_size`
|
224 |
+
- Increase `gradient_accumulation_steps`
|
225 |
+
- Enable `gradient_checkpointing`
|
226 |
+
- Use `fp16` or `bf16`
|
227 |
+
|
228 |
+
2. **Slow Training**
|
229 |
+
- Enable `flash_attention`
|
230 |
+
- Use mixed precision (`fp16`/`bf16`)
|
231 |
+
- Increase `dataloader_num_workers`
|
232 |
+
|
233 |
+
3. **Dataset Loading Issues**
|
234 |
+
- Check dataset format
|
235 |
+
- Ensure proper JSON structure
|
236 |
+
- Verify file permissions
|
237 |
+
|
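For the OOM case above, a memory-constrained run might use overrides along these lines (an illustrative sketch; the exact values depend on your GPU):

```python
from config.train_smollm3 import SmolLM3Config

# Illustrative memory-saving settings: a smaller per-device batch with more
# gradient accumulation keeps the effective batch size while lowering peak memory.
config = SmolLM3Config(
    batch_size=1,
    gradient_accumulation_steps=16,
    use_gradient_checkpointing=True,
    fp16=True,  # or bf16=True on Ampere and newer GPUs
)
```
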
### Debug Mode

Enable debug logging:

```python
import logging
logging.basicConfig(level=logging.DEBUG)
```

## Evaluation

After training, evaluate your model:

```python
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="./output-checkpoint",
    device=0,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7
)

# Test the model
messages = [{"role": "user", "content": "Explain gravity in simple terms."}]
outputs = pipe(messages)
print(outputs[0]["generated_text"][-1]["content"])
```

## Deployment

### Using vLLM
```bash
vllm serve ./output-checkpoint --enable-auto-tool-choice
```

### Using llama.cpp
```bash
# Convert to GGUF format
python -m llama_cpp.convert_model ./output-checkpoint --outfile model.gguf
```

## Resources

- [SmolLM3 Blog Post](https://huggingface.co/blog/smollm3)
- [Model Repository](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
- [GitHub Repository](https://github.com/huggingface/smollm)
- [SmolTalk Dataset](https://huggingface.co/datasets/HuggingFaceTB/smoltalk)

## License

This project follows the same license as the SmolLM3 model. Please refer to the Hugging Face model page for licensing information.
config.py
ADDED
@@ -0,0 +1,28 @@
"""
Configuration management for SmolLM3 fine-tuning
"""

import os
import importlib.util
from typing import Any
from config.train_smollm3 import SmolLM3Config, get_config as get_default_config

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return get_default_config(config_path)
config/train_smollm3.py
ADDED
@@ -0,0 +1,107 @@
"""
SmolLM3 Training Configuration
Based on nanoGPT structure but adapted for SmolLM3
"""

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class SmolLM3Config:
    """Configuration for SmolLM3 fine-tuning"""

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    use_flash_attention: bool = True
    use_gradient_checkpointing: bool = True

    # Training configuration
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 100
    max_iters: int = 1000
    eval_interval: int = 100
    log_interval: int = 10
    save_interval: int = 500

    # Optimizer configuration
    optimizer: str = "adamw"
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Scheduler configuration
    scheduler: str = "cosine"
    min_lr: float = 1e-6

    # Mixed precision
    fp16: bool = True
    bf16: bool = False

    # DDP configuration
    ddp_backend: str = "nccl"
    ddp_find_unused_parameters: bool = False

    # Logging and saving
    save_steps: int = 500
    eval_steps: int = 100
    logging_steps: int = 10
    save_total_limit: Optional[int] = 3

    # Evaluation
    eval_strategy: str = "steps"
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False
    load_best_model_at_end: bool = True

    # Data configuration
    data_dir: str = "my_dataset"
    train_file: str = "train.json"
    validation_file: Optional[str] = None
    test_file: Optional[str] = None

    # Chat template configuration
    use_chat_template: bool = True
    chat_template_kwargs: dict = None

    def __post_init__(self):
        if self.chat_template_kwargs is None:
            self.chat_template_kwargs = {
                "enable_thinking": False,
                "add_generation_prompt": True
            }

        # Validate configuration
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")

        if self.max_seq_length > 131072:  # 128k limit
            raise ValueError("max_seq_length cannot exceed 131072")

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return SmolLM3Config()

# Default configuration instance
config = SmolLM3Config()
config/train_smollm3_dpo.py
ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 DPO Training Configuration
Optimized for Direct Preference Optimization
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B-Instruct",  # Start from instruction-tuned model
    max_seq_length=4096,
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=2,  # Smaller batch size for DPO
    gradient_accumulation_steps=4,
    learning_rate=5e-6,  # Very low learning rate for DPO
    weight_decay=0.01,
    warmup_steps=100,
    max_iters=1000,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=200,
    eval_steps=100,
    logging_steps=20,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": False,  # Disable reasoning for preference learning
        "add_generation_prompt": True
    }
)
config/train_smollm3_long_context.py
ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 Long-Context Training Configuration
Optimized for long-context tasks (up to 128k tokens)
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B",
    max_seq_length=131072,  # 128k tokens
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=1,  # Reduced for long sequences
    gradient_accumulation_steps=8,  # Increased to maintain effective batch size
    learning_rate=1e-5,  # Lower learning rate for stability
    weight_decay=0.01,
    warmup_steps=200,
    max_iters=500,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=100,
    eval_steps=50,
    logging_steps=10,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": True,  # Enable reasoning mode
        "add_generation_prompt": True
    }
)
create_sample_dataset.py
ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Sample Dataset Creation Script
Creates sample datasets for testing SmolLM3 fine-tuning
"""

import os
import json
import argparse
from data import create_sample_dataset

def main():
    parser = argparse.ArgumentParser(description='Create sample dataset for SmolLM3 fine-tuning')
    parser.add_argument('--output_dir', type=str, default='my_dataset',
                        help='Output directory for the dataset')
    parser.add_argument('--format', type=str, default='chat',
                        choices=['chat', 'instruction', 'user_assistant'],
                        help='Dataset format')
    parser.add_argument('--num_samples', type=int, default=100,
                        help='Number of samples to create')

    args = parser.parse_args()

    # Create sample dataset
    output_path = create_sample_dataset(args.output_dir)

    print(f"Sample dataset created in: {output_path}")
    print(f"Format: {args.format}")
    print(f"Samples: {args.num_samples}")
    print("\nFiles created:")
    print(f"- {os.path.join(output_path, 'train.json')}")
    print(f"- {os.path.join(output_path, 'validation.json')}")

    # Show sample data
    with open(os.path.join(output_path, 'train.json'), 'r') as f:
        data = json.load(f)
        print(f"\nSample data:")
        print(json.dumps(data[0], indent=2))

if __name__ == '__main__':
    main()
data.py
ADDED
@@ -0,0 +1,238 @@
"""
SmolLM3 Dataset Handler
Handles data loading, preprocessing, and tokenization for SmolLM3 fine-tuning
"""

import os
import json
import torch
from typing import Dict, List, Optional, Union
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizer
import logging

logger = logging.getLogger(__name__)

class SmolLM3Dataset:
    """Dataset handler for SmolLM3 fine-tuning"""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int = 4096,
        use_chat_template: bool = True,
        chat_template_kwargs: Optional[Dict] = None
    ):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.use_chat_template = use_chat_template
        self.chat_template_kwargs = chat_template_kwargs or {}

        # Load and process dataset
        self.dataset = self._load_dataset()
        self.processed_dataset = self._process_dataset()

    def _load_dataset(self) -> Dataset:
        """Load dataset from various formats"""
        logger.info(f"Loading dataset from {self.data_path}")

        # Check if it's a local directory of JSON files
        if os.path.isdir(self.data_path):
            # Only include split files that actually exist
            data_files = {"train": os.path.join(self.data_path, "train.json")}
            for split in ("validation", "test"):
                split_file = os.path.join(self.data_path, f"{split}.json")
                if os.path.exists(split_file):
                    data_files[split] = split_file
            try:
                dataset = load_dataset("json", data_files=data_files)
                logger.info("Loaded dataset from local JSON files")
                return dataset
            except Exception as e:
                logger.warning(f"Failed to load as JSON dataset: {e}")

        # Try to load as a single JSON file
        if os.path.isfile(self.data_path) and self.data_path.endswith('.json'):
            try:
                with open(self.data_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Convert to dataset format
                if isinstance(data, list):
                    dataset = Dataset.from_list(data)
                else:
                    dataset = Dataset.from_dict(data)

                logger.info("Loaded dataset from single JSON file")
                return dataset
            except Exception as e:
                logger.error(f"Failed to load JSON file: {e}")
                raise

        # Try to load as a Hugging Face dataset name
        try:
            dataset = load_dataset(self.data_path)
            logger.info(f"Loaded Hugging Face dataset: {self.data_path}")
            return dataset
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _process_dataset(self) -> Dataset:
        """Process the dataset for training"""
        logger.info("Processing dataset for training")

        def format_chat_template(example):
            """Format example using chat template"""
            if self.use_chat_template:
                try:
                    # Handle different input formats
                    if "messages" in example:
                        messages = example["messages"]
                    elif "conversations" in example:
                        messages = example["conversations"]
                    elif "user" in example and "assistant" in example:
                        messages = [
                            {"role": "user", "content": example["user"]},
                            {"role": "assistant", "content": example["assistant"]}
                        ]
                    elif "instruction" in example and "output" in example:
                        messages = [
                            {"role": "user", "content": example["instruction"]},
                            {"role": "assistant", "content": example["output"]}
                        ]
                    elif "prompt" in example and "completion" in example:
                        messages = [
                            {"role": "user", "content": example["prompt"]},
                            {"role": "assistant", "content": example["completion"]}
                        ]
                    else:
                        # Fallback: treat as plain text
                        return {"text": str(example)}

                    # Apply chat template
                    text = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        **self.chat_template_kwargs
                    )
                    return {"text": text}
                except Exception as e:
                    logger.warning(f"Failed to apply chat template: {e}")
                    # Fallback to plain text
                    return {"text": str(example)}
            else:
                # Use plain text
                if "text" in example:
                    return {"text": example["text"]}
                else:
                    return {"text": str(example)}

        def tokenize_function(examples):
            """Tokenize the examples"""
            # Tokenize the texts
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_seq_length,
                return_overflowing_tokens=True,
                return_length=True,
            )

            # Calculate input length
            input_length = [len(x) for x in tokenized["input_ids"]]

            # Create labels (same as input_ids for causal LM)
            tokenized["labels"] = tokenized["input_ids"].copy()

            return {
                "input_ids": tokenized["input_ids"],
                "attention_mask": tokenized["attention_mask"],
                "labels": tokenized["labels"],
                "length": input_length,
            }

        # Process the dataset
        processed_dataset = self.dataset.map(
            format_chat_template,
            remove_columns=self.dataset["train"].column_names,
            desc="Formatting dataset"
        )

        # Tokenize the dataset
        tokenized_dataset = processed_dataset.map(
            tokenize_function,
            remove_columns=processed_dataset["train"].column_names,
            desc="Tokenizing dataset",
            batched=True,
        )

        logger.info(f"Dataset processed. Train samples: {len(tokenized_dataset['train'])}")
        if "validation" in tokenized_dataset:
            logger.info(f"Validation samples: {len(tokenized_dataset['validation'])}")

        return tokenized_dataset

    def get_train_dataset(self) -> Dataset:
        """Get training dataset"""
        return self.processed_dataset["train"]

    def get_eval_dataset(self) -> Optional[Dataset]:
        """Get evaluation dataset if available"""
        if "validation" in self.processed_dataset:
            return self.processed_dataset["validation"]
        elif "test" in self.processed_dataset:
            return self.processed_dataset["test"]
        else:
            return None

    def get_data_collator(self):
        """Get data collator for training"""
        from transformers import DataCollatorForLanguageModeling

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
        )

def create_sample_dataset(output_path: str = "my_dataset"):
    """Create a sample dataset for testing"""
    os.makedirs(output_path, exist_ok=True)

    # Sample conversations
    conversations = [
        {
            "messages": [
                {"role": "user", "content": "What is machine learning?"},
                {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "Explain gravity in simple terms."},
                {"role": "assistant", "content": "Gravity is the force that pulls objects toward each other, like how the Earth pulls things down to the ground."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "How do I make a cup of coffee?"},
                {"role": "assistant", "content": "To make a cup of coffee: 1) Boil water, 2) Add coffee grounds to a filter, 3) Pour hot water over the grounds, 4) Let it brew for a few minutes, 5) Enjoy!"}
            ]
        }
    ]

    # Split into train/validation
    train_data = conversations[:2]
    validation_data = conversations[2:]

    # Save to files
    with open(os.path.join(output_path, "train.json"), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    with open(os.path.join(output_path, "validation.json"), 'w', encoding='utf-8') as f:
        json.dump(validation_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Sample dataset created in {output_path}")
    return output_path
model.py
ADDED
@@ -0,0 +1,188 @@
"""
SmolLM3 Model Wrapper
Handles model loading, tokenizer, and training setup
"""

import os
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    Trainer
)
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)

class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                self.torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype

        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True,
                use_flash_attention_2=self.config.use_flash_attention if self.config else True,
                use_cache=False  # Disable KV cache for training
            )

            # Enable gradient checkpointing if specified
            if self.config and self.config.use_gradient_checkpointing:
                self.model.gradient_checkpointing_enable()

            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "evaluation_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }

        # Override with kwargs
        training_args.update(kwargs)

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)

        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
requirements.txt
ADDED
@@ -0,0 +1,35 @@
# Core dependencies
torch>=2.0.0
transformers>=4.53.0
datasets>=2.14.0
accelerate>=0.20.0
trl>=0.7.0

# Hugging Face ecosystem
huggingface-hub>=0.16.0
tokenizers>=0.13.0

# Training and optimization
flash-attn>=2.0.0
xformers>=0.0.20
bitsandbytes>=0.41.0

# Utilities
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
tqdm>=4.65.0
wandb>=0.15.0

# Optional: for evaluation
lighteval>=0.1.0
evaluate>=0.4.0

# Optional: for deployment
vllm>=0.2.0
sentencepiece>=0.1.99

# Development
pytest>=7.0.0
black>=23.0.0
isort>=5.12.0
test_setup.py
ADDED
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Test Setup Script
Verifies that all components are working correctly
"""

import os
import sys
import torch
import logging
from pathlib import Path

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_imports():
    """Test that all required modules can be imported"""
    logger.info("Testing imports...")

    try:
        import transformers
        logger.info(f"✓ transformers {transformers.__version__}")
    except ImportError as e:
        logger.error(f"✗ transformers: {e}")
        return False

    try:
        import datasets
        logger.info(f"✓ datasets {datasets.__version__}")
    except ImportError as e:
        logger.error(f"✗ datasets: {e}")
        return False

    try:
        import trl
        logger.info(f"✓ trl {trl.__version__}")
    except ImportError as e:
        logger.error(f"✗ trl: {e}")
        return False

    try:
        import accelerate
        logger.info(f"✓ accelerate {accelerate.__version__}")
    except ImportError as e:
        logger.error(f"✗ accelerate: {e}")
        return False

    return True

def test_local_imports():
    """Test that local modules can be imported"""
    logger.info("Testing local imports...")

    try:
        from config import get_config
        logger.info("✓ config module")
    except ImportError as e:
        logger.error(f"✗ config module: {e}")
        return False

    try:
        from model import SmolLM3Model
        logger.info("✓ model module")
    except ImportError as e:
        logger.error(f"✗ model module: {e}")
        return False

    try:
        from data import SmolLM3Dataset
        logger.info("✓ data module")
    except ImportError as e:
        logger.error(f"✗ data module: {e}")
        return False

    try:
        from trainer import SmolLM3Trainer
        logger.info("✓ trainer module")
    except ImportError as e:
        logger.error(f"✗ trainer module: {e}")
        return False

    return True

def test_config():
    """Test configuration loading"""
    logger.info("Testing configuration...")

    try:
        from config import get_config
        config = get_config("config/train_smollm3.py")
        logger.info(f"✓ Configuration loaded: {config.model_name}")
        return True
    except Exception as e:
        logger.error(f"✗ Configuration loading failed: {e}")
        return False

def test_dataset_creation():
    """Test dataset creation"""
    logger.info("Testing dataset creation...")

    try:
        from data import create_sample_dataset
        output_path = create_sample_dataset("test_dataset")

        # Check if files were created
        train_file = os.path.join(output_path, "train.json")
        val_file = os.path.join(output_path, "validation.json")

        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("✓ Sample dataset created successfully")

            # Clean up
            import shutil
            shutil.rmtree(output_path)
            return True
        else:
            logger.error("✗ Dataset files not created")
            return False
    except Exception as e:
        logger.error(f"✗ Dataset creation failed: {e}")
        return False

def test_gpu_availability():
    """Test GPU availability"""
    logger.info("Testing GPU availability...")

    if torch.cuda.is_available():
        logger.info(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
        logger.info(f"✓ CUDA version: {torch.version.cuda}")
        logger.info(f"✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        return True
    else:
        logger.warning("⚠ No GPU available, will use CPU")
        return True

def test_model_loading():
    """Test model loading (without downloading)"""
    logger.info("Testing model loading...")

    try:
        from transformers import AutoTokenizer, AutoConfig

        # Test tokenizer loading
        tokenizer = AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True,
            use_fast=True
        )
        logger.info(f"✓ Tokenizer loaded, vocab size: {tokenizer.vocab_size}")

        # Test config loading
        config = AutoConfig.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True
        )
        logger.info(f"✓ Config loaded, model type: {config.model_type}")

        return True
    except Exception as e:
        logger.error(f"✗ Model loading test failed: {e}")
        return False

def main():
    """Run all tests"""
    logger.info("Starting SmolLM3 setup tests...")

    tests = [
        ("Import Tests", test_imports),
        ("Local Import Tests", test_local_imports),
        ("Configuration Tests", test_config),
        ("Dataset Creation Tests", test_dataset_creation),
        ("GPU Availability Tests", test_gpu_availability),
        ("Model Loading Tests", test_model_loading),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        logger.info(f"\n{'='*50}")
        logger.info(f"Running: {test_name}")
        logger.info('='*50)

        try:
            if test_func():
                passed += 1
                logger.info(f"✓ {test_name} PASSED")
            else:
                logger.error(f"✗ {test_name} FAILED")
        except Exception as e:
            logger.error(f"✗ {test_name} FAILED with exception: {e}")

    logger.info(f"\n{'='*50}")
    logger.info(f"Test Results: {passed}/{total} tests passed")
    logger.info('='*50)

    if passed == total:
        logger.info("🎉 All tests passed! Setup is ready for SmolLM3 fine-tuning.")
        return 0
    else:
        logger.error("❌ Some tests failed. Please check the errors above.")
        return 1

if __name__ == '__main__':
    sys.exit(main())
train.py
ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
SmolLM3 Fine-tuning Script for FlexAI Console
Based on the nanoGPT structure but adapted for SmolLM3 model
"""

import os
import sys
import argparse
import json
import torch
import logging
from pathlib import Path
from typing import Optional, Dict, Any

# Add the current directory to the path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3Trainer

def setup_logging():
    """Setup logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler('training.log')
        ]
    )
    return logging.getLogger(__name__)

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='SmolLM3 Fine-tuning Script')

    # Configuration file
    parser.add_argument('config', type=str, help='Path to configuration file')

    # Dataset arguments
    parser.add_argument('--dataset_dir', type=str, default='my_dataset',
                        help='Path to dataset directory within /input')

    # Checkpoint arguments
    parser.add_argument('--out_dir', type=str, default='/output-checkpoint',
                        help='Output directory for checkpoints')
    parser.add_argument('--init_from', type=str, default='scratch',
                        choices=['scratch', 'resume', 'pretrained'],
                        help='Initialization method')

    # Training arguments
    parser.add_argument('--max_iters', type=int, default=None,
                        help='Maximum number of training iterations')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=None,
                        help='Learning rate')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=None,
                        help='Gradient accumulation steps')

    # Model arguments
    parser.add_argument('--model_name', type=str,
                        default='HuggingFaceTB/SmolLM3-3B',
                        help='Model name or path')
    parser.add_argument('--max_seq_length', type=int, default=4096,
                        help='Maximum sequence length')

    # Logging and saving
    parser.add_argument('--save_steps', type=int, default=500,
                        help='Save checkpoint every N steps')
    parser.add_argument('--eval_steps', type=int, default=100,
                        help='Evaluate every N steps')
    parser.add_argument('--logging_steps', type=int, default=10,
                        help='Log every N steps')

    return parser.parse_args()

def main():
    """Main training function"""
    args = parse_args()
    logger = setup_logging()

    logger.info("Starting SmolLM3 fine-tuning...")
    logger.info(f"Arguments: {vars(args)}")

    # Load configuration
    config = get_config(args.config)

    # Override config with command line arguments
    if args.max_iters is not None:
        config.max_iters = args.max_iters
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.gradient_accumulation_steps is not None:
        config.gradient_accumulation_steps = args.gradient_accumulation_steps

    # Setup paths
    dataset_path = os.path.join('/input', args.dataset_dir)
    output_path = args.out_dir

    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    logger.info(f"Dataset path: {dataset_path}")
    logger.info(f"Output path: {output_path}")

    # Initialize model
    model = SmolLM3Model(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        config=config
    )

    # Load dataset
    dataset = SmolLM3Dataset(
        data_path=dataset_path,
        tokenizer=model.tokenizer,
        max_seq_length=args.max_seq_length
    )

    # Initialize trainer
    trainer = SmolLM3Trainer(
        model=model,
        dataset=dataset,
        config=config,
        output_dir=output_path,
        init_from=args.init_from
    )

    # Start training
    try:
        trainer.train()
        logger.info("Training completed successfully!")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

if __name__ == '__main__':
    main()
trainer.py
ADDED
@@ -0,0 +1,242 @@
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""

import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json

logger = logging.getLogger(__name__)


class SmolLM3Trainer:
    """Trainer for SmolLM3 fine-tuning"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # Setup trainer
        self.trainer = self._setup_trainer()

    def _setup_trainer(self):
        """Setup the trainer"""
        logger.info("Setting up trainer")

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get datasets
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Get data collator
        data_collator = self.dataset.get_data_collator()

        if self.use_sft_trainer:
            # Use SFTTrainer for supervised fine-tuning
            trainer = SFTTrainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                data_collator=data_collator,
                dataset_text_field="text",
                max_seq_length=self.config.max_seq_length,
                packing=False,  # Disable packing for better control
            )
        else:
            # Use standard Trainer
            trainer = Trainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
            )

        return trainer

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint for resuming training"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        if self.init_from == "resume":
            # Load the model from checkpoint
            self.model.load_checkpoint(checkpoint_path)

            # Update trainer with loaded model
            self.trainer.model = self.model.model

            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            # Model is already loaded from pretrained
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    def train(self):
        """Start training"""
        logger.info("Starting training")

        # Load checkpoint if resuming
        if self.init_from == "resume":
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning(f"Checkpoint path {checkpoint_path} not found, starting from scratch")

        # Start training
        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def evaluate(self):
        """Evaluate the model"""
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            # Save evaluation results
            with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
                json.dump(eval_results, f, indent=2)

            logger.info(f"Evaluation completed: {eval_results}")
            return eval_results

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            raise

    def save_model(self, path: Optional[str] = None):
        """Save the trained model"""
        save_path = path or self.output_dir
        logger.info(f"Saving model to {save_path}")

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            # Save training configuration
            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items()
                               if not k.startswith('_')}
                with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

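Two notes on the class above. First, the SFTTrainer keyword arguments used in _setup_trainer (tokenizer, dataset_text_field, max_seq_length, packing) follow the older trl constructor; more recent trl releases expect most of these via an SFTConfig, so this code assumes the trl version pinned in requirements.txt. Second, for illustration only (this snippet is not part of trainer.py), the evaluate and save_model helpers are meant to be called on the same instance after training finishes:

# Hypothetical follow-up once training has completed; `trainer` is the
# SmolLM3Trainer instance built in train.py, and the save path is an assumption.
metrics = trainer.evaluate()                  # also writes eval_results.json to the output dir
trainer.save_model("./output/final_model")    # saves model, tokenizer, and training_config.json
print(metrics)
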

class SmolLM3DPOTrainer:
    """DPO Trainer for SmolLM3 preference optimization"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        ref_model=None
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.ref_model = ref_model

        # Setup DPO trainer
        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Setup DPO trainer"""
        from trl import DPOTrainer

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get preference dataset
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Setup DPO trainer
        trainer = DPOTrainer(
            model=self.model.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.model.tokenizer,
            max_prompt_length=self.config.max_seq_length // 2,
            max_length=self.config.max_seq_length,
        )

        return trainer

    def train(self):
        """Start DPO training"""
        logger.info("Starting DPO training")

        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"DPO training failed: {e}")
            raise
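main() in train.py only builds SmolLM3Trainer, so the DPO stage needs its own driver. Below is a minimal, hypothetical sketch of such a driver; the import paths, base model name, and preference-dataset location are assumptions based on the modules added in this commit, and the dataset is expected to provide the prompt/chosen/rejected columns that TRL's DPOTrainer consumes.

# Hypothetical DPO driver; not part of the repository.
from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3DPOTrainer

config = get_config("config/train_smollm3_dpo.py")

policy = SmolLM3Model(
    model_name="HuggingFaceTB/SmolLM3-3B",     # assumed base checkpoint
    max_seq_length=config.max_seq_length,
    config=config,
)

# Preference data in prompt/chosen/rejected format; the path is an assumption.
preference_data = SmolLM3Dataset(
    data_path="/input/preference_dataset",
    tokenizer=policy.tokenizer,
    max_seq_length=config.max_seq_length,
)

dpo_trainer = SmolLM3DPOTrainer(
    model=policy,
    dataset=preference_data,
    config=config,
    output_dir="./output-dpo",
    ref_model=None,  # with None, TRL typically derives a frozen copy of the policy
)
dpo_trainer.train()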