# app.py import logging import gradio as gr import torch import numpy as np import jax import pickle from PIL import Image from huggingface_hub import hf_hub_download from model import build_thera from super_resolve import process from diffusers import StableDiffusionXLImg2ImgPipeline from transformers import DPTFeatureExtractor, DPTForDepthEstimation # ================== CONFIGURAÇÃO DE LOGGING ================== class CustomLogger: def __init__(self, name): self.logger = logging.getLogger(name) formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') handler = logging.StreamHandler() handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) def divider(self, text=None, length=60): if text: available_space = max(length - len(text) - 12, 1) msg = f"{'=' * 10} {text.upper()} {'=' * available_space}" else: msg = "=" * length self.logger.info(msg) def etapa(self, text): self.logger.info(f"▶ {text}") def success(self, text): self.logger.info(f"✓ {text}") def error(self, text): self.logger.error(f"✗ {text}") def warning(self, text): self.logger.warning(f"⚠ {text}") logger = CustomLogger(__name__) # ================== CONFIGURAÇÃO DE HARDWARE ================== device = "cuda" if torch.cuda.is_available() else "cpu" torch_dtype = torch.float32 # Forçar precisão única para compatibilidade logger.divider("Inicialização do Sistema") logger.success(f"Dispositivo detectado: {device.upper()}") logger.success(f"Modo de precisão: float32") # ================== CARREGAMENTO DE MODELOS ================== def carregar_modelo_thera(repo_id): """Carrega modelos Thera com tratamento de erros robusto""" try: logger.divider(f"Carregando {repo_id}") model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl") with open(model_path, 'rb') as f: check = pickle.load(f) model = build_thera(3, check['backbone'], check['size']) params = check['model'] logger.success(f"{repo_id} carregado") return model, params except Exception as e: logger.error(f"Falha ao carregar {repo_id}: {str(e)}") return None, None # Carregar modelos principais modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro") modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro") # ================== MODELOS DE ARTE (CARREGAMENTO CONDICIONAL) ================== pipe = None modelo_profundidade = None processador_profundidade = None try: logger.divider("Inicializando Componentes Artísticos") # Pipeline de estilo pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype, use_safetensors=True ).to(device) # Adapter LoRA pipe.load_lora_weights( "KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors" ) # Modelo de profundidade processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device) logger.success("Componentes artísticos prontos") except Exception as e: logger.warning(f"Recursos artísticos desativados: {str(e)}") pipe = None modelo_profundidade = None # ================== PIPELINE PRINCIPAL ================== def processar_imagem(imagem, escala, modelo, prompt): """Fluxo completo de processamento com fallbacks""" try: logger.divider("Novo Processamento") # Converter entrada para PIL if not isinstance(imagem, Image.Image): imagem = Image.fromarray(imagem) # ========= 1. SUPER-RESOLUÇÃO ========= logger.etapa("Super-Resolução Thera") modelo_sr = modelo_edsr if modelo == "EDSR" else modelo_rdn params_sr = params_edsr if modelo == "EDSR" else params_rdn sr_jax = process( np.array(imagem) / 255.0, modelo_sr, params_sr, (int(imagem.height * escala), int(imagem.width * escala)), True ) sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB") logger.success(f"Resolução: {sr_pil.size[0]}x{sr_pil.size[1]}") # ========= 2. ESTILO BAIXO-RELEVO ========= arte_pil = sr_pil # Fallback padrão if pipe: try: logger.etapa("Aplicando Estilo") arte_pil = pipe( prompt=f"BAS-RELIEF {prompt}, marble texture, cinematic lighting", image=sr_pil, strength=0.6, num_inference_steps=25, guidance_scale=7.0 ).images[0] logger.success("Estilo aplicado") except Exception as e: logger.error(f"Erro no estilo: {str(e)}") # ========= 3. MAPA DE PROFUNDIDADE ========= mapa_pil = arte_pil # Fallback padrão if modelo_profundidade and arte_pil: try: logger.etapa("Calculando Profundidade") inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device) with torch.no_grad(): depth = modelo_profundidade(**inputs).predicted_depth depth = torch.nn.functional.interpolate( depth.unsqueeze(1).float(), size=arte_pil.size[::-1], mode="bicubic" ).squeeze().cpu().numpy() depth_normalized = (depth - depth.min()) / (depth.max() - depth.min()) mapa_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8)) logger.success("Profundidade calculada") except Exception as e: logger.error(f"Erro na profundidade: {str(e)}") return sr_pil, arte_pil, mapa_pil except Exception as e: logger.error(f"Erro fatal: {str(e)}") return None, None, None # ================== INTERFACE GRADIO ================== with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app: gr.Markdown("# 🏛 TheraSR - Super Resolução & Arte") with gr.Row(): with gr.Column(): input_image = gr.Image(label="Imagem de Entrada", type="pil") scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala", step=0.1) model_select = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo") style_prompt = gr.Textbox( label="Descrição do Estilo", value="ancient greek marble浮雕, ultra detailed, 8k" ) btn_process = gr.Button("Processar", variant="primary") with gr.Column(): output_sr = gr.Image(label="Super-Resolução", interactive=False) output_art = gr.Image(label="Arte em Relevo", interactive=False) output_depth = gr.Image(label="Mapa de Profundidade", interactive=False) btn_process.click( processar_imagem, inputs=[input_image, scale, model_select, style_prompt], outputs=[output_sr, output_art, output_depth] ) if __name__ == "__main__": app.launch(server_name="0.0.0.0", server_port=7860)