Kevin Fink committed on
Commit
d7a9615
·
1 Parent(s): 6fdec3f
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -38,10 +38,10 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
38
  preds = preds[0]
39
  # Replace -100s used for padding as we can't decode them
40
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
41
- preds = np.array(preds)
42
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
43
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
44
- labels = np.array(labels)
45
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
46
 
47
  result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
@@ -59,7 +59,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
59
 
60
  # Set training arguments
61
  training_args = TrainingArguments(
62
- remove_unused_columns=False,
63
  torch_empty_cache_steps=100,
64
  overwrite_output_dir=True,
65
  output_dir='/data/results',
@@ -208,6 +208,18 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
208
  #print('DONE')
209
  #return 'RUN AGAIN TO LOAD REST OF DATA'
210
  dataset = load_dataset(dataset_name.strip())
 
 
 
 
 
 
 
 
 
 
 
 
211
  #dataset['train'] = dataset['train'].select(range(8000))
212
  dataset['train'] = dataset['train'].select(range(4000))
213
  dataset['validation'] = dataset['validation'].select(range(200))
 
38
  preds = preds[0]
39
  # Replace -100s used for padding as we can't decode them
40
  preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
41
+ #preds = np.array(preds)
42
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
43
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
44
+ #labels = np.array(labels)
45
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
46
 
47
  result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
 
59
 
60
  # Set training arguments
61
  training_args = TrainingArguments(
62
+ remove_unused_columns=False,
63
  torch_empty_cache_steps=100,
64
  overwrite_output_dir=True,
65
  output_dir='/data/results',
 
208
  #print('DONE')
209
  #return 'RUN AGAIN TO LOAD REST OF DATA'
210
  dataset = load_dataset(dataset_name.strip())
211
+ for o, d in enumerate(dataset['validation']['text']):
212
+ if not isinstance(d, str):
213
+ print('hit')
214
+ print(type(d))
215
+ print(o)
216
+ for o, d in enumerate(dataset['validation']['target']):
217
+ if not isinstance(d, str):
218
+ print('hit')
219
+ print(type(d))
220
+ print(o)
221
+ return 'done'
222
+
223
  #dataset['train'] = dataset['train'].select(range(8000))
224
  dataset['train'] = dataset['train'].select(range(4000))
225
  dataset['validation'] = dataset['validation'].select(range(200))