File size: 6,910 Bytes
19a6d73
98889c8
19a6d73
1eb87a5
a7111d1
98889c8
1665fe1
19a6d73
 
d160dc6
19a6d73
a7111d1
 
19a6d73
a7111d1
46bb495
 
 
 
19a6d73
46bb495
 
 
d160dc6
3920f5c
 
1eb87a5
 
19a6d73
1557411
46bb495
 
 
19a6d73
1557411
 
 
 
 
 
 
19a6d73
46bb495
19a6d73
46bb495
3920f5c
 
1557411
19a6d73
1557411
19a6d73
1557411
 
19a6d73
 
 
 
 
1557411
 
19a6d73
 
1557411
19a6d73
1557411
19a6d73
1eb87a5
 
1557411
 
 
 
 
 
 
 
 
 
 
d85fde4
 
19a6d73
1557411
3920f5c
1557411
 
 
 
 
a7111d1
19a6d73
1557411
3920f5c
1557411
 
 
d160dc6
19a6d73
1557411
 
 
 
 
 
 
 
 
d85fde4
1557411
19a6d73
1557411
19a6d73
1557411
a7111d1
1557411
 
19a6d73
3920f5c
1557411
19a6d73
1557411
0652978
1557411
19a6d73
 
 
3920f5c
1557411
 
19a6d73
 
1557411
3920f5c
1557411
3920f5c
 
d160dc6
3920f5c
1557411
3920f5c
 
 
 
 
 
1557411
19a6d73
 
 
 
1557411
a7111d1
19a6d73
3920f5c
 
19a6d73
1557411
3920f5c
 
1557411
 
19a6d73
 
1665fe1
1557411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a6d73
 
1557411
 
19a6d73
98889c8
 
19a6d73
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# app.py
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
from utils import make_grid, interpolate_grid

# Logging: records at INFO and above go to processing.log and to the console.
_log_handlers = [
    logging.FileHandler("processing.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Devices: both JAX and PyTorch are pinned to the CPU.
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"


def load_thera_model(repo_id: str, filename: str):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id (e.g. ``"prs-eth/thera-edsr-pro"``).
        filename: Checkpoint file name inside the repository.

    Returns:
        ``(model, params)``: the module returned by
        ``build_thera(3, checkpoint['backbone'], checkpoint['size'])`` and the
        checkpoint's ``'model'`` entry.

    Raises:
        ValueError: If the checkpoint is not a mapping or lacks one of the
            keys ``'model'``, ``'backbone'``, ``'size'``.
        Exception: Download, unpickling and model-construction errors are
            logged together with their traceback and re-raised unchanged.
    """
    from collections.abc import Mapping

    required_keys = {'model', 'backbone', 'size'}
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: unpickling can execute arbitrary code stored in the file.
        # Only point this at repositories you trust.
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)

        # Validate the layout up front so a bad file fails with a clear
        # message instead of an AttributeError/KeyError further down.
        if not isinstance(checkpoint, Mapping):
            raise ValueError(
                f"Checkpoint corrompido: esperado dict, obtido {type(checkpoint).__name__}"
            )
        missing = required_keys - checkpoint.keys()
        if missing:
            raise ValueError(f"Checkpoint corrompido. Chaves faltando: {missing}")

        # NOTE(review): the leading 3 is presumably the RGB channel count;
        # confirm against build_thera's signature.
        model = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        return model, checkpoint['model']
    except Exception as e:
        logger.error(f"Erro ao carregar modelo: {str(e)}", exc_info=True)
        raise


# Model initialisation at import time. Every model is loaded once so that
# requests reuse it. Any failure is logged and re-raised, which aborts the
# import (and therefore app startup).
try:
    logger.info("Inicializando modelos...")
    # Thera super-resolution model and its parameters. The repo and variable
    # names suggest an EDSR backbone; TODO confirm.
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

    # SDXL img2img pipeline, with float32 weights on TORCH_DEVICE and the
    # bas-relief LoRA weights loaded on top.
    # NOTE(review): float32 is presumably chosen because inference runs on CPU.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32
    ).to(TORCH_DEVICE)
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

    # Depth model (DPT-Large) and its matching image preprocessor.
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)

except Exception as e:
    logger.error(f"Falha crítica na inicialização: {str(e)}")
    raise


def safe_resize(original: tuple[int, int], scale: float) -> tuple[int, int]:
    """Scale a ``(height, width)`` pair and snap each side down to a multiple of 8.

    Each side is multiplied by ``scale``, truncated with ``int``, rounded down
    to the nearest multiple of 8, and never returned smaller than 32.

    Args:
        original: ``(height, width)`` of the source image.
        scale: Multiplicative scale factor.

    Returns:
        ``(new_height, new_width)``.
    """
    def _snap(side: int) -> int:
        target = int(side * scale)
        return max(32, (target // 8) * 8)

    height, width = original
    return _snap(height), _snap(width)


def _to_uint8_image(array) -> Image.Image:
    """Convert a float image nominally in [0, 1] into an 8-bit PIL image."""
    # Clip first. Casting out-of-range floats to np.uint8 does not saturate
    # (the result wraps or is platform-dependent), which would turn slight
    # model overshoot into bright/dark speckles.
    scaled = np.clip(np.asarray(array, dtype=np.float32), 0.0, 1.0) * 255.0
    return Image.fromarray(np.rint(scaled).astype(np.uint8))


def _super_resolve(image: Image.Image, scale_factor: float) -> Image.Image:
    """Upscale an RGB image with the Thera model loaded at startup."""
    orig_w, orig_h = image.size
    new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
    logger.info(f"Novo tamanho calculado: {new_w}x{new_h}")

    # Coordinate grid for the target size.
    grid = make_grid((new_h, new_w))
    logger.debug(f"Grid gerado: {grid.shape}")

    # NOTE(review): the grid is only used for this sanity check and is not
    # passed to the model. The check assumes make_grid returns an array
    # with a leading batch axis, i.e. (1, H, W, ...); TODO confirm against
    # utils.make_grid.
    if grid.shape[1:3] != (new_h, new_w):
        raise RuntimeError(
            f"Incompatibilidade de dimensões: "
            f"Grid {grid.shape[1:3]} vs Alvo {new_h}x{new_w}"
        )

    # Float pixels in [0, 1] with a batch axis prepended.
    source = jnp.asarray(np.asarray(image, dtype=np.float32) / 255.0)[jnp.newaxis, ...]

    # Scale parameter passed to the model: 1 / scale_factor**2.
    # NOTE(review): derived from the requested factor, not the effective
    # factor after safe_resize snaps to multiples of 8; confirm intended.
    t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

    upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
    upscaled_img = _to_uint8_image(upscaled[0])
    logger.info(f"Imagem super-resolvida: {upscaled_img.size}")
    return upscaled_img


def _generate_bas_relief(image: Image.Image, prompt: str) -> Image.Image:
    """Restyle *image* as a bas-relief with the SDXL img2img pipeline and LoRA."""
    result = pipe(
        prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
        image=image,
        strength=0.7,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    bas_relief = result.images[0]
    logger.info(f"Bas-Relief gerado: {bas_relief.size}")
    return bas_relief


def _estimate_depth(image: Image.Image) -> Image.Image:
    """Return an 8-bit grayscale depth map of *image*, resized to the image's size."""
    inputs = feature_extractor(image, return_tensors="pt").to(TORCH_DEVICE)
    with torch.no_grad():
        depth = depth_model(**inputs).predicted_depth

    # PIL sizes are (W, H) while interpolate expects (H, W), hence the [::-1].
    # NOTE(review): assumes predicted_depth is (batch, H, W), so unsqueeze(1)
    # adds the channel axis interpolate needs; TODO confirm.
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()

    # Min-max normalise to [0, 1]. The epsilon avoids dividing by zero when
    # the map is flat.
    depth_min = depth_map.min()
    depth_max = depth_map.max()
    depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
    depth_img = _to_uint8_image(depth_normalized)
    logger.info("Mapa de profundidade calculado")
    return depth_img


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief generation and depth estimation in sequence.

    Args:
        image: Input picture in any PIL mode (it is converted to RGB).
            ``None`` is rejected.
        prompt: Free-text style description inserted into the bas-relief prompt.
        scale_factor: Upscaling factor for the super-resolution stage; must be > 0.

    Returns:
        ``(upscaled_img, bas_relief, depth_img)`` as PIL images:
        - ``upscaled_img``: the super-resolved input.
        - ``bas_relief``: the SDXL img2img result computed from ``upscaled_img``.
        - ``depth_img``: an 8-bit grayscale depth map with the same size as
          ``bas_relief``.

    Raises:
        gr.Error: Wraps any failure so that Gradio shows it to the user. The
            underlying exception is logged with its traceback and chained.
    """
    try:
        if image is None:
            raise ValueError("Nenhuma imagem fornecida")
        # The slider limits the UI to [1, 4], but the function is also exposed
        # through api_name="processar". Reject values for which 1 / scale**2
        # is undefined or meaningless (this test also rejects NaN).
        if not scale_factor > 0:
            raise ValueError(f"Fator de escala inválido: {scale_factor}")

        image = image.convert("RGB")
        orig_w, orig_h = image.size
        logger.info(f"Processando imagem: {orig_w}x{orig_h}")

        upscaled_img = _super_resolve(image, scale_factor)
        bas_relief = _generate_bas_relief(upscaled_img, prompt)
        depth_img = _estimate_depth(bas_relief)

        return upscaled_img, bas_relief, depth_img

    except Exception as e:
        logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
        raise gr.Error(f"Falha no processamento: {str(e)}") from e


# Gradio UI: a row with the input controls in one column and, in the other,
# one tab per pipeline output.
with gr.Blocks(title="SuperRes+BasRelief Pro", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    with gr.Row():
        with gr.Column() as input_col:
            img_input = gr.Image(label="Carregar Imagem", type="pil", height=300)
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                value="A insanely detailed and complex engraving relief, ultra-high definition",
                placeholder="Descreva o estilo desejado...",
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
            process_btn = gr.Button("Iniciar Processamento", variant="primary")

        with gr.Column() as output_col:
            with gr.Tabs():
                with gr.TabItem("Super Resolução"):
                    upscaled_output = gr.Image(label="Resultado", show_label=False)
                with gr.TabItem("Bas-Relief"):
                    basrelief_output = gr.Image(label="Relevo", show_label=False)
                with gr.TabItem("Profundidade"):
                    depth_output = gr.Image(label="Mapa 3D", show_label=False)

    # Clicking the button runs the whole pipeline. The handler is also
    # exposed as the "processar" API endpoint.
    process_btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output],
        api_name="processar",
    )

if __name__ == "__main__":
    # Serve the UI on every network interface, port 7860.
    app.launch(server_port=7860, server_name="0.0.0.0")