# app.py
"""Gradio app: Thera super-resolution -> SDXL bas-relief -> DPT depth map.

All models are loaded once at import time; ``full_pipeline`` is the single
callback wired to the UI button.
"""
import logging
import pickle

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation

from model import build_thera
from utils import make_grid, interpolate_grid

# Logging configuration: mirror every record to a file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Device settings
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"


def load_thera_model(repo_id: str, filename: str):
    """Download a Thera checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the pickled checkpoint file inside the repository.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera`` and
        the value stored under the checkpoint's ``'model'`` key.

    Raises:
        ValueError: If the checkpoint lacks any of ``'model'``, ``'backbone'``
            or ``'size'``.
        Exception: Any download/unpickling error is logged and re-raised.
    """
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point this loader at repositories you trust.
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)

        # Validate the checkpoint structure before using it.
        required_keys = {'model', 'backbone', 'size'}
        if not required_keys.issubset(checkpoint.keys()):
            missing = required_keys - checkpoint.keys()
            raise ValueError(f"Checkpoint corrompido. Chaves faltando: {missing}")

        return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
    except Exception as e:
        logger.error(f"Erro ao carregar modelo: {str(e)}")
        raise


# Model initialisation: any failure here is logged and aborts start-up.
try:
    logger.info("Inicializando modelos...")
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

    # SDXL img2img pipeline with the bas-relief LoRA
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32
    ).to(TORCH_DEVICE)
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

    # Depth-estimation model
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
except Exception as e:
    logger.error(f"Falha crítica na inicialização: {str(e)}")
    raise


def safe_resize(original: tuple[int, int], scale: float) -> tuple[int, int]:
    """Return the scaled ``(height, width)`` snapped to the pipeline's size grid.

    Each side is multiplied by ``scale``, truncated to an int, rounded down to
    a multiple of 8 and finally clamped to a minimum of 32 pixels.

    Args:
        original: Source size as ``(height, width)``.
        scale: Multiplicative scale factor.

    Returns:
        The target size as ``(height, width)``.
    """
    h, w = original
    new_h = int(h * scale)
    new_w = int(w * scale)

    # Round down to a multiple of 8, never going below 32.
    new_h = max(32, new_h - new_h % 8)
    new_w = max(32, new_w - new_w % 8)

    return (new_h, new_w)


def _to_uint8_image(values) -> Image.Image:
    """Convert an array of nominally [0, 1] floats into an 8-bit PIL image.

    Values are clipped to [0, 1] first: casting an out-of-range float to
    ``uint8`` wraps around instead of saturating, which shows up as speckle
    artefacts in bright or dark regions.
    """
    pixels = np.clip(np.asarray(values, dtype=np.float32), 0.0, 1.0)
    return Image.fromarray((pixels * 255).astype(np.uint8))


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief generation and depth estimation.

    Args:
        image: Input picture from the Gradio image component (``None`` when
            nothing was uploaded).
        prompt: Text appended to the fixed ``BAS-RELIEF`` prompt prefix.
        scale_factor: Requested upscaling factor.

    Returns:
        A tuple ``(upscaled_img, bas_relief, depth_img)`` of PIL images.

    Raises:
        gr.Error: Wraps any failure so Gradio shows it to the user; the
            original exception is attached as ``__cause__`` and logged with
            its traceback.
    """
    try:
        # Initial validation. Compare against None explicitly instead of
        # relying on the truthiness of an image object.
        if image is None:
            raise ValueError("Nenhuma imagem fornecida")

        # Normalise to 3-channel RGB.
        image = image.convert("RGB")
        orig_w, orig_h = image.size  # PIL reports (width, height)
        logger.info(f"Processando imagem: {orig_w}x{orig_h}")

        # Compute the target size.
        new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
        logger.info(f"Novo tamanho calculado: {new_h}x{new_w}")

        # Build the coordinate grid.
        grid = make_grid((new_h, new_w))
        logger.debug(f"Grid gerado: {grid.shape}")

        # Sanity check on the grid.
        # NOTE(review): this indexes axes 1-2, i.e. it assumes make_grid
        # returns a leading batch axis. If it returns (H, W, 2) this check
        # always fails - confirm against utils.make_grid.
        if grid.shape[1:3] != (new_h, new_w):
            raise RuntimeError(
                f"Incompatibilidade de dimensões: "
                f"Grid {grid.shape[1:3]} vs Alvo {new_h}x{new_w}"
            )

        # Image pre-processing: scale to [0, 1] and add a batch dimension.
        source = jnp.array(image).astype(jnp.float32) / 255.0
        source = source[jnp.newaxis, ...]

        # Scale parameter.
        # NOTE(review): uses the requested factor, not the effective one after
        # rounding the target size to a multiple of 8 - confirm this is intended.
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # Thera super-resolution.
        # NOTE(review): argument order is taken as-is; verify against the
        # signature of the model returned by build_thera.
        upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))

        # Convert to PIL, clipping so out-of-range predictions do not wrap.
        upscaled_img = _to_uint8_image(upscaled[0])
        logger.info(f"Imagem super-resolvida: {upscaled_img.size}")

        # Bas-relief generation.
        result = pipe(
            prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
            image=upscaled_img,
            strength=0.7,
            num_inference_steps=30,
            guidance_scale=7.5
        )
        bas_relief = result.images[0]
        logger.info(f"Bas-Relief gerado: {bas_relief.size}")

        # Depth estimation.
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            depth = depth_model(**inputs).predicted_depth

        # Resize the prediction back to the bas-relief resolution.
        # PIL size is (W, H); interpolate expects (H, W), hence the reversal.
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
            align_corners=False
        ).squeeze().cpu().numpy()

        # Min-max normalisation; the epsilon guards against a constant map.
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
        depth_img = _to_uint8_image(depth_normalized)
        logger.info("Mapa de profundidade calculado")

        return upscaled_img, bas_relief, depth_img

    except Exception as e:
        logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
        raise gr.Error(f"Falha no processamento: {str(e)}") from e


# Gradio interface
with gr.Blocks(title="SuperRes+BasRelief Pro", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    with gr.Row():
        input_col = gr.Column()
        output_col = gr.Column()

    with input_col:
        img_input = gr.Image(label="Carregar Imagem", type="pil", height=300)
        prompt = gr.Textbox(
            label="Descrição do Relevo",
            value="A insanely detailed and complex engraving relief, ultra-high definition",
            placeholder="Descreva o estilo desejado..."
        )
        scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
        process_btn = gr.Button("Iniciar Processamento", variant="primary")

    with output_col:
        with gr.Tabs():
            with gr.TabItem("Super Resolução"):
                upscaled_output = gr.Image(label="Resultado", show_label=False)
            with gr.TabItem("Bas-Relief"):
                basrelief_output = gr.Image(label="Relevo", show_label=False)
            with gr.TabItem("Profundidade"):
                depth_output = gr.Image(label="Mapa 3D", show_label=False)

    process_btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output],
        api_name="processar"
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)