File size: 4,273 Bytes
98889c8
19a6d73
98889c8
b82dc7d
19a6d73
b82dc7d
f41a4a7
a7111d1
b82dc7d
f41a4a7
b82dc7d
 
f41a4a7
 
 
b82dc7d
 
 
 
 
 
 
 
 
f41a4a7
b82dc7d
 
 
 
 
 
1f384c6
 
 
 
 
b82dc7d
f41a4a7
b82dc7d
 
 
 
 
 
 
 
46bb495
 
8c7829e
 
b82dc7d
 
f41a4a7
 
8c7829e
f41a4a7
 
 
8c7829e
 
 
 
 
 
 
 
 
 
 
f41a4a7
 
 
8c7829e
f41a4a7
b82dc7d
f41a4a7
b82dc7d
 
8c7829e
 
b82dc7d
 
f41a4a7
b82dc7d
f41a4a7
 
8c7829e
 
b82dc7d
 
8c7829e
f41a4a7
b82dc7d
8c7829e
b82dc7d
 
f41a4a7
 
 
 
 
 
 
 
8c7829e
 
 
 
f41a4a7
 
 
 
 
 
 
 
 
 
 
19a6d73
98889c8
f41a4a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pickle

import gradio as gr
import jax
import numpy as np
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from huggingface_hub import file_download, hf_hub_download
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImageProcessor

from model import build_thera
from super_resolve import process

# Compatibility fix: alias `file_download.cached_download` to `hf_hub_download`,
# presumably because newer huggingface_hub releases no longer provide
# `cached_download` and some dependency still looks it up.
# NOTE(review): this assignment runs after `diffusers` and `transformers` have
# already been imported above, so it cannot affect anything those imports did
# at import time; it only helps code that looks the attribute up later. The two
# functions also appear to take different arguments (URL-based vs.
# repo_id/filename) — confirm this shim is still needed.
file_download.cached_download = file_download.hf_hub_download

# ========== Thera configuration ==========
# Hugging Face Hub repositories holding the two Thera super-resolution checkpoints.
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"


def load_thera_model(repo_id):
    """Fetch a Thera checkpoint from the Hugging Face Hub and instantiate it.

    Args:
        repo_id: Hub repository that contains a ``model.pkl`` checkpoint.

    Returns:
        A ``(model, params)`` pair: the network produced by ``build_thera`` and
        the checkpoint's ``'model'`` entry.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only point this at repositories you trust.
    with open(checkpoint_path, 'rb') as stream:
        checkpoint = pickle.load(stream)
    backbone, size = checkpoint['backbone'], checkpoint['size']
    network = build_thera(3, backbone, size)
    return network, checkpoint['model']


model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)

# ========== SDXL + depth setup ==========
# Everything is pinned to CPU / float32.
# NOTE(review): an automatic CUDA + float16 selection was previously left here
# commented out; if GPU hardware is added, re-enable it deliberately and verify
# the outputs rather than toggling it blindly.
device = "cpu"
torch_dtype = torch.float32

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)

# Bas-relief style LoRA applied on top of the SDXL base model.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    # NOTE(review): `peft_backend` does not look like a documented
    # load_lora_weights argument and may be silently ignored — confirm.
    peft_backend="peft",
)

# ========== Depth model setup ==========
# DPTImageProcessor is the current name for the deprecated DPTFeatureExtractor;
# preprocessing is the same.
depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)

# ========== Integrated pipeline ==========
def _super_resolve(image, scale_factor, model_type):
    """Upscale an RGB PIL image by ``scale_factor`` with the selected Thera model.

    ``model_type == "EDSR"`` selects the EDSR checkpoint; any other value uses RDN.
    """
    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn

    # PIL reports (width, height); the target shape is passed as (height, width).
    width, height = image.size
    target_shape = (round(height * scale_factor), round(width * scale_factor))

    # NOTE(review): the trailing positional `True` is forwarded to `process`
    # unchanged; presumably it enables test-time ensembling — confirm in
    # super_resolve.
    sr_jax = process(np.array(image) / 255., sr_model, sr_params, target_shape, True)

    # JAX array -> NumPy -> PIL. Assumes `process` returns an 8-bit (H, W, 3)
    # array that Image.fromarray accepts — TODO confirm.
    return Image.fromarray(np.asarray(sr_jax))


def _stylize(image, style_prompt):
    """Run the SDXL img2img pass (with the bas-relief LoRA loaded) over ``image``."""
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    return pipe(
        prompt=prompt,
        image=image,  # PIL image is passed directly
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]


def _depth_to_uint8(depth):
    """Min-max normalise a depth array into the 0..255 uint8 range.

    The 1e-8 epsilon keeps a flat (constant) map from dividing by zero; such a
    map becomes all zeros. Values are rounded to nearest instead of truncated,
    so the quantisation is not biased downward.
    """
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    return np.rint(depth * 255).astype(np.uint8)


def _estimate_depth(image):
    """Predict a DPT depth map for ``image`` and resize it to the image's size."""
    inputs = depth_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth

    # Assumes predicted_depth is (batch, H, W): add a channel axis for
    # interpolate and resize to (height, width). Then pick the single map by
    # index rather than squeeze(), so a 1-pixel-wide/high image stays 2-D.
    resized = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    )
    return _depth_to_uint8(resized[0, 0].cpu().numpy())


def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief stylisation and depth estimation.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user has not uploaded anything.
        scale_factor: upscaling factor applied to both width and height.
        model_type: "EDSR" to use the EDSR checkpoint; anything else uses RDN.
        style_prompt: free text inserted into the SDXL prompt.

    Returns:
        ``(sr_image, bas_relief_image, depth_image)`` as PIL images.

    Raises:
        gr.Error: when no image was provided, so Gradio shows a readable message
            instead of a traceback.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # The SR models are built for 3 channels (build_thera(3, ...)); convert so
    # RGBA / grayscale / palette uploads do not break the model.
    image = image.convert("RGB")

    sr_pil = _super_resolve(image, scale_factor, model_type)  # 1. super-resolution (JAX)
    bas_relief = _stylize(sr_pil, style_prompt)               # 2. style transfer (PyTorch)
    depth = _estimate_depth(bas_relief)                       # 3. depth map
    return sr_pil, bas_relief, Image.fromarray(depth)

# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")

    with gr.Row():
        # Left column: pipeline inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
            style_prompt = gr.Textbox(
                label="Style Prompt",
                value="insanely detailed and complex engraving relief, ultra-high definition",
            )
            process_btn = gr.Button("Start Pipeline")

        # Right column: one output per stage, in the order full_pipeline returns them.
        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    process_btn.click(
        full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

# The SDXL step runs on CPU and can take a long time. Serve requests through
# Gradio's queue rather than one long blocking request. Older Gradio releases
# need this enabled explicitly; on newer ones it is already the default.
app.queue()
app.launch(debug=False)