Tonic committed on
Commit
cb276d8
·
1 Parent(s): 976e218

adds flash attention 3 kernel

Browse files
config/train_gpt_oss_openhermes_fr_memory_optimized.py CHANGED
@@ -35,7 +35,7 @@ config = GPTOSSEnhancedCustomConfig(
35
  # Dataset sampling optimized for memory constraints
36
  max_samples=800000, # Reduced from 800K for memory efficiency
37
  min_length=15, # Slightly higher minimum for quality
38
- max_length=2048, # Explicit max length for memory control
39
 
40
  # ============================================================================
41
  # MEMORY-OPTIMIZED TRAINING HYPERPARAMETERS
@@ -56,7 +56,7 @@ config = GPTOSSEnhancedCustomConfig(
56
  # MODEL CONFIGURATION - Memory Optimized for GPT-OSS
57
  # ============================================================================
58
  model_name="openai/gpt-oss-20b",
59
- max_seq_length=1024, # Reduced from 3072 for memory optimization
60
  use_flash_attention=True, # Critical for memory efficiency
61
  use_gradient_checkpointing=True, # Essential for memory optimization
62
 
@@ -106,7 +106,7 @@ config = GPTOSSEnhancedCustomConfig(
106
  # ============================================================================
107
  # Model loading with memory constraints
108
  model_kwargs={
109
- "attn_implementation": "eager", # Memory-safe attention
110
  "torch_dtype": "auto", # Let model decide (MXFP4 compatible)
111
  "use_cache": False, # Disable KV cache for training
112
  "device_map": "auto", # Automatic device mapping
@@ -114,10 +114,10 @@ config = GPTOSSEnhancedCustomConfig(
114
  "max_memory": {0: "75GB"}, # Reserve memory for other processes
115
  },
116
 
117
- # Data loading optimized for memory efficiency
118
- dataloader_num_workers=2, # Reduced workers to save memory
119
- dataloader_pin_memory=False, # Disable to save memory
120
- dataloader_prefetch_factor=1, # Minimal prefetch for memory
121
 
122
  # Memory management optimizations
123
  max_memory_per_gpu="75GB", # Explicit memory limit
@@ -126,7 +126,7 @@ config = GPTOSSEnhancedCustomConfig(
126
  remove_unused_columns=True, # Remove unnecessary data
127
 
128
  # ============================================================================
129
- # EVALUATION & LOGGING - Memory Efficient
130
  # ============================================================================
131
  eval_strategy="steps",
132
  eval_steps=500, # Less frequent evaluation for memory
@@ -134,7 +134,7 @@ config = GPTOSSEnhancedCustomConfig(
134
 
135
  save_strategy="steps",
136
  save_steps=1000, # Less frequent saves for memory/storage
137
- save_total_limit=2, # Keep only 2 checkpoints for memory
138
  save_only_model=True, # Save only model weights
139
 
140
  metric_for_best_model="eval_loss",
 
35
  # Dataset sampling optimized for memory constraints
36
  max_samples=800000, # Reduced from 800K for memory efficiency
37
  min_length=15, # Slightly higher minimum for quality
38
+ max_length=4096, # Explicit max length for memory control
39
 
40
  # ============================================================================
41
  # MEMORY-OPTIMIZED TRAINING HYPERPARAMETERS
 
56
  # MODEL CONFIGURATION - Memory Optimized for GPT-OSS
57
  # ============================================================================
58
  model_name="openai/gpt-oss-20b",
59
+ max_seq_length=4096, # Increased to 4096; flash attention 3 makes longer sequences affordable
60
  use_flash_attention=True, # Critical for memory efficiency
61
  use_gradient_checkpointing=True, # Essential for memory optimization
62
 
 
106
  # ============================================================================
107
  # Model loading with memory constraints
108
  model_kwargs={
109
+ "attn_implementation": "kernels-community/vllm-flash-attn3", # Much faster attention on A100/H100
110
  "torch_dtype": "auto", # Let model decide (MXFP4 compatible)
111
  "use_cache": False, # Disable KV cache for training
112
  "device_map": "auto", # Automatic device mapping
 
114
  "max_memory": {0: "75GB"}, # Reserve memory for other processes
115
  },
116
 
117
+ # Data loading optimized for throughput
118
+ dataloader_num_workers=4, # More workers for faster loading
119
+ dataloader_pin_memory=True, # Pin memory for faster host->GPU copies
120
+ dataloader_prefetch_factor=2,
121
 
122
  # Memory management optimizations
123
  max_memory_per_gpu="75GB", # Explicit memory limit
 
126
  remove_unused_columns=True, # Remove unnecessary data
127
 
128
  # ============================================================================
129
+ # EVALUATION & LOGGING - Fast Iterations
130
  # ============================================================================
131
  eval_strategy="steps",
132
  eval_steps=500, # Less frequent evaluation for memory
 
134
 
135
  save_strategy="steps",
136
  save_steps=1000, # Less frequent saves for memory/storage
137
+ save_total_limit=3, # Keep only 3 checkpoints for memory
138
  save_only_model=True, # Save only model weights
139
 
140
  metric_for_best_model="eval_loss",
requirements/requirements_core.txt CHANGED
@@ -5,6 +5,7 @@ datasets>=2.14.0
5
  accelerate>=0.20.0
6
  peft>=0.17.0 # Updated for GPT-OSS LoRA support
7
  trl>=0.20.0 # Updated for GPT-OSS compatibility
 
8
 
9
  # Hugging Face Hub for model and space management
10
  huggingface_hub>=0.19.0
 
5
  accelerate>=0.20.0
6
  peft>=0.17.0 # Updated for GPT-OSS LoRA support
7
  trl>=0.20.0 # Updated for GPT-OSS compatibility
8
+ kernels
9
 
10
  # Hugging Face Hub for model and space management
11
  huggingface_hub>=0.19.0
scripts/trackio_tonic/app.py CHANGED
@@ -25,10 +25,11 @@ class TrackioSpace:
25
  def __init__(self, hf_token: Optional[str] = None, dataset_repo: Optional[str] = None):
26
  self.experiments = {}
27
  self.current_experiment = None
 
28
 
29
  # Get dataset repository and HF token from parameters or environment variables
30
- # Use dynamic default based on environment or fallback to generic default
31
- default_dataset_repo = os.environ.get('TRACKIO_DATASET_REPO', 'trackio-experiments')
32
  self.dataset_repo = dataset_repo or default_dataset_repo
33
  self.hf_token = hf_token or os.environ.get('HF_TOKEN')
34
 
@@ -75,12 +76,14 @@ class TrackioSpace:
75
  # Fall back to backup data
76
  self._load_backup_experiments()
77
  else:
78
- # No HF token, use backup data
79
  self._load_backup_experiments()
 
80
 
81
  except Exception as e:
82
  logger.error(f"Failed to load experiments: {e}")
83
  self._load_backup_experiments()
 
84
 
85
  def _load_backup_experiments(self):
86
  """Load backup experiments when dataset is not available"""
@@ -314,6 +317,9 @@ class TrackioSpace:
314
  def _save_experiments(self):
315
  """Save experiments to HF Dataset"""
316
  try:
 
 
 
317
  if self.hf_token:
318
  from datasets import Dataset
319
  from huggingface_hub import HfApi
@@ -565,17 +571,20 @@ def create_dataset_repository(hf_token: str, dataset_repo: str) -> str:
565
  except Exception as e:
566
  return f"❌ Failed to create dataset: {str(e)}\n\n💡 Troubleshooting:\n1. Check your HF token has write permissions\n2. Verify the username in the repository name\n3. Ensure the dataset name is valid"
567
 
568
- # Initialize API client for remote data
569
  api_client = None
570
  try:
571
- from trackio_api_client import create_trackio_client
572
- api_client = create_trackio_client()
573
- if api_client:
 
574
  logger.info("βœ… API client initialized for remote data access")
575
  else:
576
- logger.warning("⚠️ Could not initialize API client, using local data only")
577
  except ImportError:
578
  logger.warning("⚠️ API client not available, using local data only")
 
 
579
 
580
  # Add Hugging Face Spaces compatibility
581
  def is_huggingface_spaces():
@@ -590,8 +599,8 @@ def get_persistent_data_path():
590
  else:
591
  return "trackio_experiments.json"
592
 
593
- # Override the data file path for HF Spaces
594
- if is_huggingface_spaces():
595
  logger.info("🚀 Running on Hugging Face Spaces - using persistent storage")
596
  trackio_space.data_file = get_persistent_data_path()
597
 
 
25
  def __init__(self, hf_token: Optional[str] = None, dataset_repo: Optional[str] = None):
26
  self.experiments = {}
27
  self.current_experiment = None
28
+ self.backup_mode = False
29
 
30
  # Get dataset repository and HF token from parameters or environment variables
31
+ # Respect explicit values; avoid hardcoded defaults that might point to test repos
32
+ default_dataset_repo = os.environ.get('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
33
  self.dataset_repo = dataset_repo or default_dataset_repo
34
  self.hf_token = hf_token or os.environ.get('HF_TOKEN')
35
 
 
76
  # Fall back to backup data
77
  self._load_backup_experiments()
78
  else:
79
+ # No HF token, use backup data but do not allow saving to dataset from backup
80
  self._load_backup_experiments()
81
+ self.backup_mode = True
82
 
83
  except Exception as e:
84
  logger.error(f"Failed to load experiments: {e}")
85
  self._load_backup_experiments()
86
+ self.backup_mode = True
87
 
88
  def _load_backup_experiments(self):
89
  """Load backup experiments when dataset is not available"""
 
317
  def _save_experiments(self):
318
  """Save experiments to HF Dataset"""
319
  try:
320
+ if self.backup_mode:
321
+ logger.warning("⚠️ Backup mode active; skipping dataset save to avoid overwriting real data with demo values")
322
+ return
323
  if self.hf_token:
324
  from datasets import Dataset
325
  from huggingface_hub import HfApi
 
571
  except Exception as e:
572
  return f"❌ Failed to create dataset: {str(e)}\n\nπŸ’‘ Troubleshooting:\n1. Check your HF token has write permissions\n2. Verify the username in the repository name\n3. Ensure the dataset name is valid"
573
 
574
+ # Initialize API client for remote data if environment provides a space id/url
575
  api_client = None
576
  try:
577
+ from trackio_api_client import TrackioAPIClient
578
+ space_id = os.environ.get('TRACKIO_URL') or os.environ.get('TRACKIO_SPACE_ID')
579
+ if space_id:
580
+ api_client = TrackioAPIClient(space_id, os.environ.get('HF_TOKEN'))
581
  logger.info("✅ API client initialized for remote data access")
582
  else:
583
+ logger.info("No TRACKIO_URL/TRACKIO_SPACE_ID set; remote API client disabled")
584
  except ImportError:
585
  logger.warning("⚠️ API client not available, using local data only")
586
+ except Exception as e:
587
+ logger.warning(f"⚠️ Could not initialize API client: {e}")
588
 
589
  # Add Hugging Face Spaces compatibility
590
  def is_huggingface_spaces():
 
599
  else:
600
  return "trackio_experiments.json"
601
 
602
+ # Override the data file path for HF Spaces if attribute exists
603
+ if is_huggingface_spaces() and hasattr(trackio_space, 'data_file'):
604
  logger.info("πŸš€ Running on Hugging Face Spaces - using persistent storage")
605
  trackio_space.data_file = get_persistent_data_path()
606
 
scripts/training/train_gpt_oss.py CHANGED
@@ -458,11 +458,17 @@ def split_dataset(dataset, config):
458
  def setup_trackio_tracking(config):
459
  """Setup Trackio tracking if enabled"""
460
 
461
- if not config.enable_tracking or not config.trackio_url:
462
  print("Trackio tracking disabled or URL not provided")
463
  return None
464
 
465
- print(f"Setting up Trackio tracking: {config.trackio_url}")
 
 
 
 
 
 
466
 
467
  # Import the correct TrackioAPIClient
468
  import sys
@@ -472,8 +478,8 @@ def setup_trackio_tracking(config):
472
 
473
  # Initialize Trackio client using the correct API
474
  trackio_client = TrackioAPIClient(
475
- space_id=config.trackio_url,
476
- hf_token=config.trackio_token
477
  )
478
 
479
  return trackio_client
 
458
  def setup_trackio_tracking(config):
459
  """Setup Trackio tracking if enabled"""
460
 
461
+ if not getattr(config, 'enable_tracking', False):
462
  print("Trackio tracking disabled or URL not provided")
463
  return None
464
 
465
+ # Resolve Trackio URL from config or environment
466
+ trackio_url = getattr(config, 'trackio_url', None) or os.environ.get('TRACKIO_URL') or os.environ.get('TRACKIO_SPACE_ID')
467
+ if not trackio_url:
468
+ print("Trackio tracking enabled but no TRACKIO_URL/TRACKIO_SPACE_ID provided; skipping Trackio setup")
469
+ return None
470
+
471
+ print(f"Setting up Trackio tracking: {trackio_url}")
472
 
473
  # Import the correct TrackioAPIClient
474
  import sys
 
478
 
479
  # Initialize Trackio client using the correct API
480
  trackio_client = TrackioAPIClient(
481
+ space_id=trackio_url,
482
+ hf_token=getattr(config, 'trackio_token', None) or os.environ.get('HF_TOKEN')
483
  )
484
 
485
  return trackio_client
src/monitoring.py CHANGED
@@ -120,12 +120,18 @@ class SmolLM3Monitor:
120
  """Setup Trackio API client"""
121
  try:
122
  # Get Trackio configuration from environment or parameters
123
- space_id = trackio_url or os.getenv('TRACKIO_SPACE_ID')
124
-
 
 
 
 
 
 
125
  if not space_id:
126
- # Use the deployed Trackio Space ID
127
- space_id = "Tonic/trackio-monitoring-20250727"
128
- logger.info(f"Using default Trackio Space ID: {space_id}")
129
 
130
  # Get HF token for Space resolution
131
  hf_token = self.hf_token or trackio_token or os.getenv('HF_TOKEN')
 
120
  """Setup Trackio API client"""
121
  try:
122
  # Get Trackio configuration from environment or parameters
123
+ # Accept either a full URL or an org/space identifier
124
+ # Prefer explicit parameter, then environment variables
125
+ space_id = (
126
+ trackio_url
127
+ or os.getenv('TRACKIO_URL')
128
+ or os.getenv('TRACKIO_SPACE_ID')
129
+ )
130
+
131
  if not space_id:
132
+ logger.warning("No Trackio Space configured via param or env (TRACKIO_URL/TRACKIO_SPACE_ID). Disabling Trackio tracking.")
133
+ self.enable_tracking = False
134
+ return
135
 
136
  # Get HF token for Space resolution
137
  hf_token = self.hf_token or trackio_token or os.getenv('HF_TOKEN')
src/trackio.py CHANGED
@@ -40,7 +40,12 @@ def init(
40
  project_name = os.environ.get('EXPERIMENT_NAME', 'smollm3_experiment')
41
 
42
  # Extract configuration from kwargs
43
- trackio_url = kwargs.get('trackio_url') or os.environ.get('TRACKIO_URL')
 
 
 
 
 
44
  trackio_token = kwargs.get('trackio_token') or os.environ.get('TRACKIO_TOKEN')
45
  hf_token = kwargs.get('hf_token') or os.environ.get('HF_TOKEN')
46
  dataset_repo = kwargs.get('dataset_repo') or os.environ.get('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
 
40
  project_name = os.environ.get('EXPERIMENT_NAME', 'smollm3_experiment')
41
 
42
  # Extract configuration from kwargs
43
+ # Accept both TRACKIO_URL (full URL or org/space) and TRACKIO_SPACE_ID
44
+ trackio_url = (
45
+ kwargs.get('trackio_url')
46
+ or os.environ.get('TRACKIO_URL')
47
+ or os.environ.get('TRACKIO_SPACE_ID')
48
+ )
49
  trackio_token = kwargs.get('trackio_token') or os.environ.get('TRACKIO_TOKEN')
50
  hf_token = kwargs.get('hf_token') or os.environ.get('HF_TOKEN')
51
  dataset_repo = kwargs.get('dataset_repo') or os.environ.get('TRACKIO_DATASET_REPO', 'tonic/trackio-experiments')
templates/spaces/trackio/app.py CHANGED
@@ -27,6 +27,7 @@ class TrackioSpace:
27
  def __init__(self, hf_token: Optional[str] = None, dataset_repo: Optional[str] = None):
28
  self.experiments = {}
29
  self.current_experiment = None
 
30
 
31
  # Get dataset repository and HF token from parameters or environment variables
32
  self.dataset_repo = dataset_repo or os.environ.get('TRACKIO_DATASET_REPO', 'Tonic/trackio-experiments')
@@ -80,10 +81,11 @@ class TrackioSpace:
80
  reverse=True
81
  ))
82
 
83
- # If no experiments found, use backup
84
  if not self.experiments:
85
  logger.info("📊 No experiments found in dataset, using backup data")
86
  self._load_backup_experiments()
 
87
 
88
  return
89
 
@@ -91,15 +93,18 @@ class TrackioSpace:
91
  if self.hf_token:
92
  success = self._load_experiments_direct()
93
  if success:
 
94
  return
95
 
96
  # Final fallback to backup data
97
  logger.info("🔄 Using backup data")
98
  self._load_backup_experiments()
 
99
 
100
  except Exception as e:
101
  logger.error(f"❌ Failed to load experiments: {e}")
102
  self._load_backup_experiments()
 
103
 
104
  def _load_experiments_direct(self) -> bool:
105
  """Load experiments directly from HF Dataset without dataset manager"""
@@ -423,6 +428,9 @@ class TrackioSpace:
423
  def _save_experiments(self):
424
  """Save experiments to HF Dataset with data preservation"""
425
  try:
 
 
 
426
  # Use dataset manager for safe operations if available
427
  if self.dataset_manager:
428
  logger.info("💾 Saving experiments using dataset manager (data preservation)")
@@ -782,21 +790,27 @@ def create_dataset_repository(hf_token: str, dataset_repo: str) -> str:
782
  except Exception as e:
783
  return f"❌ Failed to create dataset: {str(e)}\n\nπŸ’‘ Troubleshooting:\n1. Check your HF token has write permissions\n2. Verify the username in the repository name\n3. Ensure the dataset name is valid\n4. Check internet connectivity"
784
 
785
- # Initialize API client for remote data
 
 
 
 
786
  api_client = None
787
  try:
788
  from trackio_api_client import TrackioAPIClient
789
- # Get Trackio URL from environment or use default
790
- trackio_url = os.environ.get('TRACKIO_URL', 'https://tonic-test-trackio-test.hf.space')
791
-
792
- # Clean up URL to avoid double protocol issues
793
- if trackio_url.startswith('https://https://'):
794
- trackio_url = trackio_url.replace('https://https://', 'https://')
795
- elif trackio_url.startswith('http://http://'):
796
- trackio_url = trackio_url.replace('http://http://', 'http://')
797
-
798
- api_client = TrackioAPIClient(trackio_url)
799
- logger.info(f"βœ… API client initialized for remote data access: {trackio_url}")
 
 
800
  except ImportError:
801
  logger.warning("⚠️ API client not available, using local data only")
802
  except Exception as e:
 
27
  def __init__(self, hf_token: Optional[str] = None, dataset_repo: Optional[str] = None):
28
  self.experiments = {}
29
  self.current_experiment = None
30
+ self.using_backup_data = False
31
 
32
  # Get dataset repository and HF token from parameters or environment variables
33
  self.dataset_repo = dataset_repo or os.environ.get('TRACKIO_DATASET_REPO', 'Tonic/trackio-experiments')
 
81
  reverse=True
82
  ))
83
 
84
+ # If no experiments found, use backup but mark backup mode to avoid accidental writes
85
  if not self.experiments:
86
  logger.info("πŸ“Š No experiments found in dataset, using backup data")
87
  self._load_backup_experiments()
88
+ self.using_backup_data = True
89
 
90
  return
91
 
 
93
  if self.hf_token:
94
  success = self._load_experiments_direct()
95
  if success:
96
+ self.using_backup_data = False
97
  return
98
 
99
  # Final fallback to backup data
100
  logger.info("πŸ”„ Using backup data")
101
  self._load_backup_experiments()
102
+ self.using_backup_data = True
103
 
104
  except Exception as e:
105
  logger.error(f"❌ Failed to load experiments: {e}")
106
  self._load_backup_experiments()
107
+ self.using_backup_data = True
108
 
109
  def _load_experiments_direct(self) -> bool:
110
  """Load experiments directly from HF Dataset without dataset manager"""
 
428
  def _save_experiments(self):
429
  """Save experiments to HF Dataset with data preservation"""
430
  try:
431
+ if self.using_backup_data:
432
+ logger.warning("⚠️ Using backup data; skip saving to dataset to avoid overwriting with demo values")
433
+ return
434
  # Use dataset manager for safe operations if available
435
  if self.dataset_manager:
436
  logger.info("πŸ’Ύ Saving experiments using dataset manager (data preservation)")
 
790
  except Exception as e:
791
  return f"❌ Failed to create dataset: {str(e)}\n\n💡 Troubleshooting:\n1. Check your HF token has write permissions\n2. Verify the username in the repository name\n3. Ensure the dataset name is valid\n4. Check internet connectivity"
792
 
793
+ """
794
+ Initialize API client for remote data. We do not hardcode a default test URL to avoid
795
+ overwriting dataset content with demo data. The API client will only be initialized
796
+ when TRACKIO_URL or TRACKIO_SPACE_ID is present.
797
+ """
798
  api_client = None
799
  try:
800
  from trackio_api_client import TrackioAPIClient
801
+ # Resolve Trackio space from environment
802
+ trackio_url_env = os.environ.get('TRACKIO_URL') or os.environ.get('TRACKIO_SPACE_ID')
803
+ if trackio_url_env:
804
+ # Clean up URL to avoid double protocol issues
805
+ trackio_url = trackio_url_env
806
+ if trackio_url.startswith('https://https://'):
807
+ trackio_url = trackio_url.replace('https://https://', 'https://')
808
+ elif trackio_url.startswith('http://http://'):
809
+ trackio_url = trackio_url.replace('http://http://', 'http://')
810
+ api_client = TrackioAPIClient(trackio_url)
811
+ logger.info(f"βœ… API client initialized for remote data access: {trackio_url}")
812
+ else:
813
+ logger.info("No TRACKIO_URL/TRACKIO_SPACE_ID set; remote API client disabled")
814
  except ImportError:
815
  logger.warning("⚠️ API client not available, using local data only")
816
  except Exception as e: