tejash300 committed on
Commit
d629e1d
·
verified ·
1 Parent(s): 64af888

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -70,7 +70,7 @@ def fine_tune_cuad_model():
70
  """
71
  Fine tunes a QA model on the CUAD dataset for clause extraction.
72
  For testing, we use only 50 training examples (and 10 for validation)
73
- and restrict training to 10 steps.
74
  """
75
  from datasets import load_dataset
76
  import numpy as np
@@ -149,20 +149,19 @@ def fine_tune_cuad_model():
149
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
150
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
151
 
152
- # Set max_steps to 10 for fast testing.
153
  training_args = TrainingArguments(
154
  output_dir="./fine_tuned_legal_qa",
155
- max_steps=10,
156
- evaluation_strategy="steps",
157
- eval_steps=5,
158
  learning_rate=2e-5,
159
  per_device_train_batch_size=4,
160
  per_device_eval_batch_size=4,
161
  num_train_epochs=1,
162
  weight_decay=0.01,
163
  logging_steps=1,
164
- save_steps=5,
165
- load_best_model_at_end=True,
166
  report_to=[] # Disable wandb logging
167
  )
168
 
@@ -191,7 +190,7 @@ def fine_tune_cuad_model():
191
  try:
192
  try:
193
  nlp = spacy.load("en_core_web_sm")
194
- except:
195
  spacy.cli.download("en_core_web_sm")
196
  nlp = spacy.load("en_core_web_sm")
197
  print("✅ Loading NLP models...")
 
70
  """
71
  Fine tunes a QA model on the CUAD dataset for clause extraction.
72
  For testing, we use only 50 training examples (and 10 for validation)
73
+ and restrict training to 1 step with evaluation disabled.
74
  """
75
  from datasets import load_dataset
76
  import numpy as np
 
149
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
150
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
151
 
152
+ # Set max_steps to 1 for very fast testing and disable evaluation
153
  training_args = TrainingArguments(
154
  output_dir="./fine_tuned_legal_qa",
155
+ max_steps=1, # Only one training step
156
+ evaluation_strategy="no", # Disable evaluation during training
 
157
  learning_rate=2e-5,
158
  per_device_train_batch_size=4,
159
  per_device_eval_batch_size=4,
160
  num_train_epochs=1,
161
  weight_decay=0.01,
162
  logging_steps=1,
163
+ save_steps=1,
164
+ load_best_model_at_end=False,
165
  report_to=[] # Disable wandb logging
166
  )
167
 
 
190
  try:
191
  try:
192
  nlp = spacy.load("en_core_web_sm")
193
+ except Exception:
194
  spacy.cli.download("en_core_web_sm")
195
  nlp = spacy.load("en_core_web_sm")
196
  print("✅ Loading NLP models...")