Kevin Fink
commited on
Commit
·
451a63d
1
Parent(s):
0ee2b72
dev
Browse files
app.py
CHANGED
|
@@ -28,8 +28,9 @@ model_save_path = '/data/lora_finetuned_model' # Specify your desired save path
|
|
| 28 |
model.save_pretrained(model_save_path)
|
| 29 |
'''
|
| 30 |
|
| 31 |
-
def fine_tune_model(
|
| 32 |
try:
|
|
|
|
| 33 |
torch.cuda.empty_cache()
|
| 34 |
torch.nn.CrossEntropyLoss()
|
| 35 |
#rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
|
|
@@ -335,7 +336,7 @@ except Exception as e:
|
|
| 335 |
# Create Gradio interface
|
| 336 |
try:
|
| 337 |
iface = gr.Interface(
|
| 338 |
-
fn=
|
| 339 |
inputs=[
|
| 340 |
gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
|
| 341 |
gr.Textbox(label="HF hub to push to after training"),
|
|
|
|
| 28 |
model.save_pretrained(model_save_path)
|
| 29 |
'''
|
| 30 |
|
| 31 |
+
def fine_tune_model(dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
|
| 32 |
try:
|
| 33 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-efficient-tiny-nh8")
|
| 34 |
torch.cuda.empty_cache()
|
| 35 |
torch.nn.CrossEntropyLoss()
|
| 36 |
#rouge_metric = evaluate.load("rouge", cache_dir='/data/cache')
|
|
|
|
| 336 |
# Create Gradio interface
|
| 337 |
try:
|
| 338 |
iface = gr.Interface(
|
| 339 |
+
fn=fine_tune_model,
|
| 340 |
inputs=[
|
| 341 |
gr.Textbox(label="Dataset Name (e.g., 'imdb')"),
|
| 342 |
gr.Textbox(label="HF hub to push to after training"),
|