import os

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr

FLUX_MODEL_NAME = "black-forest-labs/FLUX.1-dev"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class FluxDataset(Dataset):
    """Pairs each prompt line with an image named image_<i>.jpg in image_dir."""

    def __init__(self, image_dir, prompt_file):
        self.image_dir = image_dir
        with open(prompt_file, 'r') as f:
            self.prompts = [line.strip() for line in f if line.strip()]
        # Assumes images are named image_1.jpg, image_2.jpg, ... in prompt order.
        self.image_files = [f"image_{i + 1}.jpg" for i in range(len(self.prompts))]
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)
        return {'image': image, 'prompt': self.prompts[idx]}


def train_lora(alpha, r, lora_dropout, bias, batch_size, num_epochs, learning_rate):
    # Gradio passes sliders/numbers as floats and the checkbox as a bool;
    # coerce them to the types LoraConfig and DataLoader expect.
    r = int(r)
    batch_size = int(batch_size)
    num_epochs = int(num_epochs)
    bias = "all" if bias else "none"  # LoraConfig expects "none", "all", or "lora_only"

    # Load the base model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # needed for padding=True below

    # Define LoRA config
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=r,
        lora_alpha=alpha,
        lora_dropout=lora_dropout,
        bias=bias,
    )

    # Wrap the base model with the PEFT/LoRA adapters
    model = get_peft_model(model, peft_config)

    # Prepare dataset and dataloader
    dataset = FluxDataset('path/to/image/directory', 'prompts.txt')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer over the (LoRA) parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in dataloader:
            images = batch['image'].to(DEVICE)  # loaded but not used: the loss below is computed on the prompts only
            prompts = batch['prompt']

            # Tokenize prompts
            inputs = tokenizer(prompts, return_tensors="pt",
                               padding=True, truncation=True).to(DEVICE)

            # Forward pass with a causal-LM objective (labels = input_ids)
            outputs = model(input_ids=inputs.input_ids,
                            attention_mask=inputs.attention_mask,
                            labels=inputs.input_ids)
            loss = outputs.loss

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")

    # Save only the LoRA adapter weights
    model.save_pretrained("path/to/save/lora_model")
    return "LoRA training completed and model saved."


# Gradio interface
iface = gr.Interface(
    fn=train_lora,
    inputs=[
        gr.Slider(1, 100, value=32, step=1, label="LoRA Alpha"),
        gr.Slider(1, 64, value=8, step=1, label="LoRA r"),
        gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
        gr.Checkbox(label="Train LoRA Bias"),
        gr.Number(value=4, label="Batch Size"),
        gr.Number(value=5, label="Number of Epochs"),
        gr.Number(value=1e-4, label="Learning Rate"),
    ],
    outputs="text",
)

iface.launch()
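

# ---------------------------------------------------------------------------
# Optional sketch: attaching LoRA directly to the FLUX transformer.
# FLUX.1-dev is published as a diffusers pipeline, so LoRA adapters are more
# commonly injected into its FluxTransformer2DModel (the image-generation
# backbone) than loaded through a causal-LM head. The helper below is a
# minimal, illustrative sketch, assuming diffusers >= 0.30 (which provides
# FluxTransformer2DModel and add_adapter) and peft installed; the
# target_modules list follows the common convention of adapting the attention
# projections. It is not wired into the Gradio app above.
def attach_flux_lora_sketch(r=8, lora_alpha=32, lora_dropout=0.1):
    from diffusers import FluxTransformer2DModel

    # Load only the transformer sub-module of the FLUX checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        FLUX_MODEL_NAME,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    transformer.requires_grad_(False)  # freeze the base weights

    # Inject LoRA adapters into the attention projection layers.
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    transformer.add_adapter(lora_config)  # PeftAdapterMixin on diffusers models

    # Only the injected LoRA matrices remain trainable.
    trainable = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
    print(f"Trainable LoRA parameters: {trainable:,}")
    return transformer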