tejash300 commited on
Commit
c575db1
·
verified ·
1 Parent(s): 814be0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -71,7 +71,8 @@ def load_document_context(task_id):
71
  def fine_tune_cuad_model():
72
  """
73
  Fine tunes a QA model on the CUAD dataset for clause extraction.
74
- This demo uses one epoch; adjust parameters as needed.
 
75
  """
76
  from datasets import load_dataset
77
  import numpy as np
@@ -81,9 +82,11 @@ def fine_tune_cuad_model():
81
  dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
82
 
83
  if "train" in dataset:
84
- train_dataset = dataset["train"].select(range(1000))
 
85
  if "validation" in dataset:
86
- val_dataset = dataset["validation"].select(range(200))
 
87
  else:
88
  split = train_dataset.train_test_split(test_size=0.2)
89
  train_dataset = split["train"]
@@ -148,17 +151,18 @@ def fine_tune_cuad_model():
148
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
149
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
150
 
 
151
  training_args = TrainingArguments(
152
  output_dir="./fine_tuned_legal_qa",
153
  evaluation_strategy="steps",
154
- eval_steps=100,
155
  learning_rate=2e-5,
156
- per_device_train_batch_size=16,
157
- per_device_eval_batch_size=16,
158
- num_train_epochs=1,
159
  weight_decay=0.01,
160
- logging_steps=50,
161
- save_steps=100,
162
  load_best_model_at_end=True,
163
  report_to=[] # Disable wandb logging
164
  )
@@ -737,3 +741,4 @@ if __name__ == "__main__":
737
  else:
738
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
739
  run()
 
 
71
  def fine_tune_cuad_model():
72
  """
73
  Fine tunes a QA model on the CUAD dataset for clause extraction.
74
+ For testing, we use only 50 training examples (and 10 for validation)
75
+ and set training arguments for very fast, minimal training.
76
  """
77
  from datasets import load_dataset
78
  import numpy as np
 
82
  dataset = load_dataset("theatticusproject/cuad-qa", trust_remote_code=True)
83
 
84
  if "train" in dataset:
85
+ # Use only 50 examples for training
86
+ train_dataset = dataset["train"].select(range(50))
87
  if "validation" in dataset:
88
+ # Use 10 examples for validation
89
+ val_dataset = dataset["validation"].select(range(10))
90
  else:
91
  split = train_dataset.train_test_split(test_size=0.2)
92
  train_dataset = split["train"]
 
151
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
152
  val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "start_positions", "end_positions"])
153
 
154
+ # Adjust training arguments for fast testing
155
  training_args = TrainingArguments(
156
  output_dir="./fine_tuned_legal_qa",
157
  evaluation_strategy="steps",
158
+ eval_steps=10,
159
  learning_rate=2e-5,
160
+ per_device_train_batch_size=4,
161
+ per_device_eval_batch_size=4,
162
+ num_train_epochs=0.1, # Very short training for testing purposes
163
  weight_decay=0.01,
164
+ logging_steps=5,
165
+ save_steps=10,
166
  load_best_model_at_end=True,
167
  report_to=[] # Disable wandb logging
168
  )
 
741
  else:
742
  print("\n⚠️ Ngrok setup failed. API will only be available locally.\n")
743
  run()
744
+