Kevin Fink committed on
Commit
05f8623
·
1 Parent(s): d177146
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -86,7 +86,8 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
86
 
87
  tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
88
 
89
- max_length = model.get_input_embeddings().weight.shape[0]
 
90
 
91
  def tokenize_function(examples):
92
 
@@ -95,7 +96,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
95
  examples['text'],
96
  max_length=max_length, # Set to None for dynamic padding
97
  truncation=True,
98
- padding=True,
99
  return_tensors='pt',
100
  )
101
 
@@ -104,7 +105,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
104
  examples['target'],
105
  max_length=max_length, # Set to None for dynamic padding
106
  truncation=True,
107
- padding=True,
108
  #text_target=examples['target'],
109
  return_tensors='pt',
110
  )
@@ -124,12 +125,14 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
124
  saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_test_dataset')
125
  print("FOUND TEST")
126
  # Create Trainer
 
127
  trainer = Trainer(
128
  model=model,
129
  args=training_args,
130
  train_dataset=train_dataset,
131
  eval_dataset=saved_test_dataset,
132
  compute_metrics=compute_metrics,
 
133
  )
134
 
135
  elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):
 
86
 
87
  tokenizer = AutoTokenizer.from_pretrained('google/t5-efficient-tiny-nh8')
88
 
89
+ #max_length = model.get_input_embeddings().weight.shape[0]
90
+ max_length = 512
91
 
92
  def tokenize_function(examples):
93
 
 
96
  examples['text'],
97
  max_length=max_length, # Set to None for dynamic padding
98
  truncation=True,
99
+ padding='max_length',
100
  return_tensors='pt',
101
  )
102
 
 
105
  examples['target'],
106
  max_length=max_length, # Set to None for dynamic padding
107
  truncation=True,
108
+ padding='max_length',
109
  #text_target=examples['target'],
110
  return_tensors='pt',
111
  )
 
125
  saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_test_dataset')
126
  print("FOUND TEST")
127
  # Create Trainer
128
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
129
  trainer = Trainer(
130
  model=model,
131
  args=training_args,
132
  train_dataset=train_dataset,
133
  eval_dataset=saved_test_dataset,
134
  compute_metrics=compute_metrics,
135
+ data_collator=data_collator,
136
  )
137
 
138
  elif os.access(f'/data/{hub_id.strip()}_train_dataset3', os.R_OK):