Tonic committed on
Commit
d7d1377
·
verified ·
1 Parent(s): 235d769

fixes variable cases sft/dpo

Browse files
Files changed (2) hide show
  1. launch.sh +4 -1
  2. tests/test_trainer_type_fix.py +169 -0
launch.sh CHANGED
@@ -432,6 +432,9 @@ echo ""
432
 
433
  select_option "Select trainer type:" "SFT" "DPO" TRAINER_TYPE
434
 
 
 
 
435
  # Step 4: Training parameters
436
  print_step "Step 4: Training Parameters"
437
  echo "==============================="
@@ -696,7 +699,7 @@ python scripts/training/train.py \
696
  --experiment-name "$EXPERIMENT_NAME" \
697
  --output-dir /output-checkpoint \
698
  --trackio-url "$TRACKIO_URL" \
699
- --trainer-type "$TRAINER_TYPE"
700
 
701
  # Step 16: Push model to Hugging Face Hub
702
  print_step "Step 16: Pushing Model to HF Hub"
 
432
 
433
  select_option "Select trainer type:" "SFT" "DPO" TRAINER_TYPE
434
 
435
+ # Convert trainer type to lowercase for the training script
436
+ TRAINER_TYPE_LOWER=$(echo "$TRAINER_TYPE" | tr '[:upper:]' '[:lower:]')
437
+
438
  # Step 4: Training parameters
439
  print_step "Step 4: Training Parameters"
440
  echo "==============================="
 
699
  --experiment-name "$EXPERIMENT_NAME" \
700
  --output-dir /output-checkpoint \
701
  --trackio-url "$TRACKIO_URL" \
702
+ --trainer-type "$TRAINER_TYPE_LOWER"
703
 
704
  # Step 16: Push model to Hugging Face Hub
705
  print_step "Step 16: Pushing Model to HF Hub"
tests/test_trainer_type_fix.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify trainer type conversion works correctly
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import subprocess
9
+ from pathlib import Path
10
+
11
+ def test_trainer_type_conversion():
12
+ """Test that trainer type is converted to lowercase correctly"""
13
+ print("🔍 Testing Trainer Type Conversion")
14
+ print("=" * 50)
15
+
16
+ # Test cases
17
+ test_cases = [
18
+ ("SFT", "sft"),
19
+ ("DPO", "dpo"),
20
+ ("sft", "sft"),
21
+ ("dpo", "dpo")
22
+ ]
23
+
24
+ all_passed = True
25
+ for input_type, expected_output in test_cases:
26
+ # Simulate the bash conversion: echo "$TRAINER_TYPE" | tr '[:upper:]' '[:lower:]'
27
+ converted = input_type.lower()
28
+
29
+ if converted == expected_output:
30
+ print(f"✅ '{input_type}' -> '{converted}' (expected: '{expected_output}')")
31
+ else:
32
+ print(f"❌ '{input_type}' -> '{converted}' (expected: '{expected_output}')")
33
+ all_passed = False
34
+
35
+ return all_passed
36
+
37
+ def test_launch_script_trainer_type():
38
+ """Test that launch script handles trainer type correctly"""
39
+ print("\n🔍 Testing Launch Script Trainer Type Handling")
40
+ print("=" * 50)
41
+
42
+ # Check if launch.sh exists
43
+ launch_script = Path("launch.sh")
44
+ if not launch_script.exists():
45
+ print("❌ launch.sh not found")
46
+ return False
47
+
48
+ # Read launch script and check for trainer type handling
49
+ script_content = launch_script.read_text(encoding='utf-8')
50
+
51
+ # Check for trainer type conversion
52
+ conversion_patterns = [
53
+ 'TRAINER_TYPE_LOWER=$(echo "$TRAINER_TYPE" | tr \'[:upper:]\' \'[:lower:]\')',
54
+ '--trainer-type "$TRAINER_TYPE_LOWER"'
55
+ ]
56
+
57
+ all_found = True
58
+ for pattern in conversion_patterns:
59
+ if pattern in script_content:
60
+ print(f"✅ Found: {pattern}")
61
+ else:
62
+ print(f"❌ Missing: {pattern}")
63
+ all_found = False
64
+
65
+ # Check that old pattern is removed
66
+ old_pattern = '--trainer-type "$TRAINER_TYPE"'
67
+ if old_pattern in script_content:
68
+ print(f"❌ Found old pattern (should be updated): {old_pattern}")
69
+ all_found = False
70
+ else:
71
+ print(f"✅ Old pattern removed: {old_pattern}")
72
+
73
+ return all_found
74
+
75
+ def test_training_script_validation():
76
+ """Test that training script accepts the correct trainer types"""
77
+ print("\n🔍 Testing Training Script Validation")
78
+ print("=" * 50)
79
+
80
+ # Check if training script exists
81
+ training_script = Path("scripts/training/train.py")
82
+ if not training_script.exists():
83
+ print("❌ Training script not found")
84
+ return False
85
+
86
+ # Read training script and check for argument validation
87
+ script_content = training_script.read_text(encoding='utf-8')
88
+
89
+ # Check for trainer type argument definition
90
+ if '--trainer-type' in script_content:
91
+ print("✅ Found trainer-type argument in training script")
92
+ else:
93
+ print("❌ Missing trainer-type argument in training script")
94
+ return False
95
+
96
+ # Check for valid choices
97
+ if 'sft' in script_content and 'dpo' in script_content:
98
+ print("✅ Found valid trainer type choices: sft, dpo")
99
+ else:
100
+ print("❌ Missing valid trainer type choices")
101
+ return False
102
+
103
+ return True
104
+
105
+ def test_trainer_type_integration():
106
+ """Test that trainer type integration works end-to-end"""
107
+ print("\n🔍 Testing Trainer Type Integration")
108
+ print("=" * 50)
109
+
110
+ # Test the conversion logic
111
+ test_cases = [
112
+ ("SFT", "sft"),
113
+ ("DPO", "dpo")
114
+ ]
115
+
116
+ all_passed = True
117
+ for input_type, expected_output in test_cases:
118
+ # Simulate the bash conversion
119
+ converted = input_type.lower()
120
+
121
+ # Check if the converted value is valid for the training script
122
+ valid_types = ["sft", "dpo"]
123
+
124
+ if converted in valid_types:
125
+ print(f"✅ '{input_type}' -> '{converted}' (valid for training script)")
126
+ else:
127
+ print(f"❌ '{input_type}' -> '{converted}' (invalid for training script)")
128
+ all_passed = False
129
+
130
+ return all_passed
131
+
132
+ def main():
133
+ """Run all trainer type fix tests"""
134
+ print("🚀 Trainer Type Fix Verification")
135
+ print("=" * 50)
136
+
137
+ tests = [
138
+ test_trainer_type_conversion,
139
+ test_launch_script_trainer_type,
140
+ test_training_script_validation,
141
+ test_trainer_type_integration
142
+ ]
143
+
144
+ all_passed = True
145
+ for test in tests:
146
+ try:
147
+ if not test():
148
+ all_passed = False
149
+ except Exception as e:
150
+ print(f"❌ Test failed with error: {e}")
151
+ all_passed = False
152
+
153
+ print("\n" + "=" * 50)
154
+ if all_passed:
155
+ print("🎉 ALL TRAINER TYPE FIX TESTS PASSED!")
156
+ print("✅ Trainer type conversion: Working")
157
+ print("✅ Launch script handling: Working")
158
+ print("✅ Training script validation: Working")
159
+ print("✅ Integration: Working")
160
+ print("\nThe trainer type fix is working correctly!")
161
+ else:
162
+ print("❌ SOME TRAINER TYPE FIX TESTS FAILED!")
163
+ print("Please check the failed tests above.")
164
+
165
+ return all_passed
166
+
167
+ if __name__ == "__main__":
168
+ success = main()
169
+ sys.exit(0 if success else 1)