Update train_model.py
train_model.py CHANGED (+9 -12)
@@ -70,9 +70,9 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     # Check if dataset_name includes a configuration
     if '/' in dataset_name:
         dataset, config = dataset_name.split('/', 1)
-        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
+        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
     else:
-        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
+        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
     logging.info("Dataset loaded successfully for generation task.")
     def tokenize_function(examples):
         return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
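Note: both branches of this hunk still load the same hardcoded dataset, so the dataset and config values parsed from dataset_name on the line above are never used; wiring them into the load call may be the intended follow-up. Also, newer releases of the datasets library (>= 2.14) deprecate use_auth_token in favor of token. A minimal sketch of the equivalent call under that assumption:

from datasets import load_dataset

# Equivalent call on datasets >= 2.14, where use_auth_token is deprecated;
# token=True reads the credential saved by `huggingface-cli login`.
dataset = load_dataset(
    "Salesforce/wikitext",
    "wikitext-103-raw-v1",
    split="train",
    token=True,
)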
@@ -185,6 +185,8 @@ def main():
         if tokenizer.pad_token is None:
             logging.info("Setting pad_token to eos_token.")
             tokenizer.pad_token = tokenizer.eos_token
+            logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
+            # Resize model's token embeddings after setting pad_token
             model = initialize_model(
                 task=args.task,
                 model_name=args.model_name,
@@ -195,7 +197,10 @@ def main():
                 attention_heads=args.attention_heads
             )
             model.resize_token_embeddings(len(tokenizer))
+            logging.info("Resized token embeddings to accommodate pad_token.")
         else:
+            logging.info(f"Tokenizer already has pad_token set to: {tokenizer.pad_token}")
+            # Initialize model normally
             model = initialize_model(
                 task=args.task,
                 model_name=args.model_name,
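Note: reusing the eos token as the pad token does not grow the vocabulary, so the resize_token_embeddings call is effectively a no-op in this branch; it becomes necessary when a brand-new pad token is added instead. A minimal self-contained sketch of both variants (the gpt2 checkpoint is illustrative, not from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

if tokenizer.pad_token is None:
    # Variant used in this commit: reuse eos as pad (vocab size unchanged).
    tokenizer.pad_token = tokenizer.eos_token
    # Alternative: add a dedicated pad token, which does grow the vocab:
    #   tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))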
@@ -206,7 +211,7 @@ def main():
                 attention_heads=args.attention_heads
             )
     except Exception as e:
-        logging.error(f"Error initializing tokenizer: {str(e)}")
+        logging.error(f"Error initializing tokenizer or model: {str(e)}")
         raise e

     # Load and prepare dataset
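Note: the widened log message matches the scope of the try block, which now covers model initialization as well. As a side note, logging.exception plus a bare raise is the more idiomatic error path, since it records the traceback and re-raises without rewriting it. A small runnable sketch (the failing initializer is a stand-in):

import logging

def initialize_model():
    raise RuntimeError("demo failure")  # stand-in for the real initializer

try:
    model = initialize_model()
except Exception:
    # logging.exception logs at ERROR level and appends the traceback;
    # a bare raise preserves the original traceback, unlike `raise e`.
    logging.exception("Error initializing tokenizer or model")
    raise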
@@ -221,9 +226,6 @@ def main():
         logging.error("Failed to load and prepare dataset.")
         raise e

-    # Initialize model (Already initialized above)
-    # model = initialize_model(...) # Moved above to handle pad_token
-
     # Define data collator
     if args.task == "generation":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
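Note: with mlm=False this collator prepares batches for causal language modeling: labels are a copy of input_ids, with padded positions set to -100 so the loss ignores them. A minimal sketch of what it produces (the gpt2 tokenizer is illustrative):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator([tokenizer("hello world"), tokenizer("hi")])
print(batch["labels"])  # padded positions appear as -100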
@@ -245,7 +247,7 @@ def main():
            learning_rate=5e-4,
            remove_unused_columns=False,
            push_to_hub=False  # We'll handle pushing manually
-
+
        )
    elif args.task == "classification":
        training_args = TrainingArguments(
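Note: the blank line inside the TrainingArguments call is replaced rather than dropped here, which reads like a whitespace-only change. For context, a minimal self-contained sketch of an argument block in this shape (output_dir and batch size are illustrative, not from this commit):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out-generation",    # illustrative path
    per_device_train_batch_size=8,  # illustrative value
    learning_rate=5e-4,
    remove_unused_columns=False,    # keep columns the collator expects
    push_to_hub=False,              # pushing is handled manually
)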
@@ -313,8 +315,3 @@ def main():

 if __name__ == "__main__":
     main()
-
-
-
-
-