# app.py import logging import gradio as gr import torch import numpy as np import jax import pickle from PIL import Image from huggingface_hub import hf_hub_download from model import build_thera from super_resolve import process from diffusers import StableDiffusionXLImg2ImgPipeline from transformers import DPTFeatureExtractor, DPTForDepthEstimation # ================== CONFIGURAÇÃO DE LOGGING ================== class CustomLogger: def __init__(self, name): self.logger = logging.getLogger(name) formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') handler = logging.StreamHandler() handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) def divider(self, text=None, length=60): if text: available_space = max(length - len(text) - 12, 1) msg = f"{'=' * 10} {text.upper()} {'=' * available_space}" else: msg = "=" * length self.logger.info(msg) def etapa(self, text): self.logger.info(f"▶ {text}") def success(self, text): self.logger.info(f"✓ {text}") def error(self, text): self.logger.error(f"✗ {text}") logger = CustomLogger(__name__) # ================== CONFIGURAÇÃO FORÇADA ================== device = "cpu" torch_dtype = torch.float32 logger.divider("Configuração Forçada") logger.success(f"Dispositivo: {device.upper()}") logger.success(f"Precisão: {str(torch_dtype).replace('torch.', '')}") # ================== CARREGAMENTO DE MODELOS ================== def carregar_modelo_thera(repo_id): try: logger.divider(f"Carregando {repo_id}") model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl") with open(model_path, 'rb') as f: check = pickle.load(f) model = build_thera(3, check['backbone'], check['size']) params = check['model'] logger.success(f"{repo_id} carregado") return model, params except Exception as e: logger.error(f"Falha no carregamento: {str(e)}") return None, None try: modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro") modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro") except Exception as e: logger.error("Falha crítica nos modelos Thera") raise # ================== PIPELINE ARTÍSTICO ================== pipe = None modelo_profundidade = None try: logger.divider("Configurando Componentes Artísticos") # Pipeline principal pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype, variant="fp32" ).to(device) # LoRA pipe.load_lora_weights( "KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors", peft_backend="peft" # This is crucial ) # Modelo de profundidade processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).float() logger.success("Componentes artísticos em float32") except Exception as e: logger.warning(f"Recursos artísticos limitados: {str(e)}") print(e) pipe = None # ================== PROCESSAMENTO PRINCIPAL ================== def processar_imagem_completa(imagem, escala, modelo, prompt): try: logger.divider("Iniciando Processamento") # Converter entrada if not isinstance(imagem, Image.Image): imagem = Image.fromarray(imagem) # ========= 1. SUPER-RESOLUÇÃO ========= logger.etapa("Processando Super-Resolução") modelo_sr = modelo_edsr if modelo == "EDSR" else modelo_rdn params_sr = params_edsr if modelo == "EDSR" else params_rdn sr_jax = process( np.array(imagem) / 255., modelo_sr, params_sr, (round(imagem.height * escala), round(imagem.width * escala)), True ) sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB") logger.success(f"SR: {sr_pil.size[0]}x{sr_pil.size[1]}") # ========= 2. ESTILO BAIXO-RELEVO ========= arte_pil = sr_pil # Fallback if pipe: try: logger.etapa("Aplicando Estilo") arte_pil = pipe( prompt=f"BAS-RELIEF {prompt}, marble texture, 8k", image=sr_pil, strength=0.6, num_inference_steps=25, guidance_scale=7.0, generator=torch.Generator(device).manual_seed(42) ).images[0] logger.success("Estilo aplicado") except Exception as e: logger.error(f"Erro no estilo: {str(e)}") print(e) # ========= 3. MAPA DE PROFUNDIDADE ========= mapa_pil = arte_pil # Fallback if modelo_profundidade: try: logger.etapa("Calculando Profundidade") inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device) with torch.no_grad(): depth = modelo_profundidade(**inputs).predicted_depth depth = torch.nn.functional.interpolate( depth.unsqueeze(1), size=arte_pil.size[::-1], mode="bicubic" ).squeeze().cpu().numpy() depth = (depth - depth.min()) / (depth.max() - depth.min()) mapa_pil = Image.fromarray((depth * 255).astype(np.uint8)) logger.success("Profundidade calculada") except Exception as e: logger.error(f"Erro na profundidade: {str(e)}") print (e) return sr_pil, arte_pil, mapa_pil except Exception as e: logger.error(f"Erro fatal: {str(e)}") print(e) return None, None, None # ================== INTERFACE GRADIO ================== with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app: gr.Markdown("# 🏛 TheraSR - Processamento Completo em Float32") with gr.Row(): with gr.Column(): input_img = gr.Image(label="Imagem de Entrada", type="pil") slider_scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala") radio_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo") text_prompt = gr.Textbox( label="Prompt de Estilo", value="ancient marble浮雕, ultra detailed, 8k cinematic" ) btn_process = gr.Button("Processar", variant="primary") with gr.Column(): output_sr = gr.Image(label="Super-Resolução", interactive=False) output_art = gr.Image(label="Arte em Relevo", interactive=False) output_depth = gr.Image(label="Mapa de Profundidade", interactive=False) btn_process.click( processar_imagem_completa, inputs=[input_img, slider_scale, radio_model, text_prompt], outputs=[output_sr, output_art, output_depth] ) if __name__ == "__main__": app.launch(server_name="0.0.0.0", server_port=7860)