File size: 7,340 Bytes
357df1b
65579be
98889c8
19a6d73
98889c8
b82dc7d
19a6d73
b82dc7d
357df1b
a7111d1
b82dc7d
f41a4a7
b82dc7d
 
65579be
357df1b
 
 
 
 
 
 
 
 
 
 
 
eb719b4
 
357df1b
eb719b4
357df1b
 
 
 
 
 
 
 
 
 
 
 
 
f41a4a7
5779b8d
0fa6c9c
 
5779b8d
 
 
65579be
 
 
 
 
eb719b4
357df1b
 
 
 
 
eb719b4
357df1b
 
5779b8d
357df1b
65579be
 
5779b8d
 
 
 
 
 
65579be
5779b8d
357df1b
 
65579be
888435a
5779b8d
eb719b4
5779b8d
888435a
 
 
a09dd26
888435a
 
5779b8d
888435a
 
160e39b
 
888435a
 
eb719b4
888435a
5779b8d
357df1b
5779b8d
888435a
5779b8d
8094627
888435a
65579be
 
5779b8d
 
65579be
5779b8d
65579be
5779b8d
65579be
 
 
eb719b4
5779b8d
eb719b4
 
 
 
5779b8d
eb719b4
 
5779b8d
 
65579be
 
 
eb719b4
5779b8d
357df1b
eb719b4
5779b8d
eb719b4
357df1b
eb719b4
 
5779b8d
357df1b
eb719b4
 
5779b8d
 
eb719b4
 
357df1b
eb719b4
8094627
357df1b
eb719b4
5779b8d
 
357df1b
 
eb719b4
357df1b
eb719b4
357df1b
 
5779b8d
357df1b
 
 
 
5779b8d
 
eb719b4
357df1b
eb719b4
8094627
357df1b
eb719b4
357df1b
 
eb719b4
8094627
357df1b
65579be
 
 
eb719b4
5779b8d
357df1b
 
 
5779b8d
 
 
 
 
 
65579be
eb719b4
65579be
357df1b
eb719b4
 
 
 
 
5779b8d
 
eb719b4
19a6d73
98889c8
65579be
357df1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# app.py
import logging
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation


# ================== CONFIGURAÇÃO DE LOGGING ==================
class CustomLogger:
    """Thin wrapper around :mod:`logging` that emits decorated progress messages.

    Messages go through a ``StreamHandler`` with the format
    ``"<asctime> - <levelname> - <message>"``. The helper methods prefix a
    visual marker:

    - ``etapa``: ▶, a step in progress
    - ``success``: ✓
    - ``warning``: ⚠
    - ``error``: ✗
    """

    # Attribute set on the handler this class installs, so that creating a
    # second CustomLogger for the same name (e.g. on a module reload) does not
    # attach a second handler and print every line twice.
    _HANDLER_FLAG = "_custom_logger_handler"

    def __init__(self, name):
        """Wrap ``logging.getLogger(name)`` and set its level to INFO.

        Args:
            name: logger name, typically ``__name__``.
        """
        self.logger = logging.getLogger(name)
        already_installed = any(
            getattr(h, self._HANDLER_FLAG, False) for h in self.logger.handlers
        )
        if not already_installed:
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            setattr(handler, self._HANDLER_FLAG, True)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    def divider(self, text=None, length=60):
        """Log a separator line at INFO level.

        Without *text*, the line is ``length`` ``=`` characters. With *text*,
        the line is 10 ``=``, the upper-cased text padded by spaces, then
        enough ``=`` to reach *length*. At least one ``=`` always follows, so
        long texts produce a line longer than *length*.
        """
        if text:
            available_space = max(length - len(text) - 12, 1)
            msg = f"{'=' * 10} {text.upper()} {'=' * available_space}"
        else:
            msg = "=" * length
        self.logger.info(msg)

    def etapa(self, text):
        """Log the start of a processing step ("etapa") at INFO level."""
        self.logger.info(f"▶ {text}")

    def success(self, text):
        """Log a successful outcome at INFO level."""
        self.logger.info(f"✓ {text}")

    def warning(self, text):
        """Log a non-fatal problem at WARNING level."""
        self.logger.warning(f"⚠ {text}")

    def error(self, text):
        """Log a failure at ERROR level."""
        self.logger.error(f"✗ {text}")


logger = CustomLogger(__name__)

# ================== FORCED CONFIGURATION ==================
# Everything is pinned to the CPU in full float32 precision; no GPU detection
# is attempted. The diffusion and depth models loaded below use these values.
device, torch_dtype = "cpu", torch.float32

logger.divider("Configuração Forçada")
logger.success("Dispositivo: " + device.upper())
logger.success("Precisão: " + str(torch_dtype).replace('torch.', ''))


# ================== CARREGAMENTO DE MODELOS ==================
def carregar_modelo_thera(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build the model.

    The checkpoint ``model.pkl`` is unpickled. Its ``'backbone'`` and
    ``'size'`` entries are passed to ``build_thera``, and ``'model'`` holds
    the parameters.

    Args:
        repo_id: Hub repository id, e.g. ``"prs-eth/thera-edsr-pro"``.

    Returns:
        ``(model, params)`` on success. On any failure the error is logged and
        ``(None, None)`` is returned instead of raising, so callers must check
        for ``None``.
    """
    try:
        logger.divider(f"Carregando {repo_id}")
        model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # This is acceptable only because repo_id is hard-coded by this app;
        # never route a user-supplied repo_id into this function.
        with open(model_path, 'rb') as f:
            check = pickle.load(f)
        # The file is closed here; building the model does not need it.
        # NOTE(review): the leading 3 is presumably the number of input (RGB)
        # channels -- confirm against build_thera's signature.
        model = build_thera(3, check['backbone'], check['size'])
        params = check['model']
        logger.success(f"{repo_id} carregado")
        return model, params
    except Exception as e:
        logger.error(f"Falha no carregamento: {e}")
        return None, None


try:
    modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
    modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
    # carregar_modelo_thera logs and swallows its own errors, returning
    # (None, None). Without this explicit check a failed load would never
    # reach the except below, and the app would start with missing models
    # that only fail later inside processar_imagem_completa.
    if any(obj is None for obj in (modelo_edsr, params_edsr, modelo_rdn, params_rdn)):
        raise RuntimeError("Thera model loading returned None")
except Exception:
    logger.error("Falha crítica nos modelos Thera")
    raise

# ================== ARTISTIC PIPELINE ==================
# Optional components. If one fails to load, the app keeps running:
# processar_imagem_completa skips the style step when `pipe` is None and the
# depth step when `modelo_profundidade` is None.
pipe = None
processador_profundidade = None
modelo_profundidade = None

logger.divider("Configurando Componentes Artísticos")

# --- SDXL img2img pipeline + bas-relief LoRA ---
# Loaded in its own try block so a failure here does not prevent the depth
# model from loading, and vice versa.
try:
    # No `variant=` argument: the repository's default weights are already
    # full precision, and "fp32" is not a published variant name.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch_dtype,
    ).to(device)

    # NOTE(review): `peft_backend` is not among the documented
    # load_lora_weights arguments; confirm it has any effect.
    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
        peft_backend="peft",
    )
    logger.success("Pipeline SDXL + LoRA em float32")
except Exception as e:
    # The style step relies on the LoRA's "BAS-RELIEF" trigger word, so a
    # LoRA failure disables the whole style step (same as before).
    logger.error(f"Recursos artísticos limitados (estilo): {e}")
    pipe = None

# --- DPT depth estimator ---
try:
    processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).float()
    logger.success("Modelo de profundidade em float32")
except Exception as e:
    logger.error(f"Recursos artísticos limitados (profundidade): {e}")
    processador_profundidade = None
    modelo_profundidade = None


# ================== PROCESSAMENTO PRINCIPAL ==================
def processar_imagem_completa(imagem, escala, modelo, prompt):
    """Run super-resolution, then bas-relief styling, then depth estimation.

    Args:
        imagem: input image, either a PIL image or an array accepted by
            ``Image.fromarray``. It is converted to RGB before processing.
        escala: upscaling factor applied to both height and width.
        modelo: ``"EDSR"`` selects the EDSR Thera model; any other value
            selects RDN.
        prompt: user text inserted into the style prompt.

    Returns:
        A tuple ``(sr_pil, arte_pil, mapa_pil)``:

        - ``arte_pil`` is ``sr_pil`` when the style pipeline is unavailable
          or fails.
        - ``mapa_pil`` is ``arte_pil`` when the depth model is unavailable
          or fails.
        - On a fatal error, including a missing input image, all three
          are ``None``.
    """
    try:
        logger.divider("Iniciando Processamento")

        # Gradio passes None when "Processar" is clicked with no image loaded.
        if imagem is None:
            logger.error("Nenhuma imagem fornecida")
            return None, None, None

        # Normalize the input to an RGB PIL image.
        if not isinstance(imagem, Image.Image):
            imagem = Image.fromarray(imagem)
        # The Thera models are built with build_thera(3, ...), presumably
        # 3 = RGB channels, so grayscale/RGBA/palette inputs are converted.
        imagem = imagem.convert("RGB")

        # ========= 1. SUPER-RESOLUTION =========
        logger.etapa("Processando Super-Resolução")
        if modelo == "EDSR":
            modelo_sr, params_sr = modelo_edsr, params_edsr
        else:
            modelo_sr, params_sr = modelo_rdn, params_rdn

        sr_jax = process(
            np.array(imagem) / 255.,  # pixel values scaled to [0, 1]
            modelo_sr,
            params_sr,
            (round(imagem.height * escala),
             round(imagem.width * escala)),  # target (height, width)
            True,
        )

        sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB")
        logger.success(f"SR: {sr_pil.size[0]}x{sr_pil.size[1]}")

        # ========= 2. BAS-RELIEF STYLE =========
        arte_pil = sr_pil  # fallback when styling is unavailable or fails
        if pipe is not None:
            try:
                logger.etapa("Aplicando Estilo")
                arte_pil = pipe(
                    prompt=f"BAS-RELIEF {prompt}, marble texture, 8k",
                    image=sr_pil,
                    strength=0.6,
                    num_inference_steps=25,
                    guidance_scale=7.0,
                    # Fixed seed so identical inputs give identical results.
                    generator=torch.Generator(device).manual_seed(42),
                ).images[0]
                logger.success("Estilo aplicado")
            except Exception as e:
                logger.error(f"Erro no estilo: {e}")

        # ========= 3. DEPTH MAP =========
        mapa_pil = arte_pil  # fallback when depth is unavailable or fails
        if modelo_profundidade is not None:
            try:
                logger.etapa("Calculando Profundidade")
                inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device)
                with torch.no_grad():
                    depth = modelo_profundidade(**inputs).predicted_depth

                # Resize to the art image's resolution. PIL's .size is
                # (width, height); interpolate expects (height, width).
                depth = torch.nn.functional.interpolate(
                    depth.unsqueeze(1),
                    size=arte_pil.size[::-1],
                    mode="bicubic",
                ).squeeze().cpu().numpy()

                # Min-max normalize to [0, 1]. A constant map would divide by
                # zero (NaNs -> garbage uint8), so it becomes all-black.
                d_min = depth.min()
                d_range = depth.max() - d_min
                if d_range > 0:
                    depth = (depth - d_min) / d_range
                else:
                    depth = np.zeros_like(depth)
                mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
                logger.success("Profundidade calculada")
            except Exception as e:
                logger.error(f"Erro na profundidade: {e}")

        return sr_pil, arte_pil, mapa_pil

    except Exception as e:
        logger.error(f"Erro fatal: {e}")
        return None, None, None


# ================== INTERFACE GRADIO ==================
with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🏛 TheraSR - Processamento Completo em Float32")

    with gr.Row():
        # Left column: inputs, in the order processar_imagem_completa takes them.
        with gr.Column():
            input_img = gr.Image(label="Imagem de Entrada", type="pil")
            slider_scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
            radio_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo")
            text_prompt = gr.Textbox(
                label="Prompt de Estilo",
                # Kept in English: SDXL's CLIP text encoders were trained on
                # mostly English captions. The previous default fused a CJK
                # word ("marble浮雕") that they are unlikely to interpret.
                value="ancient marble bas-relief, ultra detailed, 8k cinematic"
            )
            btn_process = gr.Button("Processar", variant="primary")

        # Right column: the three results, matching the function's return order.
        with gr.Column():
            output_sr = gr.Image(label="Super-Resolução", interactive=False)
            output_art = gr.Image(label="Arte em Relevo", interactive=False)
            output_depth = gr.Image(label="Mapa de Profundidade", interactive=False)

    btn_process.click(
        processar_imagem_completa,
        inputs=[input_img, slider_scale, radio_model, text_prompt],
        outputs=[output_sr, output_art, output_depth]
    )

if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0), port 7860, so the UI is
    # reachable from outside the local host.
    app.launch(server_port=7860, server_name="0.0.0.0")