Andybeyond committed on
Commit
bf13535
·
verified ·
1 Parent(s): feebf40

Update app.py to process training data

Browse files
Files changed (1) hide show
  1. app.py +74 -6
app.py CHANGED
@@ -1,13 +1,41 @@
1
- import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from peft import get_peft_model, LoraConfig, TaskType
4
- import torch
 
 
5
 
6
  FLUX_MODEL_NAME = "black-forest-labs/FLUX"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- def train_lora(alpha, r, lora_dropout, bias):
 
 
 
 
 
 
 
 
 
9
  # Load the FLUX model and tokenizer
10
- model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME)
11
  tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
12
 
13
  # Define LoRA Config
@@ -23,7 +51,44 @@ def train_lora(alpha, r, lora_dropout, bias):
23
  # Get the PEFT model
24
  model = get_peft_model(model, peft_config)
25
 
26
- return f"LoRA model created for {FLUX_MODEL_NAME} with parameters: alpha={alpha}, r={r}, dropout={lora_dropout}, bias={bias}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  iface = gr.Interface(
29
  fn=train_lora,
@@ -31,7 +96,10 @@ iface = gr.Interface(
31
  gr.Slider(1, 100, value=32, label="LoRA Alpha"),
32
  gr.Slider(1, 64, value=8, label="LoRA r"),
33
  gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
34
- gr.Checkbox(label="LoRA Bias")
 
 
 
35
  ],
36
  outputs="text"
37
  )
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import get_peft_model, LoraConfig, TaskType
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import os
8
 
9
  FLUX_MODEL_NAME = "black-forest-labs/FLUX"
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ class FluxDataset(Dataset):
13
+ def __init__(self, image_dir, prompt_file):
14
+ self.image_dir = image_dir
15
+ with open(prompt_file, 'r') as f:
16
+ self.prompts = [line.strip() for line in f if line.strip()]
17
+ self.image_files = [f"image_{i+1}.jpg" for i in range(len(self.prompts))]
18
+ self.transform = transforms.Compose([
19
+ transforms.Resize((224, 224)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ def __len__(self):
25
+ return len(self.prompts)
26
 
27
+ def __getitem__(self, idx):
28
+ img_path = os.path.join(self.image_dir, self.image_files[idx])
29
+ image = Image.open(img_path).convert('RGB')
30
+ image = self.transform(image)
31
+ return {
32
+ 'image': image,
33
+ 'prompt': self.prompts[idx]
34
+ }
35
+
36
+ def train_lora(alpha, r, lora_dropout, bias, batch_size, num_epochs, learning_rate):
37
  # Load the FLUX model and tokenizer
38
+ model = AutoModelForCausalLM.from_pretrained(FLUX_MODEL_NAME).to(DEVICE)
39
  tokenizer = AutoTokenizer.from_pretrained(FLUX_MODEL_NAME)
40
 
41
  # Define LoRA Config
 
51
  # Get the PEFT model
52
  model = get_peft_model(model, peft_config)
53
 
54
+ # Prepare dataset and dataloader
55
+ dataset = FluxDataset('path/to/image/directory', 'prompts.txt')
56
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
57
+
58
+ # Optimizer
59
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
60
+
61
+ # Training loop
62
+ model.train()
63
+ for epoch in range(num_epochs):
64
+ total_loss = 0
65
+ for batch in dataloader:
66
+ images = batch['image'].to(DEVICE)
67
+ prompts = batch['prompt']
68
+
69
+ # Tokenize prompts
70
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
71
+
72
+ # Forward pass
73
+ outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=inputs.input_ids)
74
+ loss = outputs.loss
75
+
76
+ # Backward pass and optimize
77
+ loss.backward()
78
+ optimizer.step()
79
+ optimizer.zero_grad()
80
+
81
+ total_loss += loss.item()
82
+
83
+ print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader)}")
84
+
85
+ # Save the LoRA model
86
+ model.save_pretrained("path/to/save/lora_model")
87
+
88
+ return "LoRA training completed and model saved."
89
+
90
+ # Gradio interface
91
+ import gradio as gr
92
 
93
  iface = gr.Interface(
94
  fn=train_lora,
 
96
  gr.Slider(1, 100, value=32, label="LoRA Alpha"),
97
  gr.Slider(1, 64, value=8, label="LoRA r"),
98
  gr.Slider(0, 1, value=0.1, label="LoRA Dropout"),
99
+ gr.Checkbox(label="LoRA Bias"),
100
+ gr.Number(value=4, label="Batch Size"),
101
+ gr.Number(value=5, label="Number of Epochs"),
102
+ gr.Number(value=1e-4, label="Learning Rate")
103
  ],
104
  outputs="text"
105
  )