Commit d8dd7a1 · verified · committed by Tonic · 0 parent(s)

first commit
.gitignore ADDED
@@ -0,0 +1,98 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyTorch
*.pth
*.pt
*.ckpt

# Jupyter Notebook
.ipynb_checkpoints

# Environment
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Logs
*.log
logs/
tensorboard_logs/

# Model outputs
output/
checkpoints/
models/
wandb/

# Datasets
data/
datasets/
my_dataset/
test_dataset/

# Temporary files
tmp/
temp/
*.tmp
*.temp

# Hugging Face cache
.cache/
transformers_cache/

# Accelerate
accelerate_config.yaml

# Training outputs
runs/
*.json
!config/*.json
!*.json.example

# Evaluation results
eval_results/
test_results/

# Documentation
docs/_build/
README.md ADDED
@@ -0,0 +1,291 @@
# SmolLM3 Fine-tuning for FlexAI Console

This repository provides a complete setup for fine-tuning SmolLM3 models using the FlexAI console, following the nanoGPT structure but adapted for modern transformer models.

## Overview

SmolLM3 is a 3B-parameter transformer decoder model optimized for efficiency, long-context reasoning, and multilingual support. This setup allows you to fine-tune SmolLM3 for various tasks including:

- **Supervised Fine-tuning (SFT)**: Adapt the model for instruction following
- **Direct Preference Optimization (DPO)**: Improve model alignment
- **Long-context fine-tuning**: Support for up to 128k tokens
- **Tool calling**: Fine-tune for function calling capabilities

## Quick Start

### 1. Repository Setup

The repository follows the FlexAI console structure with the following key files:

- `train.py`: Main entry point script
- `config/train_smollm3.py`: Default configuration
- `model.py`: Model wrapper and loading
- `data.py`: Dataset handling and preprocessing
- `trainer.py`: Training loop and trainer setup
- `requirements.txt`: Dependencies

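The entry point wires these pieces together roughly as follows (a condensed sketch of what `train.py` does; the config path and dataset location are illustrative):

```python
from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3Trainer

config = get_config("config/train_smollm3.py")            # training configuration
model = SmolLM3Model(model_name=config.model_name,        # model + tokenizer wrapper
                     max_seq_length=config.max_seq_length,
                     config=config)
dataset = SmolLM3Dataset(data_path="/input/my_dataset",   # mounted dataset directory
                         tokenizer=model.tokenizer,
                         max_seq_length=config.max_seq_length)
trainer = SmolLM3Trainer(model=model, dataset=dataset,
                         config=config, output_dir="/output-checkpoint")
trainer.train()
```
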
### 2. FlexAI Console Configuration

When setting up a Fine Tuning Job in the FlexAI console, use these settings:

#### Basic Configuration
- **Name**: `smollm3-finetune`
- **Cluster**: Your organization's designated cluster
- **Checkpoint**: (Optional) Previous training job checkpoint
- **Node Count**: 1
- **Accelerator Count**: 1-8 (depending on your needs)

#### Repository Settings
- **Repository URL**: `https://github.com/your-username/flexai-finetune`
- **Repository Revision**: `main`

#### Dataset Configuration
- **Datasets**: Your dataset (mounted under `/input`)
- **Mount Directory**: `my_dataset`

#### Entry Point
```
train.py config/train_smollm3.py --dataset_dir=my_dataset --init_from=resume --out_dir=/input-checkpoint --max_iters=1500
```

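The `--dataset_dir` value is resolved relative to the `/input` mount inside the training container, so the command above reads its data from `/input/my_dataset`:

```python
# from train.py: the mounted dataset directory is resolved under /input
dataset_path = os.path.join('/input', args.dataset_dir)  # e.g. /input/my_dataset
```
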
### 3. Dataset Format

The script supports multiple dataset formats:

#### Chat Format (Recommended)
```json
[
  {
    "messages": [
      {"role": "user", "content": "What is machine learning?"},
      {"role": "assistant", "content": "Machine learning is a subset of AI..."}
    ]
  }
]
```

#### Instruction Format
```json
[
  {
    "instruction": "What is machine learning?",
    "output": "Machine learning is a subset of AI..."
  }
]
```

#### User-Assistant Format
```json
[
  {
    "user": "What is machine learning?",
    "assistant": "Machine learning is a subset of AI..."
  }
]
```

### 4. Configuration Options

The default configuration in `config/train_smollm3.py` includes:

```python
@dataclass
class SmolLM3Config:
    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    use_flash_attention: bool = True

    # Training configuration
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    max_iters: int = 1000

    # Mixed precision
    fp16: bool = True
    bf16: bool = False
```

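The same configuration can be loaded programmatically with the helper that `train.py` uses:

```python
from config import get_config

config = get_config("config/train_smollm3.py")
print(config.model_name, config.max_seq_length, config.learning_rate)
```
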
### 5. Command Line Arguments

The `train.py` script accepts various arguments:

```bash
# Basic usage
python train.py config/train_smollm3.py

# With custom parameters
python train.py config/train_smollm3.py \
    --dataset_dir=my_dataset \
    --out_dir=/output-checkpoint \
    --init_from=resume \
    --max_iters=1500 \
    --batch_size=8 \
    --learning_rate=1e-5 \
    --max_seq_length=8192
```

## Advanced Usage

### 1. Custom Configuration

Create a custom configuration file:

```python
# config/my_config.py
from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    model_name="HuggingFaceTB/SmolLM3-3B-Instruct",
    max_seq_length=8192,
    batch_size=2,
    learning_rate=1e-5,
    max_iters=2000,
    use_flash_attention=True,
    fp16=True
)
```

### 2. Long-Context Fine-tuning

For long-context tasks (up to 128k tokens):

```python
config = SmolLM3Config(
    max_seq_length=131072,  # 128k tokens
    model_name="HuggingFaceTB/SmolLM3-3B",
    use_flash_attention=True,
    use_gradient_checkpointing=True
)
```

### 3. DPO Training

For preference optimization, use the DPO trainer:

```python
from trainer import SmolLM3DPOTrainer

dpo_trainer = SmolLM3DPOTrainer(
    model=model,
    dataset=dataset,
    config=config,
    output_dir="./dpo-output"
)

dpo_trainer.train()
```

### 4. Tool Calling Fine-tuning

Include tool calling examples in your dataset:

```json
[
  {
    "messages": [
      {"role": "user", "content": "What's the weather in New York?"},
      {"role": "assistant", "content": "<tool_call>\n<invoke name=\"get_weather\">\n<parameter name=\"location\">New York</parameter>\n</invoke>\n</tool_call>"},
      {"role": "tool", "content": "The weather in New York is 72°F and sunny."},
      {"role": "assistant", "content": "The weather in New York is currently 72°F and sunny."}
    ]
  }
]
```

## Model Variants

SmolLM3 comes in several variants:

- **SmolLM3-3B-Base**: Base model for general fine-tuning
- **SmolLM3-3B**: Instruction-tuned model
- **SmolLM3-3B-Instruct**: Enhanced instruction model
- **Quantized versions**: Available for deployment

## Hardware Requirements

### Minimum Requirements
- **GPU**: 16GB+ VRAM (for 3B model)
- **RAM**: 32GB+ system memory
- **Storage**: 50GB+ free space

### Recommended
- **GPU**: A100/H100 or similar
- **RAM**: 64GB+ system memory
- **Storage**: 100GB+ SSD

## Troubleshooting

### Common Issues

1. **Out of Memory (OOM)**
   - Reduce `batch_size`
   - Increase `gradient_accumulation_steps`
   - Enable `gradient_checkpointing`
   - Use `fp16` or `bf16`
   - See the sketch after this list for a combined example

2. **Slow Training**
   - Enable `flash_attention`
   - Use mixed precision (`fp16`/`bf16`)
   - Increase `dataloader_num_workers`

3. **Dataset Loading Issues**
   - Check dataset format
   - Ensure proper JSON structure
   - Verify file permissions

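A memory-friendly starting point that combines the OOM suggestions above (a sketch built from the fields defined in `config/train_smollm3.py`; adjust to your hardware):

```python
from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    batch_size=1,                     # smallest per-device batch
    gradient_accumulation_steps=16,   # keeps a reasonable effective batch size
    use_gradient_checkpointing=True,  # trade compute for memory
    fp16=True,                        # or bf16=True on Ampere+ GPUs (not both)
    max_seq_length=2048,              # shorter sequences reduce activation memory
)
```
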
### Debug Mode

Enable debug logging:

```python
import logging
logging.basicConfig(level=logging.DEBUG)
```

## Evaluation

After training, evaluate your model:

```python
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="./output-checkpoint",
    device=0,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7
)

# Test the model
messages = [{"role": "user", "content": "Explain gravity in simple terms."}]
outputs = pipe(messages)
print(outputs[0]["generated_text"][-1]["content"])
```

## Deployment

### Using vLLM
```bash
vllm serve ./output-checkpoint --enable-auto-tool-choice
```

### Using llama.cpp
```bash
# Convert to GGUF format
python -m llama_cpp.convert_model ./output-checkpoint --outfile model.gguf
```

## Resources

- [SmolLM3 Blog Post](https://huggingface.co/blog/smollm3)
- [Model Repository](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
- [GitHub Repository](https://github.com/huggingface/smollm)
- [SmolTalk Dataset](https://huggingface.co/datasets/HuggingFaceTB/smoltalk)

## License

This project follows the same license as the SmolLM3 model. Please refer to the Hugging Face model page for licensing information.
config.py ADDED
@@ -0,0 +1,28 @@
"""
Configuration management for SmolLM3 fine-tuning
"""

import os
import importlib.util
from typing import Any
from config.train_smollm3 import SmolLM3Config, get_config as get_default_config

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return get_default_config(config_path)
config/train_smollm3.py ADDED
@@ -0,0 +1,107 @@
"""
SmolLM3 Training Configuration
Based on nanoGPT structure but adapted for SmolLM3
"""

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class SmolLM3Config:
    """Configuration for SmolLM3 fine-tuning"""

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    use_flash_attention: bool = True
    use_gradient_checkpointing: bool = True

    # Training configuration
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    warmup_steps: int = 100
    max_iters: int = 1000
    eval_interval: int = 100
    log_interval: int = 10
    save_interval: int = 500

    # Optimizer configuration
    # Passed to TrainingArguments.optim; "adamw_torch" is the valid name
    # (plain "adamw" is not accepted by TrainingArguments)
    optimizer: str = "adamw_torch"
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Scheduler configuration
    scheduler: str = "cosine"
    min_lr: float = 1e-6

    # Mixed precision
    fp16: bool = True
    bf16: bool = False

    # DDP configuration
    ddp_backend: str = "nccl"
    ddp_find_unused_parameters: bool = False

    # Logging and saving
    save_steps: int = 500
    eval_steps: int = 100
    logging_steps: int = 10
    save_total_limit: Optional[int] = 3

    # Evaluation
    eval_strategy: str = "steps"
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False
    load_best_model_at_end: bool = True

    # Data configuration
    data_dir: str = "my_dataset"
    train_file: str = "train.json"
    validation_file: Optional[str] = None
    test_file: Optional[str] = None

    # Chat template configuration
    use_chat_template: bool = True
    chat_template_kwargs: dict = None

    def __post_init__(self):
        if self.chat_template_kwargs is None:
            self.chat_template_kwargs = {
                "enable_thinking": False,
                "add_generation_prompt": True
            }

        # Validate configuration
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")

        if self.max_seq_length > 131072:  # 128k limit
            raise ValueError("max_seq_length cannot exceed 131072")

def get_config(config_path: str) -> SmolLM3Config:
    """Load configuration from file or return default"""
    if os.path.exists(config_path):
        # Load from file if it exists
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            return config_module.config
        else:
            # Try to find a config class
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if isinstance(attr, SmolLM3Config):
                    return attr

    # Return default configuration
    return SmolLM3Config()

# Default configuration instance
config = SmolLM3Config()
config/train_smollm3_dpo.py ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 DPO Training Configuration
Optimized for Direct Preference Optimization
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B-Instruct",  # Start from instruction-tuned model
    max_seq_length=4096,
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=2,  # Smaller batch size for DPO
    gradient_accumulation_steps=4,
    learning_rate=5e-6,  # Very low learning rate for DPO
    weight_decay=0.01,
    warmup_steps=100,
    max_iters=1000,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=200,
    eval_steps=100,
    logging_steps=20,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": False,  # Disable reasoning for preference learning
        "add_generation_prompt": True
    }
)
config/train_smollm3_long_context.py ADDED
@@ -0,0 +1,38 @@
"""
SmolLM3 Long-Context Training Configuration
Optimized for long-context tasks (up to 128k tokens)
"""

from config.train_smollm3 import SmolLM3Config

config = SmolLM3Config(
    # Model configuration
    model_name="HuggingFaceTB/SmolLM3-3B",
    max_seq_length=131072,  # 128k tokens
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    # Training configuration
    batch_size=1,  # Reduced for long sequences
    gradient_accumulation_steps=8,  # Increased to maintain effective batch size
    learning_rate=1e-5,  # Lower learning rate for stability
    weight_decay=0.01,
    warmup_steps=200,
    max_iters=500,

    # Mixed precision
    fp16=True,
    bf16=False,

    # Logging and saving
    save_steps=100,
    eval_steps=50,
    logging_steps=10,

    # Chat template configuration
    use_chat_template=True,
    chat_template_kwargs={
        "enable_thinking": True,  # Enable reasoning mode
        "add_generation_prompt": True
    }
)
create_sample_dataset.py ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
Sample Dataset Creation Script
Creates sample datasets for testing SmolLM3 fine-tuning
"""

import os
import json
import argparse
from data import create_sample_dataset

def main():
    parser = argparse.ArgumentParser(description='Create sample dataset for SmolLM3 fine-tuning')
    parser.add_argument('--output_dir', type=str, default='my_dataset',
                        help='Output directory for the dataset')
    parser.add_argument('--format', type=str, default='chat',
                        choices=['chat', 'instruction', 'user_assistant'],
                        help='Dataset format')
    parser.add_argument('--num_samples', type=int, default=100,
                        help='Number of samples to create')

    args = parser.parse_args()

    # Create sample dataset
    # NOTE: create_sample_dataset() currently writes a small fixed set of
    # chat-format examples; --format and --num_samples are parsed but not
    # yet applied to the generated data.
    output_path = create_sample_dataset(args.output_dir)

    print(f"Sample dataset created in: {output_path}")
    print(f"Format: {args.format}")
    print(f"Samples: {args.num_samples}")
    print("\nFiles created:")
    print(f"- {os.path.join(output_path, 'train.json')}")
    print(f"- {os.path.join(output_path, 'validation.json')}")

    # Show sample data
    with open(os.path.join(output_path, 'train.json'), 'r') as f:
        data = json.load(f)
    print(f"\nSample data:")
    print(json.dumps(data[0], indent=2))

if __name__ == '__main__':
    main()
data.py ADDED
@@ -0,0 +1,238 @@
"""
SmolLM3 Dataset Handler
Handles data loading, preprocessing, and tokenization for SmolLM3 fine-tuning
"""

import os
import json
import torch
from typing import Dict, List, Optional, Union
from datasets import Dataset, DatasetDict, load_dataset
from transformers import PreTrainedTokenizer
import logging

logger = logging.getLogger(__name__)

class SmolLM3Dataset:
    """Dataset handler for SmolLM3 fine-tuning"""

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int = 4096,
        use_chat_template: bool = True,
        chat_template_kwargs: Optional[Dict] = None
    ):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.use_chat_template = use_chat_template
        self.chat_template_kwargs = chat_template_kwargs or {}

        # Load and process dataset
        self.dataset = self._load_dataset()
        self.processed_dataset = self._process_dataset()

    def _load_dataset(self) -> Dataset:
        """Load dataset from various formats"""
        logger.info(f"Loading dataset from {self.data_path}")

        # Check if it's a Hugging Face dataset
        if os.path.isdir(self.data_path):
            # Local directory: only pass splits whose files actually exist,
            # since load_dataset does not accept None entries in data_files
            try:
                data_files = {"train": os.path.join(self.data_path, "train.json")}
                for split in ("validation", "test"):
                    split_path = os.path.join(self.data_path, f"{split}.json")
                    if os.path.exists(split_path):
                        data_files[split] = split_path
                dataset = load_dataset("json", data_files=data_files)
                logger.info("Loaded dataset from local JSON files")
                return dataset
            except Exception as e:
                logger.warning(f"Failed to load as JSON dataset: {e}")

        # Try to load as a single JSON file
        if os.path.isfile(self.data_path) and self.data_path.endswith('.json'):
            try:
                with open(self.data_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Convert to dataset format (wrapped in a DatasetDict so the
                # rest of the pipeline can rely on a "train" split)
                if isinstance(data, list):
                    dataset = Dataset.from_list(data)
                else:
                    dataset = Dataset.from_dict(data)

                logger.info("Loaded dataset from single JSON file")
                return DatasetDict({"train": dataset})
            except Exception as e:
                logger.error(f"Failed to load JSON file: {e}")
                raise

        # Try to load as a Hugging Face dataset name
        try:
            dataset = load_dataset(self.data_path)
            logger.info(f"Loaded Hugging Face dataset: {self.data_path}")
            return dataset
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            raise

    def _process_dataset(self) -> Dataset:
        """Process the dataset for training"""
        logger.info("Processing dataset for training")

        def format_chat_template(example):
            """Format example using chat template"""
            if self.use_chat_template:
                try:
                    # Handle different input formats
                    if "messages" in example:
                        messages = example["messages"]
                    elif "conversations" in example:
                        messages = example["conversations"]
                    elif "user" in example and "assistant" in example:
                        messages = [
                            {"role": "user", "content": example["user"]},
                            {"role": "assistant", "content": example["assistant"]}
                        ]
                    elif "instruction" in example and "output" in example:
                        messages = [
                            {"role": "user", "content": example["instruction"]},
                            {"role": "assistant", "content": example["output"]}
                        ]
                    elif "prompt" in example and "completion" in example:
                        messages = [
                            {"role": "user", "content": example["prompt"]},
                            {"role": "assistant", "content": example["completion"]}
                        ]
                    else:
                        # Fallback: treat as plain text
                        return {"text": str(example)}

                    # Apply chat template
                    text = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        **self.chat_template_kwargs
                    )
                    return {"text": text}
                except Exception as e:
                    logger.warning(f"Failed to apply chat template: {e}")
                    # Fallback to plain text
                    return {"text": str(example)}
            else:
                # Use plain text
                if "text" in example:
                    return {"text": example["text"]}
                else:
                    return {"text": str(example)}

        def tokenize_function(examples):
            """Tokenize the examples"""
            # Tokenize the texts
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_seq_length,
                return_overflowing_tokens=True,
                return_length=True,
            )

            # Calculate input length
            input_length = [len(x) for x in tokenized["input_ids"]]

            # Create labels (same as input_ids for causal LM)
            tokenized["labels"] = tokenized["input_ids"].copy()

            return {
                "input_ids": tokenized["input_ids"],
                "attention_mask": tokenized["attention_mask"],
                "labels": tokenized["labels"],
                "length": input_length,
            }

        # Process the dataset
        processed_dataset = self.dataset.map(
            format_chat_template,
            remove_columns=self.dataset["train"].column_names,
            desc="Formatting dataset"
        )

        # Tokenize the dataset
        tokenized_dataset = processed_dataset.map(
            tokenize_function,
            remove_columns=processed_dataset["train"].column_names,
            desc="Tokenizing dataset",
            batched=True,
        )

        logger.info(f"Dataset processed. Train samples: {len(tokenized_dataset['train'])}")
        if "validation" in tokenized_dataset:
            logger.info(f"Validation samples: {len(tokenized_dataset['validation'])}")

        return tokenized_dataset

    def get_train_dataset(self) -> Dataset:
        """Get training dataset"""
        return self.processed_dataset["train"]

    def get_eval_dataset(self) -> Optional[Dataset]:
        """Get evaluation dataset if available"""
        if "validation" in self.processed_dataset:
            return self.processed_dataset["validation"]
        elif "test" in self.processed_dataset:
            return self.processed_dataset["test"]
        else:
            return None

    def get_data_collator(self):
        """Get data collator for training"""
        from transformers import DataCollatorForLanguageModeling

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # We're doing causal LM, not masked LM
        )

def create_sample_dataset(output_path: str = "my_dataset"):
    """Create a sample dataset for testing"""
    os.makedirs(output_path, exist_ok=True)

    # Sample conversations
    conversations = [
        {
            "messages": [
                {"role": "user", "content": "What is machine learning?"},
                {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "Explain gravity in simple terms."},
                {"role": "assistant", "content": "Gravity is the force that pulls objects toward each other, like how the Earth pulls things down to the ground."}
            ]
        },
        {
            "messages": [
                {"role": "user", "content": "How do I make a cup of coffee?"},
                {"role": "assistant", "content": "To make a cup of coffee: 1) Boil water, 2) Add coffee grounds to a filter, 3) Pour hot water over the grounds, 4) Let it brew for a few minutes, 5) Enjoy!"}
            ]
        }
    ]

    # Split into train/validation
    train_data = conversations[:2]
    validation_data = conversations[2:]

    # Save to files
    with open(os.path.join(output_path, "train.json"), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    with open(os.path.join(output_path, "validation.json"), 'w', encoding='utf-8') as f:
        json.dump(validation_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Sample dataset created in {output_path}")
    return output_path
model.py ADDED
@@ -0,0 +1,188 @@
"""
SmolLM3 Model Wrapper
Handles model loading, tokenizer, and training setup
"""

import os
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TrainingArguments,
    Trainer
)
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)

class SmolLM3Model:
    """Wrapper for SmolLM3 model and tokenizer"""

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM3-3B",
        max_seq_length: int = 4096,
        config: Optional[Any] = None,
        device_map: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None
    ):
        self.model_name = model_name
        self.max_seq_length = max_seq_length
        self.config = config

        # Set device and dtype
        if torch_dtype is None:
            if torch.cuda.is_available():
                # bf16 on Ampere (compute capability 8.0) and newer, fp16 otherwise
                self.torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
            else:
                self.torch_dtype = torch.float32
        else:
            self.torch_dtype = torch_dtype

        if device_map is None:
            self.device_map = "auto" if torch.cuda.is_available() else "cpu"
        else:
            self.device_map = device_map

        # Load tokenizer and model
        self._load_tokenizer()
        self._load_model()

    def _load_tokenizer(self):
        """Load the tokenizer"""
        logger.info(f"Loading tokenizer from {self.model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                use_fast=True
            )

            # Set pad token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info(f"Tokenizer loaded successfully. Vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            raise

    def _load_model(self):
        """Load the model"""
        logger.info(f"Loading model from {self.model_name}")
        try:
            # Load model configuration
            model_config = AutoConfig.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Update configuration if needed
            if hasattr(model_config, 'max_position_embeddings'):
                model_config.max_position_embeddings = self.max_seq_length

            # Load model; recent transformers releases expect attn_implementation
            # rather than the deprecated use_flash_attention_2 flag
            use_flash_attention = self.config.use_flash_attention if self.config else True
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=model_config,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True,
                attn_implementation="flash_attention_2" if use_flash_attention else "eager",
                use_cache=False  # Disable KV cache for training
            )

            # Enable gradient checkpointing if specified
            if self.config and self.config.use_gradient_checkpointing:
                self.model.gradient_checkpointing_enable()

            logger.info(f"Model loaded successfully. Parameters: {self.model.num_parameters():,}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
        """Get training arguments for the Trainer"""
        if self.config is None:
            raise ValueError("Config is required to get training arguments")

        # Merge config with kwargs
        training_args = {
            "output_dir": output_dir,
            "per_device_train_batch_size": self.config.batch_size,
            "per_device_eval_batch_size": self.config.batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "learning_rate": self.config.learning_rate,
            "weight_decay": self.config.weight_decay,
            "warmup_steps": self.config.warmup_steps,
            "max_steps": self.config.max_iters,
            "save_steps": self.config.save_steps,
            "eval_steps": self.config.eval_steps,
            "logging_steps": self.config.logging_steps,
            "save_total_limit": self.config.save_total_limit,
            "evaluation_strategy": self.config.eval_strategy,
            "metric_for_best_model": self.config.metric_for_best_model,
            "greater_is_better": self.config.greater_is_better,
            "load_best_model_at_end": self.config.load_best_model_at_end,
            "fp16": self.config.fp16,
            "bf16": self.config.bf16,
            "ddp_backend": self.config.ddp_backend,
            "ddp_find_unused_parameters": self.config.ddp_find_unused_parameters,
            "report_to": "none",  # Disable external logging
            "remove_unused_columns": False,
            "dataloader_pin_memory": False,
            "group_by_length": True,
            "length_column_name": "length",
            "ignore_data_skip": False,
            "seed": 42,
            "data_seed": 42,
            "dataloader_num_workers": 4,
            "max_grad_norm": 1.0,
            "optim": self.config.optimizer,
            "lr_scheduler_type": self.config.scheduler,
            "warmup_ratio": 0.1,
            "save_strategy": "steps",
            "logging_strategy": "steps",
            "prediction_loss_only": True,
        }

        # Override with kwargs
        training_args.update(kwargs)

        return TrainingArguments(**training_args)

    def save_pretrained(self, path: str):
        """Save model and tokenizer"""
        logger.info(f"Saving model and tokenizer to {path}")
        os.makedirs(path, exist_ok=True)

        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

        # Save configuration
        if self.config:
            import json
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            with open(os.path.join(path, 'training_config.json'), 'w') as f:
                json.dump(config_dict, f, indent=2, default=str)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                checkpoint_path,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            logger.info("Checkpoint loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise
requirements.txt ADDED
@@ -0,0 +1,35 @@
# Core dependencies
torch>=2.0.0
transformers>=4.53.0
datasets>=2.14.0
accelerate>=0.20.0
trl>=0.7.0

# Hugging Face ecosystem
huggingface-hub>=0.16.0
tokenizers>=0.13.0

# Training and optimization
flash-attn>=2.0.0
xformers>=0.0.20
bitsandbytes>=0.41.0

# Utilities
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
tqdm>=4.65.0
wandb>=0.15.0

# Optional: for evaluation
lighteval>=0.1.0
evaluate>=0.4.0

# Optional: for deployment
vllm>=0.2.0
sentencepiece>=0.1.99

# Development
pytest>=7.0.0
black>=23.0.0
isort>=5.12.0
test_setup.py ADDED
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Test Setup Script
Verifies that all components are working correctly
"""

import os
import sys
import torch
import logging
from pathlib import Path

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def test_imports():
    """Test that all required modules can be imported"""
    logger.info("Testing imports...")

    try:
        import transformers
        logger.info(f"✓ transformers {transformers.__version__}")
    except ImportError as e:
        logger.error(f"✗ transformers: {e}")
        return False

    try:
        import datasets
        logger.info(f"✓ datasets {datasets.__version__}")
    except ImportError as e:
        logger.error(f"✗ datasets: {e}")
        return False

    try:
        import trl
        logger.info(f"✓ trl {trl.__version__}")
    except ImportError as e:
        logger.error(f"✗ trl: {e}")
        return False

    try:
        import accelerate
        logger.info(f"✓ accelerate {accelerate.__version__}")
    except ImportError as e:
        logger.error(f"✗ accelerate: {e}")
        return False

    return True

def test_local_imports():
    """Test that local modules can be imported"""
    logger.info("Testing local imports...")

    try:
        from config import get_config
        logger.info("✓ config module")
    except ImportError as e:
        logger.error(f"✗ config module: {e}")
        return False

    try:
        from model import SmolLM3Model
        logger.info("✓ model module")
    except ImportError as e:
        logger.error(f"✗ model module: {e}")
        return False

    try:
        from data import SmolLM3Dataset
        logger.info("✓ data module")
    except ImportError as e:
        logger.error(f"✗ data module: {e}")
        return False

    try:
        from trainer import SmolLM3Trainer
        logger.info("✓ trainer module")
    except ImportError as e:
        logger.error(f"✗ trainer module: {e}")
        return False

    return True

def test_config():
    """Test configuration loading"""
    logger.info("Testing configuration...")

    try:
        from config import get_config
        config = get_config("config/train_smollm3.py")
        logger.info(f"✓ Configuration loaded: {config.model_name}")
        return True
    except Exception as e:
        logger.error(f"✗ Configuration loading failed: {e}")
        return False

def test_dataset_creation():
    """Test dataset creation"""
    logger.info("Testing dataset creation...")

    try:
        from data import create_sample_dataset
        output_path = create_sample_dataset("test_dataset")

        # Check if files were created
        train_file = os.path.join(output_path, "train.json")
        val_file = os.path.join(output_path, "validation.json")

        if os.path.exists(train_file) and os.path.exists(val_file):
            logger.info("✓ Sample dataset created successfully")

            # Clean up
            import shutil
            shutil.rmtree(output_path)
            return True
        else:
            logger.error("✗ Dataset files not created")
            return False
    except Exception as e:
        logger.error(f"✗ Dataset creation failed: {e}")
        return False

def test_gpu_availability():
    """Test GPU availability"""
    logger.info("Testing GPU availability...")

    if torch.cuda.is_available():
        logger.info(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
        logger.info(f"✓ CUDA version: {torch.version.cuda}")
        logger.info(f"✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        return True
    else:
        logger.warning("⚠ No GPU available, will use CPU")
        return True

def test_model_loading():
    """Test model loading (without downloading)"""
    logger.info("Testing model loading...")

    try:
        from transformers import AutoTokenizer, AutoConfig

        # Test tokenizer loading
        tokenizer = AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True,
            use_fast=True
        )
        logger.info(f"✓ Tokenizer loaded, vocab size: {tokenizer.vocab_size}")

        # Test config loading
        config = AutoConfig.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            trust_remote_code=True
        )
        logger.info(f"✓ Config loaded, model type: {config.model_type}")

        return True
    except Exception as e:
        logger.error(f"✗ Model loading test failed: {e}")
        return False

def main():
    """Run all tests"""
    logger.info("Starting SmolLM3 setup tests...")

    tests = [
        ("Import Tests", test_imports),
        ("Local Import Tests", test_local_imports),
        ("Configuration Tests", test_config),
        ("Dataset Creation Tests", test_dataset_creation),
        ("GPU Availability Tests", test_gpu_availability),
        ("Model Loading Tests", test_model_loading),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        logger.info(f"\n{'='*50}")
        logger.info(f"Running: {test_name}")
        logger.info('='*50)

        try:
            if test_func():
                passed += 1
                logger.info(f"✓ {test_name} PASSED")
            else:
                logger.error(f"✗ {test_name} FAILED")
        except Exception as e:
            logger.error(f"✗ {test_name} FAILED with exception: {e}")

    logger.info(f"\n{'='*50}")
    logger.info(f"Test Results: {passed}/{total} tests passed")
    logger.info('='*50)

    if passed == total:
        logger.info("🎉 All tests passed! Setup is ready for SmolLM3 fine-tuning.")
        return 0
    else:
        logger.error("❌ Some tests failed. Please check the errors above.")
        return 1

if __name__ == '__main__':
    sys.exit(main())
train.py ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
SmolLM3 Fine-tuning Script for FlexAI Console
Based on the nanoGPT structure but adapted for SmolLM3 model
"""

import os
import sys
import argparse
import json
import torch
import logging
from pathlib import Path
from typing import Optional, Dict, Any

# Add the current directory to the path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import get_config
from model import SmolLM3Model
from data import SmolLM3Dataset
from trainer import SmolLM3Trainer

def setup_logging():
    """Setup logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler('training.log')
        ]
    )
    return logging.getLogger(__name__)

def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='SmolLM3 Fine-tuning Script')

    # Configuration file
    parser.add_argument('config', type=str, help='Path to configuration file')

    # Dataset arguments
    parser.add_argument('--dataset_dir', type=str, default='my_dataset',
                        help='Path to dataset directory within /input')

    # Checkpoint arguments
    parser.add_argument('--out_dir', type=str, default='/output-checkpoint',
                        help='Output directory for checkpoints')
    parser.add_argument('--init_from', type=str, default='scratch',
                        choices=['scratch', 'resume', 'pretrained'],
                        help='Initialization method')

    # Training arguments
    parser.add_argument('--max_iters', type=int, default=None,
                        help='Maximum number of training iterations')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=None,
                        help='Learning rate')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=None,
                        help='Gradient accumulation steps')

    # Model arguments
    parser.add_argument('--model_name', type=str,
                        default='HuggingFaceTB/SmolLM3-3B',
                        help='Model name or path')
    parser.add_argument('--max_seq_length', type=int, default=4096,
                        help='Maximum sequence length')

    # Logging and saving
    parser.add_argument('--save_steps', type=int, default=500,
                        help='Save checkpoint every N steps')
    parser.add_argument('--eval_steps', type=int, default=100,
                        help='Evaluate every N steps')
    parser.add_argument('--logging_steps', type=int, default=10,
                        help='Log every N steps')

    return parser.parse_args()

def main():
    """Main training function"""
    args = parse_args()
    logger = setup_logging()

    logger.info("Starting SmolLM3 fine-tuning...")
    logger.info(f"Arguments: {vars(args)}")

    # Load configuration
    config = get_config(args.config)

    # Override config with command line arguments
    if args.max_iters is not None:
        config.max_iters = args.max_iters
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.learning_rate is not None:
        config.learning_rate = args.learning_rate
    if args.gradient_accumulation_steps is not None:
        config.gradient_accumulation_steps = args.gradient_accumulation_steps

    # Setup paths
    dataset_path = os.path.join('/input', args.dataset_dir)
    output_path = args.out_dir

    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)

    logger.info(f"Dataset path: {dataset_path}")
    logger.info(f"Output path: {output_path}")

    # Initialize model
    model = SmolLM3Model(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        config=config
    )

    # Load dataset
    dataset = SmolLM3Dataset(
        data_path=dataset_path,
        tokenizer=model.tokenizer,
        max_seq_length=args.max_seq_length
    )

    # Initialize trainer
    trainer = SmolLM3Trainer(
        model=model,
        dataset=dataset,
        config=config,
        output_dir=output_path,
        init_from=args.init_from
    )

    # Start training
    try:
        trainer.train()
        logger.info("Training completed successfully!")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

if __name__ == '__main__':
    main()
trainer.py ADDED
@@ -0,0 +1,242 @@
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""

import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json

logger = logging.getLogger(__name__)

class SmolLM3Trainer:
    """Trainer for SmolLM3 fine-tuning"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # Setup trainer
        self.trainer = self._setup_trainer()

    def _setup_trainer(self):
        """Setup the trainer"""
        logger.info("Setting up trainer")

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get datasets
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Get data collator
        data_collator = self.dataset.get_data_collator()

        if self.use_sft_trainer:
            # Use SFTTrainer for supervised fine-tuning
            trainer = SFTTrainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                data_collator=data_collator,
                dataset_text_field="text",
                max_seq_length=self.config.max_seq_length,
                packing=False,  # Disable packing for better control
            )
        else:
            # Use standard Trainer
            trainer = Trainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
            )

        return trainer

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint for resuming training"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        if self.init_from == "resume":
            # Load the model from checkpoint
            self.model.load_checkpoint(checkpoint_path)

            # Update trainer with loaded model
            self.trainer.model = self.model.model

            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            # Model is already loaded from pretrained
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    def train(self):
        """Start training"""
        logger.info("Starting training")

        # Load checkpoint if resuming
        if self.init_from == "resume":
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning(f"Checkpoint path {checkpoint_path} not found, starting from scratch")

        # Start training
        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def evaluate(self):
        """Evaluate the model"""
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            # Save evaluation results
            with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
                json.dump(eval_results, f, indent=2)

            logger.info(f"Evaluation completed: {eval_results}")
            return eval_results

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            raise

    def save_model(self, path: Optional[str] = None):
        """Save the trained model"""
        save_path = path or self.output_dir
        logger.info(f"Saving model to {save_path}")

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            # Save training configuration
            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items()
                               if not k.startswith('_')}
                with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

class SmolLM3DPOTrainer:
    """DPO Trainer for SmolLM3 preference optimization"""

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        ref_model=None
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.ref_model = ref_model

        # Setup DPO trainer
        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Setup DPO trainer"""
        from trl import DPOTrainer

        # Get training arguments
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        # Get preference dataset
        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()

        # Setup DPO trainer
        trainer = DPOTrainer(
            model=self.model.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.model.tokenizer,
            max_prompt_length=self.config.max_seq_length // 2,
            max_length=self.config.max_seq_length,
        )

        return trainer

    def train(self):
        """Start DPO training"""
        logger.info("Starting DPO training")

        try:
            train_result = self.trainer.train()

            # Save the final model
            self.trainer.save_model()

            # Save training results
            with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.error(f"DPO training failed: {e}")
            raise