Hugging Face Spaces (Space status: Sleeping)
Commit: "Update app.py" — Browse files
File: app.py (CHANGED)
Unified diff (reconstructed from the garbled side-by-side rendering):

@@ -65,7 +65,7 @@ MAX_INPUT_LENGTH = 256
 MAX_TARGET_LENGTH = 128
 
 
-def preprocess_function(examples):
+def preprocess_function(examples, **kwargs):
     """
     Preprocess entries of the given dataset
 
@@ -74,12 +74,14 @@ def preprocess_function(examples):
 
     Returns:
         model_inputs (BatchEncoding): tokenized dataset entries
     """
+
     inputs, targets = [], []
     for i in range(len(examples['question'])):
         inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
         targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")
 
     # apply tokenization to inputs and labels
+    tokenizer = kwargs["tokenizer"]
     model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
     labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)
 
@@ -200,7 +202,8 @@ def load_data():
     processed_dataset = split.map(
         preprocess_function,
         batched=True,
-        remove_columns=split.column_names
+        remove_columns=split.column_names,
+        fn_kwargs={"tokenizer": tokenizer}
     )
     processed_dataset.set_format('torch')
|
Updated file (after the change), excerpts — note the tokenizer is now passed
explicitly to preprocess_function via Dataset.map's fn_kwargs instead of being
read from an enclosing scope:

MAX_TARGET_LENGTH = 128


def preprocess_function(examples, **kwargs):
    """
    Preprocess entries of the given dataset

    ...

    Returns:
        model_inputs (BatchEncoding): tokenized dataset entries
    """

    inputs, targets = [], []
    for i in range(len(examples['question'])):
        inputs.append(f"Antwort: {examples['provided_answer'][i]} Lösung: {examples['reference_answer'][i]} Frage: {examples['question'][i]}")
        targets.append(f"{examples['verification_feedback'][i]} Feedback: {examples['answer_feedback'][i]}")

    # apply tokenization to inputs and labels
    tokenizer = kwargs["tokenizer"]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, padding='max_length', truncation=True)

... (inside load_data, lines 202-209) ...

    processed_dataset = split.map(
        preprocess_function,
        batched=True,
        remove_columns=split.column_names,
        fn_kwargs={"tokenizer": tokenizer}
    )
    processed_dataset.set_format('torch')