Spaces:
Running
Running
solves dataset dict issue
Browse files
data.py
CHANGED
@@ -190,26 +190,56 @@ class SmolLM3Dataset:
|
|
190 |
"length": input_length,
|
191 |
}
|
192 |
|
193 |
-
# Process the dataset
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
|
208 |
-
|
209 |
-
if
|
210 |
-
logger.info(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
-
return
|
213 |
|
214 |
def get_train_dataset(self) -> Dataset:
|
215 |
"""Get training dataset"""
|
|
|
190 |
"length": input_length,
|
191 |
}
|
192 |
|
193 |
+
# Process the dataset - handle both single dataset and dictionary of splits
|
194 |
+
if isinstance(self.dataset, dict):
|
195 |
+
# Process each split individually
|
196 |
+
processed_dataset = {}
|
197 |
+
for split_name, split_dataset in self.dataset.items():
|
198 |
+
logger.info(f"Processing {split_name} split...")
|
199 |
+
|
200 |
+
# Format the split
|
201 |
+
processed_split = split_dataset.map(
|
202 |
+
format_chat_template,
|
203 |
+
remove_columns=split_dataset.column_names,
|
204 |
+
desc=f"Formatting {split_name} dataset"
|
205 |
+
)
|
206 |
+
|
207 |
+
# Tokenize the split
|
208 |
+
tokenized_split = processed_split.map(
|
209 |
+
tokenize_function,
|
210 |
+
remove_columns=processed_split.column_names,
|
211 |
+
desc=f"Tokenizing {split_name} dataset",
|
212 |
+
batched=True,
|
213 |
+
)
|
214 |
+
|
215 |
+
processed_dataset[split_name] = tokenized_split
|
216 |
+
else:
|
217 |
+
# Single dataset
|
218 |
+
processed_dataset = self.dataset.map(
|
219 |
+
format_chat_template,
|
220 |
+
remove_columns=self.dataset.column_names,
|
221 |
+
desc="Formatting dataset"
|
222 |
+
)
|
223 |
+
|
224 |
+
# Tokenize the dataset
|
225 |
+
processed_dataset = processed_dataset.map(
|
226 |
+
tokenize_function,
|
227 |
+
remove_columns=processed_dataset.column_names,
|
228 |
+
desc="Tokenizing dataset",
|
229 |
+
batched=True,
|
230 |
+
)
|
231 |
|
232 |
+
# Log processing results
|
233 |
+
if isinstance(processed_dataset, dict):
|
234 |
+
logger.info(f"Dataset processed. Train samples: {len(processed_dataset['train'])}")
|
235 |
+
if "validation" in processed_dataset:
|
236 |
+
logger.info(f"Validation samples: {len(processed_dataset['validation'])}")
|
237 |
+
if "test" in processed_dataset:
|
238 |
+
logger.info(f"Test samples: {len(processed_dataset['test'])}")
|
239 |
+
else:
|
240 |
+
logger.info(f"Dataset processed. Samples: {len(processed_dataset)}")
|
241 |
|
242 |
+
return processed_dataset
|
243 |
|
244 |
def get_train_dataset(self) -> Dataset:
|
245 |
"""Get training dataset"""
|