Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -70,7 +70,7 @@ def fine_tune_cuad_model():
|
|
70 |
"""
|
71 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
72 |
For testing, we use only 50 training examples (and 10 for validation)
|
73 |
-
and restrict training to
|
74 |
"""
|
75 |
from datasets import load_dataset
|
76 |
import numpy as np
|
@@ -149,20 +149,19 @@ def fine_tune_cuad_model():
|
|
149 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
150 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
151 |
|
152 |
-
# Set max_steps to
|
153 |
training_args = TrainingArguments(
|
154 |
output_dir="./fine_tuned_legal_qa",
|
155 |
-
max_steps=
|
156 |
-
evaluation_strategy="
|
157 |
-
eval_steps=5,
|
158 |
learning_rate=2e-5,
|
159 |
per_device_train_batch_size=4,
|
160 |
per_device_eval_batch_size=4,
|
161 |
num_train_epochs=1,
|
162 |
weight_decay=0.01,
|
163 |
logging_steps=1,
|
164 |
-
save_steps=
|
165 |
-
load_best_model_at_end=
|
166 |
report_to=[] # Disable wandb logging
|
167 |
)
|
168 |
|
@@ -191,7 +190,7 @@ def fine_tune_cuad_model():
|
|
191 |
try:
|
192 |
try:
|
193 |
nlp = spacy.load("en_core_web_sm")
|
194 |
-
except:
|
195 |
spacy.cli.download("en_core_web_sm")
|
196 |
nlp = spacy.load("en_core_web_sm")
|
197 |
print("✅ Loading NLP models...")
|
|
|
70 |
"""
|
71 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
72 |
For testing, we use only 50 training examples (and 10 for validation)
|
73 |
+
and restrict training to 1 step with evaluation disabled.
|
74 |
"""
|
75 |
from datasets import load_dataset
|
76 |
import numpy as np
|
|
|
149 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
150 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
151 |
|
152 |
+
# Set max_steps to 1 for very fast testing and disable evaluation
|
153 |
training_args = TrainingArguments(
|
154 |
output_dir="./fine_tuned_legal_qa",
|
155 |
+
max_steps=1, # Only one training step
|
156 |
+
evaluation_strategy="no", # Disable evaluation during training
|
|
|
157 |
learning_rate=2e-5,
|
158 |
per_device_train_batch_size=4,
|
159 |
per_device_eval_batch_size=4,
|
160 |
num_train_epochs=1,
|
161 |
weight_decay=0.01,
|
162 |
logging_steps=1,
|
163 |
+
save_steps=1,
|
164 |
+
load_best_model_at_end=False,
|
165 |
report_to=[] # Disable wandb logging
|
166 |
)
|
167 |
|
|
|
190 |
try:
|
191 |
try:
|
192 |
nlp = spacy.load("en_core_web_sm")
|
193 |
+
except Exception:
|
194 |
spacy.cli.download("en_core_web_sm")
|
195 |
nlp = spacy.load("en_core_web_sm")
|
196 |
print("✅ Loading NLP models...")
|