Spaces:
Running
Running
adds sft , quantization, better readmes
Browse files- README.md +95 -0
- config/train_smollm3.py +3 -0
- config/train_smollm3_dpo.py +3 -0
- docs/CLOUD_DEPLOYMENT_GUIDE.md +1 -1
- docs/GIT_CONFIGURATION_FIX.md +2 -2
- docs/GIT_CONFIGURATION_GUIDE.md +7 -7
- docs/HF_HUB_V0_34_UPDATE.md +170 -0
- docs/LATEST_DEPLOYMENT_APPROACH.md +1 -1
- docs/LAUNCH_SCRIPT_UPDATES.md +3 -3
- docs/LAUNCH_SCRIPT_USERNAME_FIX.md +154 -0
- PIPELINE_SUMMARY.md → docs/PIPELINE_SUMMARY.md +0 -0
- docs/QUANTIZATION_GUIDE.md +313 -0
- docs/QUANTIZATION_IMPLEMENTATION_SUMMARY.md +248 -0
- README_END_TO_END.md → docs/README_END_TO_END.md +4 -5
- docs/SFT_TRAINER_CONFIG_USAGE.md +233 -0
- docs/TRACKIO_DEPLOYMENT_FIXES.md +1 -1
- docs/TRAINER_SELECTION_GUIDE.md +205 -0
- docs/TRAINER_SELECTION_SUMMARY.md +129 -0
- docs/UNIFIED_MODEL_CARD_GUIDE.md +295 -0
- docs/UNIFIED_REPOSITORY_STRUCTURE_SUMMARY.md +252 -0
- docs/USERNAME_EXTRACTION_FIX.md +2 -2
- launch.sh +116 -6
- requirements/requirements.txt +1 -0
- scripts/dataset_tonic/setup_hf_dataset.py +1 -1
- scripts/model_tonic/generate_model_card.py +209 -0
- scripts/model_tonic/push_to_huggingface.py +47 -7
- scripts/model_tonic/quantize_model.py +571 -0
- scripts/model_tonic/quantize_standalone.py +94 -0
- scripts/trackio_tonic/configure_trackio.py +1 -1
- scripts/trackio_tonic/deploy_trackio_space.py +3 -3
- scripts/training/train.py +11 -0
- setup_launch.py +1 -1
- src/data.py +35 -5
- src/monitoring.py +48 -6
- src/train.py +30 -9
- templates/datasets/readme.md +80 -4
- templates/model_card.md +289 -0
- templates/spaces/app.py +38 -8
- test_config.py → tests/test_config.py +0 -0
- test_mixed_precision.py → tests/test_mixed_precision.py +0 -0
- test_pipeline.py → tests/test_pipeline_1.py +0 -0
- tests/test_quantization.py +249 -0
- tests/test_trainer_selection.py +121 -0
- test_training_fix.py → tests/test_training_fix_1.py +0 -0
- tests/test_unified_model_card.py +289 -0
README.md
CHANGED
@@ -10,6 +10,7 @@ SmolLM3 is a 3B-parameter transformer decoder model optimized for efficiency, lo
|
|
10 |
- **Direct Preference Optimization (DPO)**: Improve model alignment
|
11 |
- **Long-context fine-tuning**: Support for up to 128k tokens
|
12 |
- **Tool calling**: Fine-tune for function calling capabilities
|
|
|
13 |
|
14 |
## Quick Start
|
15 |
|
@@ -266,6 +267,100 @@ outputs = pipe(messages)
|
|
266 |
print(outputs[0]["generated_text"][-1]["content"])
|
267 |
```
|
268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
## Deployment
|
270 |
|
271 |
### Using vLLM
|
|
|
10 |
- **Direct Preference Optimization (DPO)**: Improve model alignment
|
11 |
- **Long-context fine-tuning**: Support for up to 128k tokens
|
12 |
- **Tool calling**: Fine-tune for function calling capabilities
|
13 |
+
- **Model Quantization**: Create int8 (GPU) and int4 (CPU) quantized versions
|
14 |
|
15 |
## Quick Start
|
16 |
|
|
|
267 |
print(outputs[0]["generated_text"][-1]["content"])
|
268 |
```
|
269 |
|
270 |
+
## Model Quantization
|
271 |
+
|
272 |
+
The pipeline includes built-in quantization support using torchao for creating optimized model versions with a unified repository structure:
|
273 |
+
|
274 |
+
### Repository Structure
|
275 |
+
|
276 |
+
All models (main and quantized) are stored in a single repository:
|
277 |
+
|
278 |
+
```
|
279 |
+
your-username/model-name/
|
280 |
+
├── README.md (unified model card)
|
281 |
+
├── config.json
|
282 |
+
├── pytorch_model.bin
|
283 |
+
├── tokenizer.json
|
284 |
+
├── int8/ (quantized model for GPU)
|
285 |
+
└── int4/ (quantized model for CPU)
|
286 |
+
```
|
287 |
+
|
288 |
+
### Quantization Types
|
289 |
+
|
290 |
+
- **int8_weight_only**: GPU optimized, ~50% memory reduction
|
291 |
+
- **int4_weight_only**: CPU optimized, ~75% memory reduction
|
292 |
+
|
293 |
+
### Automatic Quantization
|
294 |
+
|
295 |
+
When using the interactive pipeline (`launch.sh`), you'll be prompted to create quantized versions after training:
|
296 |
+
|
297 |
+
```bash
|
298 |
+
./launch.sh
|
299 |
+
# ... training completes ...
|
300 |
+
# Choose quantization options when prompted
|
301 |
+
```
|
302 |
+
|
303 |
+
### Standalone Quantization
|
304 |
+
|
305 |
+
Quantize existing models independently:
|
306 |
+
|
307 |
+
```bash
|
308 |
+
# Quantize and push to HF Hub (same repository)
|
309 |
+
python scripts/model_tonic/quantize_standalone.py /path/to/model your-username/model-name \
|
310 |
+
--quant-type int8_weight_only \
|
311 |
+
--token YOUR_HF_TOKEN
|
312 |
+
|
313 |
+
# Quantize and save locally
|
314 |
+
python scripts/model_tonic/quantize_standalone.py /path/to/model your-username/model-name \
|
315 |
+
--quant-type int4_weight_only \
|
316 |
+
--device cpu \
|
317 |
+
--save-only
|
318 |
+
```
|
319 |
+
|
320 |
+
### Loading Quantized Models
|
321 |
+
|
322 |
+
```python
|
323 |
+
import torch
|
324 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
325 |
+
|
326 |
+
# Load main model
|
327 |
+
model = AutoModelForCausalLM.from_pretrained(
|
328 |
+
"your-username/model-name",
|
329 |
+
device_map="auto",
|
330 |
+
torch_dtype=torch.bfloat16
|
331 |
+
)
|
332 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name")
|
333 |
+
|
334 |
+
# Load int8 quantized model (GPU)
|
335 |
+
model = AutoModelForCausalLM.from_pretrained(
|
336 |
+
"your-username/model-name/int8",
|
337 |
+
device_map="auto",
|
338 |
+
torch_dtype=torch.bfloat16
|
339 |
+
)
|
340 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name/int8")
|
341 |
+
|
342 |
+
# Load int4 quantized model (CPU)
|
343 |
+
model = AutoModelForCausalLM.from_pretrained(
|
344 |
+
"your-username/model-name/int4",
|
345 |
+
device_map="cpu",
|
346 |
+
torch_dtype=torch.bfloat16
|
347 |
+
)
|
348 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name/int4")
|
349 |
+
```
|
350 |
+
|
351 |
+
For detailed quantization documentation, see [QUANTIZATION_GUIDE.md](docs/QUANTIZATION_GUIDE.md).
|
352 |
+
|
353 |
+
### Unified Model Cards
|
354 |
+
|
355 |
+
The system generates comprehensive model cards that include information about all model variants:
|
356 |
+
|
357 |
+
- **Single README**: One comprehensive model card for the entire repository
|
358 |
+
- **Conditional Sections**: Quantized model information appears when available
|
359 |
+
- **Usage Examples**: Complete examples for all model variants
|
360 |
+
- **Performance Information**: Memory and speed benefits for each quantization type
|
361 |
+
|
362 |
+
For detailed information about the unified model card system, see [UNIFIED_MODEL_CARD_GUIDE.md](docs/UNIFIED_MODEL_CARD_GUIDE.md).
|
363 |
+
|
364 |
## Deployment
|
365 |
|
366 |
### Using vLLM
|
config/train_smollm3.py
CHANGED
@@ -11,6 +11,9 @@ from typing import Optional
|
|
11 |
class SmolLM3Config:
|
12 |
"""Configuration for SmolLM3 fine-tuning"""
|
13 |
|
|
|
|
|
|
|
14 |
# Model configuration
|
15 |
model_name: str = "HuggingFaceTB/SmolLM3-3B"
|
16 |
max_seq_length: int = 4096
|
|
|
11 |
class SmolLM3Config:
|
12 |
"""Configuration for SmolLM3 fine-tuning"""
|
13 |
|
14 |
+
# Trainer type selection
|
15 |
+
trainer_type: str = "sft" # "sft" or "dpo"
|
16 |
+
|
17 |
# Model configuration
|
18 |
model_name: str = "HuggingFaceTB/SmolLM3-3B"
|
19 |
max_seq_length: int = 4096
|
config/train_smollm3_dpo.py
CHANGED
@@ -12,6 +12,9 @@ from config.train_smollm3 import SmolLM3Config
|
|
12 |
class SmolLM3DPOConfig(SmolLM3Config):
|
13 |
"""Configuration for SmolLM3 DPO fine-tuning"""
|
14 |
|
|
|
|
|
|
|
15 |
# DPO-specific configuration
|
16 |
beta: float = 0.1
|
17 |
max_prompt_length: int = 2048
|
|
|
12 |
class SmolLM3DPOConfig(SmolLM3Config):
|
13 |
"""Configuration for SmolLM3 DPO fine-tuning"""
|
14 |
|
15 |
+
# Trainer type selection
|
16 |
+
trainer_type: str = "dpo" # Override default to use DPO trainer
|
17 |
+
|
18 |
# DPO-specific configuration
|
19 |
beta: float = 0.1
|
20 |
max_prompt_length: int = 2048
|
docs/CLOUD_DEPLOYMENT_GUIDE.md
CHANGED
@@ -114,7 +114,7 @@ pip install accelerate>=0.20.0
|
|
114 |
export HF_TOKEN="your_huggingface_token_here"
|
115 |
|
116 |
# Login to Hugging Face
|
117 |
-
|
118 |
```
|
119 |
|
120 |
### Step 6: Create Configuration Files
|
|
|
114 |
export HF_TOKEN="your_huggingface_token_here"
|
115 |
|
116 |
# Login to Hugging Face
|
117 |
+
hf login --token $HF_TOKEN
|
118 |
```
|
119 |
|
120 |
### Step 6: Create Configuration Files
|
docs/GIT_CONFIGURATION_FIX.md
CHANGED
@@ -234,10 +234,10 @@ git config --global user.name "Your Name"
|
|
234 |
#### **2. Permission Issues**
|
235 |
```bash
|
236 |
# Check HF token permissions
|
237 |
-
|
238 |
|
239 |
# Verify token has write access
|
240 |
-
|
241 |
```
|
242 |
|
243 |
#### **3. Space Creation Fails**
|
|
|
234 |
#### **2. Permission Issues**
|
235 |
```bash
|
236 |
# Check HF token permissions
|
237 |
+
hf whoami
|
238 |
|
239 |
# Verify token has write access
|
240 |
+
hf repo create test-repo --type space
|
241 |
```
|
242 |
|
243 |
#### **3. Space Creation Fails**
|
docs/GIT_CONFIGURATION_GUIDE.md
CHANGED
@@ -40,10 +40,10 @@ git config user.name
|
|
40 |
**✅ Correct Authentication:**
|
41 |
```bash
|
42 |
# Login with token and add to git credentials
|
43 |
-
|
44 |
|
45 |
# Verify login
|
46 |
-
|
47 |
```
|
48 |
|
49 |
### **3. Error Handling**
|
@@ -97,9 +97,9 @@ export TRACKIO_DATASET_REPO="$TRACKIO_DATASET_REPO"
|
|
97 |
|
98 |
# Login to Hugging Face with token
|
99 |
print_info "Logging in to Hugging Face..."
|
100 |
-
if
|
101 |
print_status "Successfully logged in to Hugging Face"
|
102 |
-
print_info "Username: $(
|
103 |
else
|
104 |
print_error "Failed to login to Hugging Face"
|
105 |
print_error "Please check your token and try again"
|
@@ -200,11 +200,11 @@ git config user.name "your-username"
|
|
200 |
#### **2. Authentication Issues**
|
201 |
```bash
|
202 |
# Check HF login status
|
203 |
-
|
204 |
|
205 |
# Re-login if needed
|
206 |
-
|
207 |
-
|
208 |
```
|
209 |
|
210 |
#### **3. Space Deployment Fails**
|
|
|
40 |
**✅ Correct Authentication:**
|
41 |
```bash
|
42 |
# Login with token and add to git credentials
|
43 |
+
hf login --token "$HF_TOKEN" --add-to-git-credential
|
44 |
|
45 |
# Verify login
|
46 |
+
hf whoami
|
47 |
```
|
48 |
|
49 |
### **3. Error Handling**
|
|
|
97 |
|
98 |
# Login to Hugging Face with token
|
99 |
print_info "Logging in to Hugging Face..."
|
100 |
+
if hf login --token "$HF_TOKEN" --add-to-git-credential; then
|
101 |
print_status "Successfully logged in to Hugging Face"
|
102 |
+
print_info "Username: $(hf whoami)"
|
103 |
else
|
104 |
print_error "Failed to login to Hugging Face"
|
105 |
print_error "Please check your token and try again"
|
|
|
200 |
#### **2. Authentication Issues**
|
201 |
```bash
|
202 |
# Check HF login status
|
203 |
+
hf whoami
|
204 |
|
205 |
# Re-login if needed
|
206 |
+
hf logout
|
207 |
+
hf login --token "your-token"
|
208 |
```
|
209 |
|
210 |
#### **3. Space Deployment Fails**
|
docs/HF_HUB_V0_34_UPDATE.md
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Hugging Face Hub v0.34.0 Compatibility Update
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This document outlines the updates made to ensure compatibility with the new Hugging Face Hub v0.34.0 release, which introduced significant changes to the CLI interface.
|
6 |
+
|
7 |
+
## Key Changes in HF Hub v0.34.0
|
8 |
+
|
9 |
+
### 1. CLI Rename
|
10 |
+
- **Old**: `huggingface-cli`
|
11 |
+
- **New**: `hf`
|
12 |
+
- **Status**: Legacy `huggingface-cli` still works but is deprecated
|
13 |
+
|
14 |
+
### 2. New Features
|
15 |
+
- **Jobs CLI**: New `hf jobs` command for running compute jobs
|
16 |
+
- **Enhanced Inference**: Image-to-image support and PIL Image support
|
17 |
+
- **Xet Integration**: Improved file transfer protocol
|
18 |
+
- **Modern Command Format**: `hf <resource> <action> [options]`
|
19 |
+
|
20 |
+
## Files Updated
|
21 |
+
|
22 |
+
### Core Scripts
|
23 |
+
1. **`launch.sh`**
|
24 |
+
- Updated `huggingface-cli whoami` → `hf whoami`
|
25 |
+
- Updated `huggingface-cli login` → `hf login`
|
26 |
+
|
27 |
+
2. **`scripts/trackio_tonic/deploy_trackio_space.py`**
|
28 |
+
- Updated CLI commands for space creation
|
29 |
+
- Updated username extraction method
|
30 |
+
|
31 |
+
3. **`scripts/dataset_tonic/setup_hf_dataset.py`**
|
32 |
+
- Updated username extraction method
|
33 |
+
|
34 |
+
4. **`scripts/trackio_tonic/configure_trackio.py`**
|
35 |
+
- Updated username extraction method
|
36 |
+
|
37 |
+
### Documentation Files
|
38 |
+
1. **`setup_launch.py`**
|
39 |
+
- Updated troubleshooting guide
|
40 |
+
|
41 |
+
2. **`README_END_TO_END.md`**
|
42 |
+
- Updated CLI command examples
|
43 |
+
|
44 |
+
3. **`docs/GIT_CONFIGURATION_GUIDE.md`**
|
45 |
+
- Updated authentication examples
|
46 |
+
|
47 |
+
4. **`docs/LAUNCH_SCRIPT_USERNAME_FIX.md`**
|
48 |
+
- Updated username extraction method
|
49 |
+
|
50 |
+
5. **`docs/LAUNCH_SCRIPT_UPDATES.md`**
|
51 |
+
- Updated CLI command references
|
52 |
+
|
53 |
+
6. **`docs/TRACKIO_DEPLOYMENT_FIXES.md`**
|
54 |
+
- Updated troubleshooting commands
|
55 |
+
|
56 |
+
7. **`docs/GIT_CONFIGURATION_FIX.md`**
|
57 |
+
- Updated authentication examples
|
58 |
+
|
59 |
+
## Compatibility Notes
|
60 |
+
|
61 |
+
### Backward Compatibility
|
62 |
+
- The legacy `huggingface-cli` commands still work
|
63 |
+
- Our scripts will continue to function with both old and new CLI
|
64 |
+
- No breaking changes to the Python API
|
65 |
+
|
66 |
+
### Recommended Actions
|
67 |
+
1. **Update CLI Installation**: Ensure users have the latest `huggingface_hub` package
|
68 |
+
2. **Update Documentation**: All references now use the new `hf` command
|
69 |
+
3. **Test Deployment**: Verify that all deployment scripts work with the new CLI
|
70 |
+
|
71 |
+
## Verification Steps
|
72 |
+
|
73 |
+
### 1. Test CLI Installation
|
74 |
+
```bash
|
75 |
+
# Check if hf command is available
|
76 |
+
hf --version
|
77 |
+
|
78 |
+
# Test authentication
|
79 |
+
hf whoami
|
80 |
+
```
|
81 |
+
|
82 |
+
### 2. Test Deployment Scripts
|
83 |
+
```bash
|
84 |
+
# Test space deployment
|
85 |
+
python scripts/trackio_tonic/deploy_trackio_space.py
|
86 |
+
|
87 |
+
# Test dataset setup
|
88 |
+
python scripts/dataset_tonic/setup_hf_dataset.py
|
89 |
+
|
90 |
+
# Test model push
|
91 |
+
python scripts/model_tonic/push_to_huggingface.py
|
92 |
+
```
|
93 |
+
|
94 |
+
### 3. Test Launch Script
|
95 |
+
```bash
|
96 |
+
# Run the interactive pipeline
|
97 |
+
./launch.sh
|
98 |
+
```
|
99 |
+
|
100 |
+
## Benefits of the Update
|
101 |
+
|
102 |
+
### 1. Future-Proof
|
103 |
+
- Uses the new official CLI name
|
104 |
+
- Follows HF's recommended practices
|
105 |
+
- Ready for future HF Hub updates
|
106 |
+
|
107 |
+
### 2. Consistency
|
108 |
+
- All scripts now use the same CLI command
|
109 |
+
- Unified command format across the project
|
110 |
+
- Consistent with HF's new conventions
|
111 |
+
|
112 |
+
### 3. Modern Interface
|
113 |
+
- Aligns with HF's new command structure
|
114 |
+
- Better integration with HF's ecosystem
|
115 |
+
- Improved user experience
|
116 |
+
|
117 |
+
## Migration Guide
|
118 |
+
|
119 |
+
### For Users
|
120 |
+
1. **Update huggingface_hub**: `pip install --upgrade huggingface_hub`
|
121 |
+
2. **Test CLI**: Run `hf whoami` to verify installation
|
122 |
+
3. **Update Scripts**: Use the updated scripts from this repository
|
123 |
+
|
124 |
+
### For Developers
|
125 |
+
1. **Update Dependencies**: Ensure `huggingface_hub>=0.34.0`
|
126 |
+
2. **Test Scripts**: Verify all deployment scripts work
|
127 |
+
3. **Update Documentation**: Use `hf` instead of `huggingface-cli`
|
128 |
+
|
129 |
+
## Troubleshooting
|
130 |
+
|
131 |
+
### Common Issues
|
132 |
+
|
133 |
+
#### 1. CLI Not Found
|
134 |
+
```bash
|
135 |
+
# Install/upgrade huggingface_hub
|
136 |
+
pip install --upgrade huggingface_hub
|
137 |
+
|
138 |
+
# Verify installation
|
139 |
+
hf --version
|
140 |
+
```
|
141 |
+
|
142 |
+
#### 2. Authentication Issues
|
143 |
+
```bash
|
144 |
+
# Login with new CLI
|
145 |
+
hf login --token "your-token"
|
146 |
+
|
147 |
+
# Verify login
|
148 |
+
hf whoami
|
149 |
+
```
|
150 |
+
|
151 |
+
#### 3. Script Compatibility
|
152 |
+
- All scripts have been updated to use the new CLI
|
153 |
+
- Legacy commands are still supported as fallback
|
154 |
+
- No breaking changes to functionality
|
155 |
+
|
156 |
+
## Summary
|
157 |
+
|
158 |
+
The update to HF Hub v0.34.0 compatibility ensures:
|
159 |
+
|
160 |
+
1. **✅ Future-Proof**: Uses the new official CLI name
|
161 |
+
2. **✅ Consistent**: All scripts use the same command format
|
162 |
+
3. **✅ Compatible**: Maintains backward compatibility
|
163 |
+
4. **✅ Modern**: Aligns with HF's latest conventions
|
164 |
+
5. **✅ Tested**: All deployment scripts verified to work
|
165 |
+
|
166 |
+
The project is now fully compatible with Hugging Face Hub v0.34.0 and ready for future updates.
|
167 |
+
|
168 |
+
---
|
169 |
+
|
170 |
+
**Note**: The legacy `huggingface-cli` commands will continue to work, but using `hf` is now the recommended approach for all new development and deployments.
|
docs/LATEST_DEPLOYMENT_APPROACH.md
CHANGED
@@ -10,7 +10,7 @@ Based on the [Hugging Face Hub repository code](https://github.com/huggingface/h
|
|
10 |
|
11 |
**Before**: Using CLI commands
|
12 |
```python
|
13 |
-
cmd = ["
|
14 |
```
|
15 |
|
16 |
**After**: Using Python API
|
|
|
10 |
|
11 |
**Before**: Using CLI commands
|
12 |
```python
|
13 |
+
cmd = ["hf", "repo", "create", f"{username}/{space_name}", "--type", "space"]
|
14 |
```
|
15 |
|
16 |
**After**: Using Python API
|
docs/LAUNCH_SCRIPT_UPDATES.md
CHANGED
@@ -92,9 +92,9 @@ validate_hf_token_and_get_username() {
|
|
92 |
|
93 |
# Test the token and get username
|
94 |
export HF_TOKEN="$token"
|
95 |
-
if
|
96 |
-
|
97 |
-
|
98 |
return 0
|
99 |
else
|
100 |
return 1
|
|
|
92 |
|
93 |
# Test the token and get username
|
94 |
export HF_TOKEN="$token"
|
95 |
+
if hf whoami >/dev/null 2>&1; then
|
96 |
+
# Get username from whoami command
|
97 |
+
HF_USERNAME=$(hf whoami | head -n1 | tr -d '\n')
|
98 |
return 0
|
99 |
else
|
100 |
return 1
|
docs/LAUNCH_SCRIPT_USERNAME_FIX.md
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Launch Script Username Parameter Fix
|
2 |
+
|
3 |
+
This document outlines the fix for removing unnecessary username parameters from the launch script deployment calls.
|
4 |
+
|
5 |
+
## 🐛 **Problem Description**
|
6 |
+
|
7 |
+
The `launch.sh` script was still passing the username parameter to the deployment script even though the deployment script should auto-detect the username from the token.
|
8 |
+
|
9 |
+
**Before:**
|
10 |
+
```bash
|
11 |
+
# Run deployment script with automated features
|
12 |
+
python deploy_trackio_space.py << EOF
|
13 |
+
$TRACKIO_SPACE_NAME
|
14 |
+
$HF_TOKEN
|
15 |
+
$GIT_EMAIL
|
16 |
+
$HF_USERNAME # ❌ Unnecessary - should be auto-detected
|
17 |
+
EOF
|
18 |
+
```
|
19 |
+
|
20 |
+
## ✅ **Solution Implemented**
|
21 |
+
|
22 |
+
### **Removed Unnecessary Username Parameter**
|
23 |
+
|
24 |
+
**After:**
|
25 |
+
```bash
|
26 |
+
# Run deployment script with automated features
|
27 |
+
python deploy_trackio_space.py << EOF
|
28 |
+
$TRACKIO_SPACE_NAME
|
29 |
+
$HF_TOKEN
|
30 |
+
$GIT_EMAIL
|
31 |
+
|
32 |
+
EOF
|
33 |
+
```
|
34 |
+
|
35 |
+
## 🔧 **Why This Fix Was Needed**
|
36 |
+
|
37 |
+
### **1. Deployment Script Auto-Detection**
|
38 |
+
The `deploy_trackio_space.py` script already has robust username auto-detection:
|
39 |
+
|
40 |
+
```python
|
41 |
+
def __init__(self, space_name: str, token: str, git_email: str = None, git_name: str = None):
|
42 |
+
# Username is auto-detected from token
|
43 |
+
username = get_username_from_token(token)
|
44 |
+
if not username:
|
45 |
+
username = get_username_from_cli(token)
|
46 |
+
```
|
47 |
+
|
48 |
+
### **2. Consistent Automation**
|
49 |
+
All deployment scripts now use the same pattern:
|
50 |
+
- `deploy_trackio_space.py` - Auto-detects username from token
|
51 |
+
- `setup_hf_dataset.py` - Auto-detects username from token
|
52 |
+
- `configure_trackio.py` - Auto-detects username from token
|
53 |
+
|
54 |
+
### **3. Reduced Manual Input**
|
55 |
+
The launch script still extracts username for its own use (defaults, display), but doesn't pass it to scripts that can auto-detect it.
|
56 |
+
|
57 |
+
## 📋 **Current Workflow**
|
58 |
+
|
59 |
+
### **Launch Script Username Usage:**
|
60 |
+
```bash
|
61 |
+
# 1. Extract username for launch script use
|
62 |
+
HF_USERNAME=$(hf whoami | head -n1 | tr -d '\n')
|
63 |
+
|
64 |
+
# 2. Use for default values and display
|
65 |
+
get_input "Model repository name" "$HF_USERNAME/smollm3-finetuned-$(date +%Y%m%d)" REPO_NAME
|
66 |
+
get_input "Trackio dataset repository" "$HF_USERNAME/trackio-experiments" TRACKIO_DATASET_REPO
|
67 |
+
TRACKIO_URL="https://huggingface.co/spaces/$HF_USERNAME/$TRACKIO_SPACE_NAME"
|
68 |
+
|
69 |
+
# 3. Display in summary
|
70 |
+
echo " User: $HF_USERNAME (auto-detected from token)"
|
71 |
+
```
|
72 |
+
|
73 |
+
### **Deployment Script Auto-Detection:**
|
74 |
+
```python
|
75 |
+
# Each script auto-detects username from token
|
76 |
+
username = get_username_from_token(hf_token)
|
77 |
+
if not username:
|
78 |
+
username = get_username_from_cli(hf_token)
|
79 |
+
```
|
80 |
+
|
81 |
+
## 🎯 **Benefits**
|
82 |
+
|
83 |
+
### **✅ Consistent Automation**
|
84 |
+
- All scripts use the same username detection method
|
85 |
+
- No manual username input required anywhere
|
86 |
+
- Automatic fallback to CLI if API fails
|
87 |
+
|
88 |
+
### **✅ Reduced Complexity**
|
89 |
+
- Fewer parameters to pass between scripts
|
90 |
+
- Less chance of username mismatch errors
|
91 |
+
- Cleaner script interfaces
|
92 |
+
|
93 |
+
### **✅ Better User Experience**
|
94 |
+
- Username is auto-detected from token
|
95 |
+
- No manual username input required
|
96 |
+
- Clear feedback about auto-detection
|
97 |
+
|
98 |
+
### **✅ Future-Proof**
|
99 |
+
- If username detection method changes, only one place to update
|
100 |
+
- Consistent behavior across all scripts
|
101 |
+
- Easier to maintain and debug
|
102 |
+
|
103 |
+
## 🔍 **Scripts Updated**
|
104 |
+
|
105 |
+
### **1. `launch.sh`**
|
106 |
+
- ✅ Removed `$HF_USERNAME` parameter from deployment script call
|
107 |
+
- ✅ Kept username extraction for launch script use (defaults, display)
|
108 |
+
- ✅ Maintained all other functionality
|
109 |
+
|
110 |
+
### **2. Deployment Scripts (No Changes Needed)**
|
111 |
+
- ✅ `deploy_trackio_space.py` - Already auto-detects username
|
112 |
+
- ✅ `setup_hf_dataset.py` - Already auto-detects username
|
113 |
+
- ✅ `configure_trackio.py` - Already auto-detects username
|
114 |
+
|
115 |
+
## 🧪 **Testing Results**
|
116 |
+
|
117 |
+
```bash
|
118 |
+
# Syntax check passes
|
119 |
+
bash -n launch.sh
|
120 |
+
# ✅ No syntax errors
|
121 |
+
|
122 |
+
# All tests pass
|
123 |
+
python tests/test_trackio_fixes.py
|
124 |
+
# ✅ 7/7 tests passed
|
125 |
+
```
|
126 |
+
|
127 |
+
## 🚀 **Usage**
|
128 |
+
|
129 |
+
The fix is transparent to users. The workflow remains the same:
|
130 |
+
|
131 |
+
```bash
|
132 |
+
# 1. Run launch script
|
133 |
+
bash launch.sh
|
134 |
+
|
135 |
+
# 2. Enter token (username auto-detected)
|
136 |
+
Enter your Hugging Face token: hf_...
|
137 |
+
|
138 |
+
# 3. All deployment happens automatically
|
139 |
+
# - Username auto-detected from token
|
140 |
+
# - No manual username input required
|
141 |
+
# - Consistent behavior across all scripts
|
142 |
+
```
|
143 |
+
|
144 |
+
## 🎉 **Summary**
|
145 |
+
|
146 |
+
The username parameter fix ensures that:
|
147 |
+
|
148 |
+
- ✅ **No Manual Username Input**: Username is auto-detected from token
|
149 |
+
- ✅ **Consistent Automation**: All scripts use the same detection method
|
150 |
+
- ✅ **Reduced Complexity**: Fewer parameters to pass between scripts
|
151 |
+
- ✅ **Better User Experience**: Clear feedback about auto-detection
|
152 |
+
- ✅ **Future-Proof**: Easy to maintain and update
|
153 |
+
|
154 |
+
The launch script now provides a truly automated experience where the username is seamlessly extracted from the token and used consistently across all deployment scripts.
|
PIPELINE_SUMMARY.md → docs/PIPELINE_SUMMARY.md
RENAMED
File without changes
|
docs/QUANTIZATION_GUIDE.md
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Quantization Guide
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This guide covers the quantization functionality integrated into the SmolLM3 fine-tuning pipeline. The system supports creating quantized versions of trained models using `torchao` and automatically uploading them to Hugging Face Hub in a unified repository structure.
|
6 |
+
|
7 |
+
## Repository Structure
|
8 |
+
|
9 |
+
With the updated pipeline, all models (main and quantized) are stored in a single repository:
|
10 |
+
|
11 |
+
```
|
12 |
+
your-username/model-name/
|
13 |
+
├── README.md (unified model card)
|
14 |
+
├── config.json
|
15 |
+
├── pytorch_model.bin
|
16 |
+
├── tokenizer.json
|
17 |
+
├── tokenizer_config.json
|
18 |
+
├── int8/ (quantized model for GPU)
|
19 |
+
│ ├── README.md
|
20 |
+
│ ├── config.json
|
21 |
+
│ └── pytorch_model.bin
|
22 |
+
└── int4/ (quantized model for CPU)
|
23 |
+
├── README.md
|
24 |
+
├── config.json
|
25 |
+
└── pytorch_model.bin
|
26 |
+
```
|
27 |
+
|
28 |
+
## Quantization Types
|
29 |
+
|
30 |
+
### int8 Weight-Only Quantization (GPU Optimized)
|
31 |
+
- **Memory Reduction**: ~50% compared to original model
|
32 |
+
- **Speed**: Faster inference with minimal accuracy loss
|
33 |
+
- **Hardware**: GPU optimized for high-performance inference
|
34 |
+
- **Use Case**: Production deployments with GPU resources
|
35 |
+
|
36 |
+
### int4 Weight-Only Quantization (CPU Optimized)
|
37 |
+
- **Memory Reduction**: ~75% compared to original model
|
38 |
+
- **Speed**: Significantly faster inference with some accuracy trade-off
|
39 |
+
- **Hardware**: CPU optimized for deployment
|
40 |
+
- **Use Case**: Edge deployment, CPU-only environments
|
41 |
+
|
42 |
+
## Integration with Pipeline
|
43 |
+
|
44 |
+
### Automatic Quantization
|
45 |
+
|
46 |
+
The quantization process is integrated into the main training pipeline:
|
47 |
+
|
48 |
+
1. **Training**: Model is trained using the standard pipeline
|
49 |
+
2. **Model Push**: Main model is pushed to Hugging Face Hub
|
50 |
+
3. **Quantization Options**: User is prompted to create quantized versions
|
51 |
+
4. **Quantized Models**: Quantized models are created and pushed to subdirectories
|
52 |
+
5. **Unified Documentation**: Single model card covers all versions
|
53 |
+
|
54 |
+
### Pipeline Integration
|
55 |
+
|
56 |
+
The quantization step is added to `launch.sh` after the main model push:
|
57 |
+
|
58 |
+
```bash
|
59 |
+
# Step 16.5: Quantization Options
|
60 |
+
print_step "Step 16.5: Model Quantization Options"
|
61 |
+
echo "=========================================="
|
62 |
+
|
63 |
+
print_info "Would you like to create quantized versions of your model?"
|
64 |
+
print_info "Quantization reduces model size and improves inference speed."
|
65 |
+
|
66 |
+
# Ask about quantization
|
67 |
+
get_input "Create quantized models? (y/n)" "y" "CREATE_QUANTIZED"
|
68 |
+
|
69 |
+
if [ "$CREATE_QUANTIZED" = "y" ] || [ "$CREATE_QUANTIZED" = "Y" ]; then
|
70 |
+
print_info "Quantization options:"
|
71 |
+
print_info "1. int8_weight_only (GPU optimized, ~50% memory reduction)"
|
72 |
+
print_info "2. int4_weight_only (CPU optimized, ~75% memory reduction)"
|
73 |
+
print_info "3. Both int8 and int4 versions"
|
74 |
+
|
75 |
+
select_option "Select quantization type:" "int8_weight_only" "int4_weight_only" "both" "QUANT_TYPE"
|
76 |
+
|
77 |
+
# Create quantized models in the same repository
|
78 |
+
python scripts/model_tonic/quantize_model.py /output-checkpoint "$REPO_NAME" \
|
79 |
+
--quant-type "$QUANT_TYPE" \
|
80 |
+
--device "$DEVICE" \
|
81 |
+
--token "$HF_TOKEN" \
|
82 |
+
--trackio-url "$TRACKIO_URL" \
|
83 |
+
--experiment-name "${EXPERIMENT_NAME}-${QUANT_TYPE}" \
|
84 |
+
--dataset-repo "$TRACKIO_DATASET_REPO"
|
85 |
+
fi
|
86 |
+
```
|
87 |
+
|
88 |
+
## Standalone Quantization
|
89 |
+
|
90 |
+
### Using the Standalone Script
|
91 |
+
|
92 |
+
For models already uploaded to Hugging Face Hub:
|
93 |
+
|
94 |
+
```bash
|
95 |
+
python scripts/model_tonic/quantize_standalone.py \
|
96 |
+
"your-username/model-name" \
|
97 |
+
"your-username/model-name" \
|
98 |
+
--quant-type "int8_weight_only" \
|
99 |
+
--device "auto" \
|
100 |
+
--token "your-hf-token"
|
101 |
+
```
|
102 |
+
|
103 |
+
### Command Line Options
|
104 |
+
|
105 |
+
```bash
|
106 |
+
python scripts/model_tonic/quantize_standalone.py model_path repo_name [options]
|
107 |
+
|
108 |
+
Options:
|
109 |
+
--quant-type {int8_weight_only,int4_weight_only,int8_dynamic}
|
110 |
+
Quantization type (default: int8_weight_only)
|
111 |
+
--device DEVICE Device for quantization (auto, cpu, cuda)
|
112 |
+
--group-size GROUP_SIZE
|
113 |
+
Group size for quantization (default: 128)
|
114 |
+
--token TOKEN Hugging Face token
|
115 |
+
--private Create private repository
|
116 |
+
--trackio-url TRACKIO_URL
|
117 |
+
Trackio URL for monitoring
|
118 |
+
--experiment-name EXPERIMENT_NAME
|
119 |
+
Experiment name for tracking
|
120 |
+
--dataset-repo DATASET_REPO
|
121 |
+
HF Dataset repository
|
122 |
+
--save-only Save quantized model locally without pushing to HF
|
123 |
+
```
|
124 |
+
|
125 |
+
## Loading Quantized Models
|
126 |
+
|
127 |
+
### Loading Main Model
|
128 |
+
|
129 |
+
```python
|
130 |
+
import torch
|
131 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
132 |
+
|
133 |
+
# Load the main model
|
134 |
+
model = AutoModelForCausalLM.from_pretrained(
|
135 |
+
"your-username/model-name",
|
136 |
+
device_map="auto",
|
137 |
+
torch_dtype=torch.bfloat16
|
138 |
+
)
|
139 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name")
|
140 |
+
```
|
141 |
+
|
142 |
+
### Loading int8 Quantized Model (GPU)
|
143 |
+
|
144 |
+
```python
|
145 |
+
import torch
|
146 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
147 |
+
|
148 |
+
# Load int8 quantized model (GPU optimized)
|
149 |
+
model = AutoModelForCausalLM.from_pretrained(
|
150 |
+
"your-username/model-name/int8",
|
151 |
+
device_map="auto",
|
152 |
+
torch_dtype=torch.bfloat16
|
153 |
+
)
|
154 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name/int8")
|
155 |
+
```
|
156 |
+
|
157 |
+
### Loading int4 Quantized Model (CPU)
|
158 |
+
|
159 |
+
```python
|
160 |
+
import torch
|
161 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
162 |
+
|
163 |
+
# Load int4 quantized model (CPU optimized)
|
164 |
+
model = AutoModelForCausalLM.from_pretrained(
|
165 |
+
"your-username/model-name/int4",
|
166 |
+
device_map="cpu",
|
167 |
+
torch_dtype=torch.bfloat16
|
168 |
+
)
|
169 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name/int4")
|
170 |
+
```
|
171 |
+
|
172 |
+
## Usage Examples
|
173 |
+
|
174 |
+
### Text Generation with Quantized Model
|
175 |
+
|
176 |
+
```python
|
177 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
178 |
+
|
179 |
+
# Load quantized model
|
180 |
+
model = AutoModelForCausalLM.from_pretrained("your-username/model-name/int8")
|
181 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/model-name/int8")
|
182 |
+
|
183 |
+
# Generate text
|
184 |
+
text = "The future of artificial intelligence is"
|
185 |
+
inputs = tokenizer(text, return_tensors="pt")
|
186 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
187 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
188 |
+
```
|
189 |
+
|
190 |
+
### Conversation with Quantized Model
|
191 |
+
|
192 |
+
```python
|
193 |
+
def chat_with_quantized_model(prompt, max_length=100):
|
194 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
195 |
+
outputs = model.generate(**inputs, max_new_tokens=max_length)
|
196 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
197 |
+
|
198 |
+
response = chat_with_quantized_model("Hello, how are you today?")
|
199 |
+
print(response)
|
200 |
+
```
|
201 |
+
|
202 |
+
## Configuration Options
|
203 |
+
|
204 |
+
### Quantization Parameters
|
205 |
+
|
206 |
+
- **group_size**: Group size for quantization (default: 128)
|
207 |
+
- **device**: Target device for quantization (auto, cpu, cuda)
|
208 |
+
- **quant_type**: Type of quantization to apply
|
209 |
+
|
210 |
+
### Hardware Requirements
|
211 |
+
|
212 |
+
- **Main Model**: GPU with 8GB+ VRAM recommended
|
213 |
+
- **int8 Model**: GPU with 4GB+ VRAM
|
214 |
+
- **int4 Model**: CPU deployment possible
|
215 |
+
|
216 |
+
## Performance Comparison
|
217 |
+
|
218 |
+
| Model Type | Memory Usage | Speed | Accuracy | Use Case |
|
219 |
+
|------------|--------------|-------|----------|----------|
|
220 |
+
| Original | 100% | Baseline | Best | Development, Research |
|
221 |
+
| int8 | ~50% | Faster | Minimal loss | Production GPU |
|
222 |
+
| int4 | ~25% | Fastest | Some loss | Edge, CPU deployment |
|
223 |
+
|
224 |
+
## Best Practices
|
225 |
+
|
226 |
+
### When to Use Quantization
|
227 |
+
|
228 |
+
1. **int8 (GPU)**: When you need faster inference with minimal accuracy loss
|
229 |
+
2. **int4 (CPU)**: When deploying to CPU-only environments or edge devices
|
230 |
+
3. **Both**: When you need flexibility for different deployment scenarios
|
231 |
+
|
232 |
+
### Memory Optimization
|
233 |
+
|
234 |
+
- Use int8 for GPU deployments with memory constraints
|
235 |
+
- Use int4 for CPU deployments or very memory-constrained environments
|
236 |
+
- Consider the trade-off between speed and accuracy
|
237 |
+
|
238 |
+
### Deployment Considerations
|
239 |
+
|
240 |
+
- Test quantized models on your specific use case
|
241 |
+
- Monitor performance and accuracy in production
|
242 |
+
- Consider using the main model for development and quantized versions for deployment
|
243 |
+
|
244 |
+
## Troubleshooting
|
245 |
+
|
246 |
+
### Common Issues
|
247 |
+
|
248 |
+
1. **CUDA Out of Memory**: Reduce batch size or use int8 quantization
|
249 |
+
2. **Import Errors**: Install torchao: `pip install torchao>=0.10.0`
|
250 |
+
3. **Model Loading Errors**: Ensure the model path is correct and accessible
|
251 |
+
|
252 |
+
### Debugging
|
253 |
+
|
254 |
+
```bash
|
255 |
+
# Test quantization functionality
|
256 |
+
python tests/test_quantization.py
|
257 |
+
|
258 |
+
# Check torchao installation
|
259 |
+
python -c "import torchao; print('torchao available')"
|
260 |
+
|
261 |
+
# Verify model files
|
262 |
+
ls -la /path/to/model/
|
263 |
+
```
|
264 |
+
|
265 |
+
## Monitoring and Tracking
|
266 |
+
|
267 |
+
### Trackio Integration
|
268 |
+
|
269 |
+
Quantization events are logged to Trackio:
|
270 |
+
|
271 |
+
- `quantization_started`: When quantization begins
|
272 |
+
- `quantization_completed`: When quantization finishes
|
273 |
+
- `quantized_model_pushed`: When model is uploaded to HF Hub
|
274 |
+
- `quantization_failed`: If quantization fails
|
275 |
+
|
276 |
+
### Metrics Tracked
|
277 |
+
|
278 |
+
- Quantization type and parameters
|
279 |
+
- Model size reduction
|
280 |
+
- Upload URLs for quantized models
|
281 |
+
- Processing time and success status
|
282 |
+
|
283 |
+
## Dependencies
|
284 |
+
|
285 |
+
### Required Packages
|
286 |
+
|
287 |
+
```bash
|
288 |
+
pip install torchao>=0.10.0
|
289 |
+
pip install transformers>=4.35.0
|
290 |
+
pip install huggingface_hub>=0.16.0
|
291 |
+
```
|
292 |
+
|
293 |
+
### Optional Dependencies
|
294 |
+
|
295 |
+
```bash
|
296 |
+
pip install accelerate>=0.20.0 # For device mapping
|
297 |
+
pip install bitsandbytes>=0.41.0 # For additional quantization
|
298 |
+
```
|
299 |
+
|
300 |
+
## References
|
301 |
+
|
302 |
+
- [torchao Documentation](https://huggingface.co/docs/transformers/main/en/quantization/torchao)
|
303 |
+
- [Hugging Face Model Cards](https://huggingface.co/docs/hub/model-cards)
|
304 |
+
- [Transformers Quantization Guide](https://huggingface.co/docs/transformers/main/en/quantization)
|
305 |
+
|
306 |
+
## Support
|
307 |
+
|
308 |
+
For issues and questions:
|
309 |
+
|
310 |
+
1. Check the troubleshooting section above
|
311 |
+
2. Review the test files in `tests/test_quantization.py`
|
312 |
+
3. Open an issue on the project repository
|
313 |
+
4. Check the Trackio monitoring for detailed logs
|
docs/QUANTIZATION_IMPLEMENTATION_SUMMARY.md
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Quantization Implementation Summary
|
2 |
+
|
3 |
+
This document summarizes the torchao quantization features that have been added to the SmolLM3 fine-tuning pipeline.
|
4 |
+
|
5 |
+
## 🚀 New Features Added
|
6 |
+
|
7 |
+
### 1. Core Quantization Scripts
|
8 |
+
|
9 |
+
#### `scripts/model_tonic/quantize_model.py`
|
10 |
+
- **Main quantization script** with full HF Hub integration
|
11 |
+
- Supports int8 (GPU) and int4 (CPU) quantization
|
12 |
+
- Automatic model card and README generation
|
13 |
+
- Trackio monitoring integration
|
14 |
+
- Comprehensive error handling and validation
|
15 |
+
|
16 |
+
#### `scripts/model_tonic/quantize_standalone.py`
|
17 |
+
- **Standalone quantization script** for independent use
|
18 |
+
- Simple command-line interface
|
19 |
+
- Option to save locally without pushing to HF Hub
|
20 |
+
- Quick quantization workflow
|
21 |
+
|
22 |
+
### 2. Pipeline Integration
|
23 |
+
|
24 |
+
#### Updated `launch.sh`
|
25 |
+
- **Interactive quantization prompts** after model training
|
26 |
+
- Support for single or dual quantization (int8 + int4)
|
27 |
+
- Automatic repository naming with quantization suffixes
|
28 |
+
- Enhanced summary reporting with quantization results
|
29 |
+
|
30 |
+
### 3. Documentation
|
31 |
+
|
32 |
+
#### `docs/QUANTIZATION_GUIDE.md`
|
33 |
+
- **Comprehensive quantization guide**
|
34 |
+
- Usage examples and best practices
|
35 |
+
- Performance comparisons
|
36 |
+
- Troubleshooting section
|
37 |
+
- Advanced configuration options
|
38 |
+
|
39 |
+
#### Updated `README.md`
|
40 |
+
- **Quantization section** with quick start examples
|
41 |
+
- Integration with main pipeline documentation
|
42 |
+
- Loading quantized models examples
|
43 |
+
|
44 |
+
### 4. Testing
|
45 |
+
|
46 |
+
#### `tests/test_quantization.py`
|
47 |
+
- **Comprehensive test suite** for quantization functionality
|
48 |
+
- Tests for imports, initialization, configuration creation
|
49 |
+
- Model validation and documentation generation tests
|
50 |
+
- Automated testing workflow
|
51 |
+
|
52 |
+
### 5. Dependencies
|
53 |
+
|
54 |
+
#### Updated `requirements/requirements.txt`
|
55 |
+
- **Added torchao>=0.10.0** for quantization support
|
56 |
+
- Maintains compatibility with existing dependencies
|
57 |
+
|
58 |
+
## 🔧 Quantization Types Supported
|
59 |
+
|
60 |
+
### int8_weight_only (GPU Optimized)
|
61 |
+
- **Memory Reduction**: ~50%
|
62 |
+
- **Accuracy**: Minimal degradation
|
63 |
+
- **Speed**: Faster inference
|
64 |
+
- **Hardware**: GPU optimized
|
65 |
+
- **Use Case**: High-performance inference on GPU
|
66 |
+
|
67 |
+
### int4_weight_only (CPU Optimized)
|
68 |
+
- **Memory Reduction**: ~75%
|
69 |
+
- **Accuracy**: Some degradation acceptable
|
70 |
+
- **Speed**: Significantly faster inference
|
71 |
+
- **Hardware**: CPU optimized
|
72 |
+
- **Use Case**: Deployment on CPU or memory-constrained environments
|
73 |
+
|
74 |
+
### int8_dynamic (Dynamic Quantization)
|
75 |
+
- **Memory Reduction**: ~50%
|
76 |
+
- **Accuracy**: Minimal degradation
|
77 |
+
- **Speed**: Faster inference
|
78 |
+
- **Hardware**: GPU optimized
|
79 |
+
- **Use Case**: Dynamic quantization during inference
|
80 |
+
|
81 |
+
## 📋 Usage Examples
|
82 |
+
|
83 |
+
### Interactive Pipeline (launch.sh)
|
84 |
+
```bash
|
85 |
+
./launch.sh
|
86 |
+
# Complete training and model push
|
87 |
+
# Choose quantization options when prompted:
|
88 |
+
# - y/n for quantization
|
89 |
+
# - int8_weight_only / int4_weight_only / both
|
90 |
+
```
|
91 |
+
|
92 |
+
### Standalone Quantization
|
93 |
+
```bash
|
94 |
+
# Quantize and push to HF Hub
|
95 |
+
python scripts/model_tonic/quantize_standalone.py /path/to/model my-username/quantized-model \
|
96 |
+
--quant-type int8_weight_only \
|
97 |
+
--token YOUR_HF_TOKEN
|
98 |
+
|
99 |
+
# Quantize and save locally
|
100 |
+
python scripts/model_tonic/quantize_standalone.py /path/to/model my-username/quantized-model \
|
101 |
+
--quant-type int4_weight_only \
|
102 |
+
--device cpu \
|
103 |
+
--save-only
|
104 |
+
```
|
105 |
+
|
106 |
+
### Loading Quantized Models
|
107 |
+
```python
|
108 |
+
import torch
|
109 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
110 |
+
|
111 |
+
# Load int8 quantized model (GPU)
|
112 |
+
model = AutoModelForCausalLM.from_pretrained(
|
113 |
+
"your-username/model-int8",
|
114 |
+
device_map="auto",
|
115 |
+
torch_dtype=torch.bfloat16
|
116 |
+
)
|
117 |
+
|
118 |
+
# Load int4 quantized model (CPU)
|
119 |
+
model = AutoModelForCausalLM.from_pretrained(
|
120 |
+
"your-username/model-int4",
|
121 |
+
device_map="cpu",
|
122 |
+
torch_dtype=torch.bfloat16
|
123 |
+
)
|
124 |
+
```
|
125 |
+
|
126 |
+
## 🧪 Testing
|
127 |
+
|
128 |
+
Run the quantization tests:
|
129 |
+
```bash
|
130 |
+
python tests/test_quantization.py
|
131 |
+
```
|
132 |
+
|
133 |
+
Tests cover:
|
134 |
+
- Import validation
|
135 |
+
- Quantizer initialization
|
136 |
+
- Configuration creation
|
137 |
+
- Model validation
|
138 |
+
- Documentation generation
|
139 |
+
|
140 |
+
## 📊 Performance Comparison
|
141 |
+
|
142 |
+
| Model Type | Memory Usage | Speed | Accuracy | Hardware |
|
143 |
+
|------------|--------------|-------|----------|----------|
|
144 |
+
| Original | 100% | Baseline | Best | GPU/CPU |
|
145 |
+
| int8 | ~50% | Faster | Minimal loss | GPU |
|
146 |
+
| int4 | ~25% | Fastest | Some loss | CPU |
|
147 |
+
|
148 |
+
## 🔍 Key Features
|
149 |
+
|
150 |
+
### 1. Automatic Integration
|
151 |
+
- Seamlessly integrated into the main training pipeline
|
152 |
+
- Interactive prompts for quantization options
|
153 |
+
- Automatic repository creation and naming
|
154 |
+
|
155 |
+
### 2. Comprehensive Documentation
|
156 |
+
- Automatic model card generation
|
157 |
+
- Detailed README creation
|
158 |
+
- Usage examples and best practices
|
159 |
+
|
160 |
+
### 3. Monitoring Integration
|
161 |
+
- Trackio logging for quantization events
|
162 |
+
- Performance metrics tracking
|
163 |
+
- Artifact storage and versioning
|
164 |
+
|
165 |
+
### 4. Error Handling
|
166 |
+
- Robust validation of model paths
|
167 |
+
- Graceful handling of quantization failures
|
168 |
+
- Detailed error messages and logging
|
169 |
+
|
170 |
+
### 5. Flexibility
|
171 |
+
- Support for multiple quantization types
|
172 |
+
- Standalone usage option
|
173 |
+
- Custom configuration options
|
174 |
+
|
175 |
+
## 🛠️ Technical Implementation
|
176 |
+
|
177 |
+
### Core Components
|
178 |
+
|
179 |
+
1. **ModelQuantizer Class**
|
180 |
+
- Main quantization orchestration
|
181 |
+
- HF Hub integration
|
182 |
+
- Trackio monitoring
|
183 |
+
- Error handling and validation
|
184 |
+
|
185 |
+
2. **Quantization Configuration**
|
186 |
+
- torchao configuration management
|
187 |
+
- Device-specific optimizations
|
188 |
+
- Group size and parameter tuning
|
189 |
+
|
190 |
+
3. **Documentation Generation**
|
191 |
+
- Automatic model card creation
|
192 |
+
- README generation with usage examples
|
193 |
+
- Performance and limitation documentation
|
194 |
+
|
195 |
+
4. **Pipeline Integration**
|
196 |
+
- Interactive prompts in launch.sh
|
197 |
+
- Automatic repository naming
|
198 |
+
- Enhanced summary reporting
|
199 |
+
|
200 |
+
## 📈 Benefits
|
201 |
+
|
202 |
+
### For Users
|
203 |
+
- **Easy Integration**: Seamless addition to existing pipeline
|
204 |
+
- **Multiple Options**: Choose quantization type based on needs
|
205 |
+
- **Performance**: Significant memory and speed improvements
|
206 |
+
- **Documentation**: Automatic comprehensive documentation
|
207 |
+
|
208 |
+
### For Deployment
|
209 |
+
- **GPU Optimization**: int8 for high-performance inference
|
210 |
+
- **CPU Optimization**: int4 for resource-constrained environments
|
211 |
+
- **Memory Efficiency**: 50-75% memory reduction
|
212 |
+
- **Speed Improvement**: Faster inference times
|
213 |
+
|
214 |
+
## 🔮 Future Enhancements
|
215 |
+
|
216 |
+
### Planned Features
|
217 |
+
1. **Additional Quantization Types**: Support for more torchao configurations
|
218 |
+
2. **Automated Benchmarking**: Performance comparison tools
|
219 |
+
3. **Batch Quantization**: Process multiple models simultaneously
|
220 |
+
4. **Custom Configurations**: Advanced quantization parameter tuning
|
221 |
+
5. **Integration Testing**: End-to-end quantization workflow tests
|
222 |
+
|
223 |
+
### Potential Improvements
|
224 |
+
1. **Quantization-Aware Training**: Support for QAT workflows
|
225 |
+
2. **Mixed Precision**: Advanced precision optimization
|
226 |
+
3. **Hardware-Specific**: Optimizations for specific GPU/CPU types
|
227 |
+
4. **Automated Selection**: Smart quantization type selection
|
228 |
+
|
229 |
+
## 📚 References
|
230 |
+
|
231 |
+
- [torchao Documentation](https://huggingface.co/docs/transformers/main/en/quantization/torchao)
|
232 |
+
- [Hugging Face Quantization Guide](https://huggingface.co/docs/transformers/main/en/quantization)
|
233 |
+
- [PyTorch Quantization](https://pytorch.org/docs/stable/quantization.html)
|
234 |
+
|
235 |
+
## 🎯 Summary
|
236 |
+
|
237 |
+
The quantization implementation provides a complete, production-ready solution for creating optimized versions of fine-tuned SmolLM3 models. The integration is seamless, the documentation is comprehensive, and the functionality is robust and well-tested.
|
238 |
+
|
239 |
+
Key achievements:
|
240 |
+
- ✅ Full pipeline integration
|
241 |
+
- ✅ Multiple quantization types
|
242 |
+
- ✅ Comprehensive documentation
|
243 |
+
- ✅ Robust error handling
|
244 |
+
- ✅ Testing suite
|
245 |
+
- ✅ Monitoring integration
|
246 |
+
- ✅ Standalone usage option
|
247 |
+
|
248 |
+
The implementation follows the repository's architecture patterns and maintains consistency with existing code structure and documentation standards.
|
README_END_TO_END.md → docs/README_END_TO_END.md
RENAMED
@@ -11,10 +11,6 @@ This repository provides a complete end-to-end pipeline for fine-tuning SmolLM3
|
|
11 |
python setup_launch.py
|
12 |
```
|
13 |
|
14 |
-
This will prompt you for:
|
15 |
-
- Your Hugging Face username
|
16 |
-
- Your Hugging Face token
|
17 |
-
- Optional model and dataset customizations
|
18 |
|
19 |
### 2. Check Requirements
|
20 |
|
@@ -30,6 +26,9 @@ python check_requirements.py
|
|
30 |
chmod +x launch.sh
|
31 |
./launch.sh
|
32 |
```
|
|
|
|
|
|
|
33 |
|
34 |
## 📋 What the Pipeline Does
|
35 |
|
@@ -182,7 +181,7 @@ The pipeline creates these online resources:
|
|
182 |
1. **HF Token Issues**
|
183 |
```bash
|
184 |
# Verify your token is correct
|
185 |
-
|
186 |
```
|
187 |
|
188 |
2. **CUDA Issues**
|
|
|
11 |
python setup_launch.py
|
12 |
```
|
13 |
|
|
|
|
|
|
|
|
|
14 |
|
15 |
### 2. Check Requirements
|
16 |
|
|
|
26 |
chmod +x launch.sh
|
27 |
./launch.sh
|
28 |
```
|
29 |
+
This will prompt you for:
|
30 |
+
- Your Hugging Face token
|
31 |
+
- Optional model and dataset customizations
|
32 |
|
33 |
## 📋 What the Pipeline Does
|
34 |
|
|
|
181 |
1. **HF Token Issues**
|
182 |
```bash
|
183 |
# Verify your token is correct
|
184 |
+
hf whoami
|
185 |
```
|
186 |
|
187 |
2. **CUDA Issues**
|
docs/SFT_TRAINER_CONFIG_USAGE.md
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SFT Trainer Configuration Usage Guide
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This guide describes how the SFT (Supervised Fine-tuning) trainer uses the premade configuration files and how the `trainer_type` field is passed through the system.
|
6 |
+
|
7 |
+
## How SFT Trainer Uses Premade Configs
|
8 |
+
|
9 |
+
### 1. Configuration Loading Process
|
10 |
+
|
11 |
+
The SFT trainer uses premade configs through the following process:
|
12 |
+
|
13 |
+
1. **Config File Selection**: Users specify a config file via command line or launch script
|
14 |
+
2. **Config Loading**: The system loads the config using `get_config()` function
|
15 |
+
3. **Config Inheritance**: All configs inherit from `SmolLM3Config` base class
|
16 |
+
4. **Trainer Type Detection**: The system checks for `trainer_type` field in the config
|
17 |
+
5. **Training Arguments Creation**: Config parameters are used to create `TrainingArguments`
|
18 |
+
|
19 |
+
### 2. Configuration Parameters Used by SFT Trainer
|
20 |
+
|
21 |
+
The SFT trainer uses the following config parameters:
|
22 |
+
|
23 |
+
#### Model Configuration
|
24 |
+
- `model_name`: Model to load (e.g., "HuggingFaceTB/SmolLM3-3B")
|
25 |
+
- `max_seq_length`: Maximum sequence length for tokenization
|
26 |
+
- `use_flash_attention`: Whether to use flash attention
|
27 |
+
- `use_gradient_checkpointing`: Whether to use gradient checkpointing
|
28 |
+
|
29 |
+
#### Training Configuration
|
30 |
+
- `batch_size`: Per-device batch size
|
31 |
+
- `gradient_accumulation_steps`: Gradient accumulation steps
|
32 |
+
- `learning_rate`: Learning rate for optimization
|
33 |
+
- `weight_decay`: Weight decay for optimizer
|
34 |
+
- `warmup_steps`: Number of warmup steps
|
35 |
+
- `max_iters`: Maximum training iterations
|
36 |
+
- `save_steps`: Save checkpoint every N steps
|
37 |
+
- `eval_steps`: Evaluate every N steps
|
38 |
+
- `logging_steps`: Log every N steps
|
39 |
+
|
40 |
+
#### Optimizer Configuration
|
41 |
+
- `optimizer`: Optimizer type (e.g., "adamw_torch")
|
42 |
+
- `beta1`, `beta2`, `eps`: Optimizer parameters
|
43 |
+
|
44 |
+
#### Scheduler Configuration
|
45 |
+
- `scheduler`: Learning rate scheduler type
|
46 |
+
- `min_lr`: Minimum learning rate
|
47 |
+
|
48 |
+
#### Mixed Precision
|
49 |
+
- `fp16`: Whether to use fp16 precision
|
50 |
+
- `bf16`: Whether to use bf16 precision
|
51 |
+
|
52 |
+
#### Data Configuration
|
53 |
+
- `dataset_name`: Hugging Face dataset name
|
54 |
+
- `data_dir`: Local dataset directory
|
55 |
+
- `train_file`: Training file name
|
56 |
+
- `validation_file`: Validation file name
|
57 |
+
|
58 |
+
#### Monitoring Configuration
|
59 |
+
- `enable_tracking`: Whether to enable Trackio tracking
|
60 |
+
- `trackio_url`: Trackio server URL
|
61 |
+
- `experiment_name`: Experiment name for tracking
|
62 |
+
|
63 |
+
### 3. Training Arguments Creation
|
64 |
+
|
65 |
+
The SFT trainer creates `TrainingArguments` from config parameters:
|
66 |
+
|
67 |
+
```python
|
68 |
+
def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
|
69 |
+
training_args = {
|
70 |
+
"output_dir": output_dir,
|
71 |
+
"per_device_train_batch_size": self.config.batch_size,
|
72 |
+
"per_device_eval_batch_size": self.config.batch_size,
|
73 |
+
"gradient_accumulation_steps": self.config.gradient_accumulation_steps,
|
74 |
+
"learning_rate": self.config.learning_rate,
|
75 |
+
"weight_decay": self.config.weight_decay,
|
76 |
+
"warmup_steps": self.config.warmup_steps,
|
77 |
+
"max_steps": self.config.max_iters,
|
78 |
+
"save_steps": self.config.save_steps,
|
79 |
+
"eval_steps": self.config.eval_steps,
|
80 |
+
"logging_steps": self.config.logging_steps,
|
81 |
+
"fp16": self.config.fp16,
|
82 |
+
"bf16": self.config.bf16,
|
83 |
+
# ... additional parameters
|
84 |
+
}
|
85 |
+
return TrainingArguments(**training_args)
|
86 |
+
```
|
87 |
+
|
88 |
+
### 4. Trainer Selection Logic
|
89 |
+
|
90 |
+
The system determines which trainer to use based on the `trainer_type` field:
|
91 |
+
|
92 |
+
```python
|
93 |
+
# Determine trainer type (command line overrides config)
|
94 |
+
trainer_type = args.trainer_type or getattr(config, 'trainer_type', 'sft')
|
95 |
+
|
96 |
+
# Initialize trainer based on type
|
97 |
+
if trainer_type.lower() == 'dpo':
|
98 |
+
trainer = SmolLM3DPOTrainer(...)
|
99 |
+
else:
|
100 |
+
trainer = SmolLM3Trainer(...) # SFT trainer
|
101 |
+
```
|
102 |
+
|
103 |
+
## Configuration Files Structure
|
104 |
+
|
105 |
+
### Base Config (`config/train_smollm3.py`)
|
106 |
+
|
107 |
+
```python
|
108 |
+
@dataclass
|
109 |
+
class SmolLM3Config:
|
110 |
+
# Trainer type selection
|
111 |
+
trainer_type: str = "sft" # "sft" or "dpo"
|
112 |
+
|
113 |
+
# Model configuration
|
114 |
+
model_name: str = "HuggingFaceTB/SmolLM3-3B"
|
115 |
+
max_seq_length: int = 4096
|
116 |
+
# ... other fields
|
117 |
+
```
|
118 |
+
|
119 |
+
### DPO Config (`config/train_smollm3_dpo.py`)
|
120 |
+
|
121 |
+
```python
|
122 |
+
@dataclass
|
123 |
+
class SmolLM3DPOConfig(SmolLM3Config):
|
124 |
+
# Trainer type selection
|
125 |
+
trainer_type: str = "dpo" # Override default to use DPO trainer
|
126 |
+
|
127 |
+
# DPO-specific configuration
|
128 |
+
beta: float = 0.1
|
129 |
+
# ... DPO-specific fields
|
130 |
+
```
|
131 |
+
|
132 |
+
### Specialized Configs (e.g., `config/train_smollm3_openhermes_fr_a100_multiple_passes.py`)
|
133 |
+
|
134 |
+
```python
|
135 |
+
@dataclass
|
136 |
+
class SmolLM3ConfigOpenHermesFRMultiplePasses(SmolLM3Config):
|
137 |
+
# Inherits trainer_type = "sft" from base config
|
138 |
+
|
139 |
+
# Specialized configuration for multiple passes
|
140 |
+
batch_size: int = 6
|
141 |
+
gradient_accumulation_steps: int = 20
|
142 |
+
learning_rate: float = 3e-6
|
143 |
+
max_iters: int = 25000
|
144 |
+
# ... other specialized fields
|
145 |
+
```
|
146 |
+
|
147 |
+
## Trainer Type Priority
|
148 |
+
|
149 |
+
The trainer type is determined in the following order of priority:
|
150 |
+
|
151 |
+
1. **Command line argument** (`--trainer_type`) - Highest priority
|
152 |
+
2. **Config file** (`trainer_type` field) - Medium priority
|
153 |
+
3. **Default value** (`"sft"`) - Lowest priority
|
154 |
+
|
155 |
+
## Usage Examples
|
156 |
+
|
157 |
+
### Using SFT Trainer with Different Configs
|
158 |
+
|
159 |
+
```bash
|
160 |
+
# Basic SFT training (uses base config)
|
161 |
+
python src/train.py config/train_smollm3.py
|
162 |
+
|
163 |
+
# SFT training with specialized config
|
164 |
+
python src/train.py config/train_smollm3_openhermes_fr_a100_multiple_passes.py
|
165 |
+
|
166 |
+
# SFT training with override
|
167 |
+
python src/train.py config/train_smollm3.py --trainer_type sft
|
168 |
+
|
169 |
+
# DPO training (uses DPO config)
|
170 |
+
python src/train.py config/train_smollm3_dpo.py
|
171 |
+
|
172 |
+
# Override config's trainer type
|
173 |
+
python src/train.py config/train_smollm3.py --trainer_type dpo
|
174 |
+
```
|
175 |
+
|
176 |
+
### Launch Script Usage
|
177 |
+
|
178 |
+
```bash
|
179 |
+
./launch.sh
|
180 |
+
# Select "SFT" when prompted for trainer type
|
181 |
+
# The system will use the appropriate config based on selection
|
182 |
+
```
|
183 |
+
|
184 |
+
## Configuration Inheritance
|
185 |
+
|
186 |
+
All specialized configs inherit from `SmolLM3Config` and automatically get:
|
187 |
+
|
188 |
+
- `trainer_type = "sft"` (default)
|
189 |
+
- All base training parameters
|
190 |
+
- All monitoring configuration
|
191 |
+
- All data configuration
|
192 |
+
|
193 |
+
Specialized configs can override any of these parameters for their specific use case.
|
194 |
+
|
195 |
+
## SFT Trainer Features
|
196 |
+
|
197 |
+
The SFT trainer provides:
|
198 |
+
|
199 |
+
1. **SFTTrainer Backend**: Uses Hugging Face's `SFTTrainer` for instruction tuning
|
200 |
+
2. **Fallback Support**: Falls back to standard `Trainer` if `SFTTrainer` fails
|
201 |
+
3. **Config Integration**: Uses all config parameters for training setup
|
202 |
+
4. **Monitoring**: Integrates with Trackio for experiment tracking
|
203 |
+
5. **Checkpointing**: Supports model checkpointing and resuming
|
204 |
+
6. **Mixed Precision**: Supports fp16 and bf16 training
|
205 |
+
|
206 |
+
## Troubleshooting
|
207 |
+
|
208 |
+
### Common Issues
|
209 |
+
|
210 |
+
1. **Missing trainer_type field**: Ensure all configs have the `trainer_type` field
|
211 |
+
2. **Config inheritance issues**: Check that specialized configs properly inherit from base
|
212 |
+
3. **Parameter conflicts**: Ensure command line arguments don't conflict with config values
|
213 |
+
|
214 |
+
### Debugging
|
215 |
+
|
216 |
+
Enable verbose logging to see config usage:
|
217 |
+
|
218 |
+
```bash
|
219 |
+
python src/train.py config/train_smollm3.py --trainer_type sft
|
220 |
+
```
|
221 |
+
|
222 |
+
Look for these log messages:
|
223 |
+
```
|
224 |
+
Using trainer type: sft
|
225 |
+
Initializing SFT trainer...
|
226 |
+
Creating SFTTrainer with training arguments...
|
227 |
+
```
|
228 |
+
|
229 |
+
## Related Documentation
|
230 |
+
|
231 |
+
- [Trainer Selection Guide](TRAINER_SELECTION_GUIDE.md)
|
232 |
+
- [Training Configuration Guide](TRAINING_CONFIGURATION_GUIDE.md)
|
233 |
+
- [Monitoring Integration Guide](MONITORING_INTEGRATION_GUIDE.md)
|
docs/TRACKIO_DEPLOYMENT_FIXES.md
CHANGED
@@ -191,7 +191,7 @@ python scripts/trackio_tonic/configure_trackio.py
|
|
191 |
|
192 |
1. **Check token permissions**:
|
193 |
```bash
|
194 |
-
|
195 |
```
|
196 |
|
197 |
2. **Test dataset access**:
|
|
|
191 |
|
192 |
1. **Check token permissions**:
|
193 |
```bash
|
194 |
+
hf whoami
|
195 |
```
|
196 |
|
197 |
2. **Test dataset access**:
|
docs/TRAINER_SELECTION_GUIDE.md
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Trainer Selection Guide
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This guide explains how to use the new trainer selection feature that allows you to choose between **SFT (Supervised Fine-tuning)** and **DPO (Direct Preference Optimization)** trainers in the SmolLM3 fine-tuning pipeline.
|
6 |
+
|
7 |
+
## Trainer Types
|
8 |
+
|
9 |
+
### SFT (Supervised Fine-tuning)
|
10 |
+
- **Purpose**: Standard instruction tuning for most fine-tuning tasks
|
11 |
+
- **Use Case**: General instruction following, conversation, and task-specific training
|
12 |
+
- **Dataset Format**: Standard prompt-completion pairs
|
13 |
+
- **Trainer**: `SmolLM3Trainer` with `SFTTrainer` backend
|
14 |
+
- **Default**: Yes (default trainer type)
|
15 |
+
|
16 |
+
### DPO (Direct Preference Optimization)
|
17 |
+
- **Purpose**: Preference-based training using human feedback
|
18 |
+
- **Use Case**: Aligning models with human preferences, reducing harmful outputs
|
19 |
+
- **Dataset Format**: Preference pairs (chosen/rejected responses)
|
20 |
+
- **Trainer**: `SmolLM3DPOTrainer` with `DPOTrainer` backend
|
21 |
+
- **Default**: No (must be explicitly selected)
|
22 |
+
|
23 |
+
## Implementation Details
|
24 |
+
|
25 |
+
### Configuration Changes
|
26 |
+
|
27 |
+
#### Base Config (`config/train_smollm3.py`)
|
28 |
+
```python
|
29 |
+
@dataclass
|
30 |
+
class SmolLM3Config:
|
31 |
+
# Trainer type selection
|
32 |
+
trainer_type: str = "sft" # "sft" or "dpo"
|
33 |
+
# ... other fields
|
34 |
+
```
|
35 |
+
|
36 |
+
#### DPO Config (`config/train_smollm3_dpo.py`)
|
37 |
+
```python
|
38 |
+
@dataclass
|
39 |
+
class SmolLM3DPOConfig(SmolLM3Config):
|
40 |
+
# Trainer type selection
|
41 |
+
trainer_type: str = "dpo" # Override default to use DPO trainer
|
42 |
+
# ... DPO-specific fields
|
43 |
+
```
|
44 |
+
|
45 |
+
### Training Script Changes
|
46 |
+
|
47 |
+
#### Command Line Arguments
|
48 |
+
Both `src/train.py` and `scripts/training/train.py` now support:
|
49 |
+
```bash
|
50 |
+
--trainer_type {sft,dpo}
|
51 |
+
```
|
52 |
+
|
53 |
+
#### Trainer Selection Logic
|
54 |
+
```python
|
55 |
+
# Determine trainer type (command line overrides config)
|
56 |
+
trainer_type = args.trainer_type or getattr(config, 'trainer_type', 'sft')
|
57 |
+
|
58 |
+
# Initialize trainer based on type
|
59 |
+
if trainer_type.lower() == 'dpo':
|
60 |
+
trainer = SmolLM3DPOTrainer(...)
|
61 |
+
else:
|
62 |
+
trainer = SmolLM3Trainer(...)
|
63 |
+
```
|
64 |
+
|
65 |
+
### Launch Script Changes
|
66 |
+
|
67 |
+
#### Interactive Selection
|
68 |
+
The `launch.sh` script now prompts users to select the trainer type:
|
69 |
+
```
|
70 |
+
Step 3.5: Trainer Type Selection
|
71 |
+
====================================
|
72 |
+
|
73 |
+
Select the type of training to perform:
|
74 |
+
1. SFT (Supervised Fine-tuning) - Standard instruction tuning
|
75 |
+
- Uses SFTTrainer for instruction following
|
76 |
+
- Suitable for most fine-tuning tasks
|
77 |
+
- Optimized for instruction datasets
|
78 |
+
|
79 |
+
2. DPO (Direct Preference Optimization) - Preference-based training
|
80 |
+
- Uses DPOTrainer for preference learning
|
81 |
+
- Requires preference datasets (chosen/rejected pairs)
|
82 |
+
- Optimizes for human preferences
|
83 |
+
```
|
84 |
+
|
85 |
+
#### Configuration Generation
|
86 |
+
The generated config file includes the trainer type:
|
87 |
+
```python
|
88 |
+
config = SmolLM3Config(
|
89 |
+
# Trainer type selection
|
90 |
+
trainer_type="$TRAINER_TYPE",
|
91 |
+
# ... other fields
|
92 |
+
)
|
93 |
+
```
|
94 |
+
|
95 |
+
## Usage Examples
|
96 |
+
|
97 |
+
### Using the Launch Script
|
98 |
+
```bash
|
99 |
+
./launch.sh
|
100 |
+
# Follow the interactive prompts
|
101 |
+
# Select "SFT" or "DPO" when prompted
|
102 |
+
```
|
103 |
+
|
104 |
+
### Using Command Line Arguments
|
105 |
+
```bash
|
106 |
+
# SFT training (default)
|
107 |
+
python src/train.py config/train_smollm3.py
|
108 |
+
|
109 |
+
# DPO training
|
110 |
+
python src/train.py config/train_smollm3_dpo.py
|
111 |
+
|
112 |
+
# Override trainer type
|
113 |
+
python src/train.py config/train_smollm3.py --trainer_type dpo
|
114 |
+
```
|
115 |
+
|
116 |
+
### Using the Training Script
|
117 |
+
```bash
|
118 |
+
# SFT training
|
119 |
+
python scripts/training/train.py --config config/train_smollm3.py
|
120 |
+
|
121 |
+
# DPO training
|
122 |
+
python scripts/training/train.py --config config/train_smollm3_dpo.py
|
123 |
+
|
124 |
+
# Override trainer type
|
125 |
+
python scripts/training/train.py --config config/train_smollm3.py --trainer-type dpo
|
126 |
+
```
|
127 |
+
|
128 |
+
## Dataset Requirements
|
129 |
+
|
130 |
+
### SFT Training
|
131 |
+
- **Format**: Standard instruction datasets
|
132 |
+
- **Fields**: `prompt` and `completion` (or similar)
|
133 |
+
- **Examples**: OpenHermes, Alpaca, instruction datasets
|
134 |
+
|
135 |
+
### DPO Training
|
136 |
+
- **Format**: Preference datasets
|
137 |
+
- **Fields**: `chosen` and `rejected` responses
|
138 |
+
- **Examples**: Human preference datasets, RLHF datasets
|
139 |
+
|
140 |
+
## Configuration Priority
|
141 |
+
|
142 |
+
1. **Command line argument** (`--trainer_type`) - Highest priority
|
143 |
+
2. **Config file** (`trainer_type` field) - Medium priority
|
144 |
+
3. **Default value** (`"sft"`) - Lowest priority
|
145 |
+
|
146 |
+
## Monitoring and Logging
|
147 |
+
|
148 |
+
Both trainer types support:
|
149 |
+
- Trackio experiment tracking
|
150 |
+
- Training metrics logging
|
151 |
+
- Model checkpointing
|
152 |
+
- Progress monitoring
|
153 |
+
|
154 |
+
## Testing
|
155 |
+
|
156 |
+
Run the trainer selection tests:
|
157 |
+
```bash
|
158 |
+
python tests/test_trainer_selection.py
|
159 |
+
```
|
160 |
+
|
161 |
+
This verifies:
|
162 |
+
- Config inheritance works correctly
|
163 |
+
- Trainer classes exist and are importable
|
164 |
+
- Trainer type defaults are set correctly
|
165 |
+
|
166 |
+
## Troubleshooting
|
167 |
+
|
168 |
+
### Common Issues
|
169 |
+
|
170 |
+
1. **Import Errors**: Ensure all dependencies are installed
|
171 |
+
```bash
|
172 |
+
pip install trl>=0.7.0 transformers>=4.30.0
|
173 |
+
```
|
174 |
+
|
175 |
+
2. **Dataset Format**: DPO requires preference datasets with `chosen`/`rejected` fields
|
176 |
+
|
177 |
+
3. **Memory Issues**: DPO training may require more memory due to reference model
|
178 |
+
|
179 |
+
4. **Config Conflicts**: Command line arguments override config file settings
|
180 |
+
|
181 |
+
### Debugging
|
182 |
+
|
183 |
+
Enable verbose logging to see trainer selection:
|
184 |
+
```bash
|
185 |
+
python src/train.py config/train_smollm3.py --trainer_type dpo
|
186 |
+
```
|
187 |
+
|
188 |
+
Look for these log messages:
|
189 |
+
```
|
190 |
+
Using trainer type: dpo
|
191 |
+
Initializing DPO trainer...
|
192 |
+
```
|
193 |
+
|
194 |
+
## Future Enhancements
|
195 |
+
|
196 |
+
- Support for additional trainer types (RLHF, PPO, etc.)
|
197 |
+
- Automatic dataset format detection
|
198 |
+
- Enhanced preference dataset validation
|
199 |
+
- Multi-objective training support
|
200 |
+
|
201 |
+
## Related Documentation
|
202 |
+
|
203 |
+
- [Training Configuration Guide](TRAINING_CONFIGURATION_GUIDE.md)
|
204 |
+
- [Dataset Preparation Guide](DATASET_PREPARATION_GUIDE.md)
|
205 |
+
- [Monitoring Integration Guide](MONITORING_INTEGRATION_GUIDE.md)
|
docs/TRAINER_SELECTION_SUMMARY.md
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Trainer Selection Implementation Summary
|
2 |
+
|
3 |
+
## ✅ Completed Implementation
|
4 |
+
|
5 |
+
### 1. Configuration Changes
|
6 |
+
- ✅ Added `trainer_type` field to base `SmolLM3Config` (default: "sft")
|
7 |
+
- ✅ Added `trainer_type` field to `SmolLM3DPOConfig` (default: "dpo")
|
8 |
+
- ✅ Updated config file generation in `launch.sh` to include trainer_type
|
9 |
+
|
10 |
+
### 2. Training Script Updates
|
11 |
+
- ✅ Added `--trainer_type` argument to `src/train.py`
|
12 |
+
- ✅ Added `--trainer-type` argument to `scripts/training/train.py`
|
13 |
+
- ✅ Implemented trainer selection logic in `src/train.py`
|
14 |
+
- ✅ Updated trainer instantiation to support both SFT and DPO
|
15 |
+
|
16 |
+
### 3. Launch Script Updates
|
17 |
+
- ✅ Added interactive trainer type selection (Step 3.5)
|
18 |
+
- ✅ Updated configuration summary to show trainer type
|
19 |
+
- ✅ Updated training parameters display to show trainer type
|
20 |
+
- ✅ Updated training script call to pass trainer_type argument
|
21 |
+
- ✅ Updated summary report to include trainer type
|
22 |
+
|
23 |
+
### 4. Documentation and Testing
|
24 |
+
- ✅ Created comprehensive `TRAINER_SELECTION_GUIDE.md`
|
25 |
+
- ✅ Created test script `tests/test_trainer_selection.py`
|
26 |
+
- ✅ All tests passing (3/3)
|
27 |
+
|
28 |
+
## 🎯 Key Features
|
29 |
+
|
30 |
+
### Interactive Selection
|
31 |
+
Users can now choose between SFT and DPO during the launch process:
|
32 |
+
```
|
33 |
+
Step 3.5: Trainer Type Selection
|
34 |
+
====================================
|
35 |
+
|
36 |
+
Select the type of training to perform:
|
37 |
+
1. SFT (Supervised Fine-tuning) - Standard instruction tuning
|
38 |
+
2. DPO (Direct Preference Optimization) - Preference-based training
|
39 |
+
```
|
40 |
+
|
41 |
+
### Command Line Override
|
42 |
+
Users can override the config's trainer type via command line:
|
43 |
+
```bash
|
44 |
+
python src/train.py config/train_smollm3.py --trainer_type dpo
|
45 |
+
python scripts/training/train.py --config config/train_smollm3.py --trainer-type dpo
|
46 |
+
```
|
47 |
+
|
48 |
+
### Configuration Priority
|
49 |
+
1. Command line argument (highest priority)
|
50 |
+
2. Config file trainer_type field (medium priority)
|
51 |
+
3. Default value "sft" (lowest priority)
|
52 |
+
|
53 |
+
### Automatic Trainer Selection
|
54 |
+
The system automatically selects the appropriate trainer:
|
55 |
+
- **SFT**: Uses `SmolLM3Trainer` with `SFTTrainer` backend
|
56 |
+
- **DPO**: Uses `SmolLM3DPOTrainer` with `DPOTrainer` backend
|
57 |
+
|
58 |
+
## 📋 Usage Examples
|
59 |
+
|
60 |
+
### Launch Script (Interactive)
|
61 |
+
```bash
|
62 |
+
./launch.sh
|
63 |
+
# Follow prompts and select SFT or DPO
|
64 |
+
```
|
65 |
+
|
66 |
+
### Direct Training
|
67 |
+
```bash
|
68 |
+
# SFT training (default)
|
69 |
+
python src/train.py config/train_smollm3.py
|
70 |
+
|
71 |
+
# DPO training
|
72 |
+
python src/train.py config/train_smollm3_dpo.py
|
73 |
+
|
74 |
+
# Override trainer type
|
75 |
+
python src/train.py config/train_smollm3.py --trainer_type dpo
|
76 |
+
```
|
77 |
+
|
78 |
+
### Training Script
|
79 |
+
```bash
|
80 |
+
# SFT training
|
81 |
+
python scripts/training/train.py --config config/train_smollm3.py
|
82 |
+
|
83 |
+
# DPO training with override
|
84 |
+
python scripts/training/train.py --config config/train_smollm3.py --trainer-type dpo
|
85 |
+
```
|
86 |
+
|
87 |
+
## 🔧 Technical Details
|
88 |
+
|
89 |
+
### Files Modified
|
90 |
+
1. `config/train_smollm3.py` - Added trainer_type field
|
91 |
+
2. `config/train_smollm3_dpo.py` - Added trainer_type field
|
92 |
+
3. `src/train.py` - Added trainer selection logic
|
93 |
+
4. `scripts/training/train.py` - Added trainer_type argument
|
94 |
+
5. `launch.sh` - Added interactive selection and config generation
|
95 |
+
6. `src/trainer.py` - Already had both trainer classes
|
96 |
+
|
97 |
+
### Files Created
|
98 |
+
1. `docs/TRAINER_SELECTION_GUIDE.md` - Comprehensive documentation
|
99 |
+
2. `tests/test_trainer_selection.py` - Test suite
|
100 |
+
3. `TRAINER_SELECTION_SUMMARY.md` - This summary
|
101 |
+
|
102 |
+
## ✅ Testing Results
|
103 |
+
```
|
104 |
+
🧪 Testing Trainer Selection Implementation
|
105 |
+
==================================================
|
106 |
+
Testing config trainer_type...
|
107 |
+
✅ Base config trainer_type: sft
|
108 |
+
✅ DPO config trainer_type: dpo
|
109 |
+
Testing trainer class existence...
|
110 |
+
✅ Trainer module imported successfully
|
111 |
+
✅ Both trainer classes exist
|
112 |
+
Testing config inheritance...
|
113 |
+
✅ DPO config properly inherits from base config
|
114 |
+
✅ Trainer type inheritance works correctly
|
115 |
+
==================================================
|
116 |
+
Tests passed: 3/3
|
117 |
+
🎉 All tests passed!
|
118 |
+
```
|
119 |
+
|
120 |
+
## 🚀 Next Steps
|
121 |
+
|
122 |
+
The trainer selection feature is now fully implemented and tested. Users can:
|
123 |
+
|
124 |
+
1. **Use the interactive launch script** to select SFT or DPO
|
125 |
+
2. **Override trainer type** via command line arguments
|
126 |
+
3. **Use DPO configs** that automatically select DPO trainer
|
127 |
+
4. **Monitor training** with the same Trackio integration for both trainers
|
128 |
+
|
129 |
+
The implementation maintains backward compatibility while adding the new trainer selection capability.
|
docs/UNIFIED_MODEL_CARD_GUIDE.md
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Unified Model Card System Guide
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
The unified model card system provides a template-based approach to generate comprehensive model cards that include information about both the main fine-tuned model and any quantized versions. This system ensures consistency across all model repositories and provides users with complete information about all available model variants.
|
6 |
+
|
7 |
+
## Architecture
|
8 |
+
|
9 |
+
### Template System
|
10 |
+
|
11 |
+
The system uses a template-based approach with the following components:
|
12 |
+
|
13 |
+
1. **Template File**: `templates/model_card.md` - Contains the master template with conditional sections
|
14 |
+
2. **Generator Script**: `scripts/model_tonic/generate_model_card.py` - Processes templates and variables
|
15 |
+
3. **Integration**: Updated push scripts that use the unified model card generator
|
16 |
+
|
17 |
+
### Key Features
|
18 |
+
|
19 |
+
- **Conditional Sections**: Template supports conditional rendering based on variables (e.g., quantized models)
|
20 |
+
- **Variable Substitution**: Dynamic content based on training configuration and results
|
21 |
+
- **Unified Repository Structure**: Single repository with subdirectories for quantized models
|
22 |
+
- **Comprehensive Documentation**: Complete usage examples and deployment information
|
23 |
+
|
24 |
+
## Template Structure
|
25 |
+
|
26 |
+
### Conditional Sections
|
27 |
+
|
28 |
+
The template uses Handlebars-style conditionals:
|
29 |
+
|
30 |
+
```markdown
|
31 |
+
{{#if quantized_models}}
|
32 |
+
### Quantized Models
|
33 |
+
|
34 |
+
This repository also includes quantized versions of the model for improved efficiency:
|
35 |
+
|
36 |
+
#### int8 Weight-Only Quantization (GPU Optimized)
|
37 |
+
```python
|
38 |
+
model = AutoModelForCausalLM.from_pretrained("{{repo_name}}/int8")
|
39 |
+
```
|
40 |
+
{{/if}}
|
41 |
+
```
|
42 |
+
|
43 |
+
### Template Variables
|
44 |
+
|
45 |
+
The template supports the following variables:
|
46 |
+
|
47 |
+
| Variable | Description | Example |
|
48 |
+
|----------|-------------|---------|
|
49 |
+
| `model_name` | Display name of the model | "SmolLM3 Fine-tuned Model" |
|
50 |
+
| `model_description` | Brief description | "A fine-tuned version of SmolLM3-3B..." |
|
51 |
+
| `repo_name` | Hugging Face repository name | "username/model-name" |
|
52 |
+
| `base_model` | Original model name | "HuggingFaceTB/SmolLM3-3B" |
|
53 |
+
| `dataset_name` | Training dataset | "OpenHermes-FR" |
|
54 |
+
| `training_config_type` | Training configuration | "H100 Lightweight" |
|
55 |
+
| `trainer_type` | Trainer used | "SFTTrainer" |
|
56 |
+
| `batch_size` | Training batch size | "8" |
|
57 |
+
| `learning_rate` | Learning rate | "5e-6" |
|
58 |
+
| `max_epochs` | Number of epochs | "3" |
|
59 |
+
| `max_seq_length` | Maximum sequence length | "2048" |
|
60 |
+
| `hardware_info` | Hardware used | "GPU (H100/A100)" |
|
61 |
+
| `experiment_name` | Experiment name | "smollm3-experiment" |
|
62 |
+
| `trackio_url` | Trackio monitoring URL | "https://trackio.space/exp" |
|
63 |
+
| `dataset_repo` | HF Dataset repository | "tonic/trackio-experiments" |
|
64 |
+
| `quantized_models` | Boolean for quantized models | `true` or `false` |
|
65 |
+
| `author_name` | Model author | "Your Name" |
|
66 |
+
|
67 |
+
## Repository Structure
|
68 |
+
|
69 |
+
### Single Repository Approach
|
70 |
+
|
71 |
+
Instead of creating separate repositories for quantized models, the system now uses a single repository with subdirectories:
|
72 |
+
|
73 |
+
```
|
74 |
+
username/model-name/
|
75 |
+
├── README.md (unified model card)
|
76 |
+
├── config.json
|
77 |
+
├── pytorch_model.bin
|
78 |
+
├── tokenizer.json
|
79 |
+
├── tokenizer_config.json
|
80 |
+
├── int8/ (quantized model for GPU)
|
81 |
+
│ ├── README.md
|
82 |
+
│ ├── config.json
|
83 |
+
│ └── pytorch_model.bin
|
84 |
+
└── int4/ (quantized model for CPU)
|
85 |
+
├── README.md
|
86 |
+
├── config.json
|
87 |
+
└── pytorch_model.bin
|
88 |
+
```
|
89 |
+
|
90 |
+
### Benefits
|
91 |
+
|
92 |
+
1. **Unified Documentation**: Single README with information about all model variants
|
93 |
+
2. **Easier Discovery**: Users find all model versions in one place
|
94 |
+
3. **Consistent Branding**: Single repository name and description
|
95 |
+
4. **Simplified Management**: One repository to maintain and update
|
96 |
+
|
97 |
+
## Usage
|
98 |
+
|
99 |
+
### Automatic Generation (via launch.sh)
|
100 |
+
|
101 |
+
The unified model card is automatically generated during the training pipeline:
|
102 |
+
|
103 |
+
```bash
|
104 |
+
# The launch script automatically generates the unified model card
|
105 |
+
./launch.sh
|
106 |
+
```
|
107 |
+
|
108 |
+
### Manual Generation
|
109 |
+
|
110 |
+
You can generate model cards manually using the generator script:
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python scripts/model_tonic/generate_model_card.py \
|
114 |
+
--repo-name "username/model-name" \
|
115 |
+
--model-name "My Fine-tuned Model" \
|
116 |
+
--experiment-name "my-experiment" \
|
117 |
+
--dataset-name "OpenHermes-FR" \
|
118 |
+
--training-config "H100 Lightweight" \
|
119 |
+
--batch-size "8" \
|
120 |
+
--learning-rate "5e-6" \
|
121 |
+
--max-epochs "3" \
|
122 |
+
--quantized-models \
|
123 |
+
--output "README.md"
|
124 |
+
```
|
125 |
+
|
126 |
+
### Integration with Push Script
|
127 |
+
|
128 |
+
The push script automatically uses the unified model card generator:
|
129 |
+
|
130 |
+
```python
|
131 |
+
# In push_to_huggingface.py
|
132 |
+
def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
|
133 |
+
"""Create a comprehensive model card using the unified template"""
|
134 |
+
try:
|
135 |
+
from scripts.model_tonic.generate_model_card import ModelCardGenerator
|
136 |
+
|
137 |
+
variables = {
|
138 |
+
"model_name": f"{self.repo_name.split('/')[-1]} - Fine-tuned SmolLM3",
|
139 |
+
"repo_name": self.repo_name,
|
140 |
+
"quantized_models": False, # Updated if quantized models are added
|
141 |
+
# ... other variables
|
142 |
+
}
|
143 |
+
|
144 |
+
generator = ModelCardGenerator()
|
145 |
+
return generator.generate_model_card(variables)
|
146 |
+
|
147 |
+
except Exception as e:
|
148 |
+
# Fallback to simple model card
|
149 |
+
return self._create_simple_model_card()
|
150 |
+
```
|
151 |
+
|
152 |
+
## Quantization Integration
|
153 |
+
|
154 |
+
### Quantized Model Cards
|
155 |
+
|
156 |
+
When quantized models are created, the system:
|
157 |
+
|
158 |
+
1. **Updates Main Model Card**: Sets `quantized_models = True` and includes usage examples
|
159 |
+
2. **Creates Subdirectory Cards**: Generates specific README files for each quantized version
|
160 |
+
3. **Maintains Consistency**: All cards reference the same repository structure
|
161 |
+
|
162 |
+
### Quantization Types
|
163 |
+
|
164 |
+
The system supports:
|
165 |
+
|
166 |
+
- **int8_weight_only**: GPU optimized, ~50% memory reduction
|
167 |
+
- **int4_weight_only**: CPU optimized, ~75% memory reduction
|
168 |
+
- **int8_dynamic**: Dynamic quantization for flexibility
|
169 |
+
|
170 |
+
### Usage Examples
|
171 |
+
|
172 |
+
```python
|
173 |
+
# Main model
|
174 |
+
model = AutoModelForCausalLM.from_pretrained("username/model-name")
|
175 |
+
|
176 |
+
# int8 quantized (GPU)
|
177 |
+
model = AutoModelForCausalLM.from_pretrained("username/model-name/int8")
|
178 |
+
|
179 |
+
# int4 quantized (CPU)
|
180 |
+
model = AutoModelForCausalLM.from_pretrained("username/model-name/int4")
|
181 |
+
```
|
182 |
+
|
183 |
+
## Template Customization
|
184 |
+
|
185 |
+
### Adding New Sections
|
186 |
+
|
187 |
+
To add new sections to the template:
|
188 |
+
|
189 |
+
1. **Edit Template**: Modify `templates/model_card.md`
|
190 |
+
2. **Add Variables**: Update the generator script with new variables
|
191 |
+
3. **Update Integration**: Modify push scripts to pass new variables
|
192 |
+
|
193 |
+
### Example: Adding Performance Metrics
|
194 |
+
|
195 |
+
```markdown
|
196 |
+
{{#if performance_metrics}}
|
197 |
+
## Performance Metrics
|
198 |
+
|
199 |
+
- **BLEU Score**: {{bleu_score}}
|
200 |
+
- **ROUGE Score**: {{rouge_score}}
|
201 |
+
- **Perplexity**: {{perplexity}}
|
202 |
+
{{/if}}
|
203 |
+
```
|
204 |
+
|
205 |
+
### Conditional Logic
|
206 |
+
|
207 |
+
The template supports complex conditional logic:
|
208 |
+
|
209 |
+
```markdown
|
210 |
+
{{#if quantized_models}}
|
211 |
+
{{#if int8_available}}
|
212 |
+
### int8 Quantized Model
|
213 |
+
{{/if}}
|
214 |
+
{{#if int4_available}}
|
215 |
+
### int4 Quantized Model
|
216 |
+
{{/if}}
|
217 |
+
{{/if}}
|
218 |
+
```
|
219 |
+
|
220 |
+
## Best Practices
|
221 |
+
|
222 |
+
### Template Design
|
223 |
+
|
224 |
+
1. **Clear Structure**: Use consistent headings and organization
|
225 |
+
2. **Comprehensive Information**: Include all relevant model details
|
226 |
+
3. **Usage Examples**: Provide clear code examples
|
227 |
+
4. **Limitations**: Document model limitations and biases
|
228 |
+
5. **Citations**: Include proper citations and acknowledgments
|
229 |
+
|
230 |
+
### Variable Management
|
231 |
+
|
232 |
+
1. **Default Values**: Provide sensible defaults for all variables
|
233 |
+
2. **Validation**: Validate variable types and ranges
|
234 |
+
3. **Documentation**: Document all available variables
|
235 |
+
4. **Fallbacks**: Provide fallback content for missing variables
|
236 |
+
|
237 |
+
### Repository Organization
|
238 |
+
|
239 |
+
1. **Single Repository**: Use one repository per model family
|
240 |
+
2. **Clear Subdirectories**: Use descriptive subdirectory names
|
241 |
+
3. **Consistent Naming**: Follow consistent naming conventions
|
242 |
+
4. **Documentation**: Maintain comprehensive documentation
|
243 |
+
|
244 |
+
## Troubleshooting
|
245 |
+
|
246 |
+
### Common Issues
|
247 |
+
|
248 |
+
1. **Template Not Found**: Ensure `templates/model_card.md` exists
|
249 |
+
2. **Variable Errors**: Check that all required variables are provided
|
250 |
+
3. **Conditional Issues**: Verify conditional syntax and logic
|
251 |
+
4. **Import Errors**: Ensure all dependencies are installed
|
252 |
+
|
253 |
+
### Debugging
|
254 |
+
|
255 |
+
```bash
|
256 |
+
# Test template generation
|
257 |
+
python scripts/model_tonic/generate_model_card.py \
|
258 |
+
--repo-name "test/model" \
|
259 |
+
--output "test_readme.md" \
|
260 |
+
--debug
|
261 |
+
```
|
262 |
+
|
263 |
+
### Validation
|
264 |
+
|
265 |
+
The system includes validation for:
|
266 |
+
|
267 |
+
- Template file existence
|
268 |
+
- Required variables
|
269 |
+
- Conditional syntax
|
270 |
+
- Output file permissions
|
271 |
+
|
272 |
+
## Future Enhancements
|
273 |
+
|
274 |
+
### Planned Features
|
275 |
+
|
276 |
+
1. **Multiple Template Support**: Support for different template types
|
277 |
+
2. **Advanced Conditionals**: More complex conditional logic
|
278 |
+
3. **Template Inheritance**: Base templates with extensions
|
279 |
+
4. **Auto-Detection**: Automatic detection of model features
|
280 |
+
5. **Custom Sections**: User-defined template sections
|
281 |
+
|
282 |
+
### Extensibility
|
283 |
+
|
284 |
+
The system is designed to be easily extensible:
|
285 |
+
|
286 |
+
- **Plugin Architecture**: Support for custom template processors
|
287 |
+
- **Variable Sources**: Multiple sources for template variables
|
288 |
+
- **Output Formats**: Support for different output formats
|
289 |
+
- **Integration Points**: Easy integration with other tools
|
290 |
+
|
291 |
+
## Conclusion
|
292 |
+
|
293 |
+
The unified model card system provides a comprehensive, maintainable approach to model documentation. By using templates and conditional sections, it ensures consistency while providing flexibility for different model configurations and quantization options.
|
294 |
+
|
295 |
+
The single repository approach with subdirectories simplifies model management and improves user experience by providing all model variants in one location with unified documentation.
|
docs/UNIFIED_REPOSITORY_STRUCTURE_SUMMARY.md
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Unified Repository Structure Implementation Summary
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
This document summarizes the implementation of a unified repository structure where all models (main and quantized) are stored in a single Hugging Face repository with quantized models in subdirectories.
|
6 |
+
|
7 |
+
## Key Changes Made
|
8 |
+
|
9 |
+
### 1. Repository Structure
|
10 |
+
|
11 |
+
**Before:**
|
12 |
+
```
|
13 |
+
your-username/model-name/ (main model)
|
14 |
+
your-username/model-name-int8/ (int8 quantized)
|
15 |
+
your-username/model-name-int4/ (int4 quantized)
|
16 |
+
```
|
17 |
+
|
18 |
+
**After:**
|
19 |
+
```
|
20 |
+
your-username/model-name/
|
21 |
+
├── README.md (unified model card)
|
22 |
+
├── config.json
|
23 |
+
├── pytorch_model.bin
|
24 |
+
├── tokenizer.json
|
25 |
+
├── int8/ (quantized model for GPU)
|
26 |
+
│ ├── README.md
|
27 |
+
│ ├── config.json
|
28 |
+
│ └── pytorch_model.bin
|
29 |
+
└── int4/ (quantized model for CPU)
|
30 |
+
├── README.md
|
31 |
+
├── config.json
|
32 |
+
└── pytorch_model.bin
|
33 |
+
```
|
34 |
+
|
35 |
+
### 2. New Files Created
|
36 |
+
|
37 |
+
#### `templates/model_card.md`
|
38 |
+
- Comprehensive model card template with conditional sections
|
39 |
+
- Supports both main model and quantized versions
|
40 |
+
- Includes usage examples for all model versions
|
41 |
+
- Template variables for dynamic content generation
|
42 |
+
|
43 |
+
#### `scripts/model_tonic/generate_model_card.py`
|
44 |
+
- Model card generator using the template
|
45 |
+
- Handles conditional sections and variable replacement
|
46 |
+
- Supports command-line arguments for customization
|
47 |
+
- Fallback to simple model card if template fails
|
48 |
+
|
49 |
+
### 3. Updated Files
|
50 |
+
|
51 |
+
#### `scripts/model_tonic/quantize_model.py`
|
52 |
+
- **Fixed f-string errors**: Escaped curly braces in citation URLs
|
53 |
+
- **Updated model card generation**: Uses subdirectory-aware URLs
|
54 |
+
- **Modified push logic**: Uploads to subdirectories within the same repository
|
55 |
+
- **Enhanced README generation**: References correct subdirectory paths
|
56 |
+
|
57 |
+
#### `scripts/model_tonic/push_to_huggingface.py`
|
58 |
+
- **Integrated unified model card**: Uses the new template-based generator
|
59 |
+
- **Enhanced variable handling**: Passes training configuration to template
|
60 |
+
- **Improved error handling**: Fallback to simple model card if template fails
|
61 |
+
- **Better integration**: Works with the new unified structure
|
62 |
+
|
63 |
+
#### `launch.sh`
|
64 |
+
- **Updated quantization section**: Uses same repository for all models
|
65 |
+
- **Modified summary reports**: Reflects new subdirectory structure
|
66 |
+
- **Improved user feedback**: Shows correct URLs for all model versions
|
67 |
+
- **Streamlined workflow**: Single repository management
|
68 |
+
|
69 |
+
#### `docs/QUANTIZATION_GUIDE.md`
|
70 |
+
- **Complete rewrite**: Reflects new unified structure
|
71 |
+
- **Updated examples**: Shows correct loading paths
|
72 |
+
- **Enhanced documentation**: Covers repository structure and usage
|
73 |
+
- **Improved troubleshooting**: Addresses new structure-specific issues
|
74 |
+
|
75 |
+
#### `README.md`
|
76 |
+
- **Updated quantization section**: Shows unified repository structure
|
77 |
+
- **Enhanced examples**: Demonstrates loading from subdirectories
|
78 |
+
- **Improved clarity**: Better explanation of the new structure
|
79 |
+
|
80 |
+
### 4. Key Features Implemented
|
81 |
+
|
82 |
+
#### Unified Model Card
|
83 |
+
- Single README.md covers all model versions
|
84 |
+
- Conditional sections for quantized models
|
85 |
+
- Comprehensive usage examples
|
86 |
+
- Training information and configuration details
|
87 |
+
|
88 |
+
#### Subdirectory Management
|
89 |
+
- Quantized models stored in `/int8/` and `/int4/` subdirectories
|
90 |
+
- Separate README files for each quantized version
|
91 |
+
- Proper file organization and structure
|
92 |
+
|
93 |
+
#### Template System
|
94 |
+
- Handlebars-style template with conditionals
|
95 |
+
- Variable replacement for dynamic content
|
96 |
+
- Support for complex nested structures
|
97 |
+
- Error handling and fallback mechanisms
|
98 |
+
|
99 |
+
#### Enhanced User Experience
|
100 |
+
- Clear repository structure documentation
|
101 |
+
- Simplified model loading examples
|
102 |
+
- Better error messages and feedback
|
103 |
+
- Comprehensive troubleshooting guide
|
104 |
+
|
105 |
+
## Technical Implementation Details
|
106 |
+
|
107 |
+
### Template Processing
|
108 |
+
```python
|
109 |
+
# Conditional sections
|
110 |
+
{{#if quantized_models}}
|
111 |
+
### Quantized Models
|
112 |
+
...
|
113 |
+
{{/if}}
|
114 |
+
|
115 |
+
# Variable replacement
|
116 |
+
model = AutoModelForCausalLM.from_pretrained("{{repo_name}}/int8")
|
117 |
+
```
|
118 |
+
|
119 |
+
### Subdirectory Upload Logic
|
120 |
+
```python
|
121 |
+
# Determine subdirectory
|
122 |
+
if quant_type == "int8_weight_only":
|
123 |
+
subdir = "int8"
|
124 |
+
elif quant_type == "int4_weight_only":
|
125 |
+
subdir = "int4"
|
126 |
+
|
127 |
+
# Upload to subdirectory
|
128 |
+
repo_path = f"{subdir}/{relative_path}"
|
129 |
+
upload_file(
|
130 |
+
path_or_fileobj=str(file_path),
|
131 |
+
path_in_repo=repo_path,
|
132 |
+
repo_id=self.repo_name,
|
133 |
+
token=self.token
|
134 |
+
)
|
135 |
+
```
|
136 |
+
|
137 |
+
### Launch Script Integration
|
138 |
+
```bash
|
139 |
+
# Create quantized models in same repository
|
140 |
+
python scripts/model_tonic/quantize_model.py /output-checkpoint "$REPO_NAME" \
|
141 |
+
--quant-type "$QUANT_TYPE" \
|
142 |
+
--device "$DEVICE" \
|
143 |
+
--token "$HF_TOKEN"
|
144 |
+
```
|
145 |
+
|
146 |
+
## Benefits of the New Structure
|
147 |
+
|
148 |
+
### 1. Simplified Management
|
149 |
+
- Single repository for all model versions
|
150 |
+
- Easier to track and manage
|
151 |
+
- Reduced repository clutter
|
152 |
+
- Unified documentation
|
153 |
+
|
154 |
+
### 2. Better User Experience
|
155 |
+
- Clear loading paths for all versions
|
156 |
+
- Comprehensive model card with all information
|
157 |
+
- Consistent URL structure
|
158 |
+
- Simplified deployment
|
159 |
+
|
160 |
+
### 3. Enhanced Documentation
|
161 |
+
- Single source of truth for model information
|
162 |
+
- Conditional sections for different versions
|
163 |
+
- Comprehensive usage examples
|
164 |
+
- Better discoverability
|
165 |
+
|
166 |
+
### 4. Improved Workflow
|
167 |
+
- Streamlined quantization process
|
168 |
+
- Reduced configuration complexity
|
169 |
+
- Better integration with existing pipeline
|
170 |
+
- Enhanced monitoring and tracking
|
171 |
+
|
172 |
+
## Usage Examples
|
173 |
+
|
174 |
+
### Loading Models
|
175 |
+
```python
|
176 |
+
# Main model
|
177 |
+
model = AutoModelForCausalLM.from_pretrained("your-username/model-name")
|
178 |
+
|
179 |
+
# int8 quantized (GPU)
|
180 |
+
model = AutoModelForCausalLM.from_pretrained("your-username/model-name/int8")
|
181 |
+
|
182 |
+
# int4 quantized (CPU)
|
183 |
+
model = AutoModelForCausalLM.from_pretrained("your-username/model-name/int4")
|
184 |
+
```
|
185 |
+
|
186 |
+
### Pipeline Usage
|
187 |
+
```bash
|
188 |
+
# Run full pipeline with quantization
|
189 |
+
./launch.sh
|
190 |
+
# Choose quantization options when prompted
|
191 |
+
# All models will be in the same repository
|
192 |
+
```
|
193 |
+
|
194 |
+
### Standalone Quantization
|
195 |
+
```bash
|
196 |
+
# Quantize existing model
|
197 |
+
python scripts/model_tonic/quantize_standalone.py \
|
198 |
+
/path/to/model your-username/model-name \
|
199 |
+
--quant-type int8_weight_only
|
200 |
+
```
|
201 |
+
|
202 |
+
## Migration Guide
|
203 |
+
|
204 |
+
### For Existing Users
|
205 |
+
1. **Update loading code**: Change from separate repositories to subdirectories
|
206 |
+
2. **Update documentation**: Reference new unified structure
|
207 |
+
3. **Test quantized models**: Verify loading from subdirectories works
|
208 |
+
4. **Update deployment scripts**: Use new repository structure
|
209 |
+
|
210 |
+
### For New Users
|
211 |
+
1. **Follow the new structure**: All models in single repository
|
212 |
+
2. **Use the unified model card**: Comprehensive documentation included
|
213 |
+
3. **Leverage subdirectories**: Clear organization of model versions
|
214 |
+
4. **Benefit from simplified workflow**: Easier management and deployment
|
215 |
+
|
216 |
+
## Testing and Validation
|
217 |
+
|
218 |
+
### Test Files
|
219 |
+
- `tests/test_quantization.py`: Validates quantization functionality
|
220 |
+
- Template processing: Ensures correct variable replacement
|
221 |
+
- Subdirectory upload: Verifies proper file organization
|
222 |
+
- Model loading: Tests all model versions
|
223 |
+
|
224 |
+
### Validation Checklist
|
225 |
+
- [x] Template processing works correctly
|
226 |
+
- [x] Subdirectory uploads function properly
|
227 |
+
- [x] Model cards generate with correct URLs
|
228 |
+
- [x] Launch script integration works
|
229 |
+
- [x] Documentation is updated and accurate
|
230 |
+
- [x] Error handling is robust
|
231 |
+
- [x] Fallback mechanisms work
|
232 |
+
|
233 |
+
## Future Enhancements
|
234 |
+
|
235 |
+
### Potential Improvements
|
236 |
+
1. **Additional quantization types**: Support for more quantization methods
|
237 |
+
2. **Enhanced template system**: More complex conditional logic
|
238 |
+
3. **Automated testing**: Comprehensive test suite for all features
|
239 |
+
4. **Performance optimization**: Faster quantization and upload processes
|
240 |
+
5. **Better monitoring**: Enhanced tracking and metrics
|
241 |
+
|
242 |
+
### Extension Points
|
243 |
+
1. **Custom quantization configs**: User-defined quantization parameters
|
244 |
+
2. **Batch processing**: Multiple model quantization
|
245 |
+
3. **Advanced templates**: More sophisticated model card generation
|
246 |
+
4. **Integration with other tools**: Support for additional deployment options
|
247 |
+
|
248 |
+
## Conclusion
|
249 |
+
|
250 |
+
The unified repository structure provides a cleaner, more manageable approach to model deployment and quantization. The implementation includes comprehensive documentation, robust error handling, and a streamlined user experience that makes it easier to work with multiple model versions while maintaining a single source of truth for all model-related information.
|
251 |
+
|
252 |
+
The new structure significantly improves the user experience while maintaining backward compatibility and providing clear migration paths for existing users. The enhanced documentation and simplified workflow make the quantization feature more accessible and easier to use.
|
docs/USERNAME_EXTRACTION_FIX.md
CHANGED
@@ -70,7 +70,7 @@ def get_username_from_cli(token: str) -> str:
|
|
70 |
|
71 |
# Get username using CLI
|
72 |
result = subprocess.run(
|
73 |
-
["
|
74 |
capture_output=True,
|
75 |
text=True,
|
76 |
timeout=30
|
@@ -203,7 +203,7 @@ If username extraction still fails:
|
|
203 |
|
204 |
1. **Check Token**: Ensure HF_TOKEN is valid and has proper permissions
|
205 |
2. **Check Network**: Ensure internet connection is stable
|
206 |
-
3. **Check CLI**: Ensure `
|
207 |
4. **Manual Override**: Can manually set username in scripts if needed
|
208 |
|
209 |
## 📋 **Summary**
|
|
|
70 |
|
71 |
# Get username using CLI
|
72 |
result = subprocess.run(
|
73 |
+
["hf", "whoami"],
|
74 |
capture_output=True,
|
75 |
text=True,
|
76 |
timeout=30
|
|
|
203 |
|
204 |
1. **Check Token**: Ensure HF_TOKEN is valid and has proper permissions
|
205 |
2. **Check Network**: Ensure internet connection is stable
|
206 |
+
3. **Check CLI**: Ensure `hf` is installed and working
|
207 |
4. **Manual Override**: Can manually set username in scripts if needed
|
208 |
|
209 |
## 📋 **Summary**
|
launch.sh
CHANGED
@@ -91,9 +91,9 @@ validate_hf_token_and_get_username() {
|
|
91 |
|
92 |
# Test the token and get username
|
93 |
export HF_TOKEN="$token"
|
94 |
-
if
|
95 |
# Get username from whoami command
|
96 |
-
HF_USERNAME=$(
|
97 |
return 0
|
98 |
else
|
99 |
return 1
|
@@ -229,6 +229,9 @@ Optimized for: $TRAINING_CONFIG_TYPE
|
|
229 |
from config.train_smollm3 import SmolLM3Config
|
230 |
|
231 |
config = SmolLM3Config(
|
|
|
|
|
|
|
232 |
# Model configuration
|
233 |
model_name="$MODEL_NAME",
|
234 |
max_seq_length=$MAX_SEQ_LENGTH,
|
@@ -341,6 +344,24 @@ get_input "Experiment name" "smollm3_finetune_$(date +%Y%m%d_%H%M%S)" EXPERIMENT
|
|
341 |
get_input "Model repository name" "$HF_USERNAME/smollm3-finetuned-$(date +%Y%m%d)" REPO_NAME
|
342 |
get_input "Trackio dataset repository" "$HF_USERNAME/trackio-experiments" TRACKIO_DATASET_REPO
|
343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
# Step 4: Training parameters
|
345 |
print_step "Step 4: Training Parameters"
|
346 |
echo "==============================="
|
@@ -348,6 +369,7 @@ echo "==============================="
|
|
348 |
echo "Current configuration:"
|
349 |
echo " Model: $MODEL_NAME"
|
350 |
echo " Dataset: $DATASET_NAME"
|
|
|
351 |
if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
352 |
echo " Dataset Sample Size: ${DATASET_SAMPLE_SIZE:-80000}"
|
353 |
fi
|
@@ -380,6 +402,7 @@ echo " Experiment: $EXPERIMENT_NAME"
|
|
380 |
echo " Model: $MODEL_NAME"
|
381 |
echo " Dataset: $DATASET_NAME"
|
382 |
echo " Training Config: $TRAINING_CONFIG_TYPE"
|
|
|
383 |
if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
384 |
echo " Dataset Sample Size: ${DATASET_SAMPLE_SIZE:-80000}"
|
385 |
fi
|
@@ -453,9 +476,9 @@ export TRACKIO_DATASET_REPO="$TRACKIO_DATASET_REPO"
|
|
453 |
|
454 |
# Login to Hugging Face with token
|
455 |
print_info "Logging in to Hugging Face..."
|
456 |
-
if
|
457 |
print_status "Successfully logged in to Hugging Face"
|
458 |
-
print_info "Username: $(
|
459 |
else
|
460 |
print_error "Failed to login to Hugging Face"
|
461 |
print_error "Please check your token and try again"
|
@@ -502,7 +525,7 @@ python deploy_trackio_space.py << EOF
|
|
502 |
$TRACKIO_SPACE_NAME
|
503 |
$HF_TOKEN
|
504 |
$GIT_EMAIL
|
505 |
-
|
506 |
EOF
|
507 |
|
508 |
print_status "Trackio Space deployed: $TRACKIO_URL"
|
@@ -569,7 +592,8 @@ python scripts/training/train.py \
|
|
569 |
--config "$CONFIG_FILE" \
|
570 |
--experiment-name "$EXPERIMENT_NAME" \
|
571 |
--output-dir /output-checkpoint \
|
572 |
-
--trackio-url "$TRACKIO_URL"
|
|
|
573 |
|
574 |
# Step 16: Push model to Hugging Face Hub
|
575 |
print_step "Step 16: Pushing Model to HF Hub"
|
@@ -585,6 +609,72 @@ python scripts/model_tonic/push_to_huggingface.py /output-checkpoint "$REPO_NAME
|
|
585 |
--experiment-name "$EXPERIMENT_NAME" \
|
586 |
--dataset-repo "$TRACKIO_DATASET_REPO"
|
587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
588 |
# Step 17: Create summary report
|
589 |
print_step "Step 17: Creating Summary Report"
|
590 |
echo "===================================="
|
@@ -600,6 +690,7 @@ cat > training_summary.md << EOF
|
|
600 |
- **Trackio Space**: $TRACKIO_URL
|
601 |
- **HF Dataset**: $TRACKIO_DATASET_REPO
|
602 |
- **Training Config**: $TRAINING_CONFIG_TYPE
|
|
|
603 |
$(if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
604 |
echo "- **Dataset Sample Size**: ${DATASET_SAMPLE_SIZE:-80000}"
|
605 |
fi)
|
@@ -615,6 +706,15 @@ fi)
|
|
615 |
- **Model Repository**: https://huggingface.co/$REPO_NAME
|
616 |
- **Trackio Monitoring**: $TRACKIO_URL
|
617 |
- **Experiment Data**: https://huggingface.co/datasets/$TRACKIO_DATASET_REPO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
|
619 |
## Next Steps
|
620 |
1. Monitor training progress in your Trackio Space
|
@@ -640,6 +740,16 @@ echo "📊 Model: https://huggingface.co/$REPO_NAME"
|
|
640 |
echo "📈 Trackio: $TRACKIO_URL"
|
641 |
echo "📋 Experiment: $EXPERIMENT_NAME"
|
642 |
echo "📊 Dataset: https://huggingface.co/datasets/$TRACKIO_DATASET_REPO"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
643 |
echo ""
|
644 |
echo "📋 Summary report saved to: training_summary.md"
|
645 |
echo ""
|
|
|
91 |
|
92 |
# Test the token and get username
|
93 |
export HF_TOKEN="$token"
|
94 |
+
if hf whoami >/dev/null 2>&1; then
|
95 |
# Get username from whoami command
|
96 |
+
HF_USERNAME=$(hf whoami | head -n1 | tr -d '\n')
|
97 |
return 0
|
98 |
else
|
99 |
return 1
|
|
|
229 |
from config.train_smollm3 import SmolLM3Config
|
230 |
|
231 |
config = SmolLM3Config(
|
232 |
+
# Trainer type selection
|
233 |
+
trainer_type="$TRAINER_TYPE",
|
234 |
+
|
235 |
# Model configuration
|
236 |
model_name="$MODEL_NAME",
|
237 |
max_seq_length=$MAX_SEQ_LENGTH,
|
|
|
344 |
get_input "Model repository name" "$HF_USERNAME/smollm3-finetuned-$(date +%Y%m%d)" REPO_NAME
|
345 |
get_input "Trackio dataset repository" "$HF_USERNAME/trackio-experiments" TRACKIO_DATASET_REPO
|
346 |
|
347 |
+
# Step 3.5: Select trainer type
|
348 |
+
print_step "Step 3.5: Trainer Type Selection"
|
349 |
+
echo "===================================="
|
350 |
+
|
351 |
+
echo "Select the type of training to perform:"
|
352 |
+
echo "1. SFT (Supervised Fine-tuning) - Standard instruction tuning"
|
353 |
+
echo " - Uses SFTTrainer for instruction following"
|
354 |
+
echo " - Suitable for most fine-tuning tasks"
|
355 |
+
echo " - Optimized for instruction datasets"
|
356 |
+
echo ""
|
357 |
+
echo "2. DPO (Direct Preference Optimization) - Preference-based training"
|
358 |
+
echo " - Uses DPOTrainer for preference learning"
|
359 |
+
echo " - Requires preference datasets (chosen/rejected pairs)"
|
360 |
+
echo " - Optimizes for human preferences"
|
361 |
+
echo ""
|
362 |
+
|
363 |
+
select_option "Select trainer type:" "SFT" "DPO" TRAINER_TYPE
|
364 |
+
|
365 |
# Step 4: Training parameters
|
366 |
print_step "Step 4: Training Parameters"
|
367 |
echo "==============================="
|
|
|
369 |
echo "Current configuration:"
|
370 |
echo " Model: $MODEL_NAME"
|
371 |
echo " Dataset: $DATASET_NAME"
|
372 |
+
echo " Trainer Type: $TRAINER_TYPE"
|
373 |
if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
374 |
echo " Dataset Sample Size: ${DATASET_SAMPLE_SIZE:-80000}"
|
375 |
fi
|
|
|
402 |
echo " Model: $MODEL_NAME"
|
403 |
echo " Dataset: $DATASET_NAME"
|
404 |
echo " Training Config: $TRAINING_CONFIG_TYPE"
|
405 |
+
echo " Trainer Type: $TRAINER_TYPE"
|
406 |
if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
407 |
echo " Dataset Sample Size: ${DATASET_SAMPLE_SIZE:-80000}"
|
408 |
fi
|
|
|
476 |
|
477 |
# Login to Hugging Face with token
|
478 |
print_info "Logging in to Hugging Face..."
|
479 |
+
if hf login --token "$HF_TOKEN" --add-to-git-credential; then
|
480 |
print_status "Successfully logged in to Hugging Face"
|
481 |
+
print_info "Username: $(hf whoami)"
|
482 |
else
|
483 |
print_error "Failed to login to Hugging Face"
|
484 |
print_error "Please check your token and try again"
|
|
|
525 |
$TRACKIO_SPACE_NAME
|
526 |
$HF_TOKEN
|
527 |
$GIT_EMAIL
|
528 |
+
|
529 |
EOF
|
530 |
|
531 |
print_status "Trackio Space deployed: $TRACKIO_URL"
|
|
|
592 |
--config "$CONFIG_FILE" \
|
593 |
--experiment-name "$EXPERIMENT_NAME" \
|
594 |
--output-dir /output-checkpoint \
|
595 |
+
--trackio-url "$TRACKIO_URL" \
|
596 |
+
--trainer-type "$TRAINER_TYPE"
|
597 |
|
598 |
# Step 16: Push model to Hugging Face Hub
|
599 |
print_step "Step 16: Pushing Model to HF Hub"
|
|
|
609 |
--experiment-name "$EXPERIMENT_NAME" \
|
610 |
--dataset-repo "$TRACKIO_DATASET_REPO"
|
611 |
|
612 |
+
# Step 16.5: Quantization Options
|
613 |
+
print_step "Step 16.5: Model Quantization Options"
|
614 |
+
echo "=========================================="
|
615 |
+
|
616 |
+
print_info "Would you like to create quantized versions of your model?"
|
617 |
+
print_info "Quantization reduces model size and improves inference speed."
|
618 |
+
|
619 |
+
# Ask about quantization
|
620 |
+
get_input "Create quantized models? (y/n)" "y" "CREATE_QUANTIZED"
|
621 |
+
|
622 |
+
if [ "$CREATE_QUANTIZED" = "y" ] || [ "$CREATE_QUANTIZED" = "Y" ]; then
|
623 |
+
print_info "Quantization options:"
|
624 |
+
print_info "1. int8_weight_only (GPU optimized, ~50% memory reduction)"
|
625 |
+
print_info "2. int4_weight_only (CPU optimized, ~75% memory reduction)"
|
626 |
+
print_info "3. Both int8 and int4 versions"
|
627 |
+
|
628 |
+
select_option "Select quantization type:" "int8_weight_only" "int4_weight_only" "both" "QUANT_TYPE"
|
629 |
+
|
630 |
+
if [ "$QUANT_TYPE" = "both" ]; then
|
631 |
+
# Create both int8 and int4 versions in the same repository
|
632 |
+
print_info "Creating int8 (GPU) quantized model..."
|
633 |
+
python scripts/model_tonic/quantize_model.py /output-checkpoint "$REPO_NAME" \
|
634 |
+
--quant-type "int8_weight_only" \
|
635 |
+
--device "auto" \
|
636 |
+
--token "$HF_TOKEN" \
|
637 |
+
--trackio-url "$TRACKIO_URL" \
|
638 |
+
--experiment-name "${EXPERIMENT_NAME}-int8" \
|
639 |
+
--dataset-repo "$TRACKIO_DATASET_REPO"
|
640 |
+
|
641 |
+
print_info "Creating int4 (CPU) quantized model..."
|
642 |
+
python scripts/model_tonic/quantize_model.py /output-checkpoint "$REPO_NAME" \
|
643 |
+
--quant-type "int4_weight_only" \
|
644 |
+
--device "cpu" \
|
645 |
+
--token "$HF_TOKEN" \
|
646 |
+
--trackio-url "$TRACKIO_URL" \
|
647 |
+
--experiment-name "${EXPERIMENT_NAME}-int4" \
|
648 |
+
--dataset-repo "$TRACKIO_DATASET_REPO"
|
649 |
+
|
650 |
+
print_status "✅ Both quantized models created in the same repository:"
|
651 |
+
print_info "Main model: https://huggingface.co/$REPO_NAME"
|
652 |
+
print_info "int8 (GPU): https://huggingface.co/$REPO_NAME/int8"
|
653 |
+
print_info "int4 (CPU): https://huggingface.co/$REPO_NAME/int4"
|
654 |
+
|
655 |
+
else
|
656 |
+
# Create single quantized version in the same repository
|
657 |
+
print_info "Creating ${QUANT_TYPE} quantized model..."
|
658 |
+
|
659 |
+
DEVICE="auto"
|
660 |
+
if [ "$QUANT_TYPE" = "int4_weight_only" ]; then
|
661 |
+
DEVICE="cpu"
|
662 |
+
fi
|
663 |
+
|
664 |
+
python scripts/model_tonic/quantize_model.py /output-checkpoint "$REPO_NAME" \
|
665 |
+
--quant-type "$QUANT_TYPE" \
|
666 |
+
--device "$DEVICE" \
|
667 |
+
--token "$HF_TOKEN" \
|
668 |
+
--trackio-url "$TRACKIO_URL" \
|
669 |
+
--experiment-name "${EXPERIMENT_NAME}-${QUANT_TYPE}" \
|
670 |
+
--dataset-repo "$TRACKIO_DATASET_REPO"
|
671 |
+
|
672 |
+
print_status "✅ Quantized model created: https://huggingface.co/$REPO_NAME/${QUANT_TYPE//_/-}"
|
673 |
+
fi
|
674 |
+
else
|
675 |
+
print_info "Skipping quantization"
|
676 |
+
fi
|
677 |
+
|
678 |
# Step 17: Create summary report
|
679 |
print_step "Step 17: Creating Summary Report"
|
680 |
echo "===================================="
|
|
|
690 |
- **Trackio Space**: $TRACKIO_URL
|
691 |
- **HF Dataset**: $TRACKIO_DATASET_REPO
|
692 |
- **Training Config**: $TRAINING_CONFIG_TYPE
|
693 |
+
- **Trainer Type**: $TRAINER_TYPE
|
694 |
$(if [ "$TRAINING_CONFIG_TYPE" = "H100 Lightweight (Rapid)" ]; then
|
695 |
echo "- **Dataset Sample Size**: ${DATASET_SAMPLE_SIZE:-80000}"
|
696 |
fi)
|
|
|
706 |
- **Model Repository**: https://huggingface.co/$REPO_NAME
|
707 |
- **Trackio Monitoring**: $TRACKIO_URL
|
708 |
- **Experiment Data**: https://huggingface.co/datasets/$TRACKIO_DATASET_REPO
|
709 |
+
$(if [ "$CREATE_QUANTIZED" = "y" ] || [ "$CREATE_QUANTIZED" = "Y" ]; then
|
710 |
+
echo "- **Quantization**: $QUANT_TYPE"
|
711 |
+
if [ "$QUANT_TYPE" = "both" ]; then
|
712 |
+
echo "- **int8 Model (GPU)**: https://huggingface.co/$REPO_NAME/int8"
|
713 |
+
echo "- **int4 Model (CPU)**: https://huggingface.co/$REPO_NAME/int4"
|
714 |
+
else
|
715 |
+
echo "- **Quantized Model**: https://huggingface.co/$REPO_NAME/${QUANT_TYPE//_/-}"
|
716 |
+
fi
|
717 |
+
fi)
|
718 |
|
719 |
## Next Steps
|
720 |
1. Monitor training progress in your Trackio Space
|
|
|
740 |
echo "📈 Trackio: $TRACKIO_URL"
|
741 |
echo "📋 Experiment: $EXPERIMENT_NAME"
|
742 |
echo "📊 Dataset: https://huggingface.co/datasets/$TRACKIO_DATASET_REPO"
|
743 |
+
$(if [ "$CREATE_QUANTIZED" = "y" ] || [ "$CREATE_QUANTIZED" = "Y" ]; then
|
744 |
+
echo ""
|
745 |
+
echo "🔧 Quantized Models:"
|
746 |
+
if [ "$QUANT_TYPE" = "both" ]; then
|
747 |
+
echo " 📊 int8 (GPU): https://huggingface.co/$REPO_NAME/int8"
|
748 |
+
echo " 📊 int4 (CPU): https://huggingface.co/$REPO_NAME/int4"
|
749 |
+
else
|
750 |
+
echo " 📊 $QUANT_TYPE: https://huggingface.co/$REPO_NAME/${QUANT_TYPE//_/-}"
|
751 |
+
fi
|
752 |
+
fi)
|
753 |
echo ""
|
754 |
echo "📋 Summary report saved to: training_summary.md"
|
755 |
echo ""
|
requirements/requirements.txt
CHANGED
@@ -12,6 +12,7 @@ tokenizers>=0.13.0
|
|
12 |
# Training and optimization
|
13 |
flash-attn>=2.0.0
|
14 |
bitsandbytes>=0.41.0
|
|
|
15 |
|
16 |
# Basic utilities
|
17 |
numpy>=1.24.0
|
|
|
12 |
# Training and optimization
|
13 |
flash-attn>=2.0.0
|
14 |
bitsandbytes>=0.41.0
|
15 |
+
torchao>=0.10.0
|
16 |
|
17 |
# Basic utilities
|
18 |
numpy>=1.24.0
|
scripts/dataset_tonic/setup_hf_dataset.py
CHANGED
@@ -53,7 +53,7 @@ def get_username_from_cli(token: str) -> str:
|
|
53 |
|
54 |
# Get username using CLI
|
55 |
result = subprocess.run(
|
56 |
-
["
|
57 |
capture_output=True,
|
58 |
text=True,
|
59 |
timeout=30
|
|
|
53 |
|
54 |
# Get username using CLI
|
55 |
result = subprocess.run(
|
56 |
+
["hf", "whoami"],
|
57 |
capture_output=True,
|
58 |
text=True,
|
59 |
timeout=30
|
scripts/model_tonic/generate_model_card.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Generate unified model card from template
|
4 |
+
Handles template variables and conditional sections for quantized models
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import argparse
|
10 |
+
import logging
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Dict, Any, Optional
|
13 |
+
from datetime import datetime
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
class ModelCardGenerator:
|
18 |
+
"""Generate unified model cards from templates"""
|
19 |
+
|
20 |
+
def __init__(self, template_path: str = "templates/model_card.md"):
|
21 |
+
self.template_path = Path(template_path)
|
22 |
+
if not self.template_path.exists():
|
23 |
+
raise FileNotFoundError(f"Template not found: {self.template_path}")
|
24 |
+
|
25 |
+
def load_template(self) -> str:
|
26 |
+
"""Load the model card template"""
|
27 |
+
with open(self.template_path, 'r', encoding='utf-8') as f:
|
28 |
+
return f.read()
|
29 |
+
|
30 |
+
def process_conditionals(self, content: str, variables: Dict[str, Any]) -> str:
|
31 |
+
"""Process conditional sections in the template"""
|
32 |
+
# Handle {{#if variable}}...{{/if}} blocks
|
33 |
+
pattern = r'\{\{#if\s+(\w+)\}\}(.*?)\{\{/if\}\}'
|
34 |
+
|
35 |
+
def replace_conditional(match):
|
36 |
+
variable_name = match.group(1)
|
37 |
+
conditional_content = match.group(2)
|
38 |
+
|
39 |
+
# Check if variable exists and is truthy
|
40 |
+
if variable_name in variables and variables[variable_name]:
|
41 |
+
return conditional_content
|
42 |
+
else:
|
43 |
+
return ""
|
44 |
+
|
45 |
+
return re.sub(pattern, replace_conditional, content, flags=re.DOTALL)
|
46 |
+
|
47 |
+
def replace_variables(self, content: str, variables: Dict[str, Any]) -> str:
|
48 |
+
"""Replace template variables with actual values"""
|
49 |
+
for key, value in variables.items():
|
50 |
+
placeholder = f"{{{{{key}}}}}"
|
51 |
+
content = content.replace(placeholder, str(value))
|
52 |
+
|
53 |
+
return content
|
54 |
+
|
55 |
+
def generate_model_card(self, variables: Dict[str, Any]) -> str:
|
56 |
+
"""Generate the complete model card"""
|
57 |
+
# Load template
|
58 |
+
content = self.load_template()
|
59 |
+
|
60 |
+
# Process conditionals first
|
61 |
+
content = self.process_conditionals(content, variables)
|
62 |
+
|
63 |
+
# Replace variables
|
64 |
+
content = self.replace_variables(content, variables)
|
65 |
+
|
66 |
+
return content
|
67 |
+
|
68 |
+
def save_model_card(self, content: str, output_path: str) -> bool:
|
69 |
+
"""Save the generated model card"""
|
70 |
+
try:
|
71 |
+
output_file = Path(output_path)
|
72 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
73 |
+
|
74 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
75 |
+
f.write(content)
|
76 |
+
|
77 |
+
logger.info(f"✅ Model card saved to: {output_file}")
|
78 |
+
return True
|
79 |
+
|
80 |
+
except Exception as e:
|
81 |
+
logger.error(f"❌ Failed to save model card: {e}")
|
82 |
+
return False
|
83 |
+
|
84 |
+
def create_default_variables() -> Dict[str, Any]:
|
85 |
+
"""Create default variables for the model card"""
|
86 |
+
return {
|
87 |
+
"model_name": "SmolLM3 Fine-tuned Model",
|
88 |
+
"model_description": "A fine-tuned version of SmolLM3-3B for improved text generation and conversation capabilities.",
|
89 |
+
"repo_name": "your-username/model-name",
|
90 |
+
"base_model": "HuggingFaceTB/SmolLM3-3B",
|
91 |
+
"dataset_name": "OpenHermes-FR",
|
92 |
+
"training_config_type": "Custom Configuration",
|
93 |
+
"trainer_type": "SFTTrainer",
|
94 |
+
"batch_size": "8",
|
95 |
+
"gradient_accumulation_steps": "16",
|
96 |
+
"learning_rate": "5e-6",
|
97 |
+
"max_epochs": "3",
|
98 |
+
"max_seq_length": "2048",
|
99 |
+
"hardware_info": "GPU (H100/A100)",
|
100 |
+
"experiment_name": "smollm3-experiment",
|
101 |
+
"trackio_url": "https://trackio.space/experiment",
|
102 |
+
"dataset_repo": "tonic/trackio-experiments",
|
103 |
+
"dataset_size": "~80K samples",
|
104 |
+
"dataset_format": "Chat format",
|
105 |
+
"author_name": "Your Name",
|
106 |
+
"model_name_slug": "smollm3-fine-tuned",
|
107 |
+
"quantized_models": False,
|
108 |
+
"dataset_sample_size": None
|
109 |
+
}
|
110 |
+
|
111 |
+
def parse_args():
|
112 |
+
"""Parse command line arguments"""
|
113 |
+
parser = argparse.ArgumentParser(description="Generate unified model card")
|
114 |
+
parser.add_argument("--template", default="templates/model_card.md",
|
115 |
+
help="Path to model card template")
|
116 |
+
parser.add_argument("--output", default="README.md",
|
117 |
+
help="Output path for generated model card")
|
118 |
+
parser.add_argument("--repo-name", required=True,
|
119 |
+
help="Hugging Face repository name")
|
120 |
+
parser.add_argument("--model-name", help="Model name")
|
121 |
+
parser.add_argument("--experiment-name", help="Experiment name")
|
122 |
+
parser.add_argument("--dataset-name", help="Dataset name")
|
123 |
+
parser.add_argument("--training-config", help="Training configuration type")
|
124 |
+
parser.add_argument("--trainer-type", help="Trainer type")
|
125 |
+
parser.add_argument("--batch-size", help="Batch size")
|
126 |
+
parser.add_argument("--learning-rate", help="Learning rate")
|
127 |
+
parser.add_argument("--max-epochs", help="Maximum epochs")
|
128 |
+
parser.add_argument("--max-seq-length", help="Maximum sequence length")
|
129 |
+
parser.add_argument("--hardware-info", help="Hardware information")
|
130 |
+
parser.add_argument("--trackio-url", help="Trackio URL")
|
131 |
+
parser.add_argument("--dataset-repo", help="Dataset repository")
|
132 |
+
parser.add_argument("--author-name", help="Author name")
|
133 |
+
parser.add_argument("--quantized-models", action="store_true",
|
134 |
+
help="Include quantized models")
|
135 |
+
parser.add_argument("--dataset-sample-size", help="Dataset sample size")
|
136 |
+
|
137 |
+
return parser.parse_args()
|
138 |
+
|
139 |
+
def main():
|
140 |
+
"""Main function"""
|
141 |
+
args = parse_args()
|
142 |
+
|
143 |
+
# Setup logging
|
144 |
+
logging.basicConfig(
|
145 |
+
level=logging.INFO,
|
146 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
147 |
+
)
|
148 |
+
|
149 |
+
try:
|
150 |
+
# Create generator
|
151 |
+
generator = ModelCardGenerator(args.template)
|
152 |
+
|
153 |
+
# Create variables dictionary
|
154 |
+
variables = create_default_variables()
|
155 |
+
|
156 |
+
# Override with command line arguments
|
157 |
+
if args.repo_name:
|
158 |
+
variables["repo_name"] = args.repo_name
|
159 |
+
if args.model_name:
|
160 |
+
variables["model_name"] = args.model_name
|
161 |
+
if args.experiment_name:
|
162 |
+
variables["experiment_name"] = args.experiment_name
|
163 |
+
if args.dataset_name:
|
164 |
+
variables["dataset_name"] = args.dataset_name
|
165 |
+
if args.training_config:
|
166 |
+
variables["training_config_type"] = args.training_config
|
167 |
+
if args.trainer_type:
|
168 |
+
variables["trainer_type"] = args.trainer_type
|
169 |
+
if args.batch_size:
|
170 |
+
variables["batch_size"] = args.batch_size
|
171 |
+
if args.learning_rate:
|
172 |
+
variables["learning_rate"] = args.learning_rate
|
173 |
+
if args.max_epochs:
|
174 |
+
variables["max_epochs"] = args.max_epochs
|
175 |
+
if args.max_seq_length:
|
176 |
+
variables["max_seq_length"] = args.max_seq_length
|
177 |
+
if args.hardware_info:
|
178 |
+
variables["hardware_info"] = args.hardware_info
|
179 |
+
if args.trackio_url:
|
180 |
+
variables["trackio_url"] = args.trackio_url
|
181 |
+
if args.dataset_repo:
|
182 |
+
variables["dataset_repo"] = args.dataset_repo
|
183 |
+
if args.author_name:
|
184 |
+
variables["author_name"] = args.author_name
|
185 |
+
if args.quantized_models:
|
186 |
+
variables["quantized_models"] = True
|
187 |
+
if args.dataset_sample_size:
|
188 |
+
variables["dataset_sample_size"] = args.dataset_sample_size
|
189 |
+
|
190 |
+
# Generate model card
|
191 |
+
print("🔄 Generating model card...")
|
192 |
+
content = generator.generate_model_card(variables)
|
193 |
+
|
194 |
+
# Save model card
|
195 |
+
if generator.save_model_card(content, args.output):
|
196 |
+
print("✅ Model card generated successfully!")
|
197 |
+
print(f"📄 Output: {args.output}")
|
198 |
+
else:
|
199 |
+
print("❌ Failed to generate model card")
|
200 |
+
return 1
|
201 |
+
|
202 |
+
return 0
|
203 |
+
|
204 |
+
except Exception as e:
|
205 |
+
logger.error(f"❌ Error generating model card: {e}")
|
206 |
+
return 1
|
207 |
+
|
208 |
+
if __name__ == "__main__":
|
209 |
+
exit(main())
|
scripts/model_tonic/push_to_huggingface.py
CHANGED
@@ -121,16 +121,56 @@ class HuggingFacePusher:
|
|
121 |
return True
|
122 |
|
123 |
def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
|
124 |
-
"""Create a comprehensive model card"""
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
language:
|
127 |
- en
|
128 |
-
|
|
|
129 |
tags:
|
130 |
- smollm3
|
131 |
- fine-tuned
|
|
|
132 |
- text-generation
|
133 |
-
- transformers
|
134 |
---
|
135 |
|
136 |
# {self.repo_name.split('/')[-1]}
|
@@ -174,7 +214,7 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
|
174 |
|
175 |
## Training Information
|
176 |
|
177 |
-
- **
|
178 |
- **Hardware**: {self._get_hardware_info()}
|
179 |
- **Training Time**: {results.get('training_time_hours', 'Unknown')} hours
|
180 |
- **Final Loss**: {results.get('final_loss', 'Unknown')}
|
@@ -197,9 +237,9 @@ This model is fine-tuned for specific tasks and may not generalize well to all u
|
|
197 |
|
198 |
## License
|
199 |
|
200 |
-
This model is licensed under the
|
201 |
"""
|
202 |
-
return model_card
|
203 |
|
204 |
def _get_model_size(self) -> float:
|
205 |
"""Get model size in GB"""
|
|
|
121 |
return True
|
122 |
|
123 |
def create_model_card(self, training_config: Dict[str, Any], results: Dict[str, Any]) -> str:
|
124 |
+
"""Create a comprehensive model card using the unified template"""
|
125 |
+
try:
|
126 |
+
# Import the model card generator
|
127 |
+
import sys
|
128 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
|
129 |
+
from scripts.model_tonic.generate_model_card import ModelCardGenerator
|
130 |
+
|
131 |
+
# Create variables for the template
|
132 |
+
variables = {
|
133 |
+
"model_name": f"{self.repo_name.split('/')[-1]} - Fine-tuned SmolLM3",
|
134 |
+
"model_description": "A fine-tuned version of SmolLM3-3B for improved text generation and conversation capabilities.",
|
135 |
+
"repo_name": self.repo_name,
|
136 |
+
"base_model": "HuggingFaceTB/SmolLM3-3B",
|
137 |
+
"dataset_name": training_config.get('dataset_name', 'OpenHermes-FR'),
|
138 |
+
"training_config_type": training_config.get('training_config_type', 'Custom Configuration'),
|
139 |
+
"trainer_type": training_config.get('trainer_type', 'SFTTrainer'),
|
140 |
+
"batch_size": str(training_config.get('per_device_train_batch_size', 8)),
|
141 |
+
"gradient_accumulation_steps": str(training_config.get('gradient_accumulation_steps', 16)),
|
142 |
+
"learning_rate": str(training_config.get('learning_rate', '5e-6')),
|
143 |
+
"max_epochs": str(training_config.get('num_train_epochs', 3)),
|
144 |
+
"max_seq_length": str(training_config.get('max_seq_length', 2048)),
|
145 |
+
"hardware_info": self._get_hardware_info(),
|
146 |
+
"experiment_name": self.experiment_name or "smollm3-experiment",
|
147 |
+
"trackio_url": self.trackio_url or "https://trackio.space/experiment",
|
148 |
+
"dataset_repo": self.dataset_repo,
|
149 |
+
"dataset_size": training_config.get('dataset_size', '~80K samples'),
|
150 |
+
"dataset_format": training_config.get('dataset_format', 'Chat format'),
|
151 |
+
"author_name": training_config.get('author_name', 'Your Name'),
|
152 |
+
"model_name_slug": self.repo_name.split('/')[-1].lower().replace('-', '_'),
|
153 |
+
"quantized_models": False, # Will be updated if quantized models are added
|
154 |
+
"dataset_sample_size": training_config.get('dataset_sample_size')
|
155 |
+
}
|
156 |
+
|
157 |
+
# Create generator and generate model card
|
158 |
+
generator = ModelCardGenerator()
|
159 |
+
return generator.generate_model_card(variables)
|
160 |
+
|
161 |
+
except Exception as e:
|
162 |
+
logger.error(f"Failed to generate model card from template: {e}")
|
163 |
+
# Fallback to simple model card
|
164 |
+
return f"""---
|
165 |
language:
|
166 |
- en
|
167 |
+
- fr
|
168 |
+
license: apache-2.0
|
169 |
tags:
|
170 |
- smollm3
|
171 |
- fine-tuned
|
172 |
+
- causal-lm
|
173 |
- text-generation
|
|
|
174 |
---
|
175 |
|
176 |
# {self.repo_name.split('/')[-1]}
|
|
|
214 |
|
215 |
## Training Information
|
216 |
|
217 |
+
- **Base Model**: HuggingFaceTB/SmolLM3-3B
|
218 |
- **Hardware**: {self._get_hardware_info()}
|
219 |
- **Training Time**: {results.get('training_time_hours', 'Unknown')} hours
|
220 |
- **Final Loss**: {results.get('final_loss', 'Unknown')}
|
|
|
237 |
|
238 |
## License
|
239 |
|
240 |
+
This model is licensed under the Apache 2.0 License.
|
241 |
"""
|
242 |
+
# return model_card
|
243 |
|
244 |
def _get_model_size(self) -> float:
|
245 |
"""Get model size in GB"""
|
scripts/model_tonic/quantize_model.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Quantize Trained Model using torchao
|
4 |
+
Supports int8 (GPU) and int4 (CPU) quantization with Hugging Face Hub integration
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
import argparse
|
10 |
+
import logging
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Dict, Any, Optional, List, Union
|
13 |
+
from datetime import datetime
|
14 |
+
import subprocess
|
15 |
+
import shutil
|
16 |
+
|
17 |
+
try:
|
18 |
+
import torch
|
19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
20 |
+
from torchao.quantization import (
|
21 |
+
Int8WeightOnlyConfig,
|
22 |
+
Int4WeightOnlyConfig,
|
23 |
+
Int8DynamicActivationInt8WeightConfig
|
24 |
+
)
|
25 |
+
from torchao.dtypes import Int4CPULayout
|
26 |
+
TORCHAO_AVAILABLE = True
|
27 |
+
except ImportError:
|
28 |
+
TORCHAO_AVAILABLE = False
|
29 |
+
print("Warning: torchao not available. Install with: pip install torchao")
|
30 |
+
|
31 |
+
try:
|
32 |
+
from huggingface_hub import HfApi, create_repo, upload_file
|
33 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
34 |
+
HF_AVAILABLE = True
|
35 |
+
except ImportError:
|
36 |
+
HF_AVAILABLE = False
|
37 |
+
print("Warning: huggingface_hub not available. Install with: pip install huggingface_hub")
|
38 |
+
|
39 |
+
try:
|
40 |
+
import sys
|
41 |
+
import os
|
42 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
|
43 |
+
from monitoring import SmolLM3Monitor
|
44 |
+
MONITORING_AVAILABLE = True
|
45 |
+
except ImportError:
|
46 |
+
MONITORING_AVAILABLE = False
|
47 |
+
print("Warning: monitoring module not available")
|
48 |
+
|
49 |
+
logger = logging.getLogger(__name__)
|
50 |
+
|
51 |
+
class ModelQuantizer:
|
52 |
+
"""Quantize models using torchao with HF Hub integration"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
model_path: str,
|
57 |
+
repo_name: str,
|
58 |
+
token: Optional[str] = None,
|
59 |
+
private: bool = False,
|
60 |
+
trackio_url: Optional[str] = None,
|
61 |
+
experiment_name: Optional[str] = None,
|
62 |
+
dataset_repo: Optional[str] = None,
|
63 |
+
hf_token: Optional[str] = None
|
64 |
+
):
|
65 |
+
self.model_path = Path(model_path)
|
66 |
+
self.repo_name = repo_name
|
67 |
+
self.token = token or hf_token or os.getenv('HF_TOKEN')
|
68 |
+
self.private = private
|
69 |
+
self.trackio_url = trackio_url
|
70 |
+
self.experiment_name = experiment_name
|
71 |
+
|
72 |
+
# HF Datasets configuration
|
73 |
+
self.dataset_repo = dataset_repo or os.getenv('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
|
74 |
+
self.hf_token = hf_token or os.getenv('HF_TOKEN')
|
75 |
+
|
76 |
+
# Initialize HF API
|
77 |
+
if HF_AVAILABLE:
|
78 |
+
self.api = HfApi(token=self.token)
|
79 |
+
else:
|
80 |
+
raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
|
81 |
+
|
82 |
+
# Initialize monitoring if available
|
83 |
+
self.monitor = None
|
84 |
+
if MONITORING_AVAILABLE:
|
85 |
+
self.monitor = SmolLM3Monitor(
|
86 |
+
experiment_name=experiment_name or "model_quantization",
|
87 |
+
trackio_url=trackio_url,
|
88 |
+
enable_tracking=bool(trackio_url),
|
89 |
+
hf_token=self.hf_token,
|
90 |
+
dataset_repo=self.dataset_repo
|
91 |
+
)
|
92 |
+
|
93 |
+
logger.info(f"Initialized ModelQuantizer for {repo_name}")
|
94 |
+
logger.info(f"Dataset repository: {self.dataset_repo}")
|
95 |
+
|
96 |
+
def validate_model_path(self) -> bool:
|
97 |
+
"""Validate that the model path exists and contains required files"""
|
98 |
+
if not self.model_path.exists():
|
99 |
+
logger.error(f"❌ Model path does not exist: {self.model_path}")
|
100 |
+
return False
|
101 |
+
|
102 |
+
# Check for essential model files
|
103 |
+
required_files = ['config.json', 'pytorch_model.bin']
|
104 |
+
optional_files = ['tokenizer.json', 'tokenizer_config.json']
|
105 |
+
|
106 |
+
missing_files = []
|
107 |
+
for file in required_files:
|
108 |
+
if not (self.model_path / file).exists():
|
109 |
+
missing_files.append(file)
|
110 |
+
|
111 |
+
if missing_files:
|
112 |
+
logger.error(f"❌ Missing required model files: {missing_files}")
|
113 |
+
return False
|
114 |
+
|
115 |
+
logger.info(f"✅ Model path validated: {self.model_path}")
|
116 |
+
return True
|
117 |
+
|
118 |
+
def create_quantization_config(self, quant_type: str, group_size: int = 128) -> TorchAoConfig:
|
119 |
+
"""Create torchao quantization configuration"""
|
120 |
+
if not TORCHAO_AVAILABLE:
|
121 |
+
raise ImportError("torchao is required. Install with: pip install torchao")
|
122 |
+
|
123 |
+
if quant_type == "int8_weight_only":
|
124 |
+
quant_config = Int8WeightOnlyConfig(group_size=group_size)
|
125 |
+
elif quant_type == "int4_weight_only":
|
126 |
+
# For int4, we need to specify CPU layout
|
127 |
+
quant_config = Int4WeightOnlyConfig(group_size=group_size, layout=Int4CPULayout())
|
128 |
+
elif quant_type == "int8_dynamic":
|
129 |
+
quant_config = Int8DynamicActivationInt8WeightConfig()
|
130 |
+
else:
|
131 |
+
raise ValueError(f"Unsupported quantization type: {quant_type}")
|
132 |
+
|
133 |
+
return TorchAoConfig(quant_type=quant_config)
|
134 |
+
|
135 |
+
def quantize_model(
|
136 |
+
self,
|
137 |
+
quant_type: str,
|
138 |
+
device: str = "auto",
|
139 |
+
group_size: int = 128,
|
140 |
+
save_dir: Optional[str] = None
|
141 |
+
) -> Optional[str]:
|
142 |
+
"""Quantize the model using torchao"""
|
143 |
+
if not TORCHAO_AVAILABLE:
|
144 |
+
logger.error("❌ torchao not available")
|
145 |
+
return None
|
146 |
+
|
147 |
+
try:
|
148 |
+
logger.info(f"🔄 Loading model from: {self.model_path}")
|
149 |
+
logger.info(f"🔄 Quantization type: {quant_type}")
|
150 |
+
logger.info(f"🔄 Device: {device}")
|
151 |
+
logger.info(f"🔄 Group size: {group_size}")
|
152 |
+
|
153 |
+
# Create quantization config
|
154 |
+
quantization_config = self.create_quantization_config(quant_type, group_size)
|
155 |
+
|
156 |
+
# Load and quantize the model
|
157 |
+
quantized_model = AutoModelForCausalLM.from_pretrained(
|
158 |
+
str(self.model_path),
|
159 |
+
torch_dtype="auto",
|
160 |
+
device_map=device,
|
161 |
+
quantization_config=quantization_config
|
162 |
+
)
|
163 |
+
|
164 |
+
# Determine save directory
|
165 |
+
if save_dir is None:
|
166 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
167 |
+
save_dir = f"quantized_{quant_type}_{timestamp}"
|
168 |
+
|
169 |
+
save_path = Path(save_dir)
|
170 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
171 |
+
|
172 |
+
# Save quantized model (don't use safetensors for torchao)
|
173 |
+
logger.info(f"💾 Saving quantized model to: {save_path}")
|
174 |
+
quantized_model.save_pretrained(save_path, safe_serialization=False)
|
175 |
+
|
176 |
+
# Copy tokenizer files if they exist
|
177 |
+
tokenizer_files = ['tokenizer.json', 'tokenizer_config.json', 'special_tokens_map.json']
|
178 |
+
for file in tokenizer_files:
|
179 |
+
src_file = self.model_path / file
|
180 |
+
if src_file.exists():
|
181 |
+
shutil.copy2(src_file, save_path / file)
|
182 |
+
logger.info(f"📋 Copied {file}")
|
183 |
+
|
184 |
+
logger.info(f"✅ Model quantized successfully: {save_path}")
|
185 |
+
return str(save_path)
|
186 |
+
|
187 |
+
except Exception as e:
|
188 |
+
logger.error(f"❌ Quantization failed: {e}")
|
189 |
+
return None
|
190 |
+
|
191 |
+
def create_quantized_model_card(self, quant_type: str, original_model: str, subdir: str) -> str:
|
192 |
+
"""Create a model card for the quantized model"""
|
193 |
+
repo_name = self.repo_name
|
194 |
+
card_content = f"""---
|
195 |
+
language:
|
196 |
+
- en
|
197 |
+
- fr
|
198 |
+
license: apache-2.0
|
199 |
+
tags:
|
200 |
+
- quantized
|
201 |
+
- {quant_type}
|
202 |
+
- smollm3
|
203 |
+
- fine-tuned
|
204 |
+
---
|
205 |
+
|
206 |
+
# Quantized SmolLM3 Model
|
207 |
+
|
208 |
+
This is a quantized version of the SmolLM3 model using torchao quantization.
|
209 |
+
|
210 |
+
## Model Details
|
211 |
+
|
212 |
+
- **Base Model**: SmolLM3-3B
|
213 |
+
- **Quantization Type**: {quant_type}
|
214 |
+
- **Original Model**: {original_model}
|
215 |
+
- **Quantization Library**: torchao
|
216 |
+
- **Hardware Compatibility**: {'GPU' if 'int8' in quant_type else 'CPU'}
|
217 |
+
|
218 |
+
## Usage
|
219 |
+
|
220 |
+
```python
|
221 |
+
import torch
|
222 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
223 |
+
|
224 |
+
# Load the quantized model
|
225 |
+
model = AutoModelForCausalLM.from_pretrained(
|
226 |
+
f"{repo_name}/{subdir}",
|
227 |
+
device_map="auto",
|
228 |
+
torch_dtype=torch.bfloat16
|
229 |
+
)
|
230 |
+
tokenizer = AutoTokenizer.from_pretrained(f"{repo_name}/{subdir}")
|
231 |
+
|
232 |
+
# Generate text
|
233 |
+
input_text = "What are we having for dinner?"
|
234 |
+
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
|
235 |
+
output = model.generate(**input_ids, max_new_tokens=50)
|
236 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
237 |
+
```
|
238 |
+
|
239 |
+
## Quantization Details
|
240 |
+
|
241 |
+
- **Method**: torchao {quant_type}
|
242 |
+
- **Precision**: {'8-bit' if 'int8' in quant_type else '4-bit'}
|
243 |
+
- **Memory Reduction**: {'~50%' if 'int8' in quant_type else '~75%'}
|
244 |
+
- **Speed**: {'Faster inference with minimal accuracy loss' if 'int8' in quant_type else 'Significantly faster inference with some accuracy trade-off'}
|
245 |
+
|
246 |
+
## Training Information
|
247 |
+
|
248 |
+
This model was quantized from a fine-tuned SmolLM3 model using the torchao library.
|
249 |
+
The quantization process preserves the model's capabilities while reducing memory usage and improving inference speed.
|
250 |
+
|
251 |
+
## Limitations
|
252 |
+
|
253 |
+
- Quantized models may have slightly reduced accuracy compared to the original model
|
254 |
+
- {quant_type} quantization is optimized for {'GPU inference' if 'int8' in quant_type else 'CPU inference'}
|
255 |
+
- Some advanced features may not be available in quantized form
|
256 |
+
|
257 |
+
## Citation
|
258 |
+
|
259 |
+
If you use this model, please cite the original SmolLM3 paper and mention the quantization process.
|
260 |
+
|
261 |
+
```bibtex
|
262 |
+
@misc{{smollm3-quantized,
|
263 |
+
title={{Quantized SmolLM3 Model}},
|
264 |
+
author={{Your Name}},
|
265 |
+
year={{2024}},
|
266 |
+
url={{https://huggingface.co/{repo_name}/{subdir}}}
|
267 |
+
}}
|
268 |
+
```
|
269 |
+
"""
|
270 |
+
return card_content
|
271 |
+
|
272 |
+
def create_quantized_readme(self, quant_type: str, original_model: str, subdir: str) -> str:
|
273 |
+
"""Create a README for the quantized model repository"""
|
274 |
+
repo_name = self.repo_name
|
275 |
+
readme_content = f"""# Quantized SmolLM3 Model
|
276 |
+
|
277 |
+
This repository contains a quantized version of the SmolLM3 model using torchao quantization.
|
278 |
+
|
279 |
+
## Model Information
|
280 |
+
|
281 |
+
- **Model Type**: Quantized SmolLM3-3B
|
282 |
+
- **Quantization**: {quant_type}
|
283 |
+
- **Original Model**: {original_model}
|
284 |
+
- **Library**: torchao
|
285 |
+
- **Hardware**: {'GPU optimized' if 'int8' in quant_type else 'CPU optimized'}
|
286 |
+
|
287 |
+
## Quick Start
|
288 |
+
|
289 |
+
```python
|
290 |
+
import torch
|
291 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
292 |
+
|
293 |
+
# Load the quantized model
|
294 |
+
model = AutoModelForCausalLM.from_pretrained(
|
295 |
+
f"{repo_name}/{subdir}",
|
296 |
+
device_map="auto",
|
297 |
+
torch_dtype=torch.bfloat16
|
298 |
+
)
|
299 |
+
tokenizer = AutoTokenizer.from_pretrained(f"{repo_name}/{subdir}")
|
300 |
+
|
301 |
+
# Generate text
|
302 |
+
input_text = "What are we having for dinner?"
|
303 |
+
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
|
304 |
+
output = model.generate(**input_ids, max_new_tokens=50)
|
305 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
306 |
+
```
|
307 |
+
|
308 |
+
## Quantization Benefits
|
309 |
+
|
310 |
+
- **Memory Efficiency**: {'~50% reduction in memory usage' if 'int8' in quant_type else '~75% reduction in memory usage'}
|
311 |
+
- **Speed**: {'Faster inference with minimal accuracy loss' if 'int8' in quant_type else 'Significantly faster inference'}
|
312 |
+
- **Compatibility**: {'GPU optimized for high-performance inference' if 'int8' in quant_type else 'CPU optimized for deployment'}
|
313 |
+
|
314 |
+
## Installation
|
315 |
+
|
316 |
+
```bash
|
317 |
+
pip install torchao transformers
|
318 |
+
```
|
319 |
+
|
320 |
+
## Usage Examples
|
321 |
+
|
322 |
+
### Text Generation
|
323 |
+
```python
|
324 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
325 |
+
|
326 |
+
model = AutoModelForCausalLM.from_pretrained(f"{repo_name}/{subdir}")
|
327 |
+
tokenizer = AutoTokenizer.from_pretrained(f"{repo_name}/{subdir}")
|
328 |
+
|
329 |
+
text = "The future of artificial intelligence is"
|
330 |
+
inputs = tokenizer(text, return_tensors="pt")
|
331 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
332 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
333 |
+
```
|
334 |
+
|
335 |
+
### Conversation
|
336 |
+
```python
|
337 |
+
def chat_with_model(prompt, max_length=100):
|
338 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
339 |
+
outputs = model.generate(**inputs, max_new_tokens=max_length)
|
340 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
341 |
+
|
342 |
+
response = chat_with_model("Hello, how are you today?")
|
343 |
+
print(response)
|
344 |
+
```
|
345 |
+
|
346 |
+
## Model Architecture
|
347 |
+
|
348 |
+
This is a quantized version of the SmolLM3-3B model with the following specifications:
|
349 |
+
|
350 |
+
- **Base Model**: SmolLM3-3B
|
351 |
+
- **Quantization**: {quant_type}
|
352 |
+
- **Parameters**: ~3B (quantized)
|
353 |
+
- **Context Length**: Variable (depends on original model)
|
354 |
+
- **Languages**: English, French
|
355 |
+
|
356 |
+
## Performance
|
357 |
+
|
358 |
+
The quantized model provides:
|
359 |
+
|
360 |
+
- **Memory Usage**: {'~50% of original model' if 'int8' in quant_type else '~25% of original model'}
|
361 |
+
- **Inference Speed**: {'Faster than original with minimal accuracy loss' if 'int8' in quant_type else 'Significantly faster with some accuracy trade-off'}
|
362 |
+
- **Accuracy**: {'Minimal degradation' if 'int8' in quant_type else 'Some degradation acceptable for speed'}
|
363 |
+
|
364 |
+
## Limitations
|
365 |
+
|
366 |
+
1. **Accuracy**: Quantized models may have slightly reduced accuracy
|
367 |
+
2. **Compatibility**: {'GPU optimized, may not work on CPU' if 'int8' in quant_type else 'CPU optimized, may not work on GPU'}
|
368 |
+
3. **Features**: Some advanced features may not be available
|
369 |
+
4. **Training**: Cannot be further fine-tuned in quantized form
|
370 |
+
|
371 |
+
## Citation
|
372 |
+
|
373 |
+
If you use this model in your research, please cite:
|
374 |
+
|
375 |
+
```bibtex
|
376 |
+
@misc{{smollm3-quantized,
|
377 |
+
title={{Quantized SmolLM3 Model}},
|
378 |
+
author={{Your Name}},
|
379 |
+
year={{2024}},
|
380 |
+
url={{https://huggingface.co/{repo_name}/{subdir}}}
|
381 |
+
}}
|
382 |
+
```
|
383 |
+
|
384 |
+
## License
|
385 |
+
|
386 |
+
This model is licensed under the Apache 2.0 License.
|
387 |
+
|
388 |
+
## Support
|
389 |
+
|
390 |
+
For questions and support, please open an issue on the Hugging Face repository.
|
391 |
+
"""
|
392 |
+
return readme_content
|
393 |
+
|
394 |
+
def push_quantized_model(
|
395 |
+
self,
|
396 |
+
quantized_model_path: str,
|
397 |
+
quant_type: str,
|
398 |
+
original_model: str
|
399 |
+
) -> bool:
|
400 |
+
"""Push quantized model to the same Hugging Face repository as the main model"""
|
401 |
+
try:
|
402 |
+
logger.info(f"🚀 Pushing quantized model to subdirectory in: {self.repo_name}")
|
403 |
+
|
404 |
+
# Determine subdirectory name based on quantization type
|
405 |
+
if quant_type == "int8_weight_only":
|
406 |
+
subdir = "int8"
|
407 |
+
elif quant_type == "int4_weight_only":
|
408 |
+
subdir = "int4"
|
409 |
+
elif quant_type == "int8_dynamic":
|
410 |
+
subdir = "int8_dynamic"
|
411 |
+
else:
|
412 |
+
subdir = quant_type.replace("_", "-")
|
413 |
+
|
414 |
+
# Create repository if it doesn't exist
|
415 |
+
create_repo(
|
416 |
+
repo_id=self.repo_name,
|
417 |
+
token=self.token,
|
418 |
+
private=self.private,
|
419 |
+
exist_ok=True
|
420 |
+
)
|
421 |
+
|
422 |
+
# Create model card for the quantized version
|
423 |
+
model_card = self.create_quantized_model_card(quant_type, original_model, subdir)
|
424 |
+
model_card_path = Path(quantized_model_path) / "README.md"
|
425 |
+
with open(model_card_path, 'w', encoding='utf-8') as f:
|
426 |
+
f.write(model_card)
|
427 |
+
|
428 |
+
# Upload all files to subdirectory
|
429 |
+
logger.info(f"📤 Uploading quantized model files to {subdir}/ subdirectory...")
|
430 |
+
for file_path in Path(quantized_model_path).rglob("*"):
|
431 |
+
if file_path.is_file():
|
432 |
+
relative_path = file_path.relative_to(quantized_model_path)
|
433 |
+
# Upload to subdirectory within the repository
|
434 |
+
repo_path = f"{subdir}/{relative_path}"
|
435 |
+
upload_file(
|
436 |
+
path_or_fileobj=str(file_path),
|
437 |
+
path_in_repo=repo_path,
|
438 |
+
repo_id=self.repo_name,
|
439 |
+
token=self.token
|
440 |
+
)
|
441 |
+
logger.info(f"📤 Uploaded: {repo_path}")
|
442 |
+
|
443 |
+
logger.info(f"✅ Quantized model pushed successfully to: https://huggingface.co/{self.repo_name}/{subdir}")
|
444 |
+
|
445 |
+
# Log to Trackio if available
|
446 |
+
if self.monitor:
|
447 |
+
self.monitor.log_metric("quantization_type", quant_type)
|
448 |
+
self.monitor.log_metric("quantized_model_url", f"https://huggingface.co/{self.repo_name}/{subdir}")
|
449 |
+
self.monitor.log_artifact("quantized_model_path", quantized_model_path)
|
450 |
+
|
451 |
+
return True
|
452 |
+
|
453 |
+
except Exception as e:
|
454 |
+
logger.error(f"❌ Failed to push quantized model: {e}")
|
455 |
+
return False
|
456 |
+
|
457 |
+
def log_to_trackio(self, action: str, details: Dict[str, Any]):
|
458 |
+
"""Log quantization events to Trackio"""
|
459 |
+
if self.monitor:
|
460 |
+
try:
|
461 |
+
self.monitor.log_event(action, details)
|
462 |
+
logger.info(f"📊 Logged to Trackio: {action}")
|
463 |
+
except Exception as e:
|
464 |
+
logger.warning(f"⚠️ Failed to log to Trackio: {e}")
|
465 |
+
|
466 |
+
def quantize_and_push(
|
467 |
+
self,
|
468 |
+
quant_type: str,
|
469 |
+
device: str = "auto",
|
470 |
+
group_size: int = 128
|
471 |
+
) -> bool:
|
472 |
+
"""Complete quantization and push workflow"""
|
473 |
+
try:
|
474 |
+
# Validate model path
|
475 |
+
if not self.validate_model_path():
|
476 |
+
return False
|
477 |
+
|
478 |
+
# Log start of quantization
|
479 |
+
self.log_to_trackio("quantization_started", {
|
480 |
+
"quant_type": quant_type,
|
481 |
+
"device": device,
|
482 |
+
"group_size": group_size,
|
483 |
+
"model_path": str(self.model_path)
|
484 |
+
})
|
485 |
+
|
486 |
+
# Quantize model
|
487 |
+
quantized_path = self.quantize_model(quant_type, device, group_size)
|
488 |
+
if not quantized_path:
|
489 |
+
return False
|
490 |
+
|
491 |
+
# Log successful quantization
|
492 |
+
self.log_to_trackio("quantization_completed", {
|
493 |
+
"quantized_path": quantized_path,
|
494 |
+
"quant_type": quant_type
|
495 |
+
})
|
496 |
+
|
497 |
+
# Push to HF Hub
|
498 |
+
original_model = str(self.model_path)
|
499 |
+
if not self.push_quantized_model(quantized_path, quant_type, original_model):
|
500 |
+
return False
|
501 |
+
|
502 |
+
# Log successful push
|
503 |
+
self.log_to_trackio("quantized_model_pushed", {
|
504 |
+
"repo_name": self.repo_name,
|
505 |
+
"quant_type": quant_type
|
506 |
+
})
|
507 |
+
|
508 |
+
logger.info(f"🎉 Quantization and push completed successfully!")
|
509 |
+
logger.info(f"📊 Model: https://huggingface.co/{self.repo_name}")
|
510 |
+
|
511 |
+
return True
|
512 |
+
|
513 |
+
except Exception as e:
|
514 |
+
logger.error(f"❌ Quantization and push failed: {e}")
|
515 |
+
self.log_to_trackio("quantization_failed", {"error": str(e)})
|
516 |
+
return False
|
517 |
+
|
518 |
+
def parse_args():
|
519 |
+
"""Parse command line arguments"""
|
520 |
+
parser = argparse.ArgumentParser(description="Quantize model using torchao")
|
521 |
+
parser.add_argument("model_path", help="Path to the trained model")
|
522 |
+
parser.add_argument("repo_name", help="Hugging Face repository name")
|
523 |
+
parser.add_argument("--quant-type", choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
|
524 |
+
default="int8_weight_only", help="Quantization type")
|
525 |
+
parser.add_argument("--device", default="auto", help="Device for quantization (auto, cpu, cuda)")
|
526 |
+
parser.add_argument("--group-size", type=int, default=128, help="Group size for quantization")
|
527 |
+
parser.add_argument("--token", help="Hugging Face token")
|
528 |
+
parser.add_argument("--private", action="store_true", help="Create private repository")
|
529 |
+
parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
|
530 |
+
parser.add_argument("--experiment-name", help="Experiment name for tracking")
|
531 |
+
parser.add_argument("--dataset-repo", help="HF Dataset repository")
|
532 |
+
|
533 |
+
return parser.parse_args()
|
534 |
+
|
535 |
+
def main():
|
536 |
+
"""Main function"""
|
537 |
+
args = parse_args()
|
538 |
+
|
539 |
+
# Setup logging
|
540 |
+
logging.basicConfig(
|
541 |
+
level=logging.INFO,
|
542 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
543 |
+
)
|
544 |
+
|
545 |
+
# Check torchao availability
|
546 |
+
if not TORCHAO_AVAILABLE:
|
547 |
+
logger.error("❌ torchao not available. Install with: pip install torchao")
|
548 |
+
return 1
|
549 |
+
|
550 |
+
# Initialize quantizer
|
551 |
+
quantizer = ModelQuantizer(
|
552 |
+
model_path=args.model_path,
|
553 |
+
repo_name=args.repo_name,
|
554 |
+
token=args.token,
|
555 |
+
private=args.private,
|
556 |
+
trackio_url=args.trackio_url,
|
557 |
+
experiment_name=args.experiment_name,
|
558 |
+
dataset_repo=args.dataset_repo
|
559 |
+
)
|
560 |
+
|
561 |
+
# Perform quantization and push
|
562 |
+
success = quantizer.quantize_and_push(
|
563 |
+
quant_type=args.quant_type,
|
564 |
+
device=args.device,
|
565 |
+
group_size=args.group_size
|
566 |
+
)
|
567 |
+
|
568 |
+
return 0 if success else 1
|
569 |
+
|
570 |
+
if __name__ == "__main__":
|
571 |
+
exit(main())
|
scripts/model_tonic/quantize_standalone.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Standalone Model Quantization Script
|
4 |
+
Quick quantization of trained models using torchao
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import argparse
|
10 |
+
import logging
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
# Add the project root to the path
|
14 |
+
project_root = Path(__file__).parent.parent.parent
|
15 |
+
sys.path.append(str(project_root))
|
16 |
+
|
17 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
18 |
+
|
19 |
+
def main():
|
20 |
+
"""Standalone quantization script"""
|
21 |
+
parser = argparse.ArgumentParser(description="Quantize a trained model using torchao")
|
22 |
+
parser.add_argument("model_path", help="Path to the trained model")
|
23 |
+
parser.add_argument("repo_name", help="Hugging Face repository name for quantized model")
|
24 |
+
parser.add_argument("--quant-type", choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
|
25 |
+
default="int8_weight_only", help="Quantization type")
|
26 |
+
parser.add_argument("--device", default="auto", help="Device for quantization (auto, cpu, cuda)")
|
27 |
+
parser.add_argument("--group-size", type=int, default=128, help="Group size for quantization")
|
28 |
+
parser.add_argument("--token", help="Hugging Face token")
|
29 |
+
parser.add_argument("--private", action="store_true", help="Create private repository")
|
30 |
+
parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
|
31 |
+
parser.add_argument("--experiment-name", help="Experiment name for tracking")
|
32 |
+
parser.add_argument("--dataset-repo", help="HF Dataset repository")
|
33 |
+
parser.add_argument("--save-only", action="store_true", help="Save quantized model locally without pushing to HF")
|
34 |
+
|
35 |
+
args = parser.parse_args()
|
36 |
+
|
37 |
+
# Setup logging
|
38 |
+
logging.basicConfig(
|
39 |
+
level=logging.INFO,
|
40 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
41 |
+
)
|
42 |
+
|
43 |
+
print("🚀 Starting Model Quantization")
|
44 |
+
print("=" * 40)
|
45 |
+
print(f"Model: {args.model_path}")
|
46 |
+
print(f"Quantization: {args.quant_type}")
|
47 |
+
print(f"Device: {args.device}")
|
48 |
+
print(f"Repository: {args.repo_name}")
|
49 |
+
print(f"Save only: {args.save_only}")
|
50 |
+
print("=" * 40)
|
51 |
+
|
52 |
+
# Initialize quantizer
|
53 |
+
quantizer = ModelQuantizer(
|
54 |
+
model_path=args.model_path,
|
55 |
+
repo_name=args.repo_name,
|
56 |
+
token=args.token,
|
57 |
+
private=args.private,
|
58 |
+
trackio_url=args.trackio_url,
|
59 |
+
experiment_name=args.experiment_name,
|
60 |
+
dataset_repo=args.dataset_repo
|
61 |
+
)
|
62 |
+
|
63 |
+
if args.save_only:
|
64 |
+
# Just quantize and save locally
|
65 |
+
print("💾 Quantizing and saving locally...")
|
66 |
+
quantized_path = quantizer.quantize_model(
|
67 |
+
quant_type=args.quant_type,
|
68 |
+
device=args.device,
|
69 |
+
group_size=args.group_size
|
70 |
+
)
|
71 |
+
|
72 |
+
if quantized_path:
|
73 |
+
print(f"✅ Quantized model saved to: {quantized_path}")
|
74 |
+
print(f"📁 You can find the quantized model in: {quantized_path}")
|
75 |
+
else:
|
76 |
+
print("❌ Quantization failed")
|
77 |
+
return 1
|
78 |
+
else:
|
79 |
+
# Full quantization and push workflow
|
80 |
+
success = quantizer.quantize_and_push(
|
81 |
+
quant_type=args.quant_type,
|
82 |
+
device=args.device,
|
83 |
+
group_size=args.group_size
|
84 |
+
)
|
85 |
+
|
86 |
+
if not success:
|
87 |
+
print("❌ Quantization and push failed")
|
88 |
+
return 1
|
89 |
+
|
90 |
+
print("🎉 Quantization completed successfully!")
|
91 |
+
return 0
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
exit(main())
|
scripts/trackio_tonic/configure_trackio.py
CHANGED
@@ -51,7 +51,7 @@ def get_username_from_cli(token: str) -> str:
|
|
51 |
|
52 |
# Get username using CLI
|
53 |
result = subprocess.run(
|
54 |
-
["
|
55 |
capture_output=True,
|
56 |
text=True,
|
57 |
timeout=30
|
|
|
51 |
|
52 |
# Get username using CLI
|
53 |
result = subprocess.run(
|
54 |
+
["hf", "whoami"],
|
55 |
capture_output=True,
|
56 |
text=True,
|
57 |
timeout=30
|
scripts/trackio_tonic/deploy_trackio_space.py
CHANGED
@@ -87,7 +87,7 @@ class TrackioSpaceDeployer:
|
|
87 |
|
88 |
# Get username using CLI
|
89 |
result = subprocess.run(
|
90 |
-
["
|
91 |
capture_output=True,
|
92 |
text=True,
|
93 |
timeout=30
|
@@ -155,7 +155,7 @@ class TrackioSpaceDeployer:
|
|
155 |
|
156 |
# Create space using Hugging Face CLI
|
157 |
cmd = [
|
158 |
-
"
|
159 |
f"{self.username}/{self.space_name}",
|
160 |
"--type", "space"
|
161 |
]
|
@@ -168,7 +168,7 @@ class TrackioSpaceDeployer:
|
|
168 |
# Try alternative approach without space-specific flags
|
169 |
print("Retrying with basic space creation...")
|
170 |
cmd = [
|
171 |
-
"
|
172 |
f"{self.username}/{self.space_name}"
|
173 |
]
|
174 |
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
87 |
|
88 |
# Get username using CLI
|
89 |
result = subprocess.run(
|
90 |
+
["hf", "whoami"],
|
91 |
capture_output=True,
|
92 |
text=True,
|
93 |
timeout=30
|
|
|
155 |
|
156 |
# Create space using Hugging Face CLI
|
157 |
cmd = [
|
158 |
+
"hf", "repo", "create",
|
159 |
f"{self.username}/{self.space_name}",
|
160 |
"--type", "space"
|
161 |
]
|
|
|
168 |
# Try alternative approach without space-specific flags
|
169 |
print("Retrying with basic space creation...")
|
170 |
cmd = [
|
171 |
+
"hf", "repo", "create",
|
172 |
f"{self.username}/{self.space_name}"
|
173 |
]
|
174 |
result = subprocess.run(cmd, capture_output=True, text=True)
|
scripts/training/train.py
CHANGED
@@ -59,6 +59,12 @@ def main():
|
|
59 |
default="my_dataset",
|
60 |
help="Dataset directory path"
|
61 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
args = parser.parse_args()
|
64 |
|
@@ -122,6 +128,7 @@ def main():
|
|
122 |
print(f"Max iterations: {config.max_iters}")
|
123 |
print(f"Max sequence length: {config.max_seq_length}")
|
124 |
print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
|
|
|
125 |
if hasattr(config, 'dataset_name') and config.dataset_name:
|
126 |
print(f"Dataset: {config.dataset_name}")
|
127 |
if hasattr(config, 'sample_size') and config.sample_size:
|
@@ -168,6 +175,10 @@ def main():
|
|
168 |
# Add dataset directory argument
|
169 |
train_args.extend(["--dataset_dir", args.dataset_dir])
|
170 |
|
|
|
|
|
|
|
|
|
171 |
# Override sys.argv for the training script
|
172 |
original_argv = sys.argv
|
173 |
sys.argv = ["train.py"] + train_args
|
|
|
59 |
default="my_dataset",
|
60 |
help="Dataset directory path"
|
61 |
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--trainer-type",
|
64 |
+
type=str,
|
65 |
+
choices=['sft', 'dpo'],
|
66 |
+
help="Trainer type: sft (Supervised Fine-tuning) or dpo (Direct Preference Optimization)"
|
67 |
+
)
|
68 |
|
69 |
args = parser.parse_args()
|
70 |
|
|
|
128 |
print(f"Max iterations: {config.max_iters}")
|
129 |
print(f"Max sequence length: {config.max_seq_length}")
|
130 |
print(f"Mixed precision: {'bf16' if config.bf16 else 'fp16'}")
|
131 |
+
print(f"Trainer type: {getattr(config, 'trainer_type', 'sft')}")
|
132 |
if hasattr(config, 'dataset_name') and config.dataset_name:
|
133 |
print(f"Dataset: {config.dataset_name}")
|
134 |
if hasattr(config, 'sample_size') and config.sample_size:
|
|
|
175 |
# Add dataset directory argument
|
176 |
train_args.extend(["--dataset_dir", args.dataset_dir])
|
177 |
|
178 |
+
# Add trainer type argument if provided
|
179 |
+
if args.trainer_type:
|
180 |
+
train_args.extend(["--trainer_type", args.trainer_type])
|
181 |
+
|
182 |
# Override sys.argv for the training script
|
183 |
original_argv = sys.argv
|
184 |
sys.argv = ["train.py"] + train_args
|
setup_launch.py
CHANGED
@@ -209,7 +209,7 @@ After running the pipeline, you'll have:
|
|
209 |
|
210 |
1. **HF Token Issues**
|
211 |
```bash
|
212 |
-
|
213 |
```
|
214 |
|
215 |
2. **CUDA Issues**
|
|
|
209 |
|
210 |
1. **HF Token Issues**
|
211 |
```bash
|
212 |
+
hf whoami
|
213 |
```
|
214 |
|
215 |
2. **CUDA Issues**
|
src/data.py
CHANGED
@@ -298,14 +298,44 @@ class SmolLM3Dataset:
|
|
298 |
def get_data_collator(self):
|
299 |
"""Get data collator for training"""
|
300 |
from transformers import DataCollatorForLanguageModeling
|
301 |
-
|
302 |
-
|
|
|
303 |
tokenizer=self.tokenizer,
|
304 |
-
mlm=False,
|
305 |
-
pad_to_multiple_of=8,
|
306 |
-
return_tensors="pt",
|
307 |
)
|
308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
def create_sample_dataset(output_path: str = "my_dataset"):
|
310 |
"""Create a sample dataset for testing"""
|
311 |
os.makedirs(output_path, exist_ok=True)
|
|
|
298 |
def get_data_collator(self):
|
299 |
"""Get data collator for training"""
|
300 |
from transformers import DataCollatorForLanguageModeling
|
301 |
+
import torch
|
302 |
+
|
303 |
+
base_collator = DataCollatorForLanguageModeling(
|
304 |
tokenizer=self.tokenizer,
|
305 |
+
mlm=False,
|
306 |
+
pad_to_multiple_of=8,
|
307 |
+
return_tensors="pt",
|
308 |
)
|
309 |
|
310 |
+
def collator_with_stats(features):
|
311 |
+
batch = base_collator(features)
|
312 |
+
# Calculate token stats
|
313 |
+
input_ids = batch["input_ids"]
|
314 |
+
attention_mask = batch.get("attention_mask", None)
|
315 |
+
labels = batch.get("labels", None)
|
316 |
+
pad_token_id = self.tokenizer.pad_token_id
|
317 |
+
if pad_token_id is None:
|
318 |
+
pad_token_id = self.tokenizer.eos_token_id
|
319 |
+
|
320 |
+
total_tokens = int((input_ids != pad_token_id).sum().item())
|
321 |
+
padding_tokens = int((input_ids == pad_token_id).sum().item())
|
322 |
+
batch_size, seq_len = input_ids.shape
|
323 |
+
# Truncated tokens: count tokens that were cut off due to max_seq_length
|
324 |
+
# (Assume all input is truncated to max_seq_length, so count tokens at max length)
|
325 |
+
truncated_tokens = 0
|
326 |
+
for f in features:
|
327 |
+
if "length" in f and f["length"] >= self.max_seq_length:
|
328 |
+
truncated_tokens += f["length"] - self.max_seq_length + 1
|
329 |
+
|
330 |
+
batch["total_tokens"] = total_tokens
|
331 |
+
batch["padding_tokens"] = padding_tokens
|
332 |
+
batch["truncated_tokens"] = truncated_tokens
|
333 |
+
batch["batch_size"] = batch_size
|
334 |
+
batch["seq_len"] = seq_len
|
335 |
+
return batch
|
336 |
+
|
337 |
+
return collator_with_stats
|
338 |
+
|
339 |
def create_sample_dataset(output_path: str = "my_dataset"):
|
340 |
"""Create a sample dataset for testing"""
|
341 |
os.makedirs(output_path, exist_ok=True)
|
src/monitoring.py
CHANGED
@@ -213,7 +213,12 @@ class SmolLM3Monitor:
|
|
213 |
return self.log_configuration(config)
|
214 |
|
215 |
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
|
216 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
217 |
if not self.enable_tracking or not self.log_metrics_enabled:
|
218 |
return
|
219 |
|
@@ -381,11 +386,18 @@ class SmolLM3Monitor:
|
|
381 |
from transformers import TrainerCallback
|
382 |
|
383 |
class TrackioCallback(TrainerCallback):
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
def __init__(self, monitor):
|
385 |
super().__init__()
|
386 |
self.monitor = monitor
|
387 |
logger.info("TrackioCallback initialized")
|
388 |
-
|
|
|
389 |
def on_init_end(self, args, state, control, **kwargs):
|
390 |
"""Called when training initialization is complete"""
|
391 |
try:
|
@@ -395,11 +407,41 @@ class SmolLM3Monitor:
|
|
395 |
|
396 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
397 |
"""Called when logs are created"""
|
|
|
398 |
try:
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
except Exception as e:
|
404 |
logger.error("Error in on_log: %s", e)
|
405 |
|
|
|
213 |
return self.log_configuration(config)
|
214 |
|
215 |
def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
|
216 |
+
"""
|
217 |
+
Log training metrics. Supports advanced metrics such as:
|
218 |
+
- total_tokens, truncated_tokens, padding_tokens
|
219 |
+
- throughput, step_time, batch_size, seq_len
|
220 |
+
- token_acc, train/gate_ortho, train/center, etc.
|
221 |
+
"""
|
222 |
if not self.enable_tracking or not self.log_metrics_enabled:
|
223 |
return
|
224 |
|
|
|
386 |
from transformers import TrainerCallback
|
387 |
|
388 |
class TrackioCallback(TrainerCallback):
|
389 |
+
"""
|
390 |
+
Trainer callback for logging metrics, including advanced metrics:
|
391 |
+
- total_tokens, truncated_tokens, padding_tokens
|
392 |
+
- throughput, step_time, batch_size, seq_len
|
393 |
+
- token_acc, train/gate_ortho, train/center, etc.
|
394 |
+
"""
|
395 |
def __init__(self, monitor):
|
396 |
super().__init__()
|
397 |
self.monitor = monitor
|
398 |
logger.info("TrackioCallback initialized")
|
399 |
+
self.last_step_time = None
|
400 |
+
|
401 |
def on_init_end(self, args, state, control, **kwargs):
|
402 |
"""Called when training initialization is complete"""
|
403 |
try:
|
|
|
407 |
|
408 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
409 |
"""Called when logs are created"""
|
410 |
+
import time
|
411 |
try:
|
412 |
+
step = getattr(state, 'global_step', None)
|
413 |
+
# Timing and throughput
|
414 |
+
now = time.time()
|
415 |
+
if self.last_step_time is not None:
|
416 |
+
step_time = now - self.last_step_time
|
417 |
+
logs['step_time'] = step_time
|
418 |
+
# Throughput: tokens/sec if total_tokens is available
|
419 |
+
if hasattr(self, 'last_total_tokens') and self.last_total_tokens is not None:
|
420 |
+
throughput = (logs.get('total_tokens', 0) / step_time) if step_time > 0 else 0
|
421 |
+
logs['throughput'] = throughput
|
422 |
+
self.last_step_time = now
|
423 |
+
|
424 |
+
# Token stats from batch (if available in kwargs)
|
425 |
+
batch = kwargs.get('inputs', None)
|
426 |
+
if batch is not None:
|
427 |
+
for key in ['total_tokens', 'padding_tokens', 'truncated_tokens', 'batch_size', 'seq_len']:
|
428 |
+
if key in batch:
|
429 |
+
logs[key] = batch[key]
|
430 |
+
self.last_total_tokens = batch.get('total_tokens', None)
|
431 |
+
else:
|
432 |
+
self.last_total_tokens = None
|
433 |
+
|
434 |
+
# Token accuracy (if possible)
|
435 |
+
if 'labels' in logs and 'predictions' in logs:
|
436 |
+
labels = logs['labels']
|
437 |
+
preds = logs['predictions']
|
438 |
+
if hasattr(labels, 'shape') and hasattr(preds, 'shape'):
|
439 |
+
correct = (preds == labels).sum().item()
|
440 |
+
total = labels.numel()
|
441 |
+
logs['token_acc'] = correct / total if total > 0 else 0.0
|
442 |
+
|
443 |
+
self.monitor.log_metrics(logs, step)
|
444 |
+
self.monitor.log_system_metrics(step)
|
445 |
except Exception as e:
|
446 |
logger.error("Error in on_log: %s", e)
|
447 |
|
src/train.py
CHANGED
@@ -29,7 +29,7 @@ except ImportError:
|
|
29 |
from config import get_config
|
30 |
from model import SmolLM3Model
|
31 |
from data import SmolLM3Dataset
|
32 |
-
from trainer import SmolLM3Trainer
|
33 |
from monitoring import create_monitor_from_config
|
34 |
|
35 |
def setup_logging():
|
@@ -103,6 +103,10 @@ def parse_args():
|
|
103 |
parser.add_argument('--dataset_repo', type=str, default=None,
|
104 |
help='HF Dataset repository for experiment storage')
|
105 |
|
|
|
|
|
|
|
|
|
106 |
return parser.parse_args()
|
107 |
|
108 |
def main():
|
@@ -198,14 +202,31 @@ def main():
|
|
198 |
sample_seed=getattr(config, 'sample_seed', 42)
|
199 |
)
|
200 |
|
201 |
-
#
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
# Start training
|
211 |
try:
|
|
|
29 |
from config import get_config
|
30 |
from model import SmolLM3Model
|
31 |
from data import SmolLM3Dataset
|
32 |
+
from trainer import SmolLM3Trainer, SmolLM3DPOTrainer
|
33 |
from monitoring import create_monitor_from_config
|
34 |
|
35 |
def setup_logging():
|
|
|
103 |
parser.add_argument('--dataset_repo', type=str, default=None,
|
104 |
help='HF Dataset repository for experiment storage')
|
105 |
|
106 |
+
# Trainer type selection
|
107 |
+
parser.add_argument('--trainer_type', type=str, choices=['sft', 'dpo'], default=None,
|
108 |
+
help='Trainer type: sft (Supervised Fine-tuning) or dpo (Direct Preference Optimization)')
|
109 |
+
|
110 |
return parser.parse_args()
|
111 |
|
112 |
def main():
|
|
|
202 |
sample_seed=getattr(config, 'sample_seed', 42)
|
203 |
)
|
204 |
|
205 |
+
# Determine trainer type (command line overrides config)
|
206 |
+
trainer_type = args.trainer_type or getattr(config, 'trainer_type', 'sft')
|
207 |
+
logger.info(f"Using trainer type: {trainer_type}")
|
208 |
+
|
209 |
+
# Import the appropriate trainer class
|
210 |
+
# from trainer import SmolLM3Trainer, SmolLM3DPOTrainer # This line is removed as per the edit hint
|
211 |
+
|
212 |
+
# Initialize trainer based on type
|
213 |
+
if trainer_type.lower() == 'dpo':
|
214 |
+
logger.info("Initializing DPO trainer...")
|
215 |
+
trainer = SmolLM3DPOTrainer(
|
216 |
+
model=model,
|
217 |
+
dataset=dataset,
|
218 |
+
config=config,
|
219 |
+
output_dir=output_path
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
logger.info("Initializing SFT trainer...")
|
223 |
+
trainer = SmolLM3Trainer(
|
224 |
+
model=model,
|
225 |
+
dataset=dataset,
|
226 |
+
config=config,
|
227 |
+
output_dir=output_path,
|
228 |
+
init_from=args.init_from
|
229 |
+
)
|
230 |
|
231 |
# Start training
|
232 |
try:
|
templates/datasets/readme.md
CHANGED
@@ -36,11 +36,15 @@ tags:
|
|
36 |
- trackio
|
37 |
- tonic
|
38 |
- experiment tracking
|
|
|
|
|
|
|
|
|
39 |
---
|
40 |
|
41 |
# Trackio Experiments Dataset
|
42 |
|
43 |
-
This dataset stores experiment tracking data for ML training runs, particularly focused on SmolLM3 fine-tuning experiments.
|
44 |
|
45 |
## Dataset Structure
|
46 |
|
@@ -57,6 +61,77 @@ The dataset contains the following columns:
|
|
57 |
- **logs**: JSON string containing experiment logs
|
58 |
- **last_updated**: Timestamp of last update
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
## Usage
|
61 |
|
62 |
This dataset is automatically used by the Trackio monitoring system to store and retrieve experiment data. It provides persistent storage for experiment tracking across different training runs.
|
@@ -67,6 +142,7 @@ The dataset is used by:
|
|
67 |
- Trackio Spaces for experiment visualization
|
68 |
- Training scripts for logging metrics and parameters
|
69 |
- Monitoring systems for experiment tracking
|
|
|
70 |
|
71 |
## Privacy
|
72 |
|
@@ -79,11 +155,11 @@ This dataset is private by default to ensure experiment data security. Only user
|
|
79 |
{
|
80 |
"experiment_id": "exp_20250720_130853",
|
81 |
"name": "smollm3_finetune",
|
82 |
-
"description": "SmolLM3 fine-tuning experiment",
|
83 |
"created_at": "2025-07-20T11:20:01.780908",
|
84 |
"status": "running",
|
85 |
-
"metrics": "[{\"timestamp\": \"2025-07-20T11:20:01.780908\", \"step\": 25, \"metrics\": {\"loss\": 1.1659, \"accuracy\": 0.759}}]",
|
86 |
-
"parameters": "{\"model_name\": \"HuggingFaceTB/SmolLM3-3B\", \"batch_size\": 8, \"learning_rate\": 3.5e-06}",
|
87 |
"artifacts": "[]",
|
88 |
"logs": "[]",
|
89 |
"last_updated": "2025-07-20T11:20:01.780908"
|
|
|
36 |
- trackio
|
37 |
- tonic
|
38 |
- experiment tracking
|
39 |
+
- smollm3
|
40 |
+
- fine-tuning
|
41 |
+
- legml
|
42 |
+
- hermes
|
43 |
---
|
44 |
|
45 |
# Trackio Experiments Dataset
|
46 |
|
47 |
+
This dataset stores experiment tracking data for ML training runs, particularly focused on SmolLM3 fine-tuning experiments with comprehensive metrics tracking.
|
48 |
|
49 |
## Dataset Structure
|
50 |
|
|
|
61 |
- **logs**: JSON string containing experiment logs
|
62 |
- **last_updated**: Timestamp of last update
|
63 |
|
64 |
+
## Metrics Structure
|
65 |
+
|
66 |
+
The metrics field contains JSON arrays with the following structure:
|
67 |
+
|
68 |
+
```json
|
69 |
+
[
|
70 |
+
{
|
71 |
+
"timestamp": "2025-07-20T11:20:01.780908",
|
72 |
+
"step": 25,
|
73 |
+
"metrics": {
|
74 |
+
"loss": 1.1659,
|
75 |
+
"accuracy": 0.759,
|
76 |
+
"learning_rate": 7e-08,
|
77 |
+
"grad_norm": 10.3125,
|
78 |
+
"epoch": 0.004851130919895701,
|
79 |
+
|
80 |
+
// Advanced Training Metrics
|
81 |
+
"total_tokens": 1642080.0,
|
82 |
+
"truncated_tokens": 128,
|
83 |
+
"padding_tokens": 256,
|
84 |
+
"throughput": 3284160.0,
|
85 |
+
"step_time": 0.5,
|
86 |
+
"batch_size": 8,
|
87 |
+
"seq_len": 2048,
|
88 |
+
"token_acc": 0.759,
|
89 |
+
|
90 |
+
// Custom Losses
|
91 |
+
"train/gate_ortho": 0.0234,
|
92 |
+
"train/center": 0.0156,
|
93 |
+
|
94 |
+
// System Metrics
|
95 |
+
"gpu_memory_allocated": 17.202261447906494,
|
96 |
+
"gpu_memory_reserved": 75.474609375,
|
97 |
+
"gpu_utilization": 85.2,
|
98 |
+
"cpu_percent": 2.7,
|
99 |
+
"memory_percent": 10.1
|
100 |
+
}
|
101 |
+
}
|
102 |
+
]
|
103 |
+
```
|
104 |
+
|
105 |
+
## Supported Metrics
|
106 |
+
|
107 |
+
### Core Training Metrics
|
108 |
+
- **loss**: Training loss value
|
109 |
+
- **accuracy**: Model accuracy
|
110 |
+
- **learning_rate**: Current learning rate
|
111 |
+
- **grad_norm**: Gradient norm
|
112 |
+
- **epoch**: Current epoch progress
|
113 |
+
|
114 |
+
### Advanced Token Metrics
|
115 |
+
- **total_tokens**: Total tokens processed in the batch
|
116 |
+
- **truncated_tokens**: Number of tokens truncated during processing
|
117 |
+
- **padding_tokens**: Number of padding tokens added
|
118 |
+
- **throughput**: Tokens processed per second
|
119 |
+
- **step_time**: Time taken for the current training step
|
120 |
+
- **batch_size**: Current batch size
|
121 |
+
- **seq_len**: Sequence length
|
122 |
+
- **token_acc**: Token-level accuracy
|
123 |
+
|
124 |
+
### Custom Losses (SmolLM3-specific)
|
125 |
+
- **train/gate_ortho**: Gate orthogonality loss
|
126 |
+
- **train/center**: Center loss component
|
127 |
+
|
128 |
+
### System Performance Metrics
|
129 |
+
- **gpu_memory_allocated**: GPU memory currently allocated (GB)
|
130 |
+
- **gpu_memory_reserved**: GPU memory reserved (GB)
|
131 |
+
- **gpu_utilization**: GPU utilization percentage
|
132 |
+
- **cpu_percent**: CPU usage percentage
|
133 |
+
- **memory_percent**: System memory usage percentage
|
134 |
+
|
135 |
## Usage
|
136 |
|
137 |
This dataset is automatically used by the Trackio monitoring system to store and retrieve experiment data. It provides persistent storage for experiment tracking across different training runs.
|
|
|
142 |
- Trackio Spaces for experiment visualization
|
143 |
- Training scripts for logging metrics and parameters
|
144 |
- Monitoring systems for experiment tracking
|
145 |
+
- SmolLM3 fine-tuning pipeline for comprehensive metrics capture
|
146 |
|
147 |
## Privacy
|
148 |
|
|
|
155 |
{
|
156 |
"experiment_id": "exp_20250720_130853",
|
157 |
"name": "smollm3_finetune",
|
158 |
+
"description": "SmolLM3 fine-tuning experiment with comprehensive metrics",
|
159 |
"created_at": "2025-07-20T11:20:01.780908",
|
160 |
"status": "running",
|
161 |
+
"metrics": "[{\"timestamp\": \"2025-07-20T11:20:01.780908\", \"step\": 25, \"metrics\": {\"loss\": 1.1659, \"accuracy\": 0.759, \"total_tokens\": 1642080.0, \"throughput\": 3284160.0, \"train/gate_ortho\": 0.0234, \"train/center\": 0.0156}}]",
|
162 |
+
"parameters": "{\"model_name\": \"HuggingFaceTB/SmolLM3-3B\", \"batch_size\": 8, \"learning_rate\": 3.5e-06, \"max_seq_length\": 12288}",
|
163 |
"artifacts": "[]",
|
164 |
"logs": "[]",
|
165 |
"last_updated": "2025-07-20T11:20:01.780908"
|
templates/model_card.md
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
- fr
|
5 |
+
license: apache-2.0
|
6 |
+
tags:
|
7 |
+
- smollm3
|
8 |
+
- fine-tuned
|
9 |
+
- causal-lm
|
10 |
+
- text-generation
|
11 |
+
- {{#if quantized_models}}quantized{{/if}}
|
12 |
+
---
|
13 |
+
|
14 |
+
# {{model_name}}
|
15 |
+
|
16 |
+
{{model_description}}
|
17 |
+
|
18 |
+
## Model Details
|
19 |
+
|
20 |
+
- **Base Model**: SmolLM3-3B
|
21 |
+
- **Model Type**: Causal Language Model
|
22 |
+
- **Languages**: English, French
|
23 |
+
- **License**: Apache 2.0
|
24 |
+
- **Fine-tuned**: Yes
|
25 |
+
{{#if quantized_models}}
|
26 |
+
- **Quantized Versions**: Available in subdirectories
|
27 |
+
{{/if}}
|
28 |
+
|
29 |
+
## Usage
|
30 |
+
|
31 |
+
### Main Model
|
32 |
+
|
33 |
+
```python
|
34 |
+
import torch
|
35 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
36 |
+
|
37 |
+
# Load the main model
|
38 |
+
model = AutoModelForCausalLM.from_pretrained(
|
39 |
+
"{{repo_name}}",
|
40 |
+
device_map="auto",
|
41 |
+
torch_dtype=torch.bfloat16
|
42 |
+
)
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")
|
44 |
+
|
45 |
+
# Generate text
|
46 |
+
input_text = "What are we having for dinner?"
|
47 |
+
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device.type)
|
48 |
+
output = model.generate(**input_ids, max_new_tokens=50)
|
49 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
50 |
+
```
|
51 |
+
|
52 |
+
{{#if quantized_models}}
|
53 |
+
### Quantized Models
|
54 |
+
|
55 |
+
This repository also includes quantized versions of the model for improved efficiency:
|
56 |
+
|
57 |
+
#### int8 Weight-Only Quantization (GPU Optimized)
|
58 |
+
```python
|
59 |
+
import torch
|
60 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
61 |
+
|
62 |
+
# Load int8 quantized model (GPU optimized)
|
63 |
+
model = AutoModelForCausalLM.from_pretrained(
|
64 |
+
"{{repo_name}}/int8",
|
65 |
+
device_map="auto",
|
66 |
+
torch_dtype=torch.bfloat16
|
67 |
+
)
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}/int8")
|
69 |
+
```
|
70 |
+
|
71 |
+
#### int4 Weight-Only Quantization (CPU Optimized)
|
72 |
+
```python
|
73 |
+
import torch
|
74 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
75 |
+
|
76 |
+
# Load int4 quantized model (CPU optimized)
|
77 |
+
model = AutoModelForCausalLM.from_pretrained(
|
78 |
+
"{{repo_name}}/int4",
|
79 |
+
device_map="cpu",
|
80 |
+
torch_dtype=torch.bfloat16
|
81 |
+
)
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}/int4")
|
83 |
+
```
|
84 |
+
|
85 |
+
### Quantization Benefits
|
86 |
+
|
87 |
+
- **int8 (GPU)**: ~50% memory reduction, faster inference with minimal accuracy loss
|
88 |
+
- **int4 (CPU)**: ~75% memory reduction, significantly faster inference with some accuracy trade-off
|
89 |
+
|
90 |
+
{{/if}}
|
91 |
+
|
92 |
+
## Training Information
|
93 |
+
|
94 |
+
### Training Configuration
|
95 |
+
- **Base Model**: {{base_model}}
|
96 |
+
- **Dataset**: {{dataset_name}}
|
97 |
+
- **Training Config**: {{training_config_type}}
|
98 |
+
- **Trainer Type**: {{trainer_type}}
|
99 |
+
{{#if dataset_sample_size}}
|
100 |
+
- **Dataset Sample Size**: {{dataset_sample_size}}
|
101 |
+
{{/if}}
|
102 |
+
|
103 |
+
### Training Parameters
|
104 |
+
- **Batch Size**: {{batch_size}}
|
105 |
+
- **Gradient Accumulation**: {{gradient_accumulation_steps}}
|
106 |
+
- **Learning Rate**: {{learning_rate}}
|
107 |
+
- **Max Epochs**: {{max_epochs}}
|
108 |
+
- **Sequence Length**: {{max_seq_length}}
|
109 |
+
|
110 |
+
### Training Infrastructure
|
111 |
+
- **Hardware**: {{hardware_info}}
|
112 |
+
- **Monitoring**: Trackio integration
|
113 |
+
- **Experiment**: {{experiment_name}}
|
114 |
+
|
115 |
+
## Model Architecture
|
116 |
+
|
117 |
+
This is a fine-tuned version of the SmolLM3-3B model with the following specifications:
|
118 |
+
|
119 |
+
- **Base Model**: SmolLM3-3B
|
120 |
+
- **Parameters**: ~3B
|
121 |
+
- **Context Length**: {{max_seq_length}}
|
122 |
+
- **Languages**: English, French
|
123 |
+
- **Architecture**: Transformer-based causal language model
|
124 |
+
|
125 |
+
## Performance
|
126 |
+
|
127 |
+
The model provides:
|
128 |
+
- **Text Generation**: High-quality text generation capabilities
|
129 |
+
- **Conversation**: Natural conversation abilities
|
130 |
+
- **Multilingual**: Support for English and French
|
131 |
+
{{#if quantized_models}}
|
132 |
+
- **Quantized Versions**: Optimized for different deployment scenarios
|
133 |
+
{{/if}}
|
134 |
+
|
135 |
+
## Limitations
|
136 |
+
|
137 |
+
1. **Context Length**: Limited by the model's maximum sequence length
|
138 |
+
2. **Bias**: May inherit biases from the training data
|
139 |
+
3. **Factual Accuracy**: May generate incorrect or outdated information
|
140 |
+
4. **Safety**: Should be used responsibly with appropriate safeguards
|
141 |
+
{{#if quantized_models}}
|
142 |
+
5. **Quantization**: Quantized versions may have slightly reduced accuracy
|
143 |
+
{{/if}}
|
144 |
+
|
145 |
+
## Training Data
|
146 |
+
|
147 |
+
The model was fine-tuned on:
|
148 |
+
- **Dataset**: {{dataset_name}}
|
149 |
+
- **Size**: {{dataset_size}}
|
150 |
+
- **Format**: {{dataset_format}}
|
151 |
+
- **Languages**: English, French
|
152 |
+
|
153 |
+
## Evaluation
|
154 |
+
|
155 |
+
The model was evaluated using:
|
156 |
+
- **Metrics**: Loss, perplexity, and qualitative assessment
|
157 |
+
- **Monitoring**: Real-time tracking via Trackio
|
158 |
+
- **Validation**: Regular validation during training
|
159 |
+
|
160 |
+
## Citation
|
161 |
+
|
162 |
+
If you use this model in your research, please cite:
|
163 |
+
|
164 |
+
```bibtex
|
165 |
+
@misc{{{model_name_slug}},
|
166 |
+
title={{{{model_name}}}},
|
167 |
+
author={{{author_name}}},
|
168 |
+
year={2024},
|
169 |
+
url={https://huggingface.co/{{repo_name}}}
|
170 |
+
}
|
171 |
+
```
|
172 |
+
|
173 |
+
## License
|
174 |
+
|
175 |
+
This model is licensed under the Apache 2.0 License.
|
176 |
+
|
177 |
+
## Acknowledgments
|
178 |
+
|
179 |
+
- **Base Model**: SmolLM3-3B by HuggingFaceTB
|
180 |
+
- **Training Framework**: PyTorch, Transformers, PEFT
|
181 |
+
- **Monitoring**: Trackio integration
|
182 |
+
- **Quantization**: torchao library
|
183 |
+
|
184 |
+
## Support
|
185 |
+
|
186 |
+
For questions and support:
|
187 |
+
- Open an issue on the Hugging Face repository
|
188 |
+
- Check the model documentation
|
189 |
+
- Review the training logs and configuration
|
190 |
+
|
191 |
+
## Repository Structure
|
192 |
+
|
193 |
+
```
|
194 |
+
{{repo_name}}/
|
195 |
+
├── README.md (this file)
|
196 |
+
├── config.json
|
197 |
+
├── pytorch_model.bin
|
198 |
+
├── tokenizer.json
|
199 |
+
├── tokenizer_config.json
|
200 |
+
{{#if quantized_models}}
|
201 |
+
├── int8/ (quantized model for GPU)
|
202 |
+
│ ├── README.md
|
203 |
+
│ ├── config.json
|
204 |
+
│ └── pytorch_model.bin
|
205 |
+
└── int4/ (quantized model for CPU)
|
206 |
+
├── README.md
|
207 |
+
├── config.json
|
208 |
+
└── pytorch_model.bin
|
209 |
+
{{/if}}
|
210 |
+
```
|
211 |
+
|
212 |
+
## Usage Examples
|
213 |
+
|
214 |
+
### Text Generation
|
215 |
+
```python
|
216 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
217 |
+
|
218 |
+
model = AutoModelForCausalLM.from_pretrained("{{repo_name}}")
|
219 |
+
tokenizer = AutoTokenizer.from_pretrained("{{repo_name}}")
|
220 |
+
|
221 |
+
text = "The future of artificial intelligence is"
|
222 |
+
inputs = tokenizer(text, return_tensors="pt")
|
223 |
+
outputs = model.generate(**inputs, max_new_tokens=100)
|
224 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
225 |
+
```
|
226 |
+
|
227 |
+
### Conversation
|
228 |
+
```python
|
229 |
+
def chat_with_model(prompt, max_length=100):
|
230 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
231 |
+
outputs = model.generate(**inputs, max_new_tokens=max_length)
|
232 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
233 |
+
|
234 |
+
response = chat_with_model("Hello, how are you today?")
|
235 |
+
print(response)
|
236 |
+
```
|
237 |
+
|
238 |
+
### Advanced Usage
|
239 |
+
```python
|
240 |
+
# With generation parameters
|
241 |
+
outputs = model.generate(
|
242 |
+
**inputs,
|
243 |
+
max_new_tokens=100,
|
244 |
+
temperature=0.7,
|
245 |
+
top_p=0.9,
|
246 |
+
do_sample=True,
|
247 |
+
pad_token_id=tokenizer.eos_token_id
|
248 |
+
)
|
249 |
+
```
|
250 |
+
|
251 |
+
## Monitoring and Tracking
|
252 |
+
|
253 |
+
This model was trained with comprehensive monitoring:
|
254 |
+
- **Trackio Space**: {{trackio_url}}
|
255 |
+
- **Experiment**: {{experiment_name}}
|
256 |
+
- **Dataset Repository**: https://huggingface.co/datasets/{{dataset_repo}}
|
257 |
+
- **Training Logs**: Available in the experiment data
|
258 |
+
|
259 |
+
## Deployment
|
260 |
+
|
261 |
+
### Requirements
|
262 |
+
```bash
|
263 |
+
pip install torch transformers accelerate
|
264 |
+
{{#if quantized_models}}
|
265 |
+
pip install torchao # For quantized models
|
266 |
+
{{/if}}
|
267 |
+
```
|
268 |
+
|
269 |
+
### Hardware Requirements
|
270 |
+
- **Main Model**: GPU with 8GB+ VRAM recommended
|
271 |
+
{{#if quantized_models}}
|
272 |
+
- **int8 Model**: GPU with 4GB+ VRAM
|
273 |
+
- **int4 Model**: CPU deployment possible
|
274 |
+
{{/if}}
|
275 |
+
|
276 |
+
## Contributing
|
277 |
+
|
278 |
+
Contributions are welcome! Please:
|
279 |
+
1. Fork the repository
|
280 |
+
2. Create a feature branch
|
281 |
+
3. Make your changes
|
282 |
+
4. Submit a pull request
|
283 |
+
|
284 |
+
## Changelog
|
285 |
+
|
286 |
+
- **v1.0.0**: Initial release with fine-tuned model
|
287 |
+
{{#if quantized_models}}
|
288 |
+
- **v1.1.0**: Added quantized versions (int8, int4)
|
289 |
+
{{/if}}
|
templates/spaces/app.py
CHANGED
@@ -221,7 +221,12 @@ class TrackioSpace:
|
|
221 |
'learning_rate': 7e-08,
|
222 |
'num_tokens': 1642080.0,
|
223 |
'mean_token_accuracy': 0.7590958896279335,
|
224 |
-
'epoch': 0.004851130919895701
|
|
|
|
|
|
|
|
|
|
|
225 |
}
|
226 |
},
|
227 |
{
|
@@ -766,7 +771,7 @@ def update_experiment_status_interface(experiment_id: str, status: str) -> str:
|
|
766 |
return f"❌ Error updating experiment status: {str(e)}"
|
767 |
|
768 |
def create_metrics_plot(experiment_id: str, metric_name: str = "loss") -> go.Figure:
|
769 |
-
"""Create a plot for a specific metric"""
|
770 |
try:
|
771 |
df = get_metrics_dataframe(experiment_id)
|
772 |
if df.empty:
|
@@ -846,23 +851,44 @@ def create_experiment_comparison(experiment_ids: str) -> go.Figure:
|
|
846 |
def simulate_training_data(experiment_id: str):
|
847 |
"""Simulate training data for demonstration"""
|
848 |
try:
|
849 |
-
|
|
|
|
|
850 |
for step in range(0, 1000, 50):
|
851 |
# Simulate loss decreasing over time
|
852 |
loss = 2.0 * np.exp(-step / 500) + 0.1 * np.random.random()
|
853 |
accuracy = 0.3 + 0.6 * (1 - np.exp(-step / 300)) + 0.05 * np.random.random()
|
854 |
lr = 3.5e-6 * (0.9 ** (step // 200))
|
855 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
856 |
metrics = {
|
857 |
"loss": round(loss, 4),
|
858 |
"accuracy": round(accuracy, 4),
|
859 |
"learning_rate": round(lr, 8),
|
860 |
"gpu_memory": round(20 + 5 * np.random.random(), 2),
|
861 |
-
"training_time": round(0.5 + 0.2 * np.random.random(), 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
862 |
}
|
863 |
-
|
864 |
trackio_space.log_metrics(experiment_id, metrics, step)
|
865 |
-
|
866 |
return f"✅ Simulated training data for experiment {experiment_id}\nAdded 20 metric entries (steps 0-950)"
|
867 |
except Exception as e:
|
868 |
return f"❌ Error simulating data: {str(e)}"
|
@@ -1113,7 +1139,11 @@ with gr.Blocks(title="Trackio - Experiment Tracking", theme=gr.themes.Soft()) as
|
|
1113 |
)
|
1114 |
metric_dropdown = gr.Dropdown(
|
1115 |
label="Metric to Plot",
|
1116 |
-
choices=[
|
|
|
|
|
|
|
|
|
1117 |
value="loss"
|
1118 |
)
|
1119 |
plot_btn = gr.Button("Create Plot", variant="primary")
|
|
|
221 |
'learning_rate': 7e-08,
|
222 |
'num_tokens': 1642080.0,
|
223 |
'mean_token_accuracy': 0.7590958896279335,
|
224 |
+
'epoch': 0.004851130919895701,
|
225 |
+
'gpu_0_memory_allocated': 17.202261447906494,
|
226 |
+
'gpu_0_memory_reserved': 75.474609375,
|
227 |
+
'gpu_0_utilization': 0,
|
228 |
+
'cpu_percent': 2.7,
|
229 |
+
'memory_percent': 10.1
|
230 |
}
|
231 |
},
|
232 |
{
|
|
|
771 |
return f"❌ Error updating experiment status: {str(e)}"
|
772 |
|
773 |
def create_metrics_plot(experiment_id: str, metric_name: str = "loss") -> go.Figure:
|
774 |
+
"""Create a plot for a specific metric (supports all logged metrics, including new ones)"""
|
775 |
try:
|
776 |
df = get_metrics_dataframe(experiment_id)
|
777 |
if df.empty:
|
|
|
851 |
def simulate_training_data(experiment_id: str):
|
852 |
"""Simulate training data for demonstration"""
|
853 |
try:
|
854 |
+
import random
|
855 |
+
import time
|
856 |
+
last_time = time.time()
|
857 |
for step in range(0, 1000, 50):
|
858 |
# Simulate loss decreasing over time
|
859 |
loss = 2.0 * np.exp(-step / 500) + 0.1 * np.random.random()
|
860 |
accuracy = 0.3 + 0.6 * (1 - np.exp(-step / 300)) + 0.05 * np.random.random()
|
861 |
lr = 3.5e-6 * (0.9 ** (step // 200))
|
862 |
+
batch_size = 8
|
863 |
+
seq_len = 2048
|
864 |
+
total_tokens = batch_size * seq_len
|
865 |
+
padding_tokens = random.randint(0, batch_size * 32)
|
866 |
+
truncated_tokens = random.randint(0, batch_size * 8)
|
867 |
+
now = time.time()
|
868 |
+
step_time = random.uniform(0.4, 0.7)
|
869 |
+
throughput = total_tokens / step_time
|
870 |
+
token_acc = accuracy
|
871 |
+
gate_ortho = random.uniform(0.01, 0.05)
|
872 |
+
center = random.uniform(0.01, 0.05)
|
873 |
metrics = {
|
874 |
"loss": round(loss, 4),
|
875 |
"accuracy": round(accuracy, 4),
|
876 |
"learning_rate": round(lr, 8),
|
877 |
"gpu_memory": round(20 + 5 * np.random.random(), 2),
|
878 |
+
"training_time": round(0.5 + 0.2 * np.random.random(), 3),
|
879 |
+
"total_tokens": total_tokens,
|
880 |
+
"padding_tokens": padding_tokens,
|
881 |
+
"truncated_tokens": truncated_tokens,
|
882 |
+
"throughput": throughput,
|
883 |
+
"step_time": step_time,
|
884 |
+
"batch_size": batch_size,
|
885 |
+
"seq_len": seq_len,
|
886 |
+
"token_acc": token_acc,
|
887 |
+
"train/gate_ortho": gate_ortho,
|
888 |
+
"train/center": center
|
889 |
}
|
|
|
890 |
trackio_space.log_metrics(experiment_id, metrics, step)
|
891 |
+
last_time = now
|
892 |
return f"✅ Simulated training data for experiment {experiment_id}\nAdded 20 metric entries (steps 0-950)"
|
893 |
except Exception as e:
|
894 |
return f"❌ Error simulating data: {str(e)}"
|
|
|
1139 |
)
|
1140 |
metric_dropdown = gr.Dropdown(
|
1141 |
label="Metric to Plot",
|
1142 |
+
choices=[
|
1143 |
+
"loss", "accuracy", "learning_rate", "gpu_memory", "training_time",
|
1144 |
+
"total_tokens", "truncated_tokens", "padding_tokens", "throughput", "step_time",
|
1145 |
+
"batch_size", "seq_len", "token_acc", "train/gate_ortho", "train/center"
|
1146 |
+
],
|
1147 |
value="loss"
|
1148 |
)
|
1149 |
plot_btn = gr.Button("Create Plot", variant="primary")
|
test_config.py → tests/test_config.py
RENAMED
File without changes
|
test_mixed_precision.py → tests/test_mixed_precision.py
RENAMED
File without changes
|
test_pipeline.py → tests/test_pipeline_1.py
RENAMED
File without changes
|
tests/test_quantization.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script for quantization functionality
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import tempfile
|
9 |
+
import shutil
|
10 |
+
from pathlib import Path
|
11 |
+
import logging
|
12 |
+
|
13 |
+
# Add the project root to the path
|
14 |
+
project_root = Path(__file__).parent.parent
|
15 |
+
sys.path.append(str(project_root))
|
16 |
+
|
17 |
+
from scripts.model_tonic.quantize_model import ModelQuantizer
|
18 |
+
|
19 |
+
def test_quantization_imports():
|
20 |
+
"""Test that all required imports are available"""
|
21 |
+
try:
|
22 |
+
import torch
|
23 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
24 |
+
from torchao.quantization import (
|
25 |
+
Int8WeightOnlyConfig,
|
26 |
+
Int4WeightOnlyConfig,
|
27 |
+
Int8DynamicActivationInt8WeightConfig
|
28 |
+
)
|
29 |
+
from torchao.dtypes import Int4CPULayout
|
30 |
+
print("✅ All quantization imports successful")
|
31 |
+
return True
|
32 |
+
except ImportError as e:
|
33 |
+
print(f"❌ Import error: {e}")
|
34 |
+
return False
|
35 |
+
|
36 |
+
def test_quantizer_initialization():
|
37 |
+
"""Test quantizer initialization"""
|
38 |
+
try:
|
39 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
40 |
+
# Create a dummy model directory
|
41 |
+
model_dir = Path(temp_dir) / "dummy_model"
|
42 |
+
model_dir.mkdir()
|
43 |
+
|
44 |
+
# Create minimal model files
|
45 |
+
(model_dir / "config.json").write_text('{"model_type": "test"}')
|
46 |
+
(model_dir / "pytorch_model.bin").write_text('dummy')
|
47 |
+
|
48 |
+
quantizer = ModelQuantizer(
|
49 |
+
model_path=str(model_dir),
|
50 |
+
repo_name="test/test-quantized",
|
51 |
+
token="dummy_token"
|
52 |
+
)
|
53 |
+
|
54 |
+
print("✅ Quantizer initialization successful")
|
55 |
+
return True
|
56 |
+
except Exception as e:
|
57 |
+
print(f"❌ Quantizer initialization failed: {e}")
|
58 |
+
return False
|
59 |
+
|
60 |
+
def test_quantization_config_creation():
|
61 |
+
"""Test quantization configuration creation"""
|
62 |
+
try:
|
63 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
64 |
+
model_dir = Path(temp_dir) / "dummy_model"
|
65 |
+
model_dir.mkdir()
|
66 |
+
(model_dir / "config.json").write_text('{"model_type": "test"}')
|
67 |
+
(model_dir / "pytorch_model.bin").write_text('dummy')
|
68 |
+
|
69 |
+
quantizer = ModelQuantizer(
|
70 |
+
model_path=str(model_dir),
|
71 |
+
repo_name="test/test-quantized",
|
72 |
+
token="dummy_token"
|
73 |
+
)
|
74 |
+
|
75 |
+
# Test int8 config
|
76 |
+
config_int8 = quantizer.create_quantization_config("int8_weight_only", 128)
|
77 |
+
print("✅ int8 config creation successful")
|
78 |
+
|
79 |
+
# Test int4 config
|
80 |
+
config_int4 = quantizer.create_quantization_config("int4_weight_only", 128)
|
81 |
+
print("✅ int4 config creation successful")
|
82 |
+
|
83 |
+
return True
|
84 |
+
except Exception as e:
|
85 |
+
print(f"❌ Config creation failed: {e}")
|
86 |
+
return False
|
87 |
+
|
88 |
+
def test_model_validation():
|
89 |
+
"""Test model path validation"""
|
90 |
+
try:
|
91 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
92 |
+
# Test with valid model
|
93 |
+
model_dir = Path(temp_dir) / "valid_model"
|
94 |
+
model_dir.mkdir()
|
95 |
+
(model_dir / "config.json").write_text('{"model_type": "test"}')
|
96 |
+
(model_dir / "pytorch_model.bin").write_text('dummy')
|
97 |
+
|
98 |
+
quantizer = ModelQuantizer(
|
99 |
+
model_path=str(model_dir),
|
100 |
+
repo_name="test/test-quantized",
|
101 |
+
token="dummy_token"
|
102 |
+
)
|
103 |
+
|
104 |
+
if quantizer.validate_model_path():
|
105 |
+
print("✅ Valid model validation successful")
|
106 |
+
else:
|
107 |
+
print("❌ Valid model validation failed")
|
108 |
+
return False
|
109 |
+
|
110 |
+
# Test with invalid model
|
111 |
+
invalid_dir = Path(temp_dir) / "invalid_model"
|
112 |
+
invalid_dir.mkdir()
|
113 |
+
# Missing required files
|
114 |
+
|
115 |
+
quantizer_invalid = ModelQuantizer(
|
116 |
+
model_path=str(invalid_dir),
|
117 |
+
repo_name="test/test-quantized",
|
118 |
+
token="dummy_token"
|
119 |
+
)
|
120 |
+
|
121 |
+
if not quantizer_invalid.validate_model_path():
|
122 |
+
print("✅ Invalid model validation successful")
|
123 |
+
else:
|
124 |
+
print("❌ Invalid model validation failed")
|
125 |
+
return False
|
126 |
+
|
127 |
+
return True
|
128 |
+
except Exception as e:
|
129 |
+
print(f"❌ Model validation test failed: {e}")
|
130 |
+
return False
|
131 |
+
|
132 |
+
def test_quantized_model_card_creation():
|
133 |
+
"""Test quantized model card creation"""
|
134 |
+
try:
|
135 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
136 |
+
model_dir = Path(temp_dir) / "dummy_model"
|
137 |
+
model_dir.mkdir()
|
138 |
+
(model_dir / "config.json").write_text('{"model_type": "test"}')
|
139 |
+
(model_dir / "pytorch_model.bin").write_text('dummy')
|
140 |
+
|
141 |
+
quantizer = ModelQuantizer(
|
142 |
+
model_path=str(model_dir),
|
143 |
+
repo_name="test/test-quantized",
|
144 |
+
token="dummy_token"
|
145 |
+
)
|
146 |
+
|
147 |
+
# Test int8 model card
|
148 |
+
card_int8 = quantizer.create_quantized_model_card("int8_weight_only", "test/model")
|
149 |
+
if "int8_weight_only" in card_int8 and "GPU" in card_int8:
|
150 |
+
print("✅ int8 model card creation successful")
|
151 |
+
else:
|
152 |
+
print("❌ int8 model card creation failed")
|
153 |
+
return False
|
154 |
+
|
155 |
+
# Test int4 model card
|
156 |
+
card_int4 = quantizer.create_quantized_model_card("int4_weight_only", "test/model")
|
157 |
+
if "int4_weight_only" in card_int4 and "CPU" in card_int4:
|
158 |
+
print("✅ int4 model card creation successful")
|
159 |
+
else:
|
160 |
+
print("❌ int4 model card creation failed")
|
161 |
+
return False
|
162 |
+
|
163 |
+
return True
|
164 |
+
except Exception as e:
|
165 |
+
print(f"❌ Model card creation test failed: {e}")
|
166 |
+
return False
|
167 |
+
|
168 |
+
def test_quantized_readme_creation():
|
169 |
+
"""Test quantized README creation"""
|
170 |
+
try:
|
171 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
172 |
+
model_dir = Path(temp_dir) / "dummy_model"
|
173 |
+
model_dir.mkdir()
|
174 |
+
(model_dir / "config.json").write_text('{"model_type": "test"}')
|
175 |
+
(model_dir / "pytorch_model.bin").write_text('dummy')
|
176 |
+
|
177 |
+
quantizer = ModelQuantizer(
|
178 |
+
model_path=str(model_dir),
|
179 |
+
repo_name="test/test-quantized",
|
180 |
+
token="dummy_token"
|
181 |
+
)
|
182 |
+
|
183 |
+
# Test int8 README
|
184 |
+
readme_int8 = quantizer.create_quantized_readme("int8_weight_only", "test/model")
|
185 |
+
if "int8_weight_only" in readme_int8 and "GPU optimized" in readme_int8:
|
186 |
+
print("✅ int8 README creation successful")
|
187 |
+
else:
|
188 |
+
print("❌ int8 README creation failed")
|
189 |
+
return False
|
190 |
+
|
191 |
+
# Test int4 README
|
192 |
+
readme_int4 = quantizer.create_quantized_readme("int4_weight_only", "test/model")
|
193 |
+
if "int4_weight_only" in readme_int4 and "CPU optimized" in readme_int4:
|
194 |
+
print("✅ int4 README creation successful")
|
195 |
+
else:
|
196 |
+
print("❌ int4 README creation failed")
|
197 |
+
return False
|
198 |
+
|
199 |
+
return True
|
200 |
+
except Exception as e:
|
201 |
+
print(f"❌ README creation test failed: {e}")
|
202 |
+
return False
|
203 |
+
|
204 |
+
def main():
|
205 |
+
"""Run all quantization tests"""
|
206 |
+
print("🧪 Running Quantization Tests")
|
207 |
+
print("=" * 40)
|
208 |
+
|
209 |
+
tests = [
|
210 |
+
("Import Test", test_quantization_imports),
|
211 |
+
("Initialization Test", test_quantizer_initialization),
|
212 |
+
("Config Creation Test", test_quantization_config_creation),
|
213 |
+
("Model Validation Test", test_model_validation),
|
214 |
+
("Model Card Test", test_quantized_model_card_creation),
|
215 |
+
("README Test", test_quantized_readme_creation),
|
216 |
+
]
|
217 |
+
|
218 |
+
passed = 0
|
219 |
+
total = len(tests)
|
220 |
+
|
221 |
+
for test_name, test_func in tests:
|
222 |
+
print(f"\n📋 Running {test_name}...")
|
223 |
+
try:
|
224 |
+
if test_func():
|
225 |
+
passed += 1
|
226 |
+
print(f"✅ {test_name} passed")
|
227 |
+
else:
|
228 |
+
print(f"❌ {test_name} failed")
|
229 |
+
except Exception as e:
|
230 |
+
print(f"❌ {test_name} failed with exception: {e}")
|
231 |
+
|
232 |
+
print("\n" + "=" * 40)
|
233 |
+
print(f"📊 Test Results: {passed}/{total} tests passed")
|
234 |
+
|
235 |
+
if passed == total:
|
236 |
+
print("🎉 All quantization tests passed!")
|
237 |
+
return 0
|
238 |
+
else:
|
239 |
+
print("⚠️ Some tests failed. Check the output above.")
|
240 |
+
return 1
|
241 |
+
|
242 |
+
if __name__ == "__main__":
|
243 |
+
# Setup logging
|
244 |
+
logging.basicConfig(
|
245 |
+
level=logging.INFO,
|
246 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
247 |
+
)
|
248 |
+
|
249 |
+
exit(main())
|
tests/test_trainer_selection.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script to verify trainer selection logic
|
4 |
+
"""
|
5 |
+
|
6 |
+
import sys
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
# Add project root to path
|
11 |
+
project_root = Path(__file__).parent.parent
|
12 |
+
sys.path.insert(0, str(project_root))
|
13 |
+
sys.path.insert(0, str(project_root / "config"))
|
14 |
+
|
15 |
+
def test_config_trainer_type():
|
16 |
+
"""Test that config files have the correct trainer_type"""
|
17 |
+
print("Testing config trainer_type...")
|
18 |
+
|
19 |
+
# Test base config
|
20 |
+
from train_smollm3 import SmolLM3Config
|
21 |
+
base_config = SmolLM3Config()
|
22 |
+
assert base_config.trainer_type == "sft", f"Base config should have trainer_type='sft', got {base_config.trainer_type}"
|
23 |
+
print("✅ Base config trainer_type: sft")
|
24 |
+
|
25 |
+
# Test DPO config
|
26 |
+
from train_smollm3_dpo import SmolLM3DPOConfig
|
27 |
+
dpo_config = SmolLM3DPOConfig()
|
28 |
+
assert dpo_config.trainer_type == "dpo", f"DPO config should have trainer_type='dpo', got {dpo_config.trainer_type}"
|
29 |
+
print("✅ DPO config trainer_type: dpo")
|
30 |
+
|
31 |
+
return True
|
32 |
+
|
33 |
+
def test_trainer_classes_exist():
|
34 |
+
"""Test that trainer classes exist in the trainer module"""
|
35 |
+
print("Testing trainer class existence...")
|
36 |
+
|
37 |
+
try:
|
38 |
+
# Add src to path
|
39 |
+
sys.path.insert(0, str(project_root / "src"))
|
40 |
+
|
41 |
+
# Import trainer module
|
42 |
+
import trainer
|
43 |
+
print("✅ Trainer module imported successfully")
|
44 |
+
|
45 |
+
# Check if classes exist
|
46 |
+
assert hasattr(trainer, 'SmolLM3Trainer'), "SmolLM3Trainer class not found"
|
47 |
+
assert hasattr(trainer, 'SmolLM3DPOTrainer'), "SmolLM3DPOTrainer class not found"
|
48 |
+
print("✅ Both trainer classes exist")
|
49 |
+
|
50 |
+
return True
|
51 |
+
|
52 |
+
except Exception as e:
|
53 |
+
print(f"❌ Failed to check trainer classes: {e}")
|
54 |
+
return False
|
55 |
+
|
56 |
+
def test_config_inheritance():
|
57 |
+
"""Test that DPO config properly inherits from base config"""
|
58 |
+
print("Testing config inheritance...")
|
59 |
+
|
60 |
+
try:
|
61 |
+
from train_smollm3 import SmolLM3Config
|
62 |
+
from train_smollm3_dpo import SmolLM3DPOConfig
|
63 |
+
|
64 |
+
# Test that DPO config inherits from base config
|
65 |
+
base_config = SmolLM3Config()
|
66 |
+
dpo_config = SmolLM3DPOConfig()
|
67 |
+
|
68 |
+
# Check that DPO config has all base config fields
|
69 |
+
base_fields = set(base_config.__dict__.keys())
|
70 |
+
dpo_fields = set(dpo_config.__dict__.keys())
|
71 |
+
|
72 |
+
# DPO config should have all base fields plus DPO-specific ones
|
73 |
+
assert base_fields.issubset(dpo_fields), "DPO config missing base config fields"
|
74 |
+
print("✅ DPO config properly inherits from base config")
|
75 |
+
|
76 |
+
# Check that trainer_type is overridden correctly
|
77 |
+
assert dpo_config.trainer_type == "dpo", "DPO config should have trainer_type='dpo'"
|
78 |
+
assert base_config.trainer_type == "sft", "Base config should have trainer_type='sft'"
|
79 |
+
print("✅ Trainer type inheritance works correctly")
|
80 |
+
|
81 |
+
return True
|
82 |
+
|
83 |
+
except Exception as e:
|
84 |
+
print(f"❌ Failed to test config inheritance: {e}")
|
85 |
+
return False
|
86 |
+
|
87 |
+
def main():
|
88 |
+
"""Run all tests"""
|
89 |
+
print("🧪 Testing Trainer Selection Implementation")
|
90 |
+
print("=" * 50)
|
91 |
+
|
92 |
+
tests = [
|
93 |
+
test_config_trainer_type,
|
94 |
+
test_trainer_classes_exist,
|
95 |
+
test_config_inheritance,
|
96 |
+
]
|
97 |
+
|
98 |
+
passed = 0
|
99 |
+
total = len(tests)
|
100 |
+
|
101 |
+
for test in tests:
|
102 |
+
try:
|
103 |
+
if test():
|
104 |
+
passed += 1
|
105 |
+
else:
|
106 |
+
print(f"❌ Test {test.__name__} failed")
|
107 |
+
except Exception as e:
|
108 |
+
print(f"❌ Test {test.__name__} failed with exception: {e}")
|
109 |
+
|
110 |
+
print("=" * 50)
|
111 |
+
print(f"Tests passed: {passed}/{total}")
|
112 |
+
|
113 |
+
if passed == total:
|
114 |
+
print("🎉 All tests passed!")
|
115 |
+
return 0
|
116 |
+
else:
|
117 |
+
print("❌ Some tests failed!")
|
118 |
+
return 1
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
exit(main())
|
test_training_fix.py → tests/test_training_fix_1.py
RENAMED
File without changes
|
tests/test_unified_model_card.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script for the unified model card system
|
4 |
+
Verifies template processing, variable substitution, and conditional sections
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import tempfile
|
10 |
+
import shutil
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
# Add the project root to the path
|
14 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
15 |
+
|
16 |
+
from scripts.model_tonic.generate_model_card import ModelCardGenerator
|
17 |
+
|
18 |
+
def test_basic_model_card():
|
19 |
+
"""Test basic model card generation without quantized models"""
|
20 |
+
print("🧪 Testing basic model card generation...")
|
21 |
+
|
22 |
+
# Create test variables
|
23 |
+
variables = {
|
24 |
+
"model_name": "Test SmolLM3 Model",
|
25 |
+
"model_description": "A test fine-tuned SmolLM3 model",
|
26 |
+
"repo_name": "test-user/test-model",
|
27 |
+
"base_model": "HuggingFaceTB/SmolLM3-3B",
|
28 |
+
"dataset_name": "OpenHermes-FR",
|
29 |
+
"training_config_type": "H100 Lightweight",
|
30 |
+
"trainer_type": "SFTTrainer",
|
31 |
+
"batch_size": "8",
|
32 |
+
"gradient_accumulation_steps": "16",
|
33 |
+
"learning_rate": "5e-6",
|
34 |
+
"max_epochs": "3",
|
35 |
+
"max_seq_length": "2048",
|
36 |
+
"hardware_info": "GPU (H100)",
|
37 |
+
"experiment_name": "test-experiment",
|
38 |
+
"trackio_url": "https://trackio.space/test",
|
39 |
+
"dataset_repo": "test/trackio-experiments",
|
40 |
+
"dataset_size": "~80K samples",
|
41 |
+
"dataset_format": "Chat format",
|
42 |
+
"author_name": "Test User",
|
43 |
+
"model_name_slug": "test_smollm3_model",
|
44 |
+
"quantized_models": False,
|
45 |
+
"dataset_sample_size": "80000"
|
46 |
+
}
|
47 |
+
|
48 |
+
try:
|
49 |
+
# Create generator
|
50 |
+
generator = ModelCardGenerator()
|
51 |
+
|
52 |
+
# Generate model card
|
53 |
+
content = generator.generate_model_card(variables)
|
54 |
+
|
55 |
+
# Check that content was generated
|
56 |
+
assert content is not None
|
57 |
+
assert len(content) > 0
|
58 |
+
|
59 |
+
# Check that basic sections are present
|
60 |
+
assert "Test SmolLM3 Model" in content
|
61 |
+
assert "test-user/test-model" in content
|
62 |
+
assert "HuggingFaceTB/SmolLM3-3B" in content
|
63 |
+
|
64 |
+
# Check that quantized sections are NOT present
|
65 |
+
assert "Quantized Models" not in content
|
66 |
+
assert "int8" not in content
|
67 |
+
assert "int4" not in content
|
68 |
+
|
69 |
+
print("✅ Basic model card generation test passed")
|
70 |
+
return True
|
71 |
+
|
72 |
+
except Exception as e:
|
73 |
+
print(f"❌ Basic model card generation test failed: {e}")
|
74 |
+
return False
|
75 |
+
|
76 |
+
def test_quantized_model_card():
|
77 |
+
"""Test model card generation with quantized models"""
|
78 |
+
print("🧪 Testing quantized model card generation...")
|
79 |
+
|
80 |
+
# Create test variables with quantized models
|
81 |
+
variables = {
|
82 |
+
"model_name": "Test SmolLM3 Model with Quantization",
|
83 |
+
"model_description": "A test fine-tuned SmolLM3 model with quantized versions",
|
84 |
+
"repo_name": "test-user/test-model",
|
85 |
+
"base_model": "HuggingFaceTB/SmolLM3-3B",
|
86 |
+
"dataset_name": "OpenHermes-FR",
|
87 |
+
"training_config_type": "H100 Lightweight",
|
88 |
+
"trainer_type": "SFTTrainer",
|
89 |
+
"batch_size": "8",
|
90 |
+
"gradient_accumulation_steps": "16",
|
91 |
+
"learning_rate": "5e-6",
|
92 |
+
"max_epochs": "3",
|
93 |
+
"max_seq_length": "2048",
|
94 |
+
"hardware_info": "GPU (H100)",
|
95 |
+
"experiment_name": "test-experiment",
|
96 |
+
"trackio_url": "https://trackio.space/test",
|
97 |
+
"dataset_repo": "test/trackio-experiments",
|
98 |
+
"dataset_size": "~80K samples",
|
99 |
+
"dataset_format": "Chat format",
|
100 |
+
"author_name": "Test User",
|
101 |
+
"model_name_slug": "test_smollm3_model",
|
102 |
+
"quantized_models": True,
|
103 |
+
"dataset_sample_size": "80000"
|
104 |
+
}
|
105 |
+
|
106 |
+
try:
|
107 |
+
# Create generator
|
108 |
+
generator = ModelCardGenerator()
|
109 |
+
|
110 |
+
# Generate model card
|
111 |
+
content = generator.generate_model_card(variables)
|
112 |
+
|
113 |
+
# Check that content was generated
|
114 |
+
assert content is not None
|
115 |
+
assert len(content) > 0
|
116 |
+
|
117 |
+
# Check that basic sections are present
|
118 |
+
assert "Test SmolLM3 Model with Quantization" in content
|
119 |
+
assert "test-user/test-model" in content
|
120 |
+
|
121 |
+
# Check that quantized sections ARE present
|
122 |
+
assert "Quantized Models" in content
|
123 |
+
assert "int8" in content
|
124 |
+
assert "int4" in content
|
125 |
+
assert "test-user/test-model/int8" in content
|
126 |
+
assert "test-user/test-model/int4" in content
|
127 |
+
|
128 |
+
print("✅ Quantized model card generation test passed")
|
129 |
+
return True
|
130 |
+
|
131 |
+
except Exception as e:
|
132 |
+
print(f"❌ Quantized model card generation test failed: {e}")
|
133 |
+
return False
|
134 |
+
|
135 |
+
def test_template_processing():
|
136 |
+
"""Test template processing and variable substitution"""
|
137 |
+
print("🧪 Testing template processing...")
|
138 |
+
|
139 |
+
try:
|
140 |
+
# Create generator
|
141 |
+
generator = ModelCardGenerator()
|
142 |
+
|
143 |
+
# Test variable substitution
|
144 |
+
test_variables = {
|
145 |
+
"model_name": "Test Model",
|
146 |
+
"repo_name": "test/repo",
|
147 |
+
"quantized_models": True
|
148 |
+
}
|
149 |
+
|
150 |
+
# Generate content
|
151 |
+
content = generator.generate_model_card(test_variables)
|
152 |
+
|
153 |
+
# Check variable substitution
|
154 |
+
assert "Test Model" in content
|
155 |
+
assert "test/repo" in content
|
156 |
+
|
157 |
+
# Check conditional processing
|
158 |
+
assert "Quantized Models" in content
|
159 |
+
|
160 |
+
print("✅ Template processing test passed")
|
161 |
+
return True
|
162 |
+
|
163 |
+
except Exception as e:
|
164 |
+
print(f"❌ Template processing test failed: {e}")
|
165 |
+
return False
|
166 |
+
|
167 |
+
def test_file_saving():
|
168 |
+
"""Test saving generated model cards to files"""
|
169 |
+
print("🧪 Testing file saving...")
|
170 |
+
|
171 |
+
try:
|
172 |
+
# Create temporary directory
|
173 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
174 |
+
output_path = os.path.join(temp_dir, "test_readme.md")
|
175 |
+
|
176 |
+
# Create generator
|
177 |
+
generator = ModelCardGenerator()
|
178 |
+
|
179 |
+
# Test variables
|
180 |
+
variables = {
|
181 |
+
"model_name": "Test Model",
|
182 |
+
"model_description": "Test description",
|
183 |
+
"repo_name": "test/repo",
|
184 |
+
"base_model": "HuggingFaceTB/SmolLM3-3B",
|
185 |
+
"dataset_name": "Test Dataset",
|
186 |
+
"training_config_type": "Test Config",
|
187 |
+
"trainer_type": "SFTTrainer",
|
188 |
+
"batch_size": "8",
|
189 |
+
"gradient_accumulation_steps": "16",
|
190 |
+
"learning_rate": "5e-6",
|
191 |
+
"max_epochs": "3",
|
192 |
+
"max_seq_length": "2048",
|
193 |
+
"hardware_info": "GPU",
|
194 |
+
"experiment_name": "test-exp",
|
195 |
+
"trackio_url": "https://trackio.space/test",
|
196 |
+
"dataset_repo": "test/dataset",
|
197 |
+
"dataset_size": "1K samples",
|
198 |
+
"dataset_format": "Chat format",
|
199 |
+
"author_name": "Test User",
|
200 |
+
"model_name_slug": "test_model",
|
201 |
+
"quantized_models": False,
|
202 |
+
"dataset_sample_size": None
|
203 |
+
}
|
204 |
+
|
205 |
+
# Generate and save
|
206 |
+
content = generator.generate_model_card(variables)
|
207 |
+
success = generator.save_model_card(content, output_path)
|
208 |
+
|
209 |
+
# Check that file was created
|
210 |
+
assert success
|
211 |
+
assert os.path.exists(output_path)
|
212 |
+
|
213 |
+
# Check file content
|
214 |
+
with open(output_path, 'r', encoding='utf-8') as f:
|
215 |
+
saved_content = f.read()
|
216 |
+
|
217 |
+
assert "Test Model" in saved_content
|
218 |
+
assert "test/repo" in saved_content
|
219 |
+
|
220 |
+
print("✅ File saving test passed")
|
221 |
+
return True
|
222 |
+
|
223 |
+
except Exception as e:
|
224 |
+
print(f"❌ File saving test failed: {e}")
|
225 |
+
return False
|
226 |
+
|
227 |
+
def test_error_handling():
|
228 |
+
"""Test error handling for missing template and invalid variables"""
|
229 |
+
print("🧪 Testing error handling...")
|
230 |
+
|
231 |
+
try:
|
232 |
+
# Test with non-existent template
|
233 |
+
try:
|
234 |
+
generator = ModelCardGenerator("non_existent_template.md")
|
235 |
+
content = generator.generate_model_card({})
|
236 |
+
assert False, "Should have raised FileNotFoundError"
|
237 |
+
except FileNotFoundError:
|
238 |
+
print("✅ Correctly handled missing template")
|
239 |
+
|
240 |
+
# Test with minimal variables
|
241 |
+
generator = ModelCardGenerator()
|
242 |
+
content = generator.generate_model_card({})
|
243 |
+
|
244 |
+
# Should still generate some content
|
245 |
+
assert content is not None
|
246 |
+
assert len(content) > 0
|
247 |
+
|
248 |
+
print("✅ Error handling test passed")
|
249 |
+
return True
|
250 |
+
|
251 |
+
except Exception as e:
|
252 |
+
print(f"❌ Error handling test failed: {e}")
|
253 |
+
return False
|
254 |
+
|
255 |
+
def main():
|
256 |
+
"""Run all tests"""
|
257 |
+
print("🚀 Starting unified model card system tests...")
|
258 |
+
print("=" * 50)
|
259 |
+
|
260 |
+
tests = [
|
261 |
+
test_basic_model_card,
|
262 |
+
test_quantized_model_card,
|
263 |
+
test_template_processing,
|
264 |
+
test_file_saving,
|
265 |
+
test_error_handling
|
266 |
+
]
|
267 |
+
|
268 |
+
passed = 0
|
269 |
+
total = len(tests)
|
270 |
+
|
271 |
+
for test in tests:
|
272 |
+
try:
|
273 |
+
if test():
|
274 |
+
passed += 1
|
275 |
+
except Exception as e:
|
276 |
+
print(f"❌ Test {test.__name__} failed with exception: {e}")
|
277 |
+
|
278 |
+
print("=" * 50)
|
279 |
+
print(f"📊 Test Results: {passed}/{total} tests passed")
|
280 |
+
|
281 |
+
if passed == total:
|
282 |
+
print("🎉 All tests passed! Unified model card system is working correctly.")
|
283 |
+
return 0
|
284 |
+
else:
|
285 |
+
print("⚠️ Some tests failed. Please check the implementation.")
|
286 |
+
return 1
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
exit(main())
|