flux-lora-ft / app.py
Andybeyond's picture
Update app.py for FLUX LoRA fine-tuning
5f44078 verified
raw
history blame
1.14 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import torch
# Hub repository id passed to both from_pretrained() calls in train_lora()
# and echoed back in its status message.
# NOTE(review): confirm this id actually resolves on the Hugging Face Hub.
FLUX_MODEL_NAME = "black-forest-labs/FLUX"
def train_lora(alpha, r, lora_dropout, bias):
    """Wrap the base model in a LoRA adapter and report the settings used.

    Despite the name, no training loop runs here: the function loads the
    base model and tokenizer, attaches a PEFT LoRA adapter to the model,
    and returns a status string.

    Args:
        alpha: LoRA scaling factor, forwarded as ``lora_alpha``.
        r: LoRA rank. UI sliders may deliver floats, so the value is
            rounded to the nearest integer before use.
        lora_dropout: Dropout probability forwarded to ``LoraConfig``.
        bias: Either a checkbox-style bool (``True`` -> ``"lora_only"``,
            ``False`` -> ``"none"``) or one of the PEFT bias mode strings
            ``"none"``, ``"all"``, ``"lora_only"``.

    Returns:
        A human-readable summary naming the model and the effective LoRA
        parameters (rank and bias are shown after normalisation).

    Raises:
        ValueError: If the rank is below 1 or ``bias`` is not a
            recognised mode.
    """
    # Validate and normalise the cheap inputs first so that bad values
    # fail before the expensive model download/load below.
    rank = int(round(r))
    if rank < 1:
        raise ValueError(f"LoRA r must be a positive integer, got {r!r}")

    valid_bias_modes = ("none", "all", "lora_only")
    if isinstance(bias, bool):
        # The UI supplies a checkbox value, but LoraConfig expects a mode
        # string; a ticked "LoRA Bias" box is read as "train the biases
        # of the LoRA-wrapped layers".
        bias_mode = "lora_only" if bias else "none"
    elif bias in valid_bias_modes:
        bias_mode = bias
    else:
        raise ValueError(
            f"bias must be a bool or one of {valid_bias_modes}, got {bias!r}"
        )

    # Load the FLUX model and tokenizer.
    # NOTE(review): confirm AutoModelForCausalLM can load this checkpoint;
    # the tokenizer is loaded but not used further in this function.
    model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)

    # Define LoRA config (inference_mode=False keeps the adapter trainable).
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=rank,
        lora_alpha=alpha,
        lora_dropout=lora_dropout,
        bias=bias_mode,
    )

    # Get the PEFT model
    model = get_peft_model(model, peft_config)

    # Report the values that were actually applied, not the raw UI inputs.
    return f"LoRA model created for {FLUX_MODEL_NAME} with parameters: alpha={alpha}, r={rank}, dropout={lora_dropout}, bias={bias_mode}"
# Gradio UI: four controls feed train_lora() positionally
# (alpha, r, lora_dropout, bias) and its status string is shown as text.
iface = gr.Interface(
    fn=train_lora,
    inputs=[
        # step=1 pins these sliders to whole numbers: LoraConfig declares
        # both lora_alpha and r as integers, and without an explicit step
        # the slider may emit fractional values for these ranges.
        gr.Slider(1, 100, value=32, step=1, label="LoRA Alpha"),
        gr.Slider(1, 64, value=8, step=1, label="LoRA r"),
        gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
        gr.Checkbox(label="LoRA Bias"),
    ],
    outputs="text",
)
# Start the web server at import/run time, as the original script does.
iface.launch()