import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import warnings
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation

from model import build_thera

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("processing.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Settings and warning suppression
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Configure devices
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"


# 1. Load Thera models -------------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    try:
        logger.info(f"Loading Thera model from {repo_id}")
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(model_path, 'rb') as fh:
            check = pickle.load(fh)
            variables = check['model']
            backbone, size = check['backbone'], check['size']
        model = build_thera(3, backbone, size)
        return model, variables
    except Exception as e:
        logger.error(f"Failed to load Thera model: {str(e)}")
        raise


logger.info("Loading Thera EDSR...")
model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

# 2. Load SDXL + LoRA ---------------------------------------------------------------------------
try:
    logger.info("Loading SDXL + LoRA...")
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32
    ).to(TORCH_DEVICE)
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
except Exception as e:
    logger.error(f"Failed to load SDXL: {str(e)}")
    raise

# 3. Load depth model ---------------------------------------------------------------------------
try:
    logger.info("Loading DPT Depth...")
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
except Exception as e:
    logger.error(f"Failed to load DPT: {str(e)}")
    raise


def adjust_size(size):
    """Ensure the size is divisible by 8."""
    return (size // 8) * 8


def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
    try:
        progress(0.1, desc="Preprocessing...")

        # Convert to RGB and normalize to [0, 1]
        image = image.convert("RGB")
        source = np.array(image) / 255.0

        # Add a batch dimension if needed
        if source.ndim == 3:
            source = source[np.newaxis, ...]
        # Compute the target output size (must be divisible by 8)
        target_shape = (
            adjust_size(int(image.height * scale_factor)),
            adjust_size(int(image.width * scale_factor))
        )

        progress(0.3, desc="Super-resolution...")
        source_jax = jax.device_put(source, JAX_DEVICE)
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # Run Thera super-resolution
        upscaled = model_edsr.apply(
            variables_edsr,
            source_jax,
            t,
            target_shape
        )

        # Remove the batch dimension if present
        if upscaled.ndim == 4:
            upscaled = upscaled[0]

        # Clip to [0, 1] before converting to uint8 to avoid overflow artifacts
        upscaled_pil = Image.fromarray(
            (np.clip(np.array(upscaled), 0, 1) * 255).astype(np.uint8)
        )

        progress(0.6, desc="Generating bas-relief...")
        full_prompt = f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25
        ).images[0]

        progress(0.8, desc="Computing depth map...")
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # Resize the predicted depth back to the bas-relief resolution (PIL size is (w, h))
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Normalize depth to [0, 255] for visualization
        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        logger.error(f"Pipeline error: {str(e)}", exc_info=True)
        raise gr.Error(f"Error: {str(e)}")


# Gradio interface ------------------------------------------------------------------------------
with gr.Blocks(title="SuperRes + BasRelief") as app:
    gr.Markdown("## 🖼️ Super-Resolution + Bas-Relief + Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                label="Description",
                value="insanely detailed and complex engraving relief, ultra-high definition"
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")
        with gr.Column():
            img_upscaled = gr.Image(label="Upscaled")
            img_basrelief = gr.Image(label="Bas-Relief")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )

if __name__ == "__main__":
    app.launch()  # No public sharing
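
# Optional: a minimal local smoke test, kept commented out so the Gradio app above remains the
# only entry point. This is an illustrative sketch, not part of the original script: the file
# names "input.jpg", "upscaled.png", "bas_relief.png", and "depth.png" are placeholders, and it
# assumes gr.Progress() tolerates being called outside a Gradio event in your Gradio version.
#
#     test_img = Image.open("input.jpg")
#     up, relief, depth = full_pipeline(test_img, "ornate floral pattern", scale_factor=2.0)
#     up.save("upscaled.png")
#     relief.save("bas_relief.png")
#     depth.save("depth.png")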