File size: 8,264 Bytes
65579be
98889c8
19a6d73
98889c8
b82dc7d
19a6d73
b82dc7d
f41a4a7
a7111d1
b82dc7d
f41a4a7
b82dc7d
 
65579be
 
 
 
 
 
f41a4a7
 
65579be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c7829e
65579be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a6d73
98889c8
65579be
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import logging
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download, file_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# ================== CONFIGURAÇÃO INICIAL ==================
# Configurar sistema de logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix para compatibilidade do Hugging Face Hub
file_download.cached_download = file_download.hf_hub_download

# ================== CONFIGURAÇÃO DE HARDWARE ==================
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.info(f"Dispositivo selecionado: {device.upper()}")
logger.info(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}")


# ================== CARREGAMENTO DE MODELOS ==================
def carregar_modelo_thera(repo_id):
    """Carrega modelos Thera do Hugging Face Hub"""
    try:
        logger.info(f"Carregando modelo Thera: {repo_id}")
        caminho_modelo = hf_hub_download(repo_id=repo_id, filename="model.pkl")
        with open(caminho_modelo, 'rb') as arquivo:
            dados = pickle.load(arquivo)
            modelo = build_thera(3, dados['backbone'], dados['size'])
            parametros = dados['model']
        logger.success(f"Modelo {repo_id} carregado com sucesso")
        return modelo, parametros
    except Exception as erro:
        logger.error(f"Falha ao carregar {repo_id}: {str(erro)}")
        raise


# Carregar modelos Thera
try:
    logger.divider("Carregando Modelos Thera")
    modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
    modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
except Exception as erro:
    logger.critical("Falha crítica no carregamento dos modelos Thera")
    raise

# ================== PIPELINE DE ARTE ==================
# Configurar SDXL + LoRA
try:
    logger.divider("Configurando Pipeline de Arte")
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch_dtype,
        variant="fp16",
        use_safetensors=True
    ).to(device)

    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
        adapter_name="bas_relief"
    )
    logger.success("Pipeline SDXL + LoRA configurado")
except Exception as erro:
    logger.error(f"Erro no SDXL: {str(erro)}")
    pipe = None

# Configurar modelo de profundidade
try:
    logger.divider("Configurando Modelo de Profundidade")
    processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
    logger.success("Modelo de profundidade pronto")
except Exception as erro:
    logger.error(f"Erro no modelo de profundidade: {str(erro)}")
    modelo_profundidade = None


# ================== FLUXO DE PROCESSAMENTO ==================
def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo):
    """Executa todo o fluxo de processamento"""
    try:
        logger.divider("Iniciando novo processamento")

        # ========= FASE 1: SUPER-RESOLUÇÃO =========
        logger.etapa("Processando Super-Resolução")
        modelo_sr = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn
        parametros_sr = params_edsr if modelo_escolhido == "EDSR" else params_rdn

        # Converter e validar entrada
        if not isinstance(imagem, Image.Image):
            logger.warning("Convertendo entrada numpy para PIL Image")
            imagem = Image.fromarray(imagem)

        # Processar super-resolução
        imagem_sr_jax = process(
            np.array(imagem) / 255.,
            modelo_sr,
            parametros_sr,
            (round(imagem.size[1] * fator_escala),
             round(imagem.size[0] * fator_escala)),
            True
        )

        # Converter para formato compatível
        imagem_sr_pil = Image.fromarray(np.array(imagem_sr_jax)).convert("RGB")
        logger.success(f"Super-Resolução concluída: {imagem_sr_pil.size}")

        # ========= FASE 2: ESTILO BAIXO-RELEVO =========
        if device == "cpu" or not pipe:
            logger.warning("GPU não disponível - Pulando estilo")
            return imagem_sr_pil, None, None

        logger.etapa("Aplicando Estilo Baixo-Relevo")
        prompt_completo = f"BAS-RELIEF {prompt_estilo}, intricate carving, marble texture, 8k"

        with torch.autocast(device_type=device.split(':')[0], dtype=torch_dtype):
            imagem_estilizada = pipe(
                prompt=prompt_completo,
                image=imagem_sr_pil,
                strength=0.7,
                num_inference_steps=35,
                guidance_scale=7.5,
                output_type="pil"
            ).images[0]

        logger.success(f"Estilo aplicado: {imagem_estilizada.size}")

        # ========= FASE 3: MAPA DE PROFUNDIDADE =========
        logger.etapa("Gerando Mapa de Profundidade")
        inputs = processador_profundidade(
            images=imagem_estilizada,
            return_tensors="pt"
        ).to(device, dtype=torch_dtype)

        with torch.no_grad(), torch.autocast(device_type=device.split(':')[0]):
            outputs = modelo_profundidade(**inputs)
            profundidade = outputs.predicted_depth

        # Processar profundidade
        profundidade = torch.nn.functional.interpolate(
            profundidade.unsqueeze(1).float(),  # Converter para float32
            size=imagem_estilizada.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Normalizar e converter
        profundidade = (profundidade - profundidade.min()) / (profundidade.max() - profundidade.min() + 1e-8)
        mapa_profundidade = Image.fromarray((profundidade * 255).astype(np.uint8))

        logger.success("Processamento completo")
        return imagem_sr_pil, imagem_estilizada, mapa_profundidade

    except Exception as erro:
        logger.error(f"ERRO NO PIPELINE: {str(erro)}", exc_info=True)
        return imagem_sr_pil if 'imagem_sr_pil' in locals() else None, None, None


# ================== INTERFACE GRADIO ==================
with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🎨 TheraSR Art Suite
    **Combine super-resolução aliasing-free com geração artística de baixo-relevo**
    """)

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            entrada_imagem = gr.Image(label="🖼 Imagem de Entrada", type="pil")
            seletor_modelo = gr.Radio(
                ["EDSR", "RDN"],
                value="EDSR",
                label="🔧 Modelo de Super-Resolução"
            )
            controle_escala = gr.Slider(
                1.0, 6.0,
                value=2.0,
                step=0.1,
                label="🔍 Fator de Escala"
            )
            entrada_prompt = gr.Textbox(
                label="📝 Prompt de Estilo",
                value="insanely detailed and complex engraving relief, ultra HD 8k",
                placeholder="Descreva o estilo desejado..."
            )
            botao_processar = gr.Button("🚀 Processar Imagem", variant="primary")

        with gr.Column(scale=2):
            saida_sr = gr.Image(label="✨ Super-Resolução", interactive=False)
            saida_arte = gr.Image(label="🖌 Arte em Baixo-Relevo", interactive=False)
            saida_profundidade = gr.Image(label="🗺 Mapa de Profundidade", interactive=False)

    # Configurar eventos
    botao_processar.click(
        fn=pipeline_completo,
        inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt],
        outputs=[saida_sr, saida_arte, saida_profundidade]
    )

# ================== INICIALIZAÇÃO ==================
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        share=False,
        debug=False
    )