LHRuig committed on
Commit
d41cf19
·
verified ·
1 Parent(s): 8c2d04d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +100 -0
  2. dockerfile.dockerfile +6 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import torch
6
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
7
+
8
+ # ===== 1. Initialize BLIP-2 for Auto-Captioning =====
9
def load_blip_model():
    """Load the BLIP-2 captioning processor and model.

    The model is placed on CUDA when torch reports a GPU, otherwise on CPU.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the
        string ``"cuda"`` or ``"cpu"``.
    """
    model_id = "Salesforce/blip2-opt-2.7b"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only safe on GPU: many float16 kernels are not
    # implemented for CPU tensors, so fall back to float32 there.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = Blip2Processor.from_pretrained(model_id)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=dtype,
    ).to(device)
    return processor, model, device


# Loaded once at import time and shared by every captioning call.
processor, model, device = load_blip_model()
19
+
20
def generate_caption(image_path, trigger_word):
    """Caption one image with BLIP-2 and prefix the result with the trigger word.

    Args:
        image_path: Path of the image file to caption.
        trigger_word: Token placed in square brackets at the start of the
            caption.

    Returns:
        A string of the form ``"a photo of [<trigger_word>], <caption>"``.
    """
    # The context manager closes the underlying file; convert() forces the
    # pixel data to be read before that happens.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    # Cast the inputs to whatever dtype the model was loaded with instead of
    # assuming float16, so inputs and weights always agree.
    inputs = processor(image, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return f"a photo of [{trigger_word}], {caption}"
26
+
27
+ # ===== 2. Install Kohya_SS Manually =====
28
if not os.path.exists("kohya_ss"):
    import sys  # local import: only needed for this one-time bootstrap

    print("⬇️ Installing Kohya_SS...")
    # (argv, working directory) pairs, executed in order. Using argv lists
    # with ``cwd`` avoids going through a shell for ``cd ... &&``.
    install_steps = [
        (["git", "clone", "https://github.com/bmaltais/kohya_ss"], None),
        ([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], "kohya_ss"),
        ([sys.executable, "-m", "pip", "install", "."], "kohya_ss"),
    ]
    for argv, workdir in install_steps:
        # Stay best-effort (a failed step does not abort start-up), but make
        # every failure visible in the logs instead of dropping it silently.
        try:
            exit_code = subprocess.run(argv, cwd=workdir).returncode
        except OSError as exc:
            # Raised when the executable or the working directory is missing.
            print(f"⚠️ Could not run {' '.join(argv)}: {exc}")
            continue
        if exit_code != 0:
            print(f"⚠️ Command failed with exit code {exit_code}: {' '.join(argv)}")
33
+
34
+ # ===== 3. Training Function =====
35
def train_lora(images, trigger_word, progress=gr.Progress()):
    """Caption the uploaded images and train a LoRA on them with Kohya_SS.

    Args:
        images: Items from the ``gr.Files`` input. Each item may be a
            filesystem path, an object exposing its path as ``.name``
            (temp-file wrapper), or an already-opened PIL image.
        trigger_word: Token embedded in every generated caption.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        The path ``"output/lora.safetensors"`` of the trained weights.

    Raises:
        gr.Error: If no images were uploaded.
        subprocess.CalledProcessError: If the training script exits with a
            non-zero status.
    """
    import shutil  # local import: only used to reset the dataset directory

    if not images:
        raise gr.Error("Please upload at least one image before training.")

    progress(0.1, desc="Preparing data...")

    # Start from an empty dataset directory so files left over from a
    # previous run (possibly another user's) are not trained on.
    shutil.rmtree("train", ignore_errors=True)
    os.makedirs("train", exist_ok=True)

    # Save images + auto-caption
    for i, upload in enumerate(progress.tqdm(images, desc="Processing images")):
        img_path = f"train/img_{i}.jpg"
        if isinstance(upload, Image.Image):
            # JPEG cannot store alpha/palette modes, so normalise to RGB.
            upload.convert("RGB").save(img_path)
        else:
            # File uploads arrive as paths or temp-file wrappers, not as PIL
            # images, so they have to be opened before they can be re-saved.
            source = upload if isinstance(upload, (str, os.PathLike)) else upload.name
            with Image.open(source) as opened:
                opened.convert("RGB").save(img_path)
        caption = generate_caption(img_path, trigger_word)
        with open(f"train/img_{i}.txt", "w", encoding="utf-8") as f:
            f.write(caption)

    # Train LoRA (optimized for HF Spaces T4 GPU)
    # NOTE(review): the script location inside the kohya_ss checkout, the use
    # of train_network.py with an SDXL base model, and the flat layout of
    # "train/" (no "<repeats>_<name>" sub-folder) are kept as in the original
    # and should be confirmed against the installed Kohya_SS version.
    cmd = [
        "python",
        "kohya_ss/train_network.py",
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--train_data_dir=train",
        "--output_dir=output",
        # Makes the saved file match the path returned below.
        "--output_name=lora",
        # Captions are written next to the images as .txt files.
        "--caption_extension=.txt",
        # Selects the LoRA network implementation to train.
        "--network_module=networks.lora",
        "--resolution=512",
        "--network_dim=32",
        # "--lr" is only a prefix shared by several --lr_* options; the
        # learning-rate option itself is spelled out in full.
        "--learning_rate=1e-4",
        "--max_train_steps=800",
        "--mixed_precision=fp16",
        "--save_precision=fp16",
        "--optimizer_type=AdamW8bit",
        "--xformers",
    ]

    progress(0.8, desc="Training LoRA...")
    # The argv list needs no shell; check=True surfaces a failed training run
    # as an exception instead of returning a path to a file that was not made.
    subprocess.run(cmd, check=True)

    return "output/lora.safetensors"
67
+
68
+ # ===== 4. Gradio UI =====
69
# Two-column layout: inputs (images, trigger word, button) on the left,
# results (downloadable LoRA file, preview gallery) on the right.
with gr.Blocks(title="1-Click LoRA Trainer") as demo:
    gr.Markdown("""
    ## 🎨 Weights.gg-Style LoRA Trainer
    Upload 30 images + set a trigger word to train a custom LoRA.
    """)

    with gr.Row():
        with gr.Column():
            images = gr.Files(
                label="Upload Character Images (30 max)",
                file_types=["image"],
                interactive=True
            )
            trigger = gr.Textbox(
                label="Trigger Word",
                placeholder="E.g., 'my_char' (used as [my_char] in prompts)"
            )
            train_btn = gr.Button("🚀 Train LoRA", variant="primary")

        with gr.Column():
            output = gr.File(label="Download LoRA")
            # NOTE(review): the gallery is displayed but no event in this
            # file ever writes to it.
            gallery = gr.Gallery(label="Training Preview")

    # Clicking the button runs train_lora(images, trigger) and shows the
    # returned file path in the download component.
    train_btn.click(
        train_lora,
        inputs=[images, trigger],
        outputs=output,
        api_name="train"
    )

if __name__ == "__main__":
    # train_lora reports progress through gr.Progress and runs for a long
    # time; both rely on the request queue. Gradio 3.x (still allowed by the
    # pinned requirement) does not enable it by default, so do it explicitly.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)
dockerfile.dockerfile ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Container image for the Gradio LoRA trainer (app.py).
FROM python:3.10
WORKDIR /app
# git is used by app.py at start-up to clone Kohya_SS. Skip recommended
# packages and drop the apt lists so they do not bloat the layer.
RUN apt-get update \
    && apt-get install -y --no-install-recommends git ffmpeg \
    && rm -rf /var/lib/apt/lists/*
# Install Python dependencies before copying the application code so this
# layer stays cached when only the code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# app.py binds Gradio to 0.0.0.0:7860.
EXPOSE 7860
CMD ["python", "app.py"]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=3.50.2
2
+ torch==2.1.2
3
+ transformers>=4.38.2
4
+ accelerate>=0.27.2
5
+ diffusers>=0.27.2
6
+ huggingface-hub>=0.22.2
7
+ xformers==0.0.23.post1
8
+ Pillow>=10.4.0