import os
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
import copy
import random
import time
import requests
import pandas as pd
from transformers import pipeline
from gradio_imageslider import ImageSlider
import numpy as np
import warnings
# Hugging Face account whose models `refresh_models` lists. Be sure to set this
# to your own account.
USERNAME = "openfree"
huggingface_token = os.getenv("HF_TOKEN")
# Korean -> English translation model (run on CPU) used to translate Korean prompts.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
# Load prompts for randomization (the CSV has no header row; every cell is a prompt).
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()
# Load LoRAs from JSON file. Decode as UTF-8 explicitly: without `encoding`,
# open() uses the platform's locale encoding, which can garble or reject
# non-ASCII titles on hosts whose default encoding is not UTF-8.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)
# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the shared FLUX base model
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
# VAEs for LoRA generation: TAEF1 (tiny autoencoder) and the full FLUX VAE.
# NOTE(review): `taef1` is not referenced elsewhere in the visible code.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
# Image-to-image pipeline: shares the text-to-image pipeline's transformer,
# text encoders and tokenizers instead of loading a second copy of them.
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
    base_model,
    vae=good_vae,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype
).to(device)
MAX_SEED = 2**32 - 1
MAX_PIXEL_BUDGET = 1024 * 1024
# Bind the streaming helper to `pipe` as a method so it can be called as
# pipe.flux_pipe_call_that_returns_an_iterable_of_images(...).
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Attributes set on the instance:
        activity_name: Label included in the printed message; an empty label
            prints a generic message.
        start_time / end_time: ``time.perf_counter()`` readings taken on
            enter / exit (only their difference is meaningful).
        elapsed_time: ``end_time - start_time`` in seconds.

    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, so the measured
        # duration cannot go negative or jump when the wall clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Implicitly returns None (falsy): exceptions from the body propagate.
def download_file(url, directory=None, timeout=60, chunk_size=1024 * 1024):
    """Download *url* into *directory* and return the local file path.

    Args:
        url: HTTP(S) URL to fetch. Its last path segment becomes the file name.
        directory: Destination directory; defaults to the current working directory.
        timeout: Seconds ``requests`` waits for the connection / each read before
            giving up (without a timeout a stalled server blocks forever).
        chunk_size: Number of bytes per streamed chunk written to disk.

    Returns:
        Path of the downloaded file.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout / requests.ConnectionError: the server could not be reached in time.
    """
    if directory is None:
        directory = os.getcwd()  # Use current working directory if not specified
    # The last URL path segment becomes the local file name.
    filename = url.split('/')[-1]
    filepath = os.path.join(directory, filename)
    partial_path = filepath + ".part"
    try:
        # Stream in chunks so large weight files are not held fully in memory.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for bad status codes
            with open(partial_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
    except Exception:
        # Do not leave a truncated download behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    # Move into place only once the whole body has been written.
    os.replace(partial_path, filepath)
    return filepath
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
    """Toggle the clicked gallery item in or out of the LoRA selection.

    Clicking a selected item deselects it. Clicking a new item selects it
    while fewer than three are selected. Otherwise a warning is shown and
    every display output is left unchanged.

    Returns a 13-tuple matching the wired outputs: prompt placeholder update,
    three info markdowns, the index list, three scales, width, height and
    three preview images.
    """
    chosen = selected_indices or []
    clicked = evt.index

    if clicked in chosen:
        chosen.remove(clicked)
    elif len(chosen) >= 3:
        gr.Warning("You can select up to 3 LoRAs, remove one to select a new one.")
        return (
            gr.update(), gr.update(), gr.update(), gr.update(),
            chosen,
            gr.update(), gr.update(), gr.update(),
            width, height,
            gr.update(), gr.update(), gr.update(),
        )
    else:
        chosen.append(clicked)

    infos = ["Select LoRA 1", "Select LoRA 2", "Select LoRA 3"]
    previews = [None, None, None]
    for slot, lora_idx in enumerate(chosen[:3]):
        entry = loras_state[lora_idx]
        infos[slot] = f"### LoRA {slot + 1} Selected: [{entry['title']}](https://huggingface.co/{entry['repo']}) ✨"
        previews[slot] = entry['image']

    if chosen:
        placeholder_text = f"Type a prompt for {loras_state[chosen[-1]]['title']}"
    else:
        placeholder_text = "Type a prompt after selecting a LoRA"

    default_scale = 1.15
    return (
        gr.update(placeholder=placeholder_text),
        infos[0], infos[1], infos[2],
        chosen,
        default_scale, default_scale, default_scale,
        width, height,
        previews[0], previews[1], previews[2],
    )
def remove_lora(selected_indices, loras_state, index_to_remove):
    """Remove the LoRA in slot *index_to_remove* (0-based) and rebuild the slot displays.

    Args:
        selected_indices: Selected gallery indices (mutated in place); ``None`` is treated as empty.
        loras_state: List of LoRA dicts with ``title``, ``repo`` and ``image`` keys.
        index_to_remove: Slot to clear; out-of-range slots leave the selection unchanged.

    Returns:
        10-tuple ``(info_1, info_2, info_3, selected_indices, scale_1, scale_2,
        scale_3, image_1, image_2, image_3)``; scales are reset to 1.15.
    """
    selected_indices = selected_indices or []
    if 0 <= index_to_remove < len(selected_indices):
        selected_indices.pop(index_to_remove)

    infos = ["Select LoRA 1", "Select LoRA 2", "Select LoRA 3"]
    images = [None, None, None]
    for slot, lora_idx in enumerate(selected_indices[:3]):
        lora = loras_state[lora_idx]
        # Link to the Hub page like update_selection / randomize_loras do; a bare
        # repo id would render as a broken relative link.
        infos[slot] = f"### LoRA {slot + 1} Selected: [{lora['title']}](https://huggingface.co/{lora['repo']}) ✨"
        images[slot] = lora['image']
    return (infos[0], infos[1], infos[2], selected_indices,
            1.15, 1.15, 1.15,
            images[0], images[1], images[2])


def remove_lora_1(selected_indices, loras_state):
    """Clear LoRA slot 1 (see ``remove_lora``)."""
    return remove_lora(selected_indices, loras_state, 0)


def remove_lora_2(selected_indices, loras_state):
    """Clear LoRA slot 2 (see ``remove_lora``)."""
    return remove_lora(selected_indices, loras_state, 1)


def remove_lora_3(selected_indices, loras_state):
    """Clear LoRA slot 3 (see ``remove_lora``)."""
    return remove_lora(selected_indices, loras_state, 2)
def randomize_loras(selected_indices, loras_state):
    """Select three distinct random LoRAs and pick a random prompt.

    Args:
        selected_indices: Current selection; ignored because it is replaced.
        loras_state: List of LoRA dicts with ``title``, ``repo`` and optional ``image``.

    Returns:
        11-tuple ``(info_1, info_2, info_3, selected_indices, scale_1, scale_2,
        scale_3, image_1, image_2, image_3, prompt)``. On an unexpected error the
        infos read "Error", the selection is empty and images are cleared.

    Raises:
        gr.Error: fewer than three LoRAs are available (shown to the user).
    """
    # Checked outside the try below so the user-facing gr.Error propagates to
    # Gradio instead of being swallowed by the generic handler.
    if not loras_state or len(loras_state) < 3:
        raise gr.Error("Not enough LoRAs to randomize.")
    try:
        selected_indices = random.sample(range(len(loras_state)), 3)
        picked = [loras_state[i] for i in selected_indices]
        infos = [
            f"### LoRA {slot} Selected: [{lora['title']}](https://huggingface.co/{lora['repo']}) ✨"
            for slot, lora in enumerate(picked, start=1)
        ]
        # None clears the preview. A placeholder path that does not exist
        # would make the image component fail to load it.
        images = [lora.get('image') for lora in picked]
        random_prompt = random.choice(prompt_values)
        return (infos[0], infos[1], infos[2], selected_indices,
                1.15, 1.15, 1.15,
                images[0], images[1], images[2], random_prompt)
    except Exception as e:
        print(f"Error in randomize_loras: {str(e)}")
        return "Error", "Error", "Error", [], 1.15, 1.15, 1.15, None, None, None, ""
def add_custom_lora(custom_lora, selected_indices, current_loras):
    """Resolve a user-supplied LoRA reference, add it to the list, and select it.

    Args:
        custom_lora: Hugging Face repo id / URL or a direct ``*.safetensors`` URL
            (whatever ``check_custom_model`` accepts). Surrounding whitespace is ignored.
        selected_indices: Selected gallery indices (at most three; mutated in place).
        current_loras: The session's LoRA list; a new entry is appended in place.

    Returns:
        12-tuple for ``(loras_state, gallery, info_1..3, selected_indices,
        scale_1..3, image_1..3)``. For empty input or on failure only the state
        values are returned and every display output is left unchanged.
        Failures are shown to the user as a warning.
    """
    selected_indices = selected_indices if selected_indices is not None else []

    def no_change():
        return (current_loras, gr.update(), gr.update(), gr.update(), gr.update(),
                selected_indices, gr.update(), gr.update(), gr.update(),
                gr.update(), gr.update(), gr.update())

    custom_lora = (custom_lora or "").strip()
    if not custom_lora:
        return no_change()

    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
        print(f"Loaded custom LoRA: {repo}")
        existing_item_index = next(
            (index for index, item in enumerate(current_loras) if item['repo'] == repo), None
        )
        if existing_item_index is None:
            # Direct weight URLs are downloaded so the pipeline can load a local file.
            if repo.endswith(".safetensors") and repo.startswith("http"):
                repo = download_file(repo)
            new_item = {
                "image": image if image else "/home/user/app/custom.png",
                "title": title,
                "repo": repo,
                "weights": path,
                "trigger_word": trigger_word
            }
            print(f"New LoRA: {new_item}")
            existing_item_index = len(current_loras)
            current_loras.append(new_item)

        gallery_items = [(item["image"], item["title"]) for item in current_loras]

        # Select it unless it is already selected (re-adding the same repo must
        # not load the same adapter twice) or all three slots are taken.
        if existing_item_index not in selected_indices:
            if len(selected_indices) < 3:
                selected_indices.append(existing_item_index)
            else:
                gr.Warning("You can select up to 3 LoRAs, remove one to select a new one.")

        infos = ["Select a LoRA 1", "Select a LoRA 2", "Select a LoRA 3"]
        images = [None, None, None]
        for slot, lora_idx in enumerate(selected_indices[:3]):
            lora = current_loras[lora_idx]
            infos[slot] = f"### LoRA {slot + 1} Selected: {lora['title']} ✨"
            images[slot] = lora['image'] if lora['image'] else None
        print("Finished adding custom LoRA")
        return (
            current_loras,
            gr.update(value=gallery_items),
            infos[0],
            infos[1],
            infos[2],
            selected_indices,
            1.15,
            1.15,
            1.15,
            images[0],
            images[1],
            images[2]
        )
    except Exception as e:
        print(e)
        gr.Warning(str(e))
        return no_change()
def remove_custom_lora(selected_indices, current_loras):
    """Drop the last entry of *current_loras* and rebuild the gallery and slot displays.

    NOTE(review): the last entry is removed unconditionally. Nothing here checks
    that it was added by ``add_custom_lora``, although that function does append to the end.

    Returns:
        12-tuple for ``(loras_state, gallery, info_1..3, selected_indices,
        scale_1..3, image_1..3)``. When the list is empty, every display output
        is left unchanged.
    """
    selected_indices = selected_indices if selected_indices is not None else []
    if not current_loras:
        # Nothing to remove. Return a no-op update for each of the 12 wired
        # outputs, because returning None would not match them.
        return (current_loras, gr.update(), gr.update(), gr.update(), gr.update(),
                selected_indices, gr.update(), gr.update(), gr.update(),
                gr.update(), gr.update(), gr.update())

    current_loras = current_loras[:-1]
    removed_index = len(current_loras)
    # Drop every reference to the removed entry. A stale index would point past the end.
    selected_indices[:] = [idx for idx in selected_indices if idx != removed_index]

    gallery_items = [(item["image"], item["title"]) for item in current_loras]

    infos = ["Select a LoRA 1", "Select a LoRA 2", "Select a LoRA 3"]
    images = [None, None, None]
    for slot, lora_idx in enumerate(selected_indices[:3]):
        lora = current_loras[lora_idx]
        # Hub link, consistent with update_selection (a bare repo id is a broken relative link).
        infos[slot] = f"### LoRA {slot + 1} Selected: [{lora['title']}](https://huggingface.co/{lora['repo']}) ✨"
        images[slot] = lora['image']
    return (
        current_loras,
        gr.update(value=gallery_items),
        infos[0],
        infos[1],
        infos[2],
        selected_indices,
        1.15,
        1.15,
        1.15,
        images[0],
        images[1],
        images[2]
    )
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
    """Run text-to-image generation and yield the images produced along the way.

    Args:
        prompt_mash: Final prompt (user prompt plus LoRA trigger words).
        steps: Number of inference steps.
        seed: Seed for the torch generator; coerced to ``int``.
        cfg_scale: Value passed as ``guidance_scale``.
        width, height: Output size in pixels.
        progress: Accepted for interface compatibility; not used in this function.

    Yields:
        Images from ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``.
    """
    print("Generating image...")
    # Use the device detected at import time (where the pipeline was loaded)
    # rather than hard-coding "cuda", so hosts without CUDA do not fail here.
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    with calculateDuration("Generating image"):
        # Generate image
        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt_mash,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": 1.0},
            output_type="pil",
            good_vae=good_vae,
        ):
            yield img
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
    """Run image-to-image generation once and return the resulting PIL image.

    Args:
        prompt_mash: Final prompt (user prompt plus LoRA trigger words).
        image_input_path: Source image, loaded with diffusers' ``load_image``.
        image_strength: Passed as ``strength`` (the UI describes lower values as
            giving the input image more influence).
        steps: Number of inference steps.
        cfg_scale: Value passed as ``guidance_scale``.
        width, height: Output size in pixels.
        seed: Seed for the torch generator; coerced to ``int``.

    Returns:
        The first generated image.
    """
    # Use the device detected at import time instead of hard-coding "cuda".
    pipe_i2i.to(device)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    image_input = load_image(image_input_path)
    final_image = pipe_i2i(
        prompt=prompt_mash,
        image=image_input,
        strength=image_strength,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": 1.0},
        output_type="pil",
    ).images[0]
    return final_image
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices,
             lora_scale_1, lora_scale_2, lora_scale_3, randomize_seed, seed,
             width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRAs (text-to-image or image-to-image).

    This is a generator wired to the ``[result, seed, progress_bar]`` outputs.
    Each yielded ``(image, seed, progress_bar_update)`` triple is streamed to the
    UI. A generator's ``return value`` is discarded, so every outcome (final
    image, image-to-image result, or failure) is delivered with ``yield``.

    Args:
        prompt: User prompt; Korean text is translated to English first.
        image_input: Optional input image path; when given, image-to-image is used.
        image_strength: Denoising strength for image-to-image.
        cfg_scale, steps, width, height: Generation parameters.
        selected_indices: Indices into *loras_state* of the chosen LoRAs.
        lora_scale_1, lora_scale_2, lora_scale_3: Adapter weights for the
            1st, 2nd and 3rd selected LoRA.
        randomize_seed: When true, *seed* is replaced by a random value.
        seed: Generator seed.
        loras_state: The session's LoRA list.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: no LoRA is selected (displayed to the user).
    """
    try:
        prompt = prompt or ""
        # Detect Hangul (compatibility jamo or syllables) and translate to English.
        if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
            translated = translator(prompt, max_length=512)[0]['translation_text']
            print(f"Original prompt: {prompt}")
            print(f"Translated prompt: {translated}")
            prompt = translated
        if not selected_indices:
            raise gr.Error("You must select at least one LoRA before proceeding.")
        selected_loras = [loras_state[idx] for idx in selected_indices]

        # Build the prompt with trigger words
        prepends = []
        appends = []
        for lora in selected_loras:
            trigger_word = lora.get('trigger_word', '')
            if trigger_word:
                if lora.get("trigger_position") == "prepend":
                    prepends.append(trigger_word)
                else:
                    appends.append(trigger_word)
        prompt_mash = " ".join(prepends + [prompt] + appends)
        print("Prompt Mash: ", prompt_mash)

        use_img2img = image_input is not None
        target_pipe = pipe_i2i if use_img2img else pipe

        # Unload previous LoRA weights. pipe_i2i shares pipe's transformer, so
        # adapters are cleared from both pipelines.
        with calculateDuration("Unloading LoRA"):
            pipe.unload_lora_weights()
            pipe_i2i.unload_lora_weights()
        print(f"Active adapters before loading: {pipe.get_active_adapters()}")

        # Load LoRA weights with respective scales
        scales = [lora_scale_1, lora_scale_2, lora_scale_3]
        lora_names = []
        lora_weights = []
        with calculateDuration("Loading LoRA weights"):
            for idx, lora in enumerate(selected_loras):
                lora_name = f"lora_{idx}"
                try:
                    lora_path = lora['repo']
                    # Private models are downloaded first and loaded from a local file.
                    if lora.get('private', False):
                        lora_path = load_private_model(lora_path, huggingface_token)
                        print(f"Using private model path: {lora_path}")
                    target_pipe.load_lora_weights(
                        lora_path,
                        adapter_name=lora_name,
                        token=huggingface_token
                    )
                    lora_names.append(lora_name)
                    lora_weights.append(scales[min(idx, 2)])
                    print(f"Successfully loaded LoRA {lora_name} from {lora_path}")
                except Exception as e:
                    # Best effort: skip a LoRA that fails to load and use the rest.
                    print(f"Failed to load LoRA {lora_name}: {str(e)}")
        print("Loaded LoRAs:", lora_names)
        print("Adapter weights:", lora_weights)
        if not lora_names:
            print("No LoRAs were successfully loaded.")
            yield None, seed, gr.update(visible=False)
            return
        target_pipe.set_adapters(lora_names, adapter_weights=lora_weights)
        print(f"Active adapters after loading: {pipe.get_active_adapters()}")

        # Randomize seed if needed
        with calculateDuration("Randomizing seed"):
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)

        if use_img2img:
            final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
            yield final_image, seed, gr.update(visible=False)
            return

        final_image = None
        step_counter = 0
        for image in generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
            step_counter += 1
            final_image = image
            # The width of .progress-bar is computed in CSS from --current / --total.
            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
            yield image, seed, gr.update(value=progress_bar, visible=True)
        if final_image is None:
            raise Exception("Failed to generate image")
        # Re-emit the final image with the progress bar hidden.
        yield final_image, seed, gr.update(visible=False)
    except gr.Error:
        # Let Gradio show user-facing errors instead of swallowing them below.
        raise
    except Exception as e:
        print(f"Error in run_lora: {str(e)}")
        yield None, seed, gr.update(visible=False)


run_lora.zerogpu = True
def get_huggingface_safetensors(link):
    """Look up a Hugging Face LoRA repository and return its metadata.

    Args:
        link: Repository id of the form ``"user/repo"``.

    Returns:
        ``(title, repo_id, safetensors_file_name, trigger_word, image_url)``.
        *title* is the repo name and *trigger_word* comes from the card's
        ``instance_prompt``. *image_url* may be ``None``.

    Raises:
        Exception: the model card's ``base_model`` is not FLUX.1-dev / FLUX.1-schnell.
        gr.Error: *link* is not ``user/repo``, the repo cannot be listed, or it
            contains no ``*.safetensors`` file.
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        raise gr.Error("Invalid Hugging Face repository link")

    model_card = ModelCard.load(link)
    card_base_model = model_card.data.get("base_model")
    print(f"Base model: {card_base_model}")
    # Model-card metadata allows base_model to be a single id or a list of ids.
    candidates = card_base_model if isinstance(card_base_model, list) else [card_base_model]
    flux_bases = ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")
    if not any(candidate in flux_bases for candidate in candidates):
        raise Exception("Not a FLUX LoRA!")

    # `widget` / `output` may be missing, empty or null; all mean "no sample image".
    widgets = model_card.data.get("widget") or [{}]
    first_widget = widgets[0] if isinstance(widgets[0], dict) else {}
    image_path = (first_widget.get("output") or {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "")
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

    fs = HfFileSystem()
    safetensors_name = None
    try:
        list_of_files = fs.ls(link, detail=False)
        for file in list_of_files:
            if file.endswith(".safetensors"):
                safetensors_name = file.split("/")[-1]
            # Fall back to any image in the repo root when the card has none.
            if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                image_url = f"https://huggingface.co/{link}/resolve/main/{file.split('/')[-1]}"
    except Exception as e:
        print(e)
        raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
    if not safetensors_name:
        raise gr.Error("No *.safetensors file found in the repository")
    return split_link[1], link, safetensors_name, trigger_word, image_url
def check_custom_model(link):
    """Classify a user-supplied LoRA reference and return its metadata.

    Accepted forms (surrounding whitespace is ignored):
      * anything ending in ``.safetensors``: a direct link to the weights;
      * ``https://huggingface.co/user/repo[/...]``: only ``user/repo`` is used;
      * ``user/repo``: a Hugging Face repository id.

    Returns:
        ``(title, repo, weight_name, trigger_word, image_url)``. For direct
        links this is ``(basename, link, None, "", None)``.

    Raises:
        Exception: "Unsupported URL" for https links that are not on huggingface.co.
        Whatever ``get_huggingface_safetensors`` raises for repository references.
    """
    link = link.strip()
    if link.endswith(".safetensors"):
        # Treat as direct link to the LoRA weights (no specific weight name).
        return os.path.basename(link), link, None, "", None
    if link.startswith("https://"):
        if "huggingface.co" not in link:
            raise Exception("Unsupported URL")
        # Keep only "user/repo", so links such as ".../user/repo/tree/main" or
        # ones with a trailing slash still resolve to the repository.
        _, _, repo_path = link.partition("huggingface.co/")
        repo_id = "/".join([part for part in repo_path.split("/") if part][:2])
        return get_huggingface_safetensors(repo_id)
    # Assume it's a Hugging Face model path
    return get_huggingface_safetensors(link)
def update_history(new_image, history):
    """Put *new_image* at the front of the history gallery list and return it.

    A ``None`` history starts a new list, and a ``None`` image leaves the
    history unchanged. A list passed in is modified in place.
    """
    entries = [] if history is None else history
    if new_image is None:
        return entries
    entries.insert(0, new_image)
    return entries
def _find_sample_image_name(model_id, huggingface_token, default_name):
    """Return the base name of a sample ``.jpg`` in ``<model_id>/samples``, or *default_name*.

    Only ``.jpg`` files whose base name contains a digit are considered. Listing
    errors are logged and fall back to *default_name*.
    """
    try:
        fs = HfFileSystem(token=huggingface_token)
        files = fs.ls(f"{model_id}/samples", detail=True)
        jpg_files = [
            f['name'] for f in files
            if isinstance(f, dict)
            and 'name' in f
            and f['name'].lower().endswith('.jpg')
            and any(char.isdigit() for char in os.path.basename(f['name']))
        ]
        if jpg_files:
            return os.path.basename(jpg_files[0])
    except Exception as e:
        print(f"Error accessing samples folder for {model_id}: {str(e)}")
    return default_name


def _cache_private_image(model_id, image_name, headers, timeout):
    """Download a private model's sample image to ``models/<id>/samples`` and return its path.

    If the image is not on disk after the attempt (e.g. the request failed),
    the remote URL is returned instead, so the gallery never points at a
    local file that does not exist.
    """
    image_url = f"https://huggingface.co/{model_id}/resolve/main/samples/{image_name}"
    cache_dir = f"models/{model_id.replace('/', '_')}/samples"
    os.makedirs(cache_dir, exist_ok=True)
    local_image_path = os.path.join(cache_dir, image_name)
    if not os.path.exists(local_image_path):
        image_response = requests.get(image_url, headers=headers, timeout=timeout)
        if image_response.status_code == 200:
            with open(local_image_path, 'wb') as f:
                f.write(image_response.content)
    return local_image_path if os.path.exists(local_image_path) else image_url


def refresh_models(huggingface_token):
    """Fetch the FLUX LoRAs owned by ``USERNAME`` and merge them into the LoRA list.

    Models tagged ``flux`` or ``flux-lora`` become gallery entries. Private
    models get ``[Private]`` in the title and a locally cached sample image.

    Args:
        huggingface_token: Hub token (may be empty/None for public data only).

    Returns:
        Fetched entries first, followed by the module-level ``loras`` whose repo
        was not fetched. On any top-level failure the module-level ``loras`` is
        returned unchanged.
    """
    timeout = 30  # seconds; without a timeout a stalled request blocks the UI forever
    try:
        headers = {"Accept": "application/json"}
        # Only authenticate when a token is configured, rather than sending "Bearer None".
        if huggingface_token:
            headers["Authorization"] = f"Bearer {huggingface_token}"

        username = USERNAME
        api_url = f"https://huggingface.co/api/models?author={username}"
        response = requests.get(api_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"Failed to fetch models from HuggingFace. Status code: {response.status_code}")
        all_models = response.json()
        print(f"Found {len(all_models)} models for user {username}")

        flux_tags = {'flux', 'flux-lora'}
        user_models = [
            model for model in all_models
            if any(tag.lower() in flux_tags for tag in (model.get('tags') or []))
        ]
        print(f"Found {len(user_models)} FLUX models")

        new_models = []
        for model in user_models:
            try:
                model_id = model['id']
                info_response = requests.get(
                    f"https://huggingface.co/api/models/{model_id}", headers=headers, timeout=timeout
                )
                api_info = info_response.json()
                is_private = model.get('private', False)
                base_image_name = _find_sample_image_name(
                    model_id, huggingface_token, "1732195028106__000001000_0.jpg"
                )
                if is_private:
                    image_url = _cache_private_image(model_id, base_image_name, headers, timeout)
                else:
                    image_url = f"https://huggingface.co/{model_id}/resolve/main/samples/{base_image_name}"
                # The Hub API returns model-card metadata under "cardData"; keep the
                # old top-level lookup as a fallback.
                card_data = api_info.get('cardData') or {}
                trigger_word = card_data.get('instance_prompt') or api_info.get('instance_prompt', '')
                new_models.append({
                    "image": image_url,
                    "title": f"[Private] {model_id.split('/')[-1]}" if is_private else model_id.split('/')[-1],
                    "repo": model_id,
                    "weights": "pytorch_lora_weights.safetensors",
                    "trigger_word": trigger_word,
                    "private": is_private
                })
                print(f"Added model: {model_id} with image: {image_url}")
            except Exception as e:
                print(f"Error processing model {model.get('id')}: {str(e)}")
                continue

        fetched_repos = {m['repo'] for m in new_models}
        updated_loras = new_models + [lora for lora in loras if lora['repo'] not in fetched_repos]
        print(f"Total models after refresh: {len(updated_loras)}")
        return updated_loras
    except Exception as e:
        print(f"Error refreshing models: {str(e)}")
        return loras
def load_private_model(model_id, huggingface_token):
    """Download a (private) Hub repository and return the path of its LoRA weights file.

    The snapshot goes to ``models/<user>_<repo>``. When several ``.safetensors``
    files exist, ``pytorch_lora_weights.safetensors`` is preferred (the weight
    name ``refresh_models`` records). Otherwise the shallowest file wins, with
    ties broken by path order, so the choice is deterministic.

    Returns:
        Full path of the chosen ``.safetensors`` file.

    Raises:
        Any download error from ``snapshot_download``, or ``Exception`` when the
        snapshot contains no ``.safetensors`` file.
    """
    try:
        # Download the model. `local_dir_use_symlinks` is deprecated in
        # huggingface_hub; files are written to local_dir by default.
        local_dir = snapshot_download(
            repo_id=model_id,
            token=huggingface_token,
            local_dir=f"models/{model_id.replace('/', '_')}",
        )
        # Find safetensors files
        found = sorted(
            (
                os.path.join(root, name)
                for root, _dirs, files in os.walk(local_dir)
                for name in files
                if name.endswith('.safetensors')
            ),
            key=lambda p: (p.count(os.sep), p),
        )
        if not found:
            raise Exception(f"No .safetensors file found in {local_dir}")
        preferred = [p for p in found if os.path.basename(p) == "pytorch_lora_weights.safetensors"]
        safetensors_file = preferred[0] if preferred else found[0]
        print(f"Found safetensors file: {safetensors_file}")
        return safetensors_file  # full path
    except Exception as e:
        print(f"Error loading private model {model_id}: {str(e)}")
        raise
# Blue / purple / slate theme. The overrides restyle buttons, page and block
# backgrounds, labels and inputs; each "_dark" key is the dark-mode variant.
custom_theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="purple",
    neutral_hue="slate",
).set(**{
    # Primary buttons
    "button_primary_background_fill": "*primary_500",
    "button_primary_background_fill_dark": "*primary_600",
    "button_primary_background_fill_hover": "*primary_400",
    "button_primary_border_color": "*primary_500",
    "button_primary_border_color_dark": "*primary_600",
    "button_primary_text_color": "white",
    "button_primary_text_color_dark": "white",
    # Secondary buttons
    "button_secondary_background_fill": "*neutral_100",
    "button_secondary_background_fill_dark": "*neutral_700",
    "button_secondary_background_fill_hover": "*neutral_50",
    "button_secondary_text_color": "*neutral_800",
    "button_secondary_text_color_dark": "white",
    # Page and block surfaces
    "background_fill_primary": "*neutral_50",
    "background_fill_primary_dark": "*neutral_900",
    "block_background_fill": "white",
    "block_background_fill_dark": "*neutral_800",
    "block_label_background_fill": "*primary_500",
    "block_label_background_fill_dark": "*primary_600",
    "block_label_text_color": "white",
    "block_label_text_color_dark": "white",
    "block_title_text_color": "*neutral_800",
    "block_title_text_color_dark": "white",
    # Inputs
    "input_background_fill": "white",
    "input_background_fill_dark": "*neutral_800",
    "input_border_color": "*neutral_200",
    "input_border_color_dark": "*neutral_700",
    "input_placeholder_color": "*neutral_400",
    "input_placeholder_color_dark": "*neutral_400",
    # Shadows
    "shadow_spread": "8px",
    "shadow_inset": "0px 2px 4px 0px rgba(0,0,0,0.05)",
})
# Custom stylesheet passed to gr.Blocks. The `.history-gallery` rule is closed
# before the refresh-button rules so those are top-level rules, not nested
# inside the history gallery selector.
css = '''
/* 기본 버튼 및 컴포넌트 스타일 */
#gen_btn {
height: 100%
}
#title {
text-align: center
}
#title h1 {
font-size: 3em;
display: inline-flex;
align-items: center
}
#title img {
width: 100px;
margin-right: 0.25em
}
#lora_list {
background: var(--block-background-fill);
padding: 0 1em .3em;
font-size: 90%
}
/* 커스텀 LoRA 카드 스타일 */
.custom_lora_card {
margin-bottom: 1em
}
.card_internal {
display: flex;
height: 100px;
margin-top: .5em
}
.card_internal img {
margin-right: 1em
}
/* 유틸리티 클래스 */
.styler {
--form-gap-width: 0px !important
}
/* 프로그레스 바 스타일 */
#progress {
height: 30px;
width: 90% !important;
margin: 0 auto !important;
}
#progress .generating {
display: none
}
.progress-container {
width: 100%;
height: 30px;
background-color: #f0f0f0;
border-radius: 15px;
overflow: hidden;
margin-bottom: 20px
}
.progress-bar {
height: 100%;
background-color: #4f46e5;
width: calc(var(--current) / var(--total) * 100%);
transition: width 0.5s ease-in-out
}
/* 컴포넌트 특정 스타일 */
#component-8, .button_total {
height: 100%;
align-self: stretch;
}
#loaded_loras [data-testid="block-info"] {
font-size: 80%
}
#custom_lora_structure {
background: var(--block-background-fill)
}
#custom_lora_btn {
margin-top: auto;
margin-bottom: 11px
}
#random_btn {
font-size: 300%
}
#component-11 {
align-self: stretch;
}
/* 갤러리 메인 스타일 */
#lora_gallery {
margin: 20px 0;
padding: 10px;
border: 1px solid #ddd;
border-radius: 12px;
background: linear-gradient(to bottom right, #ffffff, #f8f9fa);
width: 100% !important;
height: 800px !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
display: block !important;
}
/* 갤러리 그리드 스타일 */
#gallery {
display: grid !important;
grid-template-columns: repeat(10, 1fr) !important;
gap: 10px !important;
padding: 10px !important;
width: 100% !important;
height: 100% !important;
overflow-y: auto !important;
max-width: 100% !important;
}
/* 갤러리 아이템 스타일 */
.gallery-item {
position: relative !important;
width: 100% !important;
aspect-ratio: 1 !important;
margin: 0 !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease, box-shadow 0.3s ease;
border-radius: 12px;
overflow: hidden;
}
.gallery-item img {
width: 100% !important;
height: 100% !important;
object-fit: cover !important;
border-radius: 12px !important;
}
/* 갤러리 그리드 래퍼 */
.wrap, .svelte-w6dy5e {
display: grid !important;
grid-template-columns: repeat(10, 1fr) !important;
gap: 10px !important;
width: 100% !important;
max-width: 100% !important;
}
/* 컨테이너 공통 스타일 */
.container, .content, .block, .contain {
width: 100% !important;
max-width: 100% !important;
margin: 0 !important;
padding: 0 !important;
}
.row {
width: 100% !important;
margin: 0 !important;
padding: 0 !important;
}
/* 버튼 스타일 */
.button_total {
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
transition: all 0.3s ease;
}
.button_total:hover {
transform: translateY(-2px);
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05);
}
/* 입력 필드 스타일 */
input, textarea {
box-shadow: inset 0 2px 4px 0 rgba(0, 0, 0, 0.06);
transition: all 0.3s ease;
}
input:focus, textarea:focus {
box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.5);
}
/* 컴포넌트 border-radius */
.gradio-container .input,
.gradio-container .button,
.gradio-container .block {
border-radius: 12px;
}
/* 스크롤바 스타일 */
#gallery::-webkit-scrollbar {
width: 8px;
}
#gallery::-webkit-scrollbar-track {
background: #f1f1f1;
border-radius: 4px;
}
#gallery::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
#gallery::-webkit-scrollbar-thumb:hover {
background: #555;
}
/* Flex 컨테이너 */
.flex {
width: 100% !important;
max-width: 100% !important;
display: flex !important;
}
/* Svelte 특정 클래스 */
.svelte-1p9xokt {
width: 100% !important;
max-width: 100% !important;
}
/* Footer 숨김 */
#footer {
visibility: hidden;
}
/* 결과 이미지 및 컨테이너 스타일 */
#result_column, #result_column > div {
display: flex !important;
flex-direction: column !important;
align-items: flex-start !important; /* center에서 flex-start로 변경 */
width: 100% !important;
margin: 0 !important; /* auto에서 0으로 변경 */
}
.generated-image, .generated-image > div {
display: flex !important;
justify-content: flex-start !important; /* center에서 flex-start로 변경 */
align-items: flex-start !important; /* center에서 flex-start로 변경 */
width: 90% !important;
max-width: 768px !important;
margin: 0 !important; /* auto에서 0으로 변경 */
margin-left: 20px !important; /* 왼쪽 여백 추가 */
}
.generated-image img {
margin: 0 !important; /* auto에서 0으로 변경 */
display: block !important;
max-width: 100% !important;
}
/* 히스토리 갤러리도 좌측 정렬로 변경 */
.history-gallery {
display: flex !important;
justify-content: flex-start !important; /* center에서 flex-start로 변경 */
width: 90% !important;
max-width: 90% !important;
margin: 0 !important; /* auto에서 0으로 변경 */
margin-left: 20px !important; /* 왼쪽 여백 추가 */
}
/* 새로고침 버튼 스타일 */
#refresh-button {
margin: 10px;
padding: 8px 16px;
background-color: #4a5568;
color: white;
border-radius: 8px;
transition: all 0.3s ease;
}
#refresh-button:hover {
background-color: #2d3748;
transform: scale(1.05);
}
#refresh-button:active {
transform: scale(0.95);
}
'''
with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
    # Per-session state: the LoRA catalogue and the selected gallery indices (max three).
    loras_state = gr.State(loras)
    selected_indices = gr.State([])
    gr.Markdown(
        """
# MixGen3: 멀티 Lora(이미지 학습) 통합 생성 모델
### 사용 안내:
갤러리에서 원하는 모델을 선택(최대 3개까지) < 프롬프트에 한글 또는 영문으로 원하는 내용을 입력 < Generate 버튼 실행
"""
    )
    # Refresh button (reloads the account's models, including private ones).
    # elem_id matches the #refresh-button rules in `css`; without it they never apply.
    with gr.Row():
        refresh_button = gr.Button("🔄 모델 새로고침(나만의 맞춤 학습된 Private 모델 불러오기)", variant="secondary", elem_id="refresh-button")
    with gr.Row(elem_id="lora_gallery", equal_height=True):
        gallery = gr.Gallery(
            value=[(item["image"], item["title"]) for item in loras],
            label="LoRA Explorer Gallery",
            columns=11,
            elem_id="gallery",
            height=800,
            object_fit="cover",
            show_label=True,
            allow_preview=False,
            show_share_button=False,
            container=True,
            preview=False
        )
    with gr.Tab(label="Generate"):
        # Prompt and Generate Button
        with gr.Row():
            with gr.Column(scale=3):
                prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
            with gr.Column(scale=1):
                generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
        # LoRA Selection Area
        with gr.Row(elem_id="loaded_loras"):
            # Randomize Button
            with gr.Column(scale=1, min_width=25):
                randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
            # LoRA 1
            with gr.Column(scale=8):
                with gr.Row():
                    with gr.Column(scale=0, min_width=50):
                        lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
                    with gr.Column(scale=3, min_width=100):
                        selected_info_1 = gr.Markdown("Select a LoRA 1")
                    with gr.Column(scale=5, min_width=50):
                        lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
                with gr.Row():
                    remove_button_1 = gr.Button("Remove", size="sm")
            # LoRA 2
            with gr.Column(scale=8):
                with gr.Row():
                    with gr.Column(scale=0, min_width=50):
                        lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
                    with gr.Column(scale=3, min_width=100):
                        selected_info_2 = gr.Markdown("Select a LoRA 2")
                    with gr.Column(scale=5, min_width=50):
                        lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
                with gr.Row():
                    remove_button_2 = gr.Button("Remove", size="sm")
            # LoRA 3
            with gr.Column(scale=8):
                with gr.Row():
                    with gr.Column(scale=0, min_width=50):
                        lora_image_3 = gr.Image(label="LoRA 3 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
                    with gr.Column(scale=3, min_width=100):
                        selected_info_3 = gr.Markdown("Select a LoRA 3")
                    with gr.Column(scale=5, min_width=50):
                        lora_scale_3 = gr.Slider(label="LoRA 3 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
                with gr.Row():
                    remove_button_3 = gr.Button("Remove", size="sm")
        # Result and Progress Area
        with gr.Column(elem_id="result_column"):
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            with gr.Column(elem_id="result_box"):  # Column used instead of Box
                result = gr.Image(
                    label="Generated Image",
                    interactive=False,
                    elem_classes=["generated-image"],
                    container=True,
                    elem_id="result_image",
                    width="100%"
                )
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(
                    label="History",
                    columns=6,
                    object_fit="contain",
                    interactive=False,
                    elem_classes=["history-gallery"]
                )
        # Advanced Settings
        with gr.Row():
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    input_image = gr.Image(label="Input image", type="filepath")
                    image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
                with gr.Column():
                    with gr.Row():
                        cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                    with gr.Row():
                        randomize_seed = gr.Checkbox(True, label="Randomize seed")
                        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
        # Custom LoRA Section
        with gr.Column():
            with gr.Group():
                with gr.Row(elem_id="custom_lora_structure"):
                    custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="ginipick/flux-lora-eric-cat", scale=3, min_width=150)
                    add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
                remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
            gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
    # Event Handlers
    gallery.select(
        update_selection,
        inputs=[selected_indices, loras_state, width, height],
        outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_indices,
                 lora_scale_1, lora_scale_2, lora_scale_3, width, height,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    remove_button_1.click(
        remove_lora_1,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_info_3, selected_indices,
                 lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    remove_button_2.click(
        remove_lora_2,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_info_3, selected_indices,
                 lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    remove_button_3.click(
        remove_lora_3,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_info_3, selected_indices,
                 lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    randomize_button.click(
        randomize_loras,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_info_3, selected_indices,
                 lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3, prompt]
    )
    add_custom_lora_button.click(
        add_custom_lora,
        inputs=[custom_lora, selected_indices, loras_state],
        outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_info_3,
                 selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    remove_custom_lora_button.click(
        remove_custom_lora,
        inputs=[selected_indices, loras_state],
        outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_info_3,
                 selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
                 lora_image_1, lora_image_2, lora_image_3]
    )
    # Generate on button click or Enter in the prompt box, then prepend the
    # result to the history gallery.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, input_image, image_strength, cfg_scale, steps,
                selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
                randomize_seed, seed, width, height, loras_state],
        outputs=[result, seed, progress_bar]
    ).then(
        fn=lambda x, history: update_history(x, history) if x is not None else history,
        inputs=[result, history_gallery],
        outputs=history_gallery
    )

    # Refresh-button event handler: re-fetch the model list and update the
    # gallery and the session's LoRA state.
    def refresh_gallery():
        updated_loras = refresh_models(huggingface_token)
        return (
            gr.update(value=[(item["image"], item["title"]) for item in updated_loras]),
            updated_loras
        )

    refresh_button.click(
        refresh_gallery,
        outputs=[gallery, loras_state]
    )
if __name__ == "__main__":
    # Cap pending requests at 20, then start the server with debug logging;
    # Blocks.queue() returns the Blocks instance, so the calls chain.
    app.queue(max_size=20).launch(debug=True)