Tonic committed
Commit 40f937e · verified · 1 Parent(s): cb932c7

solves dataset dict issue

Files changed (1)
  data.py  +48 -18
data.py CHANGED
@@ -190,26 +190,56 @@ class SmolLM3Dataset:
                 "length": input_length,
             }
 
-        # Process the dataset
-        processed_dataset = self.dataset.map(
-            format_chat_template,
-            remove_columns=self.dataset["train"].column_names,
-            desc="Formatting dataset"
-        )
-
-        # Tokenize the dataset
-        tokenized_dataset = processed_dataset.map(
-            tokenize_function,
-            remove_columns=processed_dataset["train"].column_names,
-            desc="Tokenizing dataset",
-            batched=True,
-        )
-
-        logger.info(f"Dataset processed. Train samples: {len(tokenized_dataset['train'])}")
-        if "validation" in tokenized_dataset:
-            logger.info(f"Validation samples: {len(tokenized_dataset['validation'])}")
-
-        return tokenized_dataset
+        # Process the dataset - handle both single dataset and dictionary of splits
+        if isinstance(self.dataset, dict):
+            # Process each split individually
+            processed_dataset = {}
+            for split_name, split_dataset in self.dataset.items():
+                logger.info(f"Processing {split_name} split...")
+
+                # Format the split
+                processed_split = split_dataset.map(
+                    format_chat_template,
+                    remove_columns=split_dataset.column_names,
+                    desc=f"Formatting {split_name} dataset"
+                )
+
+                # Tokenize the split
+                tokenized_split = processed_split.map(
+                    tokenize_function,
+                    remove_columns=processed_split.column_names,
+                    desc=f"Tokenizing {split_name} dataset",
+                    batched=True,
+                )
+
+                processed_dataset[split_name] = tokenized_split
+        else:
+            # Single dataset
+            processed_dataset = self.dataset.map(
+                format_chat_template,
+                remove_columns=self.dataset.column_names,
+                desc="Formatting dataset"
+            )
+
+            # Tokenize the dataset
+            processed_dataset = processed_dataset.map(
+                tokenize_function,
+                remove_columns=processed_dataset.column_names,
+                desc="Tokenizing dataset",
+                batched=True,
+            )
+
+        # Log processing results
+        if isinstance(processed_dataset, dict):
+            logger.info(f"Dataset processed. Train samples: {len(processed_dataset['train'])}")
+            if "validation" in processed_dataset:
+                logger.info(f"Validation samples: {len(processed_dataset['validation'])}")
+            if "test" in processed_dataset:
+                logger.info(f"Test samples: {len(processed_dataset['test'])}")
+        else:
+            logger.info(f"Dataset processed. Samples: {len(processed_dataset)}")
+
+        return processed_dataset
 
     def get_train_dataset(self) -> Dataset:
         """Get training dataset"""