File size: 4,311 Bytes
98889c8
 
1eb87a5
98889c8
1665fe1
3920f5c
 
1eb87a5
98889c8
3920f5c
1eb87a5
3920f5c
 
 
1eb87a5
 
3920f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb87a5
 
 
3920f5c
1eb87a5
 
3920f5c
 
1eb87a5
3920f5c
1eb87a5
 
3920f5c
1eb87a5
3920f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1665fe1
 
 
1eb87a5
3920f5c
1eb87a5
 
1665fe1
 
3920f5c
 
1eb87a5
 
 
 
 
 
1665fe1
98889c8
 
1665fe1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
import pickle
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from model import build_thera  # Importar do código original do Thera

# Device configuration: everything in this app runs on the CPU.
# TORCH_DEVICE is the device string the PyTorch models are moved to; JAX_DEVICE
# is the first CPU device that JAX reports, used to place the Thera input.
TORCH_DEVICE = "cpu"
JAX_DEVICE = jax.devices(TORCH_DEVICE)[0]


# 1. Load the Thera models ---------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    """Download a Thera checkpoint from the Hugging Face Hub and rebuild the model.

    Args:
        repo_id: Hub repository that holds the checkpoint.
        filename: Name of the pickled checkpoint file inside that repository.

    Returns:
        A ``(model, params)`` pair: the network produced by ``build_thera`` and
        the parameters stored under the checkpoint's ``'model'`` key.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
    # the file. Only point this at repositories you trust.
    with open(checkpoint_path, "rb") as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    params = checkpoint["model"]
    backbone = checkpoint["backbone"]
    size = checkpoint["size"]

    # The leading 3 is passed through unchanged; presumably the RGB channel
    # count. Confirm against build_thera's signature.
    return build_thera(3, backbone, size), params


print("Carregando Thera EDSR...")
model_edsr, params_edsr = load_thera_model(
    repo_id="prs-eth/thera-edsr-pro",
    filename="model.pkl",
)

# 2. Load SDXL img2img + the bas-relief LoRA -------------------------------------------------------
print("Carregando SDXL + LoRA...")
# float32 weights: presumably chosen because the pipeline runs on TORCH_DEVICE ("cpu").
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
)
pipe = pipe.to(TORCH_DEVICE)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)

# 3. Load the DPT depth-estimation model ---------------------------------------------------------
print("Carregando DPT Depth...")
# NOTE(review): recent transformers releases deprecate DPTFeatureExtractor in
# favour of DPTImageProcessor; consider migrating once the pinned version allows.
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
depth_model = depth_model.to(TORCH_DEVICE)


# Main pipeline --------------------------------------------------------------------------------
def _to_uint8_image(array):
    """Convert a float image with values nominally in [0, 1] into a uint8 PIL image.

    Values are clipped to [0, 1] before scaling. Without the clip, any slight
    overshoot (e.g. 1.002) or undershoot (e.g. -0.001) wraps around when cast to
    uint8 and shows up as dark or bright speckles.
    """
    pixels = np.clip(np.asarray(array, dtype=np.float32), 0.0, 1.0)
    # Round instead of truncating so the conversion is not biased towards black.
    return Image.fromarray((pixels * 255.0).round().astype(np.uint8))


def _normalize_unit_range(array):
    """Min-max rescale *array* to [0, 1]; a constant array maps to all zeros."""
    lo = float(array.min())
    hi = float(array.max())
    span = hi - lo
    if span <= 0.0:
        # Flat map: (x - min) / (max - min) would be 0 / 0 and yield NaNs.
        return np.zeros_like(array)
    return (array - lo) / span


def full_pipeline(image, prompt, scale_factor=2.0):
    """Upscale *image* with Thera, restyle it as a bas-relief, and estimate its depth.

    Args:
        image: Input PIL image (``None`` when the Gradio image input is empty).
        prompt: Free-text description inserted into the bas-relief prompt.
        scale_factor: Upscaling factor applied to both height and width.

    Returns:
        ``(upscaled_pil, bas_relief, depth_map)``:
        - the Thera-upscaled PIL image;
        - the img2img bas-relief image;
        - a float depth map rescaled to [0, 1], with the same height and width
          as ``bas_relief``.

    Raises:
        gr.Error: if no image was supplied or any processing step fails.
    """
    # Checked before the try so this specific message is not re-wrapped below.
    if image is None:
        raise gr.Error("Envie uma imagem de entrada.")

    try:
        # 1. Super-resolution with Thera.
        # Explicit float32: JAX (without x64 enabled) would downcast float64 anyway.
        source = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
        source_jax = jax.device_put(source, JAX_DEVICE)

        # NOTE(review): the call signature (params, source, target_shape,
        # do_ensemble=...) is kept as-is; confirm it matches the Thera model API.
        upscaled = model_edsr.apply(
            params_edsr,
            source_jax,
            target_shape,
            do_ensemble=True,
        )
        upscaled_pil = _to_uint8_image(upscaled)

        # 2. Generate the bas-relief with SDXL img2img + LoRA.
        full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5,
        ).images[0]

        # 3. Estimate the depth map of the bas-relief.
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            predicted_depth = depth_model(**inputs).predicted_depth
            # Bicubic interpolate needs an (N, C, H, W) tensor, so add a channel
            # axis. PIL's .size is (width, height), hence the reversal.
            depth_map = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=bas_relief.size[::-1],
                mode="bicubic",
            ).squeeze().cpu().numpy()

        return upscaled_pil, bas_relief, _normalize_unit_range(depth_map)

    except Exception as e:
        # UI boundary: surface any failure to the user, keeping the cause chained.
        raise gr.Error(f"Erro no processamento: {str(e)}") from e


# Gradio interface -----------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown(value="## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            img_input = gr.Image(label="Imagem de Entrada", type="pil")
            prompt = gr.Textbox(
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.",
                label="Descrição",
            )
            scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button(value="Processar")

        # Right column: one display per value returned by full_pipeline.
        with gr.Column():
            img_upscaled = gr.Image(label="Super Resolvida")
            img_basrelief = gr.Image(label="Bas-Relief")
            img_depth = gr.Image(label="Mapa de Profundidade")

    # inputs follow full_pipeline's parameter order (image, prompt, scale_factor);
    # outputs follow the order of its returned tuple.
    btn.click(
        fn=full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

if __name__ == "__main__":
    app.launch()