amos1088 commited on
Commit
005e7d1
·
1 Parent(s): d29bf84
Files changed (1) hide show
  1. app.py +93 -5
app.py CHANGED
@@ -313,7 +313,7 @@ def prepare_dpo_dataset(df):
313
  return pd.DataFrame(dpo_data)
314
 
315
 
316
- def train_model(train_df, val_df, epochs=3, batch_size=4, lr=5e-5, max_samples=None):
317
  """Training with DPO (Direct Preference Optimization)"""
318
  global current_model, current_tokenizer
319
 
@@ -415,7 +415,7 @@ def train_model(train_df, val_df, epochs=3, batch_size=4, lr=5e-5, max_samples=N
415
  report_to=[],
416
  max_length=seq_length,
417
  max_prompt_length=seq_length,
418
- beta=0.1,
419
  optim="adamw_8bit" if current_model_id == "openai/gpt-oss-20b" else "adamw_torch",
420
  dataloader_num_workers=2, # A100 can handle parallel loading
421
  )
@@ -443,8 +443,65 @@ def train_model(train_df, val_df, epochs=3, batch_size=4, lr=5e-5, max_samples=N
443
 
444
  # Custom callback for status updates
445
  from transformers import TrainerCallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  class StatusCallback(TrainerCallback):
 
 
 
 
 
448
  def on_log(self, args, state, control, logs=None, **kwargs):
449
  if logs:
450
  with training_lock:
@@ -452,13 +509,32 @@ def train_model(train_df, val_df, epochs=3, batch_size=4, lr=5e-5, max_samples=N
452
  training_status["logs"].append(f"Step {state.global_step}: Loss = {logs['loss']:.4f}")
453
  if "eval_loss" in logs:
454
  training_status["logs"].append(f"Eval Loss = {logs['eval_loss']:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  # Update progress
456
  if state.global_step > 0:
457
  total_steps = len(train_dataset) // batch_size * epochs
458
  training_status["progress"] = min(int((state.global_step / total_steps) * 100), 99)
459
 
460
- # Add callback
461
- dpo_trainer.add_callback(StatusCallback())
 
462
 
463
  # Train
464
  try:
@@ -471,6 +547,18 @@ def train_model(train_df, val_df, epochs=3, batch_size=4, lr=5e-5, max_samples=N
471
  current_tokenizer.save_pretrained(save_path)
472
  logger.info(f"Model saved to {save_path}")
473
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  # Update global model reference
475
  current_model = dpo_trainer.model
476
  current_model.eval()
@@ -546,7 +634,7 @@ def run_training(csv_path, shuffle_flag=False, split_ratio=0.8):
546
  max_samples = 2000 # Start conservative
547
  else:
548
  max_samples = None
549
- train_model(train_df, test_df, epochs=3, batch_size=32, lr=5e-5, max_samples=max_samples)
550
 
551
  with training_lock:
552
  training_status["status"] = "completed"
 
313
  return pd.DataFrame(dpo_data)
314
 
315
 
316
+ def train_model(train_df, val_df, epochs=3, batch_size=4, lr=2e-5, max_samples=None):
317
  """Training with DPO (Direct Preference Optimization)"""
318
  global current_model, current_tokenizer
319
 
 
415
  report_to=[],
416
  max_length=seq_length,
417
  max_prompt_length=seq_length,
418
+ beta=1.0, # Increased from 0.1 for stronger preference learning
419
  optim="adamw_8bit" if current_model_id == "openai/gpt-oss-20b" else "adamw_torch",
420
  dataloader_num_workers=2, # A100 can handle parallel loading
421
  )
 
443
 
444
  # Custom callback for status updates
445
  from transformers import TrainerCallback
446
+ import numpy as np
447
+
448
+ def compute_accuracy_metrics(trainer, eval_dataset, num_samples=100):
449
+ """Compute accuracy metrics on a subset of eval data"""
450
+ # Sample subset for faster evaluation
451
+ indices = np.random.choice(len(eval_dataset), min(num_samples, len(eval_dataset)), replace=False)
452
+
453
+ predictions_yes = 0
454
+ predictions_no = 0
455
+ correct = 0
456
+
457
+ for idx in indices:
458
+ item = eval_dataset[int(idx)]
459
+ prompt = item['prompt']
460
+ true_choice = item['chosen'] # This is the correct answer
461
+
462
+ # Tokenize and run inference
463
+ inputs = current_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
464
+ inputs = {k: v.to(trainer.model.device) for k, v in inputs.items()}
465
+
466
+ with torch.no_grad():
467
+ outputs = trainer.model(**inputs)
468
+ logits = outputs.logits[0, -1, :]
469
+
470
+ # Get token IDs
471
+ yes_token_id = current_tokenizer.encode("yes", add_special_tokens=False)[0]
472
+ no_token_id = current_tokenizer.encode("no", add_special_tokens=False)[0]
473
+
474
+ yes_logit = logits[yes_token_id].item()
475
+ no_logit = logits[no_token_id].item()
476
+
477
+ # Get prediction
478
+ prediction = " yes" if yes_logit > no_logit else " no"
479
+
480
+ if prediction == " yes":
481
+ predictions_yes += 1
482
+ else:
483
+ predictions_no += 1
484
+
485
+ if prediction.strip() == true_choice.strip():
486
+ correct += 1
487
+
488
+ accuracy = correct / len(indices)
489
+ yes_ratio = predictions_yes / len(indices)
490
+ no_ratio = predictions_no / len(indices)
491
+
492
+ return {
493
+ 'accuracy': accuracy,
494
+ 'yes_ratio': yes_ratio,
495
+ 'no_ratio': no_ratio,
496
+ 'total_samples': len(indices)
497
+ }
498
 
499
  class StatusCallback(TrainerCallback):
500
+ def __init__(self, trainer, eval_dataset):
501
+ self.trainer = trainer
502
+ self.eval_dataset = eval_dataset
503
+ self.eval_every_n_steps = 50 # Evaluate every 50 steps
504
+
505
  def on_log(self, args, state, control, logs=None, **kwargs):
506
  if logs:
507
  with training_lock:
 
509
  training_status["logs"].append(f"Step {state.global_step}: Loss = {logs['loss']:.4f}")
510
  if "eval_loss" in logs:
511
  training_status["logs"].append(f"Eval Loss = {logs['eval_loss']:.4f}")
512
+
513
+ # Compute accuracy metrics periodically
514
+ if state.global_step > 0 and state.global_step % self.eval_every_n_steps == 0:
515
+ metrics = compute_accuracy_metrics(self.trainer, self.eval_dataset)
516
+ training_status["logs"].append(
517
+ f"Step {state.global_step} Metrics: "
518
+ f"Accuracy={metrics['accuracy']:.2%}, "
519
+ f"Yes={metrics['yes_ratio']:.1%}, "
520
+ f"No={metrics['no_ratio']:.1%}"
521
+ )
522
+
523
+ # Warn if model is biased
524
+ if metrics['yes_ratio'] < 0.2 or metrics['no_ratio'] < 0.2:
525
+ training_status["logs"].append(
526
+ f"⚠️ WARNING: Model is heavily biased! "
527
+ f"(Yes: {metrics['yes_ratio']:.1%}, No: {metrics['no_ratio']:.1%})"
528
+ )
529
+
530
  # Update progress
531
  if state.global_step > 0:
532
  total_steps = len(train_dataset) // batch_size * epochs
533
  training_status["progress"] = min(int((state.global_step / total_steps) * 100), 99)
534
 
535
+ # Add callback with trainer and eval dataset
536
+ status_callback = StatusCallback(dpo_trainer, val_dataset)
537
+ dpo_trainer.add_callback(status_callback)
538
 
539
  # Train
540
  try:
 
547
  current_tokenizer.save_pretrained(save_path)
548
  logger.info(f"Model saved to {save_path}")
549
 
550
+ # Compute final metrics
551
+ logger.info("Computing final accuracy metrics...")
552
+ final_metrics = compute_accuracy_metrics(dpo_trainer, val_dataset, num_samples=200)
553
+ logger.info(f"Final Accuracy: {final_metrics['accuracy']:.2%}")
554
+ logger.info(f"Final Prediction Distribution - Yes: {final_metrics['yes_ratio']:.1%}, No: {final_metrics['no_ratio']:.1%}")
555
+
556
+ with training_lock:
557
+ training_status["logs"].append(f"\n=== FINAL RESULTS ===")
558
+ training_status["logs"].append(f"Accuracy: {final_metrics['accuracy']:.2%}")
559
+ training_status["logs"].append(f"Yes predictions: {final_metrics['yes_ratio']:.1%}")
560
+ training_status["logs"].append(f"No predictions: {final_metrics['no_ratio']:.1%}")
561
+
562
  # Update global model reference
563
  current_model = dpo_trainer.model
564
  current_model.eval()
 
634
  max_samples = 2000 # Start conservative
635
  else:
636
  max_samples = None
637
+ train_model(train_df, test_df, epochs=3, batch_size=32, lr=2e-5, max_samples=max_samples)
638
 
639
  with training_lock:
640
  training_status["status"] = "completed"