Kevin Fink committed on
Commit
10e867c
·
1 Parent(s): 8849792
Files changed (1) hide show
  1. app.py +38 -40
app.py CHANGED
@@ -115,47 +115,45 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
115
  max_length = model.get_input_embeddings().weight.shape[0]
116
  try:
117
  saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
118
- try:
119
- load_from_disk(f'/data/{hub_id.strip()}_validation_dataset')
120
- try:
121
- train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
122
- try:
123
-
124
- saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_test_dataset')
125
- print("FOUND TEST")
126
- # Create Trainer
127
- trainer = Trainer(
128
- model=model,
129
- args=training_args,
130
- train_dataset=train_dataset,
131
- eval_dataset=saved_test_dataset,
132
- compute_metrics=compute_metrics,
133
- )
134
- except:
135
- if len(dataset['train']) == len(train_dataset['train']):
136
- dataset = load_dataset(dataset_name.strip())
137
- del dataset['train']
138
- del dataset['validation']
139
- test_set = dataset.map(tokenize_function, batched=True)
140
- test_set['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
141
- return 'TRAINING DONE'
142
- except:
143
  dataset = load_dataset(dataset_name.strip())
144
- train_size = len(dataset['train'])
145
- third_size = train_size // 3
146
- del dataset['test']
147
- del dataset['validation']
148
- print("FOUND VALIDATION")
149
- saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
150
- third_third = dataset['train'].select(range(third_size*2, train_size))
151
- dataset['train'] = third_third
152
- print(dataset)
153
- print(dataset.keys())
154
- tokenized_second_half = dataset.map(tokenize_function, batched=True)
155
- dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_second_half['train']])
156
- dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
157
- return 'THIRD THIRD LOADED'
158
- except:
 
 
 
159
  dataset = load_dataset(dataset_name.strip())
160
  train_size = len(dataset['train'])
161
  third_size = train_size // 3
 
115
  max_length = model.get_input_embeddings().weight.shape[0]
116
  try:
117
  saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset')
118
+ if os.access(f'/data/{hub_id.strip()}_validation_dataset'):
119
+ dataset = load_dataset(dataset_name.strip())
120
+ train_size = len(dataset['train'])
121
+ third_size = train_size // 3
122
+ del dataset['test']
123
+ del dataset['validation']
124
+ print("FOUND VALIDATION")
125
+ saved_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset2')
126
+ third_third = dataset['train'].select(range(third_size*2, train_size))
127
+ dataset['train'] = third_third
128
+ print(dataset)
129
+ print(dataset.keys())
130
+ tokenized_second_half = dataset.map(tokenize_function, batched=True)
131
+ dataset['train'] = concatenate_datasets([saved_dataset['train'], tokenized_second_half['train']])
132
+ dataset['train'].save_to_disk(f'/data/{hub_id.strip()}_train_dataset3')
133
+ return 'THIRD THIRD LOADED'
134
+
135
+ if not os.access(f'/data/{hub_id.strip()}_train_dataset3'):
136
+ train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
137
+ if len(dataset['train']) == len(train_dataset['train']):
 
 
 
 
 
138
  dataset = load_dataset(dataset_name.strip())
139
+ del dataset['train']
140
+ del dataset['validation']
141
+ test_set = dataset.map(tokenize_function, batched=True)
142
+ test_set['test'].save_to_disk(f'/data/{hub_id.strip()}_test_dataset')
143
+ return 'TRAINING DONE'
144
+ else:
145
+ train_dataset = load_from_disk(f'/data/{hub_id.strip()}_train_dataset3')
146
+ saved_test_dataset = load_from_disk(f'/data/{hub_id.strip()}_test_dataset')
147
+ print("FOUND TEST")
148
+ # Create Trainer
149
+ trainer = Trainer(
150
+ model=model,
151
+ args=training_args,
152
+ train_dataset=train_dataset,
153
+ eval_dataset=saved_test_dataset,
154
+ compute_metrics=compute_metrics,
155
+ )
156
+ if os.access(f'/data/{hub_id.strip()}_train_dataset' and not os.access(f'/data/{hub_id.strip()}_train_dataset3')):
157
  dataset = load_dataset(dataset_name.strip())
158
  train_size = len(dataset['train'])
159
  third_size = train_size // 3