# Redux / app.py -- "Update app.py" (commit 62b6248, verified; nftnik), 6.18 kB
import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
# CUDA diagnostics: report the interpreter, the torch build and the GPUs
# torch can see, once at startup.
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
has_cuda = torch.cuda.is_available()
print("CUDA disponível:", has_cuda)
print("Quantidade de GPUs:", torch.cuda.device_count())
if has_cuda:
    # Only query the device name when a CUDA device is actually usable.
    print("GPU atual:", torch.cuda.get_device_name(0))
# Put the ComfyUI folder that sits next to this script on sys.path so the
# ComfyUI modules imported below can be resolved.
script_path = os.path.abspath(__file__)
current_dir = os.path.dirname(script_path)
comfyui_path = os.path.join(current_dir, "ComfyUI")
sys.path.append(comfyui_path)
# Importar ComfyUI components
import execution
from nodes import NODE_CLASS_MAPPINGS
import folder_paths
from comfy import model_management
# Directory configuration: generated images go to an "output" folder next to
# this script (symlinks resolved), and ComfyUI is told to use the same folder.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
Path(output_dir).mkdir(parents=True, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# Initialize execution.
# NOTE(review): confirm the bundled ComfyUI's `execution` module really
# exposes `init()`; if it does not, this line raises at import time.
execution.init()
# Helper for reading one output out of a node result.
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the element of *obj* at *index*.

    *obj* is indexed directly first. If that raises ``KeyError`` (a mapping
    without that key), the value is read from ``obj["result"][index]``
    instead. Any other exception, such as ``IndexError`` from a sequence,
    propagates to the caller.
    """
    try:
        value = obj[index]
    except KeyError:
        # Mapping-style results keep their positional outputs under "result".
        value = obj["result"][index]
    return value
# Download the required models.
def download_models():
    """Fetch every model file used by this app from the Hugging Face Hub.

    Each entry is ``(repo id, filename, target directory)``. The target
    directories are relative paths, so they are resolved against the current
    working directory; each one is created if missing. A failure for one file
    is printed and skipped so the remaining files are still attempted.
    Returns ``None``.
    """
    print("Baixando modelos...")
    required_files = (
        ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "models/style_models"),
        ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
        ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "models/text_encoders"),
        ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
        ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/diffusion_models"),
        ("google/siglip-so400m-patch14-384", "model.safetensors", "models/clip_vision"),
        ("nftnik/NFTNIK-FLUX.1-dev-LoRA", "NFTNIK_V1.safetensors", "models/lora"),
    )
    for hub_repo, weight_file, target_dir in required_files:
        try:
            os.makedirs(target_dir, exist_ok=True)
            print(f"Baixando {weight_file} de {hub_repo}...")
            hf_hub_download(repo_id=hub_repo, filename=weight_file, local_dir=target_dir)
        except Exception as e:
            # Best effort: report the failure and move on to the next file.
            print(f"Erro ao baixar {weight_file} de {hub_repo}: {str(e)}")
# Fetch the model files before any loader node is used.
download_models()

# Initialize the models shared by every request.
print("Inicializando modelos...")
with torch.inference_mode():
    # Instantiate the loader nodes and run them once at startup.
    intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()

    # Text encoders: the T5-XXL and CLIP ViT-L files, loaded as a "flux" pair.
    # NOTE(review): the names below carry a "models/..." prefix; confirm the
    # loader nodes accept that form rather than a bare filename.
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    dualcliploader_357 = dualcliploader.load_clip(
        clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
        clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
        type="flux",
    )

    # Redux style model.
    stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
    stylemodelloader_441 = stylemodelloader.load_style_model(
        style_model_name="models/style_models/flux1-redux-dev.safetensors"
    )

    # VAE.
    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vaeloader_359 = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")

    # Load the models onto the GPU. Each loader result is unwrapped to its
    # `patcher` attribute when it has one; entries whose first output (or
    # whose `patcher`) is a dict are skipped.
    model_loaders = [dualcliploader_357, vaeloader_359, stylemodelloader_441]
    valid_models = []
    for loader in model_loaders:
        loaded = loader[0]
        if isinstance(loaded, dict):
            continue
        if isinstance(getattr(loaded, "patcher", None), dict):
            continue
        valid_models.append(getattr(loaded, "patcher", loaded))
    model_management.load_models_gpu(valid_models)
def _decoded_to_pil(decoded_images):
    """Convert decoded image data with float values in [0, 1] to a PIL image.

    Accepts a torch tensor or an array exposing ``ndim``/``clip``/``astype``.
    When the data has a leading batch dimension (4-D), only the first image
    is converted.
    """
    if isinstance(decoded_images, torch.Tensor):
        # Tensors have no ``astype``: move to host memory and widen to float32
        # (NumPy has no bfloat16) before handing the data to NumPy.
        decoded_images = decoded_images.detach().cpu().float().numpy()
    if decoded_images.ndim == 4:
        decoded_images = decoded_images[0]
    # Clamp before the cast so out-of-range values do not wrap around.
    pixels = (decoded_images * 255.0).clip(0, 255).astype("uint8")
    return Image.fromarray(pixels)


@spaces.GPU
def generate_image(prompt, input_image, lora_weight, progress=gr.Progress(track_tqdm=True)):
    """Run the node pipeline for one request and save the result as a PNG.

    Args:
        prompt: Text handed to the CLIPTextEncode node.
        input_image: Value handed to the LoadImage node (the UI supplies a
            file path).
        lora_weight: Passed to the LoRA loader as ``strength_model``.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        Path of the PNG written under ``output_dir``, or ``None`` when any
        step raised (the error is printed).
    """
    try:
        with torch.inference_mode():
            # Encode the prompt.
            # NOTE(review): `encoded_text` is never consumed below -- confirm
            # whether a conditioning/sampling stage is missing.
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=get_value_at_index(dualcliploader_357, 0)
            )

            # Apply the LoRA.
            # NOTE(review): download_models() fetches "NFTNIK_V1.safetensors"
            # into models/lora, but a different filename is loaded here --
            # confirm which one is correct.
            # NOTE(review): the style model is passed as `model`; presumably a
            # diffusion model is expected -- verify.
            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            lora_model = loraloadermodelonly.load_lora_model_only(
                lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
                strength_model=lora_weight,
                model=get_value_at_index(stylemodelloader_441, 0)
            )

            # Load the input image.
            # NOTE(review): `loaded_image` is never consumed below either.
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)

            # Decode.
            # NOTE(review): `samples` receives the LoRA-patched model rather
            # than latents -- looks like a sampler step is missing; confirm.
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=get_value_at_index(lora_model, 0),
                vae=get_value_at_index(vaeloader_359, 0)
            )

            # Save the image.
            temp_filename = f"Flux_{random.randint(0, 99999)}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            _decoded_to_pil(get_value_at_index(decoded, 0)).save(temp_path)
            return temp_path
    except Exception as e:
        # Top-level boundary for the UI callback: report and return no image.
        print(f"Erro ao gerar imagem: {str(e)}")
        return None
# Gradio interface: inputs in the left column, generated image on the right.
with gr.Blocks() as app:
    gr.Markdown("# Gerador de Imagens FLUX Redux")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Digite seu prompt aqui...",
                lines=5,
            )
            input_image = gr.Image(label="Imagem de Entrada", type="filepath")
            lora_weight = gr.Slider(
                minimum=0,
                maximum=2,
                step=0.1,
                value=0.6,
                label="Peso LoRA",
            )
            generate_btn = gr.Button("Gerar Imagem")

        with gr.Column():
            output_image = gr.Image(label="Imagem Gerada", type="filepath")

    # Wire the button to the generation function.
    generate_btn.click(
        generate_image,
        [prompt_input, input_image, lora_weight],
        [output_image],
    )

# Only start the server when this file is run as a script.
if __name__ == "__main__":
    app.launch()