import logging import gradio as gr import torch import numpy as np import jax import pickle from PIL import Image from huggingface_hub import hf_hub_download, file_download from model import build_thera from super_resolve import process from diffusers import StableDiffusionXLImg2ImgPipeline from transformers import DPTFeatureExtractor, DPTForDepthEstimation # ================== CONFIGURAÇÃO INICIAL ================== # Configurar sistema de logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Fix para compatibilidade do Hugging Face Hub file_download.cached_download = file_download.hf_hub_download # ================== CONFIGURAÇÃO DE HARDWARE ================== device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float16 if device == "cuda" else torch.float32 logger.info(f"Dispositivo selecionado: {device.upper()}") logger.info(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}") # ================== CARREGAMENTO DE MODELOS ================== def carregar_modelo_thera(repo_id): """Carrega modelos Thera do Hugging Face Hub""" try: logger.info(f"Carregando modelo Thera: {repo_id}") caminho_modelo = hf_hub_download(repo_id=repo_id, filename="model.pkl") with open(caminho_modelo, 'rb') as arquivo: dados = pickle.load(arquivo) modelo = build_thera(3, dados['backbone'], dados['size']) parametros = dados['model'] logger.success(f"Modelo {repo_id} carregado com sucesso") return modelo, parametros except Exception as erro: logger.error(f"Falha ao carregar {repo_id}: {str(erro)}") raise # Carregar modelos Thera try: logger.divider("Carregando Modelos Thera") modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro") modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro") except Exception as erro: logger.critical("Falha crítica no carregamento dos modelos Thera") raise # ================== PIPELINE DE ARTE ================== # Configurar SDXL + LoRA try: logger.divider("Configurando Pipeline de Arte") pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype, variant="fp16", use_safetensors=True ).to(device) pipe.load_lora_weights( "KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors", adapter_name="bas_relief" ) logger.success("Pipeline SDXL + LoRA configurado") except Exception as erro: logger.error(f"Erro no SDXL: {str(erro)}") pipe = None # Configurar modelo de profundidade try: logger.divider("Configurando Modelo de Profundidade") processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device) logger.success("Modelo de profundidade pronto") except Exception as erro: logger.error(f"Erro no modelo de profundidade: {str(erro)}") modelo_profundidade = None # ================== FLUXO DE PROCESSAMENTO ================== def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo): """Executa todo o fluxo de processamento""" try: logger.divider("Iniciando novo processamento") # ========= FASE 1: SUPER-RESOLUÇÃO ========= logger.etapa("Processando Super-Resolução") modelo_sr = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn parametros_sr = params_edsr if modelo_escolhido == "EDSR" else params_rdn # Converter e validar entrada if not isinstance(imagem, Image.Image): logger.warning("Convertendo entrada numpy para PIL Image") imagem = Image.fromarray(imagem) # Processar super-resolução imagem_sr_jax = process( np.array(imagem) / 255., modelo_sr, parametros_sr, (round(imagem.size[1] * fator_escala), round(imagem.size[0] * fator_escala)), True ) # Converter para formato compatível imagem_sr_pil = Image.fromarray(np.array(imagem_sr_jax)).convert("RGB") logger.success(f"Super-Resolução concluída: {imagem_sr_pil.size}") # ========= FASE 2: ESTILO BAIXO-RELEVO ========= if device == "cpu" or not pipe: logger.warning("GPU não disponível - Pulando estilo") return imagem_sr_pil, None, None logger.etapa("Aplicando Estilo Baixo-Relevo") prompt_completo = f"BAS-RELIEF {prompt_estilo}, intricate carving, marble texture, 8k" with torch.autocast(device_type=device.split(':')[0], dtype=torch_dtype): imagem_estilizada = pipe( prompt=prompt_completo, image=imagem_sr_pil, strength=0.7, num_inference_steps=35, guidance_scale=7.5, output_type="pil" ).images[0] logger.success(f"Estilo aplicado: {imagem_estilizada.size}") # ========= FASE 3: MAPA DE PROFUNDIDADE ========= logger.etapa("Gerando Mapa de Profundidade") inputs = processador_profundidade( images=imagem_estilizada, return_tensors="pt" ).to(device, dtype=torch_dtype) with torch.no_grad(), torch.autocast(device_type=device.split(':')[0]): outputs = modelo_profundidade(**inputs) profundidade = outputs.predicted_depth # Processar profundidade profundidade = torch.nn.functional.interpolate( profundidade.unsqueeze(1).float(), # Converter para float32 size=imagem_estilizada.size[::-1], mode="bicubic" ).squeeze().cpu().numpy() # Normalizar e converter profundidade = (profundidade - profundidade.min()) / (profundidade.max() - profundidade.min() + 1e-8) mapa_profundidade = Image.fromarray((profundidade * 255).astype(np.uint8)) logger.success("Processamento completo") return imagem_sr_pil, imagem_estilizada, mapa_profundidade except Exception as erro: logger.error(f"ERRO NO PIPELINE: {str(erro)}", exc_info=True) return imagem_sr_pil if 'imagem_sr_pil' in locals() else None, None, None # ================== INTERFACE GRADIO ================== with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app: gr.Markdown(""" # 🎨 TheraSR Art Suite **Combine super-resolução aliasing-free com geração artística de baixo-relevo** """) with gr.Row(variant="panel"): with gr.Column(scale=1): entrada_imagem = gr.Image(label="🖼 Imagem de Entrada", type="pil") seletor_modelo = gr.Radio( ["EDSR", "RDN"], value="EDSR", label="🔧 Modelo de Super-Resolução" ) controle_escala = gr.Slider( 1.0, 6.0, value=2.0, step=0.1, label="🔍 Fator de Escala" ) entrada_prompt = gr.Textbox( label="📝 Prompt de Estilo", value="insanely detailed and complex engraving relief, ultra HD 8k", placeholder="Descreva o estilo desejado..." ) botao_processar = gr.Button("🚀 Processar Imagem", variant="primary") with gr.Column(scale=2): saida_sr = gr.Image(label="✨ Super-Resolução", interactive=False) saida_arte = gr.Image(label="🖌 Arte em Baixo-Relevo", interactive=False) saida_profundidade = gr.Image(label="🗺 Mapa de Profundidade", interactive=False) # Configurar eventos botao_processar.click( fn=pipeline_completo, inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt], outputs=[saida_sr, saida_arte, saida_profundidade] ) # ================== INICIALIZAÇÃO ================== if __name__ == "__main__": app.launch( server_name="0.0.0.0", server_port=7860, show_error=True, share=False, debug=False )