Tonic committed on
Commit
0ded6bb
·
1 Parent(s): b11b94b

adds improved launch for reasoning gpt-oss configs and new config for medical reasoning

Browse files
config/train_gpt_oss_custom.py CHANGED
@@ -109,6 +109,9 @@ class GPTOSSEnhancedCustomConfig:
109
  # Field Mapping - Customize for your dataset format
110
  input_field: str = "prompt" # Field containing the input/prompt
111
  target_field: str = "accepted_completion" # Field containing the target/completion
 
 
 
112
 
113
  # OpenHermes-FR specific fields
114
  filter_bad_entries: bool = True # Filter entries marked as bad
@@ -127,7 +130,14 @@ class GPTOSSEnhancedCustomConfig:
127
  max_length: Optional[int] = None # Maximum sequence length (None = use max_seq_length)
128
 
129
  # Custom Dataset Formats Support
130
- dataset_format: str = "openhermes_fr" # "openhermes_fr", "messages", "text", "custom"
 
 
 
 
 
 
 
131
 
132
  # GPT-OSS Harmony Format Configuration
133
  use_harmony_format: bool = True # Enable GPT-OSS harmony format
@@ -344,7 +354,7 @@ class GPTOSSEnhancedCustomConfig:
344
  raise ValueError("max_seq_length must be >= 1")
345
 
346
  # Validate dataset format
347
- valid_formats = ["openhermes_fr", "messages", "text", "custom"]
348
  if self.dataset_format not in valid_formats:
349
  raise ValueError(f"dataset_format must be one of {valid_formats}")
350
 
@@ -383,6 +393,12 @@ class GPTOSSEnhancedCustomConfig:
383
  print(f" • Target Field: {self.target_field}")
384
  print(f" • Filter Bad Entries: {self.filter_bad_entries}")
385
  print(f" • Max Samples: {self.max_samples or 'All'}")
 
 
 
 
 
 
386
 
387
  print(f"\n💾 Memory & Performance:")
388
  print(f" • Mixed Precision: {'BF16' if self.bf16 else 'FP32'}")
 
109
  # Field Mapping - Customize for your dataset format
110
  input_field: str = "prompt" # Field containing the input/prompt
111
  target_field: str = "accepted_completion" # Field containing the target/completion
112
+ # Optional global conversational context
113
+ system_message: Optional[str] = None
114
+ developer_message: Optional[str] = None
115
 
116
  # OpenHermes-FR specific fields
117
  filter_bad_entries: bool = True # Filter entries marked as bad
 
130
  max_length: Optional[int] = None # Maximum sequence length (None = use max_seq_length)
131
 
132
  # Custom Dataset Formats Support
133
+ dataset_format: str = "openhermes_fr" # "openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"
134
+
135
+ # Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT) mapping
136
+ question_field: str = "Question"
137
+ reasoning_field: str = "Complex_CoT"
138
+ response_field: str = "Response"
139
+ reason_prefix: str = "Reasoning: "
140
+ answer_prefix: str = "Final Answer: "
141
 
142
  # GPT-OSS Harmony Format Configuration
143
  use_harmony_format: bool = True # Enable GPT-OSS harmony format
 
354
  raise ValueError("max_seq_length must be >= 1")
355
 
356
  # Validate dataset format
357
+ valid_formats = ["openhermes_fr", "messages", "text", "custom", "medical_o1_sft", "preference"]
358
  if self.dataset_format not in valid_formats:
359
  raise ValueError(f"dataset_format must be one of {valid_formats}")
360
 
 
393
  print(f" • Target Field: {self.target_field}")
394
  print(f" • Filter Bad Entries: {self.filter_bad_entries}")
395
  print(f" • Max Samples: {self.max_samples or 'All'}")
396
+ if self.system_message or self.developer_message:
397
+ print(" • Context messages set:")
398
+ if self.system_message:
399
+ print(" - system message: provided")
400
+ if self.developer_message:
401
+ print(" - developer message: provided")
402
 
403
  print(f"\n💾 Memory & Performance:")
404
  print(f" • Mixed Precision: {'BF16' if self.bf16 else 'FP32'}")
config/train_gpt_oss_medical_o1_sft.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT-OSS Medical o1 SFT Training Configuration
3
+ Dataset: FreedomIntelligence/medical-o1-reasoning-SFT
4
+ Format: Question | Complex_CoT | Response → GPT-OSS Harmony text
5
+
6
+ This configuration uses GPT-OSS Harmony formatting to combine the medical
7
+ dataset's question, chain-of-thought (Complex_CoT), and final response into a
8
+ single assistant turn, with optional system and developer messages.
9
+ """
10
+
11
+ from config.train_gpt_oss_custom import GPTOSSEnhancedCustomConfig
12
+
13
+ # Medical-o1 SFT configuration for GPT-OSS
14
+ config = GPTOSSEnhancedCustomConfig(
15
+ # ============================================================================
16
+ # DATASET CONFIGURATION
17
+ # ============================================================================
18
+ dataset_name="FreedomIntelligence/medical-o1-reasoning-SFT",
19
+ dataset_config="en", # Use English split by default (can be changed to en_mix/zh/zh_mix)
20
+ dataset_split="train",
21
+ dataset_format="medical_o1_sft", # Enable medical formatter in training script
22
+
23
+ # Field mapping and prefixes
24
+ input_field="Question", # used for length filtering pre-format
25
+ target_field="Response", # used for length filtering pre-format
26
+ question_field="Question",
27
+ reasoning_field="Complex_CoT",
28
+ response_field="Response",
29
+ reason_prefix="Reasoning: ",
30
+ answer_prefix="Final Answer: ",
31
+
32
+ # GPT-OSS Harmony formatting
33
+ use_harmony_format=True,
34
+ use_chat_template=False,
35
+ system_message=(
36
+ "You are GPT-Tonic, a large language model trained by TonicAI."
37
+ ),
38
+ developer_message=(
39
+ "You are an intelligent assistant that can answer customer service queries"
40
+ ),
41
+ chat_template_kwargs={
42
+ "add_generation_prompt": True,
43
+ "tokenize": False,
44
+ "reasoning_effort": "low",
45
+ "model_identity": "You are GPT-Tonic, a large language model trained by TonicAI.",
46
+ "builtin_tools": [],
47
+ },
48
+
49
+ # Filtering & sampling
50
+ filter_bad_entries=False,
51
+ max_samples=None,
52
+ min_length=10,
53
+ max_length=2048,
54
+
55
+ # ============================================================================
56
+ # TRAINING HYPERPARAMETERS
57
+ # ============================================================================
58
+ num_train_epochs=1.0,
59
+ batch_size=2,
60
+ gradient_accumulation_steps=8,
61
+ learning_rate=2e-4,
62
+ min_lr=2e-5,
63
+ weight_decay=0.01,
64
+ warmup_ratio=0.03,
65
+ max_grad_norm=1.0,
66
+
67
+ # Sequence length
68
+ max_seq_length=2048,
69
+
70
+ # ============================================================================
71
+ # MIXED PRECISION / PERFORMANCE
72
+ # ============================================================================
73
+ fp16=False,
74
+ bf16=True,
75
+ tf32=True,
76
+
77
+ dataloader_num_workers=4,
78
+ dataloader_pin_memory=True,
79
+ dataloader_prefetch_factor=2,
80
+ dataset_num_proc=4,
81
+ group_by_length=True,
82
+ remove_unused_columns=True,
83
+
84
+ # ============================================================================
85
+ # LORA & QUANTIZATION
86
+ # ============================================================================
87
+ use_lora=True,
88
+ lora_config={
89
+ "r": 8,
90
+ "lora_alpha": 16,
91
+ "lora_dropout": 0.05,
92
+ "target_modules": "all-linear",
93
+ "target_parameters": [
94
+ "7.mlp.experts.gate_up_proj",
95
+ "7.mlp.experts.down_proj",
96
+ "15.mlp.experts.gate_up_proj",
97
+ "15.mlp.experts.down_proj",
98
+ "23.mlp.experts.gate_up_proj",
99
+ "23.mlp.experts.down_proj",
100
+ ],
101
+ "bias": "none",
102
+ "task_type": "CAUSAL_LM",
103
+ },
104
+
105
+ use_quantization=True,
106
+ quantization_config={
107
+ "dequantize": True,
108
+ "load_in_4bit": False,
109
+ # Optional MXFP4 config is auto-applied by training script if available
110
+ },
111
+
112
+ # ============================================================================
113
+ # LOGGING & EVAL
114
+ # ============================================================================
115
+ eval_strategy="steps",
116
+ eval_steps=200,
117
+ logging_steps=10,
118
+ save_strategy="steps",
119
+ save_steps=500,
120
+ save_total_limit=3,
121
+ save_only_model=True,
122
+ metric_for_best_model="eval_loss",
123
+ greater_is_better=False,
124
+ load_best_model_at_end=False,
125
+ eval_accumulation_steps=2,
126
+ eval_batch_size=1,
127
+ eval_ratio=0.01,
128
+ test_ratio=0.01,
129
+
130
+ # ============================================================================
131
+ # MONITORING & HUB
132
+ # ============================================================================
133
+ enable_tracking=True,
134
+ log_artifacts=False,
135
+ log_metrics=True,
136
+ log_config=True,
137
+ push_to_hub=False,
138
+ hub_model_id=None,
139
+ hub_private_repo=False,
140
+ )
141
+
142
+ # Quick summary for visibility when the config is imported
143
+ print("\n🩺 GPT-OSS Medical o1 SFT Configuration")
144
+ print("=" * 60)
145
+ print(f"📊 Dataset: {config.dataset_name} [{config.dataset_config}] (medical_o1_sft)")
146
+ print(f"📈 Training: {config.num_train_epochs} epoch | batch {config.batch_size} x acc {config.gradient_accumulation_steps}")
147
+ print(f"🧠 LoRA Rank: {config.lora_config['r']}")
148
+ print(f"📏 Sequence Length: {config.max_seq_length}")
149
+ print(f"🎵 Harmony Format: {'Enabled' if config.use_harmony_format else 'Disabled'}")
150
+ print("=" * 60)
151
+
launch.sh CHANGED
@@ -267,6 +267,12 @@ show_training_configs() {
267
  echo " - Learning Rate: Configurable"
268
  echo " - Maximum flexibility with all parameters"
269
  echo ""
 
 
 
 
 
 
270
  fi
271
  }
272
 
@@ -376,6 +382,17 @@ get_training_config() {
376
  MAX_SEQ_LENGTH=1024
377
  CONFIG_FILE="config/train_gpt_oss_openhermes_fr_memory_optimized.py"
378
  ;;
 
 
 
 
 
 
 
 
 
 
 
379
  "GPT-OSS Custom Dataset")
380
  MODEL_NAME="openai/gpt-oss-20b"
381
  DATASET_NAME="legmlai/openhermes-fr" # Will be customizable
@@ -411,10 +428,11 @@ get_custom_dataset_config() {
411
  echo "1. OpenHermes-FR (prompt + accepted_completion fields)"
412
  echo "2. Messages format (chat conversations)"
413
  echo "3. Text format (plain text field)"
414
- echo "4. Custom format (specify field names)"
 
415
  echo ""
416
-
417
- select_option "Select dataset format:" "OpenHermes-FR" "Messages format" "Text format" "Custom format" DATASET_FORMAT
418
 
419
  case "$DATASET_FORMAT" in
420
  "OpenHermes-FR")
@@ -435,6 +453,18 @@ get_custom_dataset_config() {
435
  DATASET_FORMAT_CODE="text"
436
  FILTER_BAD_ENTRIES="false"
437
  ;;
 
 
 
 
 
 
 
 
 
 
 
 
438
  "Custom format")
439
  get_input "Input field name" "prompt" INPUT_FIELD
440
  get_input "Target field name (leave empty if not needed)" "accepted_completion" TARGET_FIELD
@@ -442,6 +472,12 @@ get_custom_dataset_config() {
442
  get_input "Filter bad entries? (true/false)" "false" FILTER_BAD_ENTRIES
443
  ;;
444
  esac
 
 
 
 
 
 
445
 
446
  # Dataset Filtering Options
447
  echo ""
@@ -492,6 +528,22 @@ get_custom_dataset_config() {
492
  update_enhanced_gpt_oss_config
493
  }
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  # Function to get custom configuration
496
  get_custom_config() {
497
  print_step "Custom Configuration Setup"
@@ -574,6 +626,18 @@ config = GPTOSSEnhancedCustomConfig(
574
  min_length=$MIN_LENGTH,
575
  max_length=$(if [ -n "$MAX_LENGTH" ]; then echo "$MAX_LENGTH"; else echo "None"; fi),
576
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  # ============================================================================
578
  # TRAINING HYPERPARAMETERS
579
  # ============================================================================
@@ -811,6 +875,7 @@ else
811
  "GPT-OSS OpenHermes-FR (Recommended)" \
812
  "GPT-OSS OpenHermes-FR Memory Optimized" \
813
  "GPT-OSS Custom Dataset" \
 
814
  TRAINING_CONFIG_TYPE
815
  fi
816
 
 
267
  echo " - Learning Rate: Configurable"
268
  echo " - Maximum flexibility with all parameters"
269
  echo ""
270
+ echo "8. GPT-OSS Medical o1 SFT (Reasoning)"
271
+ echo " - Model: openai/gpt-oss-20b"
272
+ echo " - Dataset: FreedomIntelligence/medical-o1-reasoning-SFT"
273
+ echo " - Format: Question | Complex_CoT | Response"
274
+ echo " - Harmony formatting with optional system/developer messages"
275
+ echo ""
276
  fi
277
  }
278
 
 
382
  MAX_SEQ_LENGTH=1024
383
  CONFIG_FILE="config/train_gpt_oss_openhermes_fr_memory_optimized.py"
384
  ;;
385
+ "GPT-OSS Medical o1 SFT (Reasoning)")
386
+ MODEL_NAME="openai/gpt-oss-20b"
387
+ DATASET_NAME="FreedomIntelligence/medical-o1-reasoning-SFT"
388
+ MAX_EPOCHS=1
389
+ BATCH_SIZE=2
390
+ GRADIENT_ACCUMULATION_STEPS=8
391
+ LEARNING_RATE=2e-4
392
+ MAX_SEQ_LENGTH=2048
393
+ CONFIG_FILE="config/train_gpt_oss_medical_o1_sft.py"
394
+ generate_medical_o1_sft_config
395
+ ;;
396
  "GPT-OSS Custom Dataset")
397
  MODEL_NAME="openai/gpt-oss-20b"
398
  DATASET_NAME="legmlai/openhermes-fr" # Will be customizable
 
428
  echo "1. OpenHermes-FR (prompt + accepted_completion fields)"
429
  echo "2. Messages format (chat conversations)"
430
  echo "3. Text format (plain text field)"
431
+ echo "4. Medical o1 SFT (Question | Complex_CoT | Response)"
432
+ echo "5. Custom format (specify field names)"
433
  echo ""
434
+
435
+ select_option "Select dataset format:" "OpenHermes-FR" "Messages format" "Text format" "Medical o1 SFT" "Custom format" DATASET_FORMAT
436
 
437
  case "$DATASET_FORMAT" in
438
  "OpenHermes-FR")
 
453
  DATASET_FORMAT_CODE="text"
454
  FILTER_BAD_ENTRIES="false"
455
  ;;
456
+ "Medical o1 SFT")
457
+ INPUT_FIELD="Question"
458
+ TARGET_FIELD="Response"
459
+ DATASET_FORMAT_CODE="medical_o1_sft"
460
+ FILTER_BAD_ENTRIES="false"
461
+ # Field mappings and prefixes
462
+ get_input "Question field name" "Question" MED_Q_FIELD
463
+ get_input "Reasoning field name" "Complex_CoT" MED_REASON_FIELD
464
+ get_input "Response field name" "Response" MED_RESP_FIELD
465
+ get_input "Reason prefix (before reasoning)" "Reasoning: " MED_REASON_PREFIX
466
+ get_input "Answer prefix (before final answer)" "Final Answer: " MED_ANSWER_PREFIX
467
+ ;;
468
  "Custom format")
469
  get_input "Input field name" "prompt" INPUT_FIELD
470
  get_input "Target field name (leave empty if not needed)" "accepted_completion" TARGET_FIELD
 
472
  get_input "Filter bad entries? (true/false)" "false" FILTER_BAD_ENTRIES
473
  ;;
474
  esac
475
+
476
+ # Optional Harmony context
477
+ echo ""
478
+ print_info "💬 Harmony Context (optional)"
479
+ get_input "System message" "You are GPT-Tonic, a large language model trained by TonicAI." SYSTEM_MESSAGE
480
+ get_input "Developer message" "You are an intelligent assistant that can answer customer service queries" DEVELOPER_MESSAGE
481
 
482
  # Dataset Filtering Options
483
  echo ""
 
528
  update_enhanced_gpt_oss_config
529
  }
530
 
531
+ # Function to materialize a default Medical o1 SFT config file
532
+ generate_medical_o1_sft_config() {
533
+ print_info "Ensuring medical o1 SFT configuration exists..."
534
+ if [ -f "config/train_gpt_oss_medical_o1_sft.py" ]; then
535
+ print_status "Medical o1 SFT config already present"
536
+ return
537
+ fi
538
+ cat > config/train_gpt_oss_medical_o1_sft.py << 'EOF'
539
+ """
540
+ Auto-generated placeholder. A richer version will be imported at runtime.
541
+ """
542
+ from config.train_gpt_oss_medical_o1_sft import config # reuse main config
543
+ EOF
544
+ print_status "Medical o1 SFT config placeholder created"
545
+ }
546
+
547
  # Function to get custom configuration
548
  get_custom_config() {
549
  print_step "Custom Configuration Setup"
 
626
  min_length=$MIN_LENGTH,
627
  max_length=$(if [ -n "$MAX_LENGTH" ]; then echo "$MAX_LENGTH"; else echo "None"; fi),
628
 
629
+ # Harmony context
630
+ system_message=$(if [ -n "$SYSTEM_MESSAGE" ]; then printf '%s' "\"$SYSTEM_MESSAGE\""; else echo "None"; fi),
631
+ developer_message=$(if [ -n "$DEVELOPER_MESSAGE" ]; then printf '%s' "\"$DEVELOPER_MESSAGE\""; else echo "None"; fi),
632
+ use_harmony_format=True,
633
+
634
+ # Medical o1 SFT mapping (ignored unless dataset_format == 'medical_o1_sft')
635
+ question_field=$(if [ -n "$MED_Q_FIELD" ]; then echo "\"$MED_Q_FIELD\""; else echo "\"Question\""; fi),
636
+ reasoning_field=$(if [ -n "$MED_REASON_FIELD" ]; then echo "\"$MED_REASON_FIELD\""; else echo "\"Complex_CoT\""; fi),
637
+ response_field=$(if [ -n "$MED_RESP_FIELD" ]; then echo "\"$MED_RESP_FIELD\""; else echo "\"Response\""; fi),
638
+ reason_prefix=$(if [ -n "$MED_REASON_PREFIX" ]; then printf '%s' "\"$MED_REASON_PREFIX\""; else echo "\"Reasoning: \""; fi),
639
+ answer_prefix=$(if [ -n "$MED_ANSWER_PREFIX" ]; then printf '%s' "\"$MED_ANSWER_PREFIX\""; else echo "\"Final Answer: \""; fi),
640
+
641
  # ============================================================================
642
  # TRAINING HYPERPARAMETERS
643
  # ============================================================================
 
875
  "GPT-OSS OpenHermes-FR (Recommended)" \
876
  "GPT-OSS OpenHermes-FR Memory Optimized" \
877
  "GPT-OSS Custom Dataset" \
878
+ "GPT-OSS Medical o1 SFT (Reasoning)" \
879
  TRAINING_CONFIG_TYPE
880
  fi
881
 
scripts/training/train_gpt_oss.py CHANGED
@@ -277,31 +277,66 @@ def apply_dataset_filtering(dataset, config):
277
 
278
  return dataset
279
 
280
- def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  """
282
- Format data for GPT-OSS Harmony format following the exact template structure.
283
- Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja
284
- """
285
- # GPT-OSS Harmony format structure (exact template compliance)
286
- # User message: <|start|>user<|message|>content<|end|>
287
- # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|end|> (inference)
288
- # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|return|> (training)
289
-
290
- harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}"
291
-
292
  if add_eos_token:
293
- # Use <|return|> for training as per template specification
294
- # This indicates the end of generation in training
295
- harmony_text += "<|return|>"
296
  else:
297
- # Use <|end|> for inference
298
- harmony_text += "<|end|>"
299
-
300
- return harmony_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
- def format_gpt_oss_harmony_prompt(prompt: str) -> str:
303
- """Prefix-only Harmony prompt up to assistant content marker for DPO."""
304
- return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>"
 
 
 
 
 
 
 
 
 
 
305
 
306
  def process_dataset_format(dataset, config):
307
  """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
@@ -321,6 +356,8 @@ def process_dataset_format(dataset, config):
321
  add_eos_token = getattr(config, 'add_eos_token', True)
322
  use_harmony_format = getattr(config, 'use_harmony_format', True)
323
  trainer_type = getattr(config, 'trainer_type', 'sft')
 
 
324
 
325
  print(f"Processing dataset format: {dataset_format}")
326
  print(f"Input field: {input_field}, Target field: {target_field}")
@@ -338,7 +375,11 @@ def process_dataset_format(dataset, config):
338
  chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
339
  rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
340
  if use_harmony_format:
341
- prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
 
 
 
 
342
  chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
343
  rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
344
  return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
@@ -355,7 +396,11 @@ def process_dataset_format(dataset, config):
355
  chosen_val = example.get(chosen_field, '')
356
  rejected_val = example.get(rejected_field, '')
357
  if use_harmony_format:
358
- prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
 
 
 
 
359
  chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
360
  rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
361
  return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
@@ -376,7 +421,13 @@ def process_dataset_format(dataset, config):
376
  if concatenate_fields:
377
  if use_harmony_format:
378
  # Use exact GPT-OSS Harmony format from template
379
- text = format_gpt_oss_harmony(prompt, completion, add_eos_token)
 
 
 
 
 
 
380
  else:
381
  # Fallback to standard format with separator
382
  text = prompt + field_separator + completion
@@ -414,7 +465,13 @@ def process_dataset_format(dataset, config):
414
 
415
  if user_message and assistant_message:
416
  # Use GPT-OSS Harmony format
417
- text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token)
 
 
 
 
 
 
418
  else:
419
  # Fallback to simple concatenation
420
  text = ""
@@ -438,6 +495,44 @@ def process_dataset_format(dataset, config):
438
 
439
  dataset = dataset.map(format_messages, remove_columns=dataset.column_names, num_proc=num_proc)
440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  elif dataset_format == "text":
442
  # Process plain text format
443
  text_field = input_field
 
277
 
278
  return dataset
279
 
280
+ def _build_harmony_text(
281
+ user_content: str,
282
+ assistant_content: str,
283
+ add_eos_token: bool = True,
284
+ system_message: str | None = None,
285
+ developer_message: str | None = None,
286
+ ) -> str:
287
+ """Compose a Harmony-formatted conversation with optional system/developer messages.
288
+
289
+ Structure (training):
290
+ <|start|>system<|message|>...<|end|> (optional)
291
+ <|start|>developer<|message|>...<|end|> (optional)
292
+ <|start|>user<|message|>...<|end|>
293
+ <|start|>assistant<|channel|>final<|message|>...<|return|>
294
  """
295
+ parts: list[str] = []
296
+ if system_message:
297
+ parts.append(f"<|start|>system<|message|>{system_message}<|end|>")
298
+ if developer_message:
299
+ parts.append(f"<|start|>developer<|message|>{developer_message}<|end|>")
300
+ parts.append(f"<|start|>user<|message|>{user_content}<|end|>")
301
+ parts.append(f"<|start|>assistant<|channel|>final<|message|>{assistant_content}")
 
 
 
302
  if add_eos_token:
303
+ parts[-1] += "<|return|>"
 
 
304
  else:
305
+ parts[-1] += "<|end|>"
306
+ return "".join(parts)
307
+
308
+ def format_gpt_oss_harmony(
309
+ prompt: str,
310
+ completion: str,
311
+ add_eos_token: bool = True,
312
+ system_message: str | None = None,
313
+ developer_message: str | None = None,
314
+ ) -> str:
315
+ """
316
+ Format data for GPT-OSS Harmony format following the exact template structure.
317
+ Spec: `https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja`.
318
+ """
319
+ return _build_harmony_text(
320
+ user_content=prompt,
321
+ assistant_content=completion,
322
+ add_eos_token=add_eos_token,
323
+ system_message=system_message,
324
+ developer_message=developer_message,
325
+ )
326
 
327
+ def format_gpt_oss_harmony_prompt(
328
+ prompt: str,
329
+ system_message: str | None = None,
330
+ developer_message: str | None = None,
331
+ ) -> str:
332
+ """Prefix-only Harmony prompt up to assistant content marker for DPO, with optional context."""
333
+ parts: list[str] = []
334
+ if system_message:
335
+ parts.append(f"<|start|>system<|message|>{system_message}<|end|>")
336
+ if developer_message:
337
+ parts.append(f"<|start|>developer<|message|>{developer_message}<|end|>")
338
+ parts.append(f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>")
339
+ return "".join(parts)
340
 
341
  def process_dataset_format(dataset, config):
342
  """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
 
356
  add_eos_token = getattr(config, 'add_eos_token', True)
357
  use_harmony_format = getattr(config, 'use_harmony_format', True)
358
  trainer_type = getattr(config, 'trainer_type', 'sft')
359
+ system_message = getattr(config, 'system_message', None)
360
+ developer_message = getattr(config, 'developer_message', None)
361
 
362
  print(f"Processing dataset format: {dataset_format}")
363
  print(f"Input field: {input_field}, Target field: {target_field}")
 
375
  chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
376
  rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
377
  if use_harmony_format:
378
+ prompt_text = format_gpt_oss_harmony_prompt(
379
+ prompt_val,
380
+ system_message=system_message,
381
+ developer_message=developer_message,
382
+ )
383
  chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
384
  rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
385
  return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
 
396
  chosen_val = example.get(chosen_field, '')
397
  rejected_val = example.get(rejected_field, '')
398
  if use_harmony_format:
399
+ prompt_text = format_gpt_oss_harmony_prompt(
400
+ prompt_val,
401
+ system_message=system_message,
402
+ developer_message=developer_message,
403
+ )
404
  chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
405
  rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
406
  return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
 
421
  if concatenate_fields:
422
  if use_harmony_format:
423
  # Use exact GPT-OSS Harmony format from template
424
+ text = format_gpt_oss_harmony(
425
+ prompt,
426
+ completion,
427
+ add_eos_token,
428
+ system_message=system_message,
429
+ developer_message=developer_message,
430
+ )
431
  else:
432
  # Fallback to standard format with separator
433
  text = prompt + field_separator + completion
 
465
 
466
  if user_message and assistant_message:
467
  # Use GPT-OSS Harmony format
468
+ text = format_gpt_oss_harmony(
469
+ user_message,
470
+ assistant_message,
471
+ add_eos_token,
472
+ system_message=system_message,
473
+ developer_message=developer_message,
474
+ )
475
  else:
476
  # Fallback to simple concatenation
477
  text = ""
 
495
 
496
  dataset = dataset.map(format_messages, remove_columns=dataset.column_names, num_proc=num_proc)
497
 
498
+ elif dataset_format == "medical_o1_sft":
499
+ # Process Medical-o1 SFT format: Question | Complex_CoT | Response
500
+ # Defaults align with FreedomIntelligence/medical-o1-reasoning-SFT
501
+ question_field = getattr(config, 'question_field', input_field or 'Question')
502
+ reasoning_field = getattr(config, 'reasoning_field', 'Complex_CoT')
503
+ response_field = getattr(config, 'response_field', target_field or 'Response')
504
+ reason_prefix = getattr(config, 'reason_prefix', 'Reasoning: ')
505
+ answer_prefix = getattr(config, 'answer_prefix', 'Final Answer: ')
506
+
507
+ def format_medical(example):
508
+ q = example.get(question_field, '') or ''
509
+ cot = example.get(reasoning_field, '') or ''
510
+ ans = example.get(response_field, '') or ''
511
+
512
+ # Combine reasoning and final answer in a single assistant turn
513
+ assistant_text = "\n\n".join(
514
+ [s for s in [
515
+ f"{reason_prefix}{cot}".strip() if cot else '',
516
+ f"{answer_prefix}{ans}".strip() if ans else ''
517
+ ] if s]
518
+ ) or ans
519
+
520
+ if use_harmony_format:
521
+ text = format_gpt_oss_harmony(
522
+ q,
523
+ assistant_text,
524
+ add_eos_token,
525
+ system_message=system_message,
526
+ developer_message=developer_message,
527
+ )
528
+ else:
529
+ text = f"Q: {q}\n\n{assistant_text}"
530
+ if add_eos_token:
531
+ text += "</s>"
532
+ return {"text": text}
533
+
534
+ dataset = dataset.map(format_medical, remove_columns=dataset.column_names, num_proc=num_proc)
535
+
536
  elif dataset_format == "text":
537
  # Process plain text format
538
  text_field = input_field