import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import warnings
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera

# Configuration and warning suppression
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Configure devices
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"
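# Note: both JAX and PyTorch are pinned to CPU here (typical for a free Hugging Face Space);
# this is presumably also why SDXL is loaded in float32 below, since half precision is poorly supported on CPU.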


# 1. Load Thera super-resolution model --------------------------------------------------------
def load_thera_model(repo_id, filename):
    """Download a Thera checkpoint from the Hugging Face Hub and rebuild the model."""
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
        params, backbone, size = check['model'], check['backbone'], check['size']
    model = build_thera(3, backbone, size)
    return model, params


print("Carregando Thera EDSR...")
model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

# 2. Load SDXL + LoRA ---------------------------------------------------------------------
print("Loading SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to(TORCH_DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
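# The KappaNeuro/bas-relief LoRA appears to use "BAS-RELIEF" as its trigger word,
# which is why full_pipeline() prefixes every generation prompt with it below.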

# 3. Load depth estimation model ----------------------------------------------------------
print("Loading DPT depth model...")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
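# DPT predicts relative (unitless) depth, so the map is min-max normalized later
# before being rendered as an 8-bit grayscale image.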


# Main pipeline --------------------------------------------------------------------------
def full_pipeline(image, prompt, scale_factor=2.0):
    if image is None:
        raise gr.Error("Please provide an input image.")

    try:
        # 1. Super-resolution with Thera
        image = image.convert("RGB")
        source = np.array(image) / 255.0
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))

# Prepare inputs for JAX
        source_jax = jax.device_put(source, JAX_DEVICE)
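        # Thera conditions its field on t = 1/scale**2; this follows the convention used in the reference implementation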
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

# Run super-resolution with Thera
        upscaled = model_edsr.apply(
            params_edsr,
            source_jax,
            t,
            target_shape,
            do_ensemble=True
        )
        # Clip to [0, 1] before converting to uint8 to avoid wrap-around artifacts
        upscaled_pil = Image.fromarray((np.clip(np.array(upscaled), 0, 1) * 255).astype(np.uint8))

# 2. Generate bas-relief with SDXL + LoRA
        full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5
        ).images[0]

# 3. Compute depth map
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # PIL reports size as (width, height); interpolate expects (height, width), hence [::-1]
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Min-max normalize to [0, 1]; the epsilon guards against a constant depth map
        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        raise gr.Error(f"Erro no processamento: {str(e)}")


# Gradio interface ----------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                label="Relief Description",
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")

        with gr.Column():
            img_upscaled = gr.Image(label="Super-Resolved Image")
            img_basrelief = gr.Image(label="Bas-Relief Result")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )

if __name__ == "__main__":
    app.launch()