Kevin Fink commited on
Commit
ee975a5
·
1 Parent(s): debdc1c
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -97,6 +97,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
97
  examples['text'],
98
  max_length=max_length, # Set to None for dynamic padding
99
  truncation=True,
 
100
  )
101
 
102
  # Setup the decoder input IDs (shifted right)
@@ -104,6 +105,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
104
  examples['target'],
105
  max_length=max_length, # Set to None for dynamic padding
106
  truncation=True,
 
107
  text_target=examples['target'] # Use text_target for target text
108
  )
109
 
@@ -147,6 +149,14 @@ def predict(text):
147
  def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
148
  config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
149
  model = AutoModelForSeq2SeqLM.from_config(config)
 
 
 
 
 
 
 
 
150
  result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
151
  return result
152
  # Create Gradio interface
 
97
  examples['text'],
98
  max_length=max_length, # Set to None for dynamic padding
99
  truncation=True,
100
+ padding=True,
101
  )
102
 
103
  # Setup the decoder input IDs (shifted right)
 
105
  examples['target'],
106
  max_length=max_length, # Set to None for dynamic padding
107
  truncation=True,
108
+ padding=True,
109
  text_target=examples['target'] # Use text_target for target text
110
  )
111
 
 
149
  def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
150
  config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
151
  model = AutoModelForSeq2SeqLM.from_config(config)
152
+ lora_config = LoraConfig(
153
+ r=16, # Rank of the low-rank adaptation
154
+ lora_alpha=32, # Scaling factor
155
+ lora_dropout=0.1, # Dropout for LoRA layers
156
+ bias="none" # Bias handling
157
+ )
158
+ model = get_peft_model(model, lora_config)
159
+ model.gradient_checkpointing_enable()
160
  result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
161
  return result
162
  # Create Gradio interface