Vishwas1 commited on
Commit
706ea4a
·
verified ·
1 Parent(s): 8fd0cb7

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +2 -2
train_model.py CHANGED
@@ -65,13 +65,13 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
- train_dataset = load_dataset(dataset_name,use_auth_token=True)
69
  dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
70
  logging.info("Dataset loaded successfully for generation task.")
71
  def tokenize_function(examples):
72
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
73
  elif task == "classification":
74
- train_dataset = load_dataset(dataset_name,use_auth_token=True)
75
  dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
76
  logging.info("Dataset loaded successfully for classification task.")
77
  # Assuming the dataset has 'text' and 'label' columns
 
65
  logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
66
  try:
67
  if task == "generation":
68
+ train_dataset = load_dataset(dataset_name,split='train',use_auth_token=True)
69
  dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
70
  logging.info("Dataset loaded successfully for generation task.")
71
  def tokenize_function(examples):
72
  return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
73
  elif task == "classification":
74
+ train_dataset = load_dataset(dataset_name,split='train',use_auth_token=True)
75
  dataset = train_dataset['train'].shuffle(seed=42).select(range(500))
76
  logging.info("Dataset loaded successfully for classification task.")
77
  # Assuming the dataset has 'text' and 'label' columns