Spaces:
Runtime error
Runtime error
Update app.py to process training data
Browse files
app.py
CHANGED
@@ -1,13 +1,41 @@
|
|
1 |
-
import
|
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
from peft import get_peft_model, LoraConfig, TaskType
|
4 |
-
import
|
|
|
|
|
5 |
|
6 |
FLUX_MODEL_NAME = "black-forest-labs/FLUX"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# Load the FLUX model and tokenizer
|
10 |
-
model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME)
|
11 |
tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
|
12 |
|
13 |
# Define LoRA Config
|
@@ -23,7 +51,44 @@ def train_lora(alpha, r, lora_dropout, bias):
|
|
23 |
# Get the PEFT model
|
24 |
model = get_peft_model(model, peft_config)
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
iface = gr.Interface(
|
29 |
fn=train_lora,
|
@@ -31,7 +96,10 @@ iface = gr.Interface(
|
|
31 |
gr.Slider(1, 100, value=32, label="LoRA Alpha"),
|
32 |
gr.Slider(1, 64, value=8, label="LoRA r"),
|
33 |
gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
|
34 |
-
gr.Checkbox(label="LoRA Bias")
|
|
|
|
|
|
|
35 |
],
|
36 |
outputs="text"
|
37 |
)
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset, DataLoader
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
from peft import get_peft_model, LoraConfig, TaskType
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import os
|
8 |
|
9 |
FLUX_MODEL_NAME = "black-forest-labs/FLUX"
|
10 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
|
12 |
+
class FluxDataset(Dataset):
    """Paired image/prompt dataset for FLUX LoRA fine-tuning.

    Reads one prompt per non-blank line of ``prompt_file`` and pairs the
    i-th prompt with an image file in ``image_dir``.  By default the i-th
    prompt maps to ``image_{i+1}.jpg``; pass ``image_files`` to use real
    filenames instead of that naming convention.

    Args:
        image_dir: Directory containing the training images.
        prompt_file: Text file with one prompt per line (blank lines skipped).
        image_files: Optional explicit list of image filenames, one per
            prompt.  Defaults to ``["image_1.jpg", ..., "image_N.jpg"]``.
        transform: Optional callable applied to each PIL image.  Defaults to
            resize(224x224) -> ToTensor -> ImageNet-mean/std normalization.

    Raises:
        ValueError: If ``image_files`` is given and its length does not match
            the number of prompts (fail fast instead of a FileNotFoundError
            mid-training).
    """

    def __init__(self, image_dir, prompt_file, image_files=None, transform=None):
        self.image_dir = image_dir
        # Blank lines are dropped so trailing newlines don't create empty prompts.
        with open(prompt_file, 'r') as f:
            self.prompts = [line.strip() for line in f if line.strip()]
        if image_files is None:
            # Original hard-coded convention: image_1.jpg ... image_N.jpg.
            image_files = [f"image_{i + 1}.jpg" for i in range(len(self.prompts))]
        if len(image_files) != len(self.prompts):
            raise ValueError(
                f"{len(image_files)} image files for {len(self.prompts)} prompts"
            )
        self.image_files = image_files
        if transform is None:
            # NOTE(review): 224x224 + ImageNet stats look copied from a
            # classification pipeline — confirm this matches FLUX's expected
            # input preprocessing.
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Number of (image, prompt) pairs — one per non-blank prompt line."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Load, convert to RGB, and transform the idx-th image.

        Returns:
            dict with ``'image'`` (transformed tensor/object) and
            ``'prompt'`` (str).
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)
        return {
            'image': image,
            'prompt': self.prompts[idx],
        }
|
35 |
+
|
36 |
+
def train_lora(alpha, r, lora_dropout, bias, batch_size, num_epochs, learning_rate):
|
37 |
# Load the FLUX model and tokenizer
|
38 |
+
model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME).to(DEVICE)
|
39 |
tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
|
40 |
|
41 |
# Define LoRA Config
|
|
|
51 |
# Get the PEFT model
|
52 |
model = get_peft_model(model, peft_config)
|
53 |
|
54 |
+
# Prepare dataset and dataloader
|
55 |
+
dataset = FluxDataset('path/to/image/directory', 'prompts.txt')
|
56 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
57 |
+
|
58 |
+
# Optimizer
|
59 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
60 |
+
|
61 |
+
# Training loop
|
62 |
+
model.train()
|
63 |
+
for epoch in range(num_epochs):
|
64 |
+
total_loss = 0
|
65 |
+
for batch in dataloader:
|
66 |
+
images = batch['image'].to(DEVICE)
|
67 |
+
prompts = batch['prompt']
|
68 |
+
|
69 |
+
# Tokenize prompts
|
70 |
+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
|
71 |
+
|
72 |
+
# Forward pass
|
73 |
+
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=inputs.input_ids)
|
74 |
+
loss = outputs.loss
|
75 |
+
|
76 |
+
# Backward pass and optimize
|
77 |
+
loss.backward()
|
78 |
+
optimizer.step()
|
79 |
+
optimizer.zero_grad()
|
80 |
+
|
81 |
+
total_loss += loss.item()
|
82 |
+
|
83 |
+
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}")
|
84 |
+
|
85 |
+
# Save the LoRA model
|
86 |
+
model.save_pretrained("path/to/save/lora_model")
|
87 |
+
|
88 |
+
return "LoRA training completed and model saved."
|
89 |
+
|
90 |
+
# Gradio interface
|
91 |
+
import gradio as gr
|
92 |
|
93 |
iface = gr.Interface(
|
94 |
fn=train_lora,
|
|
|
96 |
gr.Slider(1, 100, value=32, label="LoRA Alpha"),
|
97 |
gr.Slider(1, 64, value=8, label="LoRA r"),
|
98 |
gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
|
99 |
+
gr.Checkbox(label="LoRA Bias"),
|
100 |
+
gr.Number(value=4, label="Batch Size"),
|
101 |
+
gr.Number(value=5, label="Number of Epochs"),
|
102 |
+
gr.Number(value=1e-4, label="Learning Rate")
|
103 |
],
|
104 |
outputs="text"
|
105 |
)
|