LHRuig committed on
Commit 1c1f2e1 · verified · 1 Parent(s): fea9a9e

Upload 3 files

Files changed (3):
  1. app.py +57 -0
  2. caption.py +21 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,57 @@
import os
import subprocess
import gradio as gr
from PIL import Image
from caption import generate_caption  # BLIP-2 captioning helper (caption.py)

# ===== 1. Install Kohya_SS manually =====
if not os.path.exists("kohya_ss"):
    print("⬇️ Cloning Kohya_SS...")
    os.system("git clone https://github.com/bmaltais/kohya_ss")
    os.chdir("kohya_ss")
    os.system("pip install -r requirements.txt")
    os.system("pip install .")  # install the kohya_ss package
    os.chdir("..")

# ===== 2. Training function =====
def train_lora(images, trigger_word, model_choice="Flux"):
    # model_choice is passed straight to --pretrained_model_name_or_path, so it
    # should be a local checkpoint path or Hugging Face repo id of the base model.
    # Kohya expects image folders named "<repeats>_<name>" inside train_data_dir.
    img_dir = f"train/10_{trigger_word}"
    os.makedirs(img_dir, exist_ok=True)
    for i, upload in enumerate(images):
        # gr.Files yields temp-file objects, not PIL images; re-save each upload as JPEG
        img_path = f"{img_dir}/img_{i}.jpg"
        Image.open(upload.name).convert("RGB").save(img_path)
        # Auto-caption (from caption.py)
        caption = generate_caption(img_path, trigger_word)
        with open(f"{img_dir}/img_{i}.txt", "w") as f:
            f.write(caption)

    # Train the LoRA (simplified Kohya_SS command; adjust flags to the installed version)
    cmd = f"""
    python kohya_ss/train_network.py \
        --pretrained_model_name_or_path="{model_choice}" \
        --train_data_dir="train" \
        --output_dir="output" \
        --output_name="lora" \
        --network_module=networks.lora \
        --caption_extension=".txt" \
        --resolution=512 \
        --network_dim=64 \
        --learning_rate=1e-4 \
        --max_train_steps=1000
    """
    subprocess.run(cmd, shell=True)
    return "output/lora.safetensors"

# ===== 3. Gradio UI =====
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 1-Click LoRA Trainer")
    with gr.Row():
        images = gr.Files(label="Upload 30 Images", file_types=["image"])
        trigger = gr.Textbox(label="Trigger Word (e.g., 'my_char')")
    train_btn = gr.Button("🚀 Train LoRA")
    output = gr.File(label="Download LoRA")

    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
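
Note on the training call: app.py assembles the Kohya_SS command as one shell string. A minimal alternative sketch (not part of the commit) builds the same flags as an argument list and runs them without shell=True, so a failed training run raises instead of passing silently; the flag values simply mirror the simplified command above and may need adjusting for the installed kohya_ss version.

# Sketch only (not in the commit): the simplified Kohya_SS call as an argument
# list, run without shell=True; check=True raises if training exits non-zero.
import subprocess

def run_kohya(model_choice: str) -> None:
    cmd = [
        "python", "kohya_ss/train_network.py",
        f"--pretrained_model_name_or_path={model_choice}",
        "--train_data_dir=train",
        "--output_dir=output",
        "--output_name=lora",
        "--network_module=networks.lora",
        "--caption_extension=.txt",
        "--resolution=512",
        "--network_dim=64",
        "--learning_rate=1e-4",
        "--max_train_steps=1000",
    ]
    subprocess.run(cmd, check=True)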
caption.py ADDED
@@ -0,0 +1,21 @@
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import torch

_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32  # fp16 only on GPU
_processor = None
_model = None


def generate_caption(image_path, trigger_word):
    global _processor, _model

    # Load BLIP-2 once and reuse it across calls (smaller 2.7B variant for HF Spaces)
    if _model is None:
        _processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        _model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=_dtype
        ).to(_device)

    # Generate a caption and prepend the LoRA trigger word
    image = Image.open(image_path).convert("RGB")
    inputs = _processor(image, return_tensors="pt").to(_device, _dtype)
    generated_ids = _model.generate(**inputs, max_new_tokens=50)
    caption = _processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    return f"a photo of [{trigger_word}], {caption}"
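
For reference, a minimal way to exercise the captioner on its own (the image path and trigger word here are made-up examples, not files from the commit):

# Quick local check of caption.py (hypothetical sample.jpg and trigger word)
from caption import generate_caption

print(generate_caption("sample.jpg", "my_char"))
# -> "a photo of [my_char], <BLIP-2 description of the image>"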
requirements.txt ADDED
@@ -0,0 +1,7 @@
gradio==3.50.2
torch==2.2.1
accelerate==0.27.2
diffusers==0.27.2
transformers==4.38.2
huggingface-hub==0.22.2
xformers==0.0.23.post1