Kevin Fink
committed
Commit ee975a5
1 Parent(s): debdc1c
dev

app.py CHANGED
@@ -97,6 +97,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         examples['text'],
         max_length=max_length, # Set to None for dynamic padding
         truncation=True,
+        padding=True,
     )
 
     # Setup the decoder input IDs (shifted right)
@@ -104,6 +105,7 @@ def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size
         examples['target'],
         max_length=max_length, # Set to None for dynamic padding
         truncation=True,
+        padding=True,
         text_target=examples['target'] # Use text_target for target text
     )
 
@@ -147,6 +149,14 @@ def predict(text):
 def run_train(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
     config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
     model = AutoModelForSeq2SeqLM.from_config(config)
+    lora_config = LoraConfig(
+        r=16, # Rank of the low-rank adaptation
+        lora_alpha=32, # Scaling factor
+        lora_dropout=0.1, # Dropout for LoRA layers
+        bias="none" # Bias handling
+    )
+    model = get_peft_model(model, lora_config)
+    model.gradient_checkpointing_enable()
     result = fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad)
     return result
 # Create Gradio interface
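
Note on the padding change (not part of the diff): with padding=True alongside a fixed max_length, the tokenizer pads every sequence to the longest sequence in the batch it receives (padding="max_length" would pad to max_length instead). A common alternative is to truncate during tokenization and let DataCollatorForSeq2Seq pad each training batch dynamically. A minimal sketch of that alternative, assuming the checkpoint ships a T5 tokenizer and the dataset has "text" and "target" columns as in fine_tune_model:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

# Assumption: tokenizer files are available for this checkpoint on the Hub.
tokenizer = AutoTokenizer.from_pretrained("google/t5-efficient-tiny")

def tokenize(examples, max_length=128):
    # Truncate here; leave padding to the collator so each batch is padded
    # only to the length of its own longest sequence.
    model_inputs = tokenizer(examples["text"], max_length=max_length, truncation=True)
    labels = tokenizer(text_target=examples["target"], max_length=max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# padding=True here means "pad each batch dynamically"; pass the collator to the trainer.
data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True)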
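
Note on the LoRA hunk (not part of the diff): the added lines rely on LoraConfig and get_peft_model from the peft library, whose imports are not shown in this commit. A minimal, self-contained sketch of the same pattern; task_type, target_modules, and print_trainable_parameters are illustrative additions, not part of the commit:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("google/t5-efficient-tiny")
model = AutoModelForSeq2SeqLM.from_config(config)  # randomly initialized weights, as in the commit

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # encoder-decoder language modeling
    r=16,                             # rank of the low-rank update matrices
    lora_alpha=32,                    # scaling factor applied to the update
    lora_dropout=0.1,                 # dropout on the LoRA branch
    bias="none",                      # leave bias terms untouched
    target_modules=["q", "v"],        # T5 attention projections; adjust per architecture
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()    # sanity check: only the LoRA parameters should train

When gradient checkpointing is enabled on a PEFT-wrapped model, as model.gradient_checkpointing_enable() does in the hunk above, calling enable_input_require_grads() on the base model is a commonly used workaround if the loss fails to backpropagate to the adapter weights.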