Kevin Fink committed on
Commit
0aa217c
·
1 Parent(s): bc59d39
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -117,32 +117,36 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
117
  third_size = train_size // 3
118
  max_length = model.get_input_embeddings().weight.shape[0]
119
  try:
120
- saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
121
- if 'test' in saved_dataset.keys():
122
- print("FOUND TEST")
123
- # Create Trainer
124
- trainer = Trainer(
125
- model=model,
126
- args=training_args,
127
- train_dataset=tokenized_train_dataset,
128
- eval_dataset=tokenized_test_dataset,
129
- compute_metrics=compute_metrics,
130
- )
131
- elif 'validation' in saved_dataset.keys():
132
- print("FOUND VALIDATION")
133
- third_third = dataset['train'].select(range(third_size*2, train_size))
134
- dataset['train'] = third_third
135
- tokenized_second_half = dataset.map(tokenize_function, batched=True)
136
- dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_second_half['train']])
137
- tokenized_train_dataset = dataset['train']
138
- tokenized_test_dataset = dataset['test']
 
 
 
 
139
  else:
140
  second_third = dataset['train'].select(range(third_size, third_size*2))
141
  dataset['train'] = second_third
142
  del dataset['test']
143
  tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True)
144
  dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_sh_fq_dataset['train']])
145
- dataset.save_to_disk(f'/data/{hub_id.strip()}_train_dataset')
146
  return
147
 
148
  except:
 
117
  third_size = train_size // 3
118
  max_length = model.get_input_embeddings().weight.shape[0]
119
  try:
120
+ saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
121
+ if 'validation' in saved_dataset.keys():
122
+ if 'test' in saved_dataset.keys():
123
+ print("FOUND TEST")
124
+ dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
125
+ # Create Trainer
126
+ trainer = Trainer(
127
+ model=model,
128
+ args=training_args,
129
+ train_dataset=tokenized_train_dataset,
130
+ eval_dataset=tokenized_test_dataset,
131
+ compute_metrics=compute_metrics,
132
+ )
133
+ else:
134
+ print("FOUND VALIDATION")
135
+ saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
136
+ third_third = dataset['train'].select(range(third_size*2, train_size))
137
+ dataset['train'] = third_third
138
+ tokenized_second_half = dataset.map(tokenize_function, batched=True)
139
+ dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_second_half['train']])
140
+ tokenized_train_dataset = dataset['train']
141
+ tokenized_test_dataset = dataset['test']
142
+ dataset.save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
143
  else:
144
  second_third = dataset['train'].select(range(third_size, third_size*2))
145
  dataset['train'] = second_third
146
  del dataset['test']
147
  tokenized_sh_fq_dataset = dataset.map(tokenize_function, batched=True)
148
  dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_sh_fq_dataset['train']])
149
+ dataset.save_to_disk(f'/data/{hub_id.strip()}_train_dataset2')
150
  return
151
 
152
  except: