import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import warnings
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
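# build_thera is assumed to come from a model.py bundled alongside this app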
from model import build_thera

# Configuration and warning suppression
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Configure devices
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"


# 1. Load Thera models ------------------------------------------------------------------------
def load_thera_model(repo_id, filename):
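    """Download a Thera checkpoint from the Hub and rebuild the model.

    The pickled checkpoint holds 'model' (the full variables structure,
    i.e. {'params': ...}) plus the 'backbone' and 'size' that build_thera needs.
    """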
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
        # Load the full variables structure
        variables = check['model']  # should contain {'params': ...}
        backbone, size = check['backbone'], check['size']
    model = build_thera(3, backbone, size)
    return model, variables


print("Carregando Thera EDSR...")
model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

# 2. Load SDXL + LoRA --------------------------------------------------------------------------
print("Loading SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32
).to(TORCH_DEVICE)
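# LoRA adapter that biases SDXL toward engraved bas-relief imagery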
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

# 3. Load the depth model ----------------------------------------------------------------------
print("Loading DPT Depth...")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)


# Main pipeline --------------------------------------------------------------------------------
def full_pipeline(image, prompt, scale_factor=2.0):
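    """Run the three stages: Thera super-resolution, SDXL+LoRA bas-relief
    generation, and DPT depth estimation.

    Returns the upscaled image, the bas-relief render, and the depth map,
    matching the three Gradio outputs below.
    """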
    try:
        # 1. Super-resolution with Thera
        image = image.convert("RGB")
        source = np.array(image) / 255.0
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))

        source_jax = jax.device_put(source, JAX_DEVICE)
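        # Thera's time parameter: t = 1/s**2 for an s-times upscale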
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # Apply the model with the full variables structure
        upscaled = model_edsr.apply(
            variables_edsr,  # full structure {'params': ...}
            source_jax,
            t,
            target_shape
        )

        # Clip to [0, 1] before the uint8 cast so overshooting values don't wrap around
        upscaled_pil = Image.fromarray((np.clip(np.array(upscaled), 0, 1) * 255).astype(np.uint8))

        # 2. Generate the bas-relief
        full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5
        ).images[0]

        # 3. Compute the depth map
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
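            # predicted_depth has shape (batch, H, W) at the DPT working resolution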
            depth = outputs.predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],  # PIL .size is (W, H); interpolate expects (H, W)
            mode="bicubic",
            align_corners=False
        ).squeeze().cpu().numpy()

        # Min-max normalize to [0, 1]; the epsilon guards against a constant depth map
        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        raise gr.Error(f"Erro no processamento: {str(e)}")


# Gradio interface -----------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                label="Relief Description",
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")

        with gr.Column():
            img_upscaled = gr.Image(label="Upscaled Image")
            img_basrelief = gr.Image(label="Bas-Relief Result")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )

if __name__ == "__main__":
    app.launch(share=False)