Vishwas1 committed on
Commit
55f1be4
·
verified ·
1 Parent(s): d3d62d9

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +9 -12
train_model.py CHANGED
@@ -70,9 +70,9 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
70
  # Check if dataset_name includes a configuration
71
  if '/' in dataset_name:
72
  dataset, config = dataset_name.split('/', 1)
73
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
74
  else:
75
- dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
76
  logging.info("Dataset loaded successfully for generation task.")
77
  def tokenize_function(examples):
78
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
@@ -185,6 +185,8 @@ def main():
185
  if tokenizer.pad_token is None:
186
  logging.info("Setting pad_token to eos_token.")
187
  tokenizer.pad_token = tokenizer.eos_token
 
 
188
  model = initialize_model(
189
  task=args.task,
190
  model_name=args.model_name,
@@ -195,7 +197,10 @@ def main():
195
  attention_heads=args.attention_heads
196
  )
197
  model.resize_token_embeddings(len(tokenizer))
 
198
  else:
 
 
199
  model = initialize_model(
200
  task=args.task,
201
  model_name=args.model_name,
@@ -206,7 +211,7 @@ def main():
206
  attention_heads=args.attention_heads
207
  )
208
  except Exception as e:
209
- logging.error(f"Error initializing tokenizer: {str(e)}")
210
  raise e
211
 
212
  # Load and prepare dataset
@@ -221,9 +226,6 @@ def main():
221
  logging.error("Failed to load and prepare dataset.")
222
  raise e
223
 
224
- # Initialize model (Already initialized above)
225
- # model = initialize_model(...) # Moved above to handle pad_token
226
-
227
  # Define data collator
228
  if args.task == "generation":
229
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
@@ -245,7 +247,7 @@ def main():
245
  learning_rate=5e-4,
246
  remove_unused_columns=False,
247
  push_to_hub=False # We'll handle pushing manually
248
-
249
  )
250
  elif args.task == "classification":
251
  training_args = TrainingArguments(
@@ -313,8 +315,3 @@ def main():
313
 
314
  if __name__ == "__main__":
315
  main()
316
-
317
-
318
-
319
-
320
-
 
70
  # Check if dataset_name includes a configuration
71
  if '/' in dataset_name:
72
  dataset, config = dataset_name.split('/', 1)
73
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
74
  else:
75
+ dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train', use_auth_token=True)
76
  logging.info("Dataset loaded successfully for generation task.")
77
  def tokenize_function(examples):
78
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
 
185
  if tokenizer.pad_token is None:
186
  logging.info("Setting pad_token to eos_token.")
187
  tokenizer.pad_token = tokenizer.eos_token
188
+ logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
189
+ # Resize model's token embeddings after setting pad_token
190
  model = initialize_model(
191
  task=args.task,
192
  model_name=args.model_name,
 
197
  attention_heads=args.attention_heads
198
  )
199
  model.resize_token_embeddings(len(tokenizer))
200
+ logging.info("Resized token embeddings to accommodate pad_token.")
201
  else:
202
+ logging.info(f"Tokenizer already has pad_token set to: {tokenizer.pad_token}")
203
+ # Initialize model normally
204
  model = initialize_model(
205
  task=args.task,
206
  model_name=args.model_name,
 
211
  attention_heads=args.attention_heads
212
  )
213
  except Exception as e:
214
+ logging.error(f"Error initializing tokenizer or model: {str(e)}")
215
  raise e
216
 
217
  # Load and prepare dataset
 
226
  logging.error("Failed to load and prepare dataset.")
227
  raise e
228
 
 
 
 
229
  # Define data collator
230
  if args.task == "generation":
231
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
247
  learning_rate=5e-4,
248
  remove_unused_columns=False,
249
  push_to_hub=False # We'll handle pushing manually
250
+
251
  )
252
  elif args.task == "classification":
253
  training_args = TrainingArguments(
 
315
 
316
  if __name__ == "__main__":
317
  main()