import os import subprocess from typing import Union from huggingface_hub import whoami, HfApi from fastapi import FastAPI from starlette.middleware.sessions import SessionMiddleware import sys import gradio as gr from PIL import Image import torch import uuid import shutil import json import yaml from slugify import slugify from transformers import AutoProcessor, AutoModelForCausalLM import numpy as np # Set environment variables os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" # Check if we're running on HF Spaces is_spaces = True if os.environ.get("SPACE_ID") else False # FastAPI app setup app = FastAPI() app.add_middleware(SessionMiddleware, secret_key="your-secret-key") # Constants MAX_IMAGES = 150 # Hugging Face token setup HF_TOKEN = os.getenv("HF_TOKEN") if not HF_TOKEN: raise ValueError("HF_TOKEN environment variable is not set") os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN # Initialize HF API api = HfApi(token=HF_TOKEN) # Create default train config def get_default_train_config(lora_name, username, trigger_word=None): """Generate a default training configuration""" slugged_lora_name = slugify(lora_name) config = { "config": { "name": slugged_lora_name, "process": [{ "model": { "name_or_path": "black-forest-labs/FLUX.1-dev", "assistant_lora_path": None, "low_vram": False, }, "network": { "linear": 16, "linear_alpha": 16 }, "train": { "skip_first_sample": True, "steps": 1000, "lr": 4e-4, "disable_sampling": False }, "datasets": [{ "folder_path": "", # Will be filled later }], "save": { "push_to_hub": True, "hf_repo_id": f"{username}/{slugged_lora_name}", "hf_private": True, "hf_token": HF_TOKEN }, "sample": { "sample_steps": 28, "sample_every": 1000, "prompts": [] } }] } } if trigger_word: config["config"]["process"][0]["trigger_word"] = trigger_word return config # Helper functions def load_captioning(uploaded_files, concept_sentence): """Load images and prepare captioning UI""" uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')] txt_files = [file for file in uploaded_files if file.endswith('.txt')] txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files} updates = [] if len(uploaded_images) <= 1: raise gr.Error( "Please upload at least 2 images to train your model (the ideal number is between 4-30)" ) elif len(uploaded_images) > MAX_IMAGES: raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training") # Update captioning area visibility updates.append(gr.update(visible=True)) # Update individual captioning rows for i in range(1, MAX_IMAGES + 1): visible = i <= len(uploaded_images) updates.append(gr.update(visible=visible)) image_value = uploaded_images[i - 1] if visible else None updates.append(gr.update(value=image_value, visible=visible)) corresponding_caption = False if image_value: base_name = os.path.splitext(os.path.basename(image_value))[0] if base_name in txt_files_dict: with open(txt_files_dict[base_name], 'r') as file: corresponding_caption = file.read() text_value = corresponding_caption if visible and corresponding_caption else "[trigger]" if visible and concept_sentence else None updates.append(gr.update(value=text_value, visible=visible)) # Update sample caption area updates.append(gr.update(visible=True)) updates.append(gr.update(placeholder=f'A portrait of person in a bustling cafe {concept_sentence}', value=f'A person in a bustling cafe {concept_sentence}')) updates.append(gr.update(placeholder=f"A mountainous landscape in the style of {concept_sentence}")) updates.append(gr.update(placeholder=f"A {concept_sentence} in a mall")) return updates def hide_captioning(): """Hide captioning UI elements""" return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) def create_dataset(images, *captions): """Create dataset directory with images and captions""" destination_folder = str(f"datasets/{uuid.uuid4()}") if not os.path.exists(destination_folder): os.makedirs(destination_folder) jsonl_file_path = os.path.join(destination_folder, "metadata.jsonl") with open(jsonl_file_path, "a") as jsonl_file: for index, image in enumerate(images): if image: # Skip None values new_image_path = shutil.copy(image, destination_folder) caption = captions[index] file_name = os.path.basename(new_image_path) data = {"file_name": file_name, "prompt": caption} jsonl_file.write(json.dumps(data) + "\n") return destination_folder def run_captioning(images, concept_sentence, *captions): """Run automatic captioning using Microsoft Florence model""" try: device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 # Load model and processor model = AutoModelForCausalLM.from_pretrained( "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True ).to(device) processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True) captions = list(captions) for i, image_path in enumerate(images): if not image_path: # Skip None values continue if isinstance(image_path, str): # If image is a file path try: image = Image.open(image_path).convert("RGB") except Exception as e: print(f"Error opening image {image_path}: {e}") continue prompt = "" inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) generated_ids = model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 ) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] parsed_answer = processor.post_process_generation( generated_text, task=prompt, image_size=(image.width, image.height) ) caption_text = parsed_answer[""].replace("The image shows ", "") if concept_sentence: caption_text = f"{caption_text} [trigger]" captions[i] = caption_text yield captions # Clean up to free memory model.to("cpu") del model del processor torch.cuda.empty_cache() except Exception as e: print(f"Error in captioning: {e}") raise gr.Error(f"Captioning failed: {str(e)}") def update_pricing(steps): """Update estimated cost based on training steps""" try: seconds_per_iteration = 7.54 total_seconds = (steps * seconds_per_iteration) + 240 cost_per_second = 0.80/60/60 cost = round(cost_per_second * total_seconds, 2) cost_preview = f'''To train this LoRA, a paid L4 GPU will be used during training. ### Estimated to take ~{round(int(total_seconds)/60, 2)} minutes with your current settings ({int(steps)} iterations)''' return gr.update(visible=True), cost_preview, gr.update(visible=False), gr.update(visible=True) except: return gr.update(visible=False), "", gr.update(visible=False), gr.update(visible=True) def run_training_process(config_path): """Run the actual training process""" try: # This is a simplified placeholder for the actual training code # Instead of using the ai-toolkit which is causing errors, we'll implement our own training logic # Call to a direct training script that doesn't require the problematic dependencies script_path = os.path.join(os.getcwd(), "direct_train_lora.py") with open(script_path, "w") as f: f.write(""" import os import sys import yaml import torch from peft import LoraConfig, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer from datasets import load_dataset import json def train_lora(config_path): # Load config with open(config_path, 'r') as f: config = yaml.safe_load(f) process_config = config['config']['process'][0] # Get basic parameters model_name = process_config['model']['name_or_path'] lora_rank = process_config['network']['linear'] steps = process_config['train']['steps'] lr = process_config['train']['lr'] dataset_path = process_config['datasets'][0]['folder_path'] repo_id = process_config['save']['hf_repo_id'] hf_token = process_config['save']['hf_token'] # Load dataset dataset = [] with open(os.path.join(dataset_path, "metadata.jsonl"), 'r') as f: for line in f: data = json.loads(line) image_path = os.path.join(dataset_path, data['file_name']) prompt = data['prompt'] dataset.append({"image_path": image_path, "text": prompt}) # Load base model print(f"Loading model {model_name}") model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True, use_auth_token=hf_token ) tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) # Configure LoRA lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_rank, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) # Apply LoRA model = get_peft_model(model, lora_config) # Training parameters training_args = TrainingArguments( output_dir=f"./lora_train/{repo_id.split('/')[-1]}", num_train_epochs=3, per_device_train_batch_size=1, gradient_accumulation_steps=4, learning_rate=lr, max_steps=steps, fp16=True, logging_steps=10, save_steps=steps // 2, push_to_hub=True, hub_model_id=repo_id, hub_token=hf_token, ) # Simple dataset preparation def process_batch(examples): return tokenizer( examples["text"], padding="max_length", truncation=True, max_length=256 ) # Convert dataset to huggingface format train_dataset = load_dataset('json', data_files={'train': dataset_path + '/metadata.jsonl'})['train'] # Set up trainer trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, data_collator=lambda data: {'input_ids': torch.stack([f['input_ids'] for f in data]), 'attention_mask': torch.stack([f['attention_mask'] for f in data])}, ) # Train print("Starting training...") trainer.train() # Save and push to hub model.save_pretrained(f"./lora_final/{repo_id.split('/')[-1]}") tokenizer.save_pretrained(f"./lora_final/{repo_id.split('/')[-1]}") if process_config['save']['push_to_hub']: model.push_to_hub(repo_id, use_auth_token=hf_token) tokenizer.push_to_hub(repo_id, use_auth_token=hf_token) print(f"Training completed! Model saved to {repo_id}") return repo_id if __name__ == "__main__": if len(sys.argv) > 1: train_lora(sys.argv[1]) else: print("Please provide config path") """) result = subprocess.run([sys.executable, script_path, config_path], capture_output=True, text=True, check=True) print(result.stdout) if result.returncode != 0: raise Exception(f"Training script failed: {result.stderr}") # Extract repo ID from config with open(config_path, "r") as f: config = yaml.safe_load(f) repo_id = config["config"]["process"][0]["save"]["hf_repo_id"] return repo_id except Exception as e: raise Exception(f"Training process failed: {str(e)}") def start_training( lora_name, concept_sentence, which_model, steps, lr, rank, dataset_folder, sample_1, sample_2, sample_3, use_more_advanced_options, more_advanced_options, ): """Start the LoRA training process""" if not lora_name: raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.") try: username = whoami()["name"] except: raise gr.Error("Failed to get username. Please check your HF_TOKEN.") print("Started training") slugged_lora_name = slugify(lora_name) # Get base config config = get_default_train_config(lora_name, username, concept_sentence) # Update config with form values config["config"]["process"][0]["train"]["steps"] = int(steps) config["config"]["process"][0]["train"]["lr"] = float(lr) config["config"]["process"][0]["network"]["linear"] = int(rank) config["config"]["process"][0]["network"]["linear_alpha"] = int(rank) config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder # Add sample prompts if provided if sample_1 or sample_2 or sample_3: config["config"]["process"][0]["sample"]["prompts"] = [] if sample_1: config["config"]["process"][0]["sample"]["prompts"].append(sample_1) if sample_2: config["config"]["process"][0]["sample"]["prompts"].append(sample_2) if sample_3: config["config"]["process"][0]["sample"]["prompts"].append(sample_3) else: config["config"]["process"][0]["train"]["disable_sampling"] = True # Apply advanced options if enabled if use_more_advanced_options: try: more_advanced_options_dict = yaml.safe_load(more_advanced_options) def recursive_update(d, u): for k, v in u.items(): if isinstance(v, dict) and v: d[k] = recursive_update(d.get(k, {}), v) else: d[k] = v return d config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict) except Exception as e: raise gr.Error(f"Error in advanced options: {str(e)}") try: # Save the config os.makedirs("tmp", exist_ok=True) config_path = f"tmp/{uuid.uuid4()}-{slugged_lora_name}.yaml" with open(config_path, "w") as f: yaml.dump(config, f) # Run training process repo_id = run_training_process(config_path) return f"""# Training completed successfully! ## Your model is available at: {repo_id}""" except Exception as e: raise gr.Error(f"Training failed: {str(e)}") # UI Theme and CSS custom_theme = gr.themes.Base( primary_hue="indigo", secondary_hue="slate", neutral_hue="slate", ).set( background_fill_primary="#1a1a1a", background_fill_secondary="#2d2d2d", border_color_primary="#404040", button_primary_background_fill="#4F46E5", button_primary_background_fill_dark="#4338CA", button_primary_background_fill_hover="#6366F1", button_primary_border_color="#4F46E5", button_primary_border_color_dark="#4338CA", button_primary_text_color="white", button_primary_text_color_dark="white", button_secondary_background_fill="#374151", button_secondary_background_fill_dark="#1F2937", button_secondary_background_fill_hover="#4B5563", button_secondary_text_color="white", button_secondary_text_color_dark="white", block_background_fill="#2d2d2d", block_background_fill_dark="#1F2937", block_label_background_fill="#4F46E5", block_label_background_fill_dark="#4338CA", block_label_text_color="white", block_label_text_color_dark="white", block_title_text_color="white", block_title_text_color_dark="white", input_background_fill="#374151", input_background_fill_dark="#1F2937", input_border_color="#4B5563", input_border_color_dark="#374151", input_placeholder_color="#9CA3AF", input_placeholder_color_dark="#6B7280", ) css = ''' /* Base styles */ h1 { font-size: 2.5em; text-align: center; margin-bottom: 0.5em; color: white !important; } h3 { margin-top: 0; font-size: 1.2em; color: white !important; } /* Ensure all text is white */ .markdown, .markdown h1, .markdown h2, .markdown h3, .markdown h4, .markdown h5, .markdown h6, .markdown p, label, .label-text, .gradio-radio label span, .gradio-checkbox label span, input, textarea, .gradio-textbox input, .gradio-textbox textarea, .gradio-number input, select, option, button { color: white !important; } /* Input style improvements */ input[type="text"], textarea, .input-text, .input-textarea { background-color: #374151 !important; border-color: #4B5563 !important; color: white !important; } /* Button styling */ button { transition: all 0.3s ease; } button:hover { transform: translateY(-2px); box-shadow: 0 4px 6px rgba(0,0,0,0.1); } /* Image area */ .image-upload-area { border: 2px dashed #4B5563; border-radius: 12px; padding: 20px; text-align: center; margin-bottom: 20px; } /* Caption rows */ .caption-row { display: flex; align-items: center; margin-bottom: 10px; gap: 10px; } ''' # Gradio UI with gr.Blocks(theme=custom_theme, css=css) as demo: gr.Markdown( """# ๐Ÿ†” Gini LoRA ํ•™์Šต ### 1) LoRA ์ด๋ฆ„ ์ž…๋ ฅ 2) ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด ์ž…๋ ฅ 3) ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ(2-30์žฅ ๊ถŒ์žฅ) 4) ๋น„์ „ ์ธ์‹ LLM ๋ผ๋ฒจ๋ง 5) START ํด๋ฆญ""", elem_classes=["markdown"] ) with gr.Tab("Train"): with gr.Column(): # LoRA ์„ค์ • with gr.Group(): with gr.Row(): lora_name = gr.Textbox( label="LoRA ์ด๋ฆ„", info="๊ณ ์œ ํ•œ ์ด๋ฆ„์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค", placeholder="์˜ˆ: Persian Miniature Style, Cat Toy" ) concept_sentence = gr.Textbox( label="ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด/๋ฌธ์žฅ", info="์‚ฌ์šฉํ•  ํŠธ๋ฆฌ๊ฑฐ ๋‹จ์–ด๋‚˜ ๋ฌธ์žฅ", placeholder="p3rs0n์ด๋‚˜ trtcrd๊ฐ™์€ ํŠน์ดํ•œ ๋‹จ์–ด, ๋˜๋Š” 'in the style of CNSTLL'๊ฐ™์€ ๋ฌธ์žฅ" ) model_warning = gr.Markdown(visible=False) which_model = gr.Radio( ["๊ณ ํ€„๋ฆฌํ‹ฐ ๋งž์ถค ํ•™์Šต ๋ชจ๋ธ"], label="๊ธฐ๋ณธ ๋ชจ๋ธ", value="๊ณ ํ€„๋ฆฌํ‹ฐ ๋งž์ถค ํ•™์Šต ๋ชจ๋ธ" ) # ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ with gr.Group(visible=True, elem_classes="image-upload-area") as image_upload: with gr.Row(): images = gr.File( file_types=["image", ".txt"], label="Upload your images", file_count="multiple", interactive=True, visible=True, scale=1, ) with gr.Column(scale=3, visible=False) as captioning_area: with gr.Column(): gr.Markdown( """# ์ด๋ฏธ์ง€ ๋ผ๋ฒจ๋ง

๋น„์ „์ธ์‹ LLM์ด ์ด๋ฏธ์ง€๋ฅผ ์ธ์‹ํ•˜์—ฌ ์ž๋™์œผ๋กœ ๋ผ๋ฒจ๋ง(์ด๋ฏธ์ง€ ์ธ์‹์„ ์œ„ํ•œ ํ•„์ˆ˜ ์„ค๋ช…). [trigger] 'ํŠธ๋ฆฌ๊ฑฐ ์›Œ๋“œ'๋Š” ํ•™์Šตํ•œ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•˜๋Š” ๊ณ ์œ  ํ‚ค๊ฐ’

""", elem_classes="group_padding") do_captioning = gr.Button("๋น„์ „ ์ธ์‹ LLM ์ž๋™ ๋ผ๋ฒจ๋ง") output_components = [captioning_area] caption_list = [] for i in range(1, MAX_IMAGES + 1): locals()[f"captioning_row_{i}"] = gr.Row(visible=False) with locals()[f"captioning_row_{i}"]: locals()[f"image_{i}"] = gr.Image( type="filepath", width=111, height=111, min_width=111, interactive=False, scale=2, show_label=False, show_share_button=False, show_download_button=False, ) locals()[f"caption_{i}"] = gr.Textbox( label=f"Caption {i}", scale=15, interactive=True ) output_components.append(locals()[f"captioning_row_{i}"]) output_components.append(locals()[f"image_{i}"]) output_components.append(locals()[f"caption_{i}"]) caption_list.append(locals()[f"caption_{i}"]) # ๊ณ ๊ธ‰ ์„ค์ • with gr.Accordion("Advanced options", open=False): steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1) lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6) rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4) with gr.Accordion("Even more advanced options", open=False): use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False) more_advanced_options = gr.Code( value=""" device: cuda:0 model: is_flux: true quantize: true network: linear: 16 linear_alpha: 16 type: lora sample: guidance_scale: 3.5 height: 1024 neg: '' sample_steps: 28 sampler: flowmatch seed: 42 walk_seed: true width: 1024 save: dtype: float16 hf_private: true max_step_saves_to_keep: 4 push_to_hub: true save_every: 10000 train: batch_size: 1 dtype: bf16 ema_config: ema_decay: 0.99 use_ema: true gradient_accumulation_steps: 1 gradient_checkpointing: true noise_scheduler: flowmatch optimizer: adamw8bit train_text_encoder: false train_unet: true """, language="yaml" ) # ์ƒ˜ํ”Œ ํ”„๋กฌํ”„ํŠธ with gr.Accordion("Sample prompts (optional)", visible=False) as sample: gr.Markdown( "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)" ) sample_1 = gr.Textbox(label="Test prompt 1") sample_2 = gr.Textbox(label="Test prompt 2") sample_3 = gr.Textbox(label="Test prompt 3") # ๋น„์šฉ ์•ˆ๋‚ด with gr.Group(visible=False) as cost_preview: cost_preview_info = gr.Markdown(elem_id="cost_preview_info", elem_classes="group_padding") payment_update = gr.Button("I have set up a payment method", visible=False) # ์กฐํ•ฉ ๋ณ€์ˆ˜ output_components.append(sample) output_components.append(sample_1) output_components.append(sample_2) output_components.append(sample_3) # ์‹œ์ž‘ ๋ฒ„ํŠผ start = gr.Button("START ํด๋ฆญ ('์•ฝ 15-20๋ถ„ ํ›„ ํ•™์Šต์ด ์ข…๋ฃŒ๋˜๊ณ  ์™„๋ฃŒ ๋ฉ”์‹œ์ง€๊ฐ€ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค')", visible=False) # ์ง„ํ–‰ ์ƒํƒœ progress_area = gr.Markdown("") # ์ƒํƒœ ๋ณ€์ˆ˜ dataset_folder = gr.State() # ์ด๋ฒคํŠธ ๋ฐ”์ธ๋”ฉ images.upload( load_captioning, inputs=[images, concept_sentence], outputs=output_components ).then( update_pricing, inputs=[steps], outputs=[cost_preview, cost_preview_info, payment_update, start] ) images.clear( hide_captioning, outputs=[captioning_area, cost_preview, sample, start] ) images.delete( load_captioning, inputs=[images, concept_sentence], outputs=output_components ).then( update_pricing, inputs=[steps], outputs=[cost_preview, cost_preview_info, payment_update, start] ) steps.change( update_pricing, inputs=[steps], outputs=[cost_preview, cost_preview_info, payment_update, start] ) start.click( fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder ).then( fn=start_training, inputs=[ lora_name, concept_sentence, which_model, steps, lr, rank, dataset_folder, sample_1, sample_2, sample_3, use_more_advanced_options, more_advanced_options ], outputs=progress_area, ) do_captioning.click( fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list ) # Launch the app if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, auth=("gini", "pick"), show_error=True)