# Hugging Face Spaces app (page scrape noted "Runtime error" — see comments below for the likely cause)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import torch

# NOTE(review): "black-forest-labs/FLUX" is not a valid Hub repo id (the published
# checkpoints are e.g. "black-forest-labs/FLUX.1-dev"), and FLUX is a text-to-image
# diffusion model, so loading it with AutoModelForCausalLM will likely fail —
# this is a plausible cause of the Space's "Runtime error". TODO: confirm repo id
# and model class against the intended checkpoint.
FLUX_MODEL_NAME = "black-forest-labs/FLUX"
def train_lora(alpha, r, lora_dropout, bias):
    """Wrap the FLUX base model with a freshly-configured LoRA adapter.

    Args:
        alpha: LoRA scaling factor (``lora_alpha``); sliders deliver floats,
            so it is coerced to ``int``.
        r: LoRA rank; coerced to ``int`` for the same reason.
        lora_dropout: Dropout probability applied to the LoRA layers.
        bias: Either a ``peft`` bias mode string (``"none"``/``"all"``/
            ``"lora_only"``) or a bool from the Gradio checkbox, which is
            mapped to ``"all"``/``"none"``.

    Returns:
        A human-readable summary string of the configuration that was applied.
    """
    # Load the FLUX model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
    # BUG FIX: LoraConfig(bias=...) requires one of the strings
    # "none" | "all" | "lora_only"; the Gradio Checkbox supplies a bool,
    # which raised at config construction. Map bool -> string, and let
    # already-valid strings pass through unchanged (backward compatible).
    if isinstance(bias, bool):
        bias = "all" if bias else "none"
    # Define LoRA Config. Sliders emit floats, so r / alpha are cast to int
    # as peft expects integer rank and scaling values.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=int(r),
        lora_alpha=int(alpha),
        lora_dropout=lora_dropout,
        bias=bias,
    )
    # Get the PEFT model (wraps the base model in-place with LoRA adapters).
    model = get_peft_model(model, peft_config)
    return f"LoRA model created for {FLUX_MODEL_NAME} with parameters: alpha={alpha}, r={r}, dropout={lora_dropout}, bias={bias}"
# Gradio UI: exposes the LoRA hyper-parameters and reports the resulting
# configuration as text. step=1 on the integer-valued sliders so they emit
# whole numbers matching what LoraConfig expects.
iface = gr.Interface(
    fn=train_lora,
    inputs=[
        gr.Slider(1, 100, value=32, step=1, label="LoRA Alpha"),
        gr.Slider(1, 64, value=8, step=1, label="LoRA r"),
        gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
        gr.Checkbox(label="LoRA Bias"),
    ],
    outputs="text",
)
iface.launch()