Kevin Fink
committed on
Commit
·
d7a9615
1
Parent(s):
6fdec3f
deve
Browse files
app.py
CHANGED
@@ -38,10 +38,10 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
38 |
preds = preds[0]
|
39 |
# Replace -100s used for padding as we can't decode them
|
40 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
41 |
-
preds = np.array(preds)
|
42 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
43 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
44 |
-
labels = np.array(labels)
|
45 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
46 |
|
47 |
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
@@ -59,7 +59,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
59 |
|
60 |
# Set training arguments
|
61 |
training_args = TrainingArguments(
|
62 |
-
remove_unused_columns=False,
|
63 |
torch_empty_cache_steps=100,
|
64 |
overwrite_output_dir=True,
|
65 |
output_dir='/data/results',
|
@@ -208,6 +208,18 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
|
|
208 |
#print('DONE')
|
209 |
#return 'RUN AGAIN TO LOAD REST OF DATA'
|
210 |
dataset = load_dataset(dataset_name.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
#dataset['train'] = dataset['train'].select(range(8000))
|
212 |
dataset['train'] = dataset['train'].select(range(4000))
|
213 |
dataset['validation'] = dataset['validation'].select(range(200))
|
|
|
38 |
preds = preds[0]
|
39 |
# Replace -100s used for padding as we can't decode them
|
40 |
preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
|
41 |
+
#preds = np.array(preds)
|
42 |
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
43 |
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
|
44 |
+
#labels = np.array(labels)
|
45 |
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
46 |
|
47 |
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
|
|
59 |
|
60 |
# Set training arguments
|
61 |
training_args = TrainingArguments(
|
62 |
+
remove_unused_columns=False,
|
63 |
torch_empty_cache_steps=100,
|
64 |
overwrite_output_dir=True,
|
65 |
output_dir='/data/results',
|
|
|
208 |
#print('DONE')
|
209 |
#return 'RUN AGAIN TO LOAD REST OF DATA'
|
210 |
dataset = load_dataset(dataset_name.strip())
|
211 |
+
for o, d in enumerate(dataset['validation']['text']):
|
212 |
+
if not isinstance(d, str):
|
213 |
+
print('hit')
|
214 |
+
print(type(d))
|
215 |
+
print(o)
|
216 |
+
for o, d in enumerate(dataset['validation']['target']):
|
217 |
+
if not isinstance(d, str):
|
218 |
+
print('hit')
|
219 |
+
print(type(d))
|
220 |
+
print(o)
|
221 |
+
return 'done'
|
222 |
+
|
223 |
#dataset['train'] = dataset['train'].select(range(8000))
|
224 |
dataset['train'] = dataset['train'].select(range(4000))
|
225 |
dataset['validation'] = dataset['validation'].select(range(200))
|