Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -71,7 +71,8 @@ def load_document_context(task_id):
|
|
71 |
def fine_tune_cuad_model():
|
72 |
"""
|
73 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
74 |
-
|
|
|
75 |
"""
|
76 |
from datasets import load_dataset
|
77 |
import numpy as np
|
@@ -81,9 +82,11 @@ def fine_tune_cuad_model():
|
|
81 |
dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
|
82 |
|
83 |
if "train" in dataset:
|
84 |
-
|
|
|
85 |
if "validation" in dataset:
|
86 |
-
|
|
|
87 |
else:
|
88 |
split = train_dataset.train_test_split(test_size=0.2)
|
89 |
train_dataset = split["train"]
|
@@ -148,17 +151,18 @@ def fine_tune_cuad_model():
|
|
148 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
149 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
150 |
|
|
|
151 |
training_args = TrainingArguments(
|
152 |
output_dir="./fine_tuned_legal_qa",
|
153 |
evaluation_strategy="steps",
|
154 |
-
eval_steps=
|
155 |
learning_rate=2e-5,
|
156 |
-
per_device_train_batch_size=
|
157 |
-
per_device_eval_batch_size=
|
158 |
-
num_train_epochs=1,
|
159 |
weight_decay=0.01,
|
160 |
-
logging_steps=
|
161 |
-
save_steps=
|
162 |
load_best_model_at_end=True,
|
163 |
report_to=[] # Disable wandb logging
|
164 |
)
|
@@ -737,3 +741,4 @@ if __name__ == "__main__":
|
|
737 |
else:
|
738 |
print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
|
739 |
run()
|
|
|
|
71 |
def fine_tune_cuad_model():
|
72 |
"""
|
73 |
Fine tunes a QA model on the CUAD dataset for clause extraction.
|
74 |
+
For testing, we use only 50 training examples (and 10 for validation)
|
75 |
+
and set training arguments for very fast, minimal training.
|
76 |
"""
|
77 |
from datasets import load_dataset
|
78 |
import numpy as np
|
|
|
82 |
dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
|
83 |
|
84 |
if "train" in dataset:
|
85 |
+
# Use only 50 examples for training
|
86 |
+
train_dataset = dataset["train"].select(range(50))
|
87 |
if "validation" in dataset:
|
88 |
+
# Use 10 examples for validation
|
89 |
+
val_dataset = dataset["validation"].select(range(10))
|
90 |
else:
|
91 |
split = train_dataset.train_test_split(test_size=0.2)
|
92 |
train_dataset = split["train"]
|
|
|
151 |
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
152 |
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
|
153 |
|
154 |
+
# Adjust training arguments for fast testing
|
155 |
training_args = TrainingArguments(
|
156 |
output_dir="./fine_tuned_legal_qa",
|
157 |
evaluation_strategy="steps",
|
158 |
+
eval_steps=10,
|
159 |
learning_rate=2e-5,
|
160 |
+
per_device_train_batch_size=4,
|
161 |
+
per_device_eval_batch_size=4,
|
162 |
+
num_train_epochs=0.1, # Very short training for testing purposes
|
163 |
weight_decay=0.01,
|
164 |
+
logging_steps=5,
|
165 |
+
save_steps=10,
|
166 |
load_best_model_at_end=True,
|
167 |
report_to=[] # Disable wandb logging
|
168 |
)
|
|
|
741 |
else:
|
742 |
print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
|
743 |
run()
|
744 |
+
|