import os
import subprocess
from typing import Union
from huggingface_hub import whoami, HfApi
from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware
import sys
import gradio as gr
from PIL import Image
import torch
import uuid
import shutil
import json
import yaml
from slugify import slugify
from transformers import AutoProcessor, AutoModelForCausalLM
import numpy as np
# Set environment variables
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Check if we're running on HF Spaces
is_spaces = bool(os.environ.get("SPACE_ID"))
# FastAPI app setup
app = FastAPI()
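# NOTE: "your-secret-key" is a placeholder; in production the session secret should come from an environment variable or Space secret.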
app.add_middleware(SessionMiddleware, secret_key="your-secret-key")
# Constants
MAX_IMAGES = 150
# Hugging Face token setup
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("HF_TOKEN environment variable is not set")
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
# Initialize HF API
api = HfApi(token=HF_TOKEN)
# Create default train config
def get_default_train_config(lora_name, username, trigger_word=None):
"""Generate a default training configuration"""
slugged_lora_name = slugify(lora_name)
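# A single "process" entry holds everything the training step reads: base model, LoRA network size (rank/alpha), train schedule, dataset folder, Hub save target, and sample prompts.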
config = {
"config": {
"name": slugged_lora_name,
"process": [{
"model": {
"name_or_path": "black-forest-labs/FLUX.1-dev",
"assistant_lora_path": None,
"low_vram": False,
},
"network": {
"linear": 16,
"linear_alpha": 16
},
"train": {
"skip_first_sample": True,
"steps": 1000,
"lr": 4e-4,
"disable_sampling": False
},
"datasets": [{
"folder_path": "", # Will be filled later
}],
"save": {
"push_to_hub": True,
"hf_repo_id": f"{username}/{slugged_lora_name}",
"hf_private": True,
"hf_token": HF_TOKEN
},
"sample": {
"sample_steps": 28,
"sample_every": 1000,
"prompts": []
}
}]
}
}
if trigger_word:
config["config"]["process"][0]["trigger_word"] = trigger_word
return config
# Helper functions
def load_captioning(uploaded_files, concept_sentence):
"""Load images and prepare captioning UI"""
uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
txt_files = [file for file in uploaded_files if file.endswith('.txt')]
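# Captions can be supplied alongside images as .txt files; map each caption file by the image basename it belongs to.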
txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
updates = []
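# "updates" must stay in the same order as the UI's output_components: the captioning area, then (row, image, caption) per slot, then the sample accordion and its three prompt boxes.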
if len(uploaded_images) <= 1:
raise gr.Error(
"Please upload at least 2 images to train your model (the ideal number is between 4-30)"
)
elif len(uploaded_images) > MAX_IMAGES:
raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
# Update captioning area visibility
updates.append(gr.update(visible=True))
# Update individual captioning rows
for i in range(1, MAX_IMAGES + 1):
visible = i <= len(uploaded_images)
updates.append(gr.update(visible=visible))
image_value = uploaded_images[i - 1] if visible else None
updates.append(gr.update(value=image_value, visible=visible))
corresponding_caption = False
if image_value:
base_name = os.path.splitext(os.path.basename(image_value))[0]
if base_name in txt_files_dict:
with open(txt_files_dict[base_name], 'r') as file:
corresponding_caption = file.read()
text_value = corresponding_caption if visible and corresponding_caption else "[trigger]" if visible and concept_sentence else None
updates.append(gr.update(value=text_value, visible=visible))
# Update sample caption area
updates.append(gr.update(visible=True))
updates.append(gr.update(placeholder=f'A portrait of a person in a bustling cafe {concept_sentence}', value=f'A person in a bustling cafe {concept_sentence}'))
updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}"))
updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall"))
return updates
def hide_captioning():
"""Hide captioning UI elements"""
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
def create_dataset(images, *captions):
"""Create dataset directory with images and captions"""
destination_folder = f"datasets/{uuid.uuid4()}"
os.makedirs(destination_folder, exist_ok=True)
jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl")
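# metadata.jsonl gets one JSON object per image: {"file_name": ..., "prompt": ...}, the format the training script reads back.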
with open(jsonl_file_path, "a") as jsonl_file:
for index, image in enumerate(images):
if image: # Skip None values
new_image_path = shutil.copy(image, destination_folder)
caption = captions[index]
file_name = os.path.basename(new_image_path)
data = {"file_name": file_name, "prompt": caption}
jsonl_file.write(json.dumps(data) + "\n")
return destination_folder
def run_captioning(images, concept_sentence, *captions):
"""Run automatic captioning using Microsoft Florence model"""
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16
# Load model and processor
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
captions = list(captions)
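# Gradio passes the current caption textboxes as a tuple; copy to a list so entries can be overwritten in place.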
for i, image_path in enumerate(images):
if not image_path: # Skip None values
continue
if isinstance(image_path, str): # If image is a file path
try:
image = Image.open(image_path).convert("RGB")
except Exception as e:
print(f"Error opening image {image_path}: {e}")
continue
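# "<DETAILED_CAPTION>" is a Florence-2 task prompt; post_process_generation below parses the raw output for that task.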
prompt = "<DETAILED_CAPTION>"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
generated_text, task=prompt, image_size=(image.width, image.height)
)
caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
if concept_sentence:
caption_text = f"{caption_text} [trigger]"
captions[i] = caption_text
yield captions
# Clean up to free memory
model.to("cpu")
del model
del processor
torch.cuda.empty_cache()
except Exception as e:
print(f"Error in captioning: {e}")
raise gr.Error(f"Captioning failed: {str(e)}")
def update_pricing(steps):
"""Update estimated cost based on training steps"""
try:
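# Rough estimate: ~7.54 s per training step plus ~4 minutes of fixed overhead, billed at $0.80/hour for an L4 GPU (the dollar cost is computed for reference; only the time is shown).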
seconds_per_iteration = 7.54
total_seconds = (steps * seconds_per_iteration) + 240
cost_per_second = 0.80/60/60
cost = round(cost_per_second * total_seconds, 2)
cost_preview = f'''To train this LoRA, a paid L4 GPU will be used during training.
### Estimated to take <b>~{round(int(total_seconds)/60, 2)} minutes</b> with your current settings <small>({int(steps)} iterations)</small>'''
return gr.update(visible=True), cost_preview, gr.update(visible=False), gr.update(visible=True)
except Exception:
return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True)
def run_training_process(config_path):
"""Run the actual training process"""
try:
# This is a simplified placeholder for the actual training code
# Instead of using the ai-toolkit which is causing errors, we'll implement our own training logic
# Call to a direct training script that doesn't require the problematic dependencies
script_path = os.path.join(os.getcwd(), "direct_train_lora.py")
with open(script_path, "w") as f:
f.write("""
import os
import sys
import yaml
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
import json
def train_lora(config_path):
# Load config
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
process_config = config['config']['process'][0]
# Get basic parameters
model_name = process_config['model']['name_or_path']
lora_rank = process_config['network']['linear']
steps = process_config['train']['steps']
lr = process_config['train']['lr']
dataset_path = process_config['datasets'][0]['folder_path']
repo_id = process_config['save']['hf_repo_id']
hf_token = process_config['save']['hf_token']
# Load dataset
dataset = []
with open(os.path.join(dataset_path, "metadata.jsonl"), 'r') as f:
for line in f:
data = json.loads(line)
image_path = os.path.join(dataset_path, data['file_name'])
prompt = data['prompt']
dataset.append({"image_path": image_path, "text": prompt})
# Load base model
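# NOTE: name_or_path points at a FLUX diffusion checkpoint, while AutoModelForCausalLM only fits text models;
# this generated script is the simplified text-LoRA placeholder, not a real FLUX LoRA trainer.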
print(f"Loading model {model_name}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
use_auth_token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
# Configure LoRA
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Apply LoRA
model = get_peft_model(model, lora_config)
# Training parameters
training_args = TrainingArguments(
output_dir=f"./lora_train/{repo_id.split('/')[-1]}",
num_train_epochs=3,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
learning_rate=lr,
max_steps=steps,
fp16=True,
logging_steps=10,
save_steps=steps // 2,
push_to_hub=True,
hub_model_id=repo_id,
hub_token=hf_token,
)
# Simple dataset preparation: tokenize the prompt column so the collator below receives input_ids/attention_mask
def process_batch(examples):
return tokenizer(
examples["prompt"],
padding="max_length",
truncation=True,
max_length=256
)
# Convert dataset to huggingface format and tokenize it
train_dataset = load_dataset('json', data_files={'train': dataset_path + '/metadata.jsonl'})['train']
train_dataset = train_dataset.map(process_batch, batched=True, remove_columns=train_dataset.column_names)
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
# Set up trainer; labels mirror input_ids so the causal-LM loss can be computed
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]),
'attention_mask': torch.stack([f['attention_mask'] for f in data]),
'labels': torch.stack([f['input_ids'] for f in data])},
)
# Train
print("Starting training...")
trainer.train()
# Save and push to hub
model.save_pretrained(f"./lora_final/{repo_id.split('/')[-1]}")
tokenizer.save_pretrained(f"./lora_final/{repo_id.split('/')[-1]}")
if process_config['save']['push_to_hub']:
model.push_to_hub(repo_id, use_auth_token=hf_token)
tokenizer.push_to_hub(repo_id, use_auth_token=hf_token)
print(f"Training completed! Model saved to {repo_id}")
return repo_id
if __name__ == "__main__":
if len(sys.argv) > 1:
train_lora(sys.argv[1])
else:
print("Please provide config path")
""")
result = subprocess.run([sys.executable, script_path, config_path],
capture_output=True, text=True)
print(result.stdout)
if result.returncode != 0:
raise Exception(f"Training script failed: {result.stderr}")
# Extract repo ID from config
with open(config_path, "r") as f:
config = yaml.safe_load(f)
repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]
return repo_id
except Exception as e:
raise Exception(f"Training process failed: {str(e)}")
def start_training(
lora_name,
concept_sentence,
which_model,
steps,
lr,
rank,
dataset_folder,
sample_1,
sample_2,
sample_3,
use_more_advanced_options,
more_advanced_options,
):
"""Start the LoRA training process"""
if not lora_name:
raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
try:
username = whoami(token=HF_TOKEN)["name"]
except Exception:
raise gr.Error("Failed to get username. Please check your HF_TOKEN.")
print("Started training")
slugged_lora_name = slugify(lora_name)
# Get base config
config = get_default_train_config(lora_name, username, concept_sentence)
# Update config with form values
config["config"]["process"][0]["train"]["steps"] = int(steps)
config["config"]["process"][0]["train"]["lr"] = float(lr)
config["config"]["process"][0]["network"]["linear"] = int(rank)
config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
# Add sample prompts if provided
if sample_1 or sample_2 or sample_3:
config["config"]["process"][0]["sample"]["prompts"] = []
if sample_1:
config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
if sample_2:
config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
if sample_3:
config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
else:
config["config"]["process"][0]["train"]["disable_sampling"] = True
# Apply advanced options if enabled
if use_more_advanced_options:
try:
more_advanced_options_dict = yaml.safe_load(more_advanced_options)
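# Deep-merge the user-supplied YAML over the defaults: nested dicts are merged key by key, scalars and lists are replaced outright.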
def recursive_update(d, u):
for k, v in u.items():
if isinstance(v, dict) and v:
d[k] = recursive_update(d.get(k, {}), v)
else:
d[k] = v
return d
config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
except Exception as e:
raise gr.Error(f"Error in advanced options: {str(e)}")
try:
# Save the config
os.makedirs("tmp", exist_ok=True)
config_path = f"tmp/{uuid.uuid4()}-{slugged_lora_name}.yaml"
with open(config_path, "w") as f:
yaml.dump(config, f)
# Run training process
repo_id = run_training_process(config_path)
return f"""# Training completed successfully!
## Your model is available at: <a href='https://huggingface.co/{repo_id}'>{repo_id}</a>"""
except Exception as e:
raise gr.Error(f"Training failed: {str(e)}")
# UI Theme and CSS
custom_theme = gr.themes.Base(
primary_hue="indigo",
secondary_hue="slate",
neutral_hue="slate",
).set(
background_fill_primary="#1a1a1a",
background_fill_secondary="#2d2d2d",
border_color_primary="#404040",
button_primary_background_fill="#4F46E5",
button_primary_background_fill_dark="#4338CA",
button_primary_background_fill_hover="#6366F1",
button_primary_border_color="#4F46E5",
button_primary_border_color_dark="#4338CA",
button_primary_text_color="white",
button_primary_text_color_dark="white",
button_secondary_background_fill="#374151",
button_secondary_background_fill_dark="#1F2937",
button_secondary_background_fill_hover="#4B5563",
button_secondary_text_color="white",
button_secondary_text_color_dark="white",
block_background_fill="#2d2d2d",
block_background_fill_dark="#1F2937",
block_label_background_fill="#4F46E5",
block_label_background_fill_dark="#4338CA",
block_label_text_color="white",
block_label_text_color_dark="white",
block_title_text_color="white",
block_title_text_color_dark="white",
input_background_fill="#374151",
input_background_fill_dark="#1F2937",
input_border_color="#4B5563",
input_border_color_dark="#374151",
input_placeholder_color="#9CA3AF",
input_placeholder_color_dark="#6B7280",
)
css = '''
/* Base styles */
h1 {
font-size: 2.5em;
text-align: center;
margin-bottom: 0.5em;
color: white !important;
}
h3 {
margin-top: 0;
font-size: 1.2em;
color: white !important;
}
/* Ensure all text is white */
.markdown, .markdown h1, .markdown h2, .markdown h3,
.markdown h4, .markdown h5, .markdown h6, .markdown p,
label, .label-text, .gradio-radio label span, .gradio-checkbox label span,
input, textarea, .gradio-textbox input, .gradio-textbox textarea,
.gradio-number input, select, option, button {
color: white !important;
}
/* Input style improvements */
input[type="text"], textarea, .input-text, .input-textarea {
background-color: #374151 !important;
border-color: #4B5563 !important;
color: white !important;
}
/* Button styling */
button {
transition: all 0.3s ease;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
/* Image area */
.image-upload-area {
border: 2px dashed #4B5563;
border-radius: 12px;
padding: 20px;
text-align: center;
margin-bottom: 20px;
}
/* Caption rows */
.caption-row {
display: flex;
align-items: center;
margin-bottom: 10px;
gap: 10px;
}
'''
# Gradio UI
with gr.Blocks(theme=custom_theme, css=css) as demo:
gr.Markdown(
"""# ๐Ÿ†” Gini LoRA ํ•™์Šต
### 1) LoRA ์ด๋ฆ„ ์ž…๋ ฅ 2) ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด ์ž…๋ ฅ 3) ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ(2-30์žฅ ๊ถŒ์žฅ) 4) ๋น„์ „ ์ธ์‹ LLM ๋ผ๋ฒจ๋ง 5) START ํด๋ฆญ""",
elem_classes=["markdown"]
)
with gr.Tab("Train"):
with gr.Column():
# LoRA settings
with gr.Group():
with gr.Row():
lora_name = gr.Textbox(
label="LoRA ์ด๋ฆ„",
info="๊ณ ์œ ํ•œ ์ด๋ฆ„์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค",
placeholder="์˜ˆ: Persian Miniature Style, Cat Toy"
)
concept_sentence = gr.Textbox(
label="ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด/๋ฌธ์žฅ",
info="์‚ฌ์šฉํ•  ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด๋‚˜ ๋ฌธ์žฅ",
placeholder="p3rs0n์ด๋‚˜ trtcrd๊ฐ™์€ ํŠน์ดํ•œ ๋‹จ์–ด, ๋˜๋Š” 'in the style of CNSTLL'๊ฐ™์€ ๋ฌธ์žฅ"
)
model_warning = gr.Markdown(visible=False)
which_model = gr.Radio(
["๊ณ ํ€„๋ฆฌํ‹ฐ ๋งž์ถค ํ•™์Šต ๋ชจ๋ธ"],
label="๊ธฐ๋ณธ ๋ชจ๋ธ",
value="๊ณ ํ€„๋ฆฌํ‹ฐ ๋งž์ถค ํ•™์Šต ๋ชจ๋ธ"
)
# Image upload
with gr.Group(visible=True, elem_classes="image-upload-area") as image_upload:
with gr.Row():
images = gr.File(
file_types=["image", ".txt"],
label="Upload your images",
file_count="multiple",
interactive=True,
visible=True,
scale=1,
)
with gr.Column(scale=3, visible=False) as captioning_area:
with gr.Column():
gr.Markdown(
"""# ์ด๋ฏธ์ง€ ๋ผ๋ฒจ๋ง
<p style="margin-top:0"> ๋น„์ „์ธ์‹ LLM์ด ์ด๋ฏธ์ง€๋ฅผ ์ธ์‹ํ•˜์—ฌ ์ž๋™์œผ๋กœ ๋ผ๋ฒจ๋ง(์ด๋ฏธ์ง€ ์ธ์‹์„ ์œ„ํ•œ ํ•„์ˆ˜ ์„ค๋ช…). [trigger] 'ํŠธ๋ฆฌ๊ฑฐ ์›Œ๋“œ'๋Š” ํ•™์Šตํ•œ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๋Š” ๊ณ ์œ  ํ‚ค๊ฐ’</p>
""", elem_classes="group_padding")
do_captioning = gr.Button("Auto-label with vision LLM")
output_components = [captioning_area]
caption_list = []
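# Pre-create MAX_IMAGES hidden (row, image, caption) slots; load_captioning toggles their visibility and fills them in the same order.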
for i in range(1, MAX_IMAGES + 1):
locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
with locals()[f"captioning_row_{i}"]:
locals()[f"image_{i}"] = gr.Image(
type="filepath",
width=111,
height=111,
min_width=111,
interactive=False,
scale=2,
show_label=False,
show_share_button=False,
show_download_button=False,
)
locals()[f"caption_{i}"] = gr.Textbox(
label=f"Caption {i}", scale=15, interactive=True
)
output_components.append(locals()[f"captioning_row_{i}"])
output_components.append(locals()[f"image_{i}"])
output_components.append(locals()[f"caption_{i}"])
caption_list.append(locals()[f"caption_{i}"])
# Advanced settings
with gr.Accordion("Advanced options", open=False):
steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
with gr.Accordion("Even more advanced options", open=False):
use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False)
more_advanced_options = gr.Code(
value="""
device: cuda:0
model:
is_flux: true
quantize: true
network:
linear: 16
linear_alpha: 16
type: lora
sample:
guidance_scale: 3.5
height: 1024
neg: ''
sample_steps: 28
sampler: flowmatch
seed: 42
walk_seed: true
width: 1024
save:
dtype: float16
hf_private: true
max_step_saves_to_keep: 4
push_to_hub: true
save_every: 10000
train:
batch_size: 1
dtype: bf16
ema_config:
ema_decay: 0.99
use_ema: true
gradient_accumulation_steps: 1
gradient_checkpointing: true
noise_scheduler: flowmatch
optimizer: adamw8bit
train_text_encoder: false
train_unet: true
""",
language="yaml"
)
# Sample prompts
with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
gr.Markdown(
"Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)"
)
sample_1 = gr.Textbox(label="Test prompt 1")
sample_2 = gr.Textbox(label="Test prompt 2")
sample_3 = gr.Textbox(label="Test prompt 3")
# Cost estimate
with gr.Group(visible=False) as cost_preview:
cost_preview_info = gr.Markdown(elem_id="cost_preview_info", elem_classes="group_padding")
payment_update = gr.Button("I have set up a payment method", visible=False)
# Add the sample prompt components to the captioning outputs
output_components.append(sample)
output_components.append(sample_1)
output_components.append(sample_2)
output_components.append(sample_3)
# Start button
start = gr.Button("START ํด๋ฆญ ('์•ฝ 15-20๋ถ„ ํ›„ ํ•™์Šต์ด ์ข…๋ฃŒ๋˜๊ณ  ์™„๋ฃŒ ๋ฉ”์‹œ์ง€๊ฐ€ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค')", visible=False)
# Progress status
progress_area = gr.Markdown("")
# State: holds the dataset folder path passed from create_dataset to start_training
dataset_folder = gr.State()
# Event bindings
images.upload(
load_captioning,
inputs=[images, concept_sentence],
outputs=output_components
).then(
update_pricing,
inputs=[steps],
outputs=[cost_preview, cost_preview_info, payment_update, start]
)
images.clear(
hide_captioning,
outputs=[captioning_area, cost_preview, sample, start]
)
images.delete(
load_captioning,
inputs=[images, concept_sentence],
outputs=output_components
).then(
update_pricing,
inputs=[steps],
outputs=[cost_preview, cost_preview_info, payment_update, start]
)
steps.change(
update_pricing,
inputs=[steps],
outputs=[cost_preview, cost_preview_info, payment_update, start]
)
start.click(
fn=create_dataset,
inputs=[images] + caption_list,
outputs=dataset_folder
).then(
fn=start_training,
inputs=[
lora_name,
concept_sentence,
which_model,
steps,
lr,
rank,
dataset_folder,
sample_1,
sample_2,
sample_3,
use_more_advanced_options,
more_advanced_options
],
outputs=progress_area,
)
do_captioning.click(
fn=run_captioning,
inputs=[images, concept_sentence] + caption_list,
outputs=caption_list
)
# Launch the app
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, auth=("gini", "pick"), show_error=True)