# app.py
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera  # local module shipped with this Space
from utils import make_grid    # local helper that builds coordinate grids

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Configuration
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"
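
# Everything runs on CPU; with SDXL at 30 inference steps, expect minutes per
# image (assumption: the deployment environment has no GPU).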


def load_thera_model(repo_id: str, filename: str):
    """Download a Thera checkpoint from the Hugging Face Hub and build the model."""
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)
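        # Expected checkpoint keys (inferred from the usage below): 'backbone'
        # (architecture config), 'size', and 'model' (trained parameters).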
        return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise


# Model initialization
try:
    logger.info("Loading models...")
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32
    ).to(TORCH_DEVICE)
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
except Exception as e:
    logger.error(f"Initialization failed: {str(e)}")
    raise
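
# Note: the first run downloads several GB of weights from the Hub (the SDXL
# base model alone is multi-GB), so startup may take a while.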


def adjust_size(original: int, scale: float) -> int:
    """Scale a dimension and round it to a size the pipeline can handle."""
    scaled = int(original * scale)
    adjusted = (scaled // 8) * 8  # Round down to a multiple of 8
    return max(32, adjusted)  # Enforce an absolute minimum of 32 px
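
# Example: adjust_size(513, 2.0) -> 1024, since 513 * 2.0 = 1026 and
# (1026 // 8) * 8 = 1024.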


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Full pipeline: super-resolution -> bas-relief styling -> depth map."""
    try:
        # Preprocessing
        image = image.convert("RGB")
        orig_w, orig_h = image.size

        # Compute the target size
        new_h = adjust_size(orig_h, scale_factor)
        new_w = adjust_size(orig_w, scale_factor)
        logger.info(f"Resizing: {orig_h}x{orig_w} -> {new_h}x{new_w}")

        # Generate the target coordinate grid
        coords = make_grid((new_h, new_w))
        logger.debug(f"Grid shape: {coords.shape}")

        # Sanity check: the grid must match the target size
        if coords.shape[1:3] != (new_h, new_w):
            raise ValueError(f"Grid mismatch: {coords.shape[1:3]} vs ({new_h}, {new_w})")

        # Super-resolution
        source = jnp.array(image).astype(jnp.float32) / 255.0  # HWC, values in [0, 1]
        source = source[jnp.newaxis, ...]  # Add batch dimension

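        # Thera conditions the upscaler on t = scale**(-2); the precise meaning
        # of t is defined in the local model.py (assumption based on this usage).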
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
        upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))

        # Post-processing: clip to [0, 1] before casting, since the network
        # output can slightly overshoot the valid range
        upscaled_img = Image.fromarray((np.clip(np.array(upscaled[0]), 0.0, 1.0) * 255).astype(np.uint8))

        # Bas-Relief
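        # img2img strength=0.7 replays roughly the last 70% of the denoising
        # schedule on top of the upscaled image, so the relief style dominates
        # while the overall composition is loosely preserved.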
        result = pipe(
            prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
            image=upscaled_img,
            strength=0.7,
            num_inference_steps=30
        )
        bas_relief = result.images[0]

        # Depth map estimation with DPT
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            depth = depth_model(**inputs).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Normalize to [0, 1]; the epsilon guards against a constant depth map
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
        depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_img, bas_relief, depth_img

    except Exception as e:
        logger.error(f"PIPELINE ERROR: {str(e)}", exc_info=True)
        raise gr.Error(f"Processing failed: {str(e)}")


# Interface
with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
    gr.Markdown("# 🖼️ Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Relief Description",
                value="Insanely detailed and complex engraving relief, ultra-high definition",
                placeholder="Describe the desired style..."
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process Image", variant="primary")

        with gr.Column():
            gr.Markdown("## Results")
            with gr.Tabs():
                with gr.TabItem("Super-Resolution"):
                    upscaled_output = gr.Image(label="Super-Resolution Result")
                with gr.TabItem("Bas-Relief"):
                    basrelief_output = gr.Image(label="Generated Relief")
                with gr.TabItem("Depth"):
                    depth_output = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output]
    )
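
# Local run (sketch): `python app.py`, then open http://localhost:7860.
# Binding to 0.0.0.0 exposes the server on all interfaces, as needed in a container.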

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)