File size: 3,900 Bytes
98889c8
19a6d73
98889c8
b82dc7d
19a6d73
b82dc7d
f41a4a7
a7111d1
b82dc7d
f41a4a7
b82dc7d
 
f41a4a7
 
 
b82dc7d
 
 
 
 
 
 
 
 
f41a4a7
b82dc7d
 
 
 
 
 
 
 
 
f41a4a7
b82dc7d
 
 
 
 
 
 
 
46bb495
 
b82dc7d
 
 
 
f41a4a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b82dc7d
f41a4a7
b82dc7d
 
f41a4a7
 
b82dc7d
 
f41a4a7
b82dc7d
f41a4a7
 
 
b82dc7d
 
 
f41a4a7
 
b82dc7d
f41a4a7
b82dc7d
 
 
f41a4a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a6d73
98889c8
f41a4a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pickle

import gradio as gr
import jax
import numpy as np
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from huggingface_hub import file_download, hf_hub_download
from PIL import Image
from transformers import (
    DPTFeatureExtractor,
    DPTForDepthEstimation,
    DPTImageProcessor,
)

from model import build_thera
from super_resolve import process

# Compatibility shim: expose `file_download.cached_download` as an alias of
# `hf_hub_download`, presumably for a dependency that still references the
# former name (TODO: confirm which dependency actually needs it).
# NOTE(review): this only patches the `file_download` submodule attribute and
# runs after the diffusers/transformers/model imports above, so it cannot help
# a `from huggingface_hub import cached_download` executed during those imports.
# `hf_hub_download` is called as (repo_id=..., filename=...); verify that the
# alias matches how callers invoke `cached_download` (its historic signature
# was URL-based).
# Only install the alias when the attribute is absent, so a genuine
# `cached_download` provided by an older huggingface_hub is not overwritten.
if not hasattr(file_download, "cached_download"):
    file_download.cached_download = file_download.hf_hub_download

# ========== Thera configuration ==========
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"


def load_thera_model(repo_id, revision=None):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository id that contains a ``model.pkl`` checkpoint.
        revision: Optional branch, tag or commit hash to pin the download to.
            ``None`` (the default) keeps the previous behaviour of fetching
            the Hub's default revision.

    Returns:
        A ``(model, params)`` tuple: the network returned by ``build_thera``
        and the object stored under the checkpoint's ``'model'`` key.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl",
                                 revision=revision)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints from repositories you trust; pinning `revision` to a
    # commit hash prevents a later upstream change from being picked up silently.
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
    # The checkpoint dict carries the 'backbone' and 'size' settings needed to
    # rebuild the network. The leading 3 is presumably the number of image
    # channels (verify against model.build_thera).
    return build_thera(3, check['backbone'], check['size']), check['model']


# Both Thera variants are loaded up front so switching models in the UI is instant.
model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)

# ========== SDXL + depth configuration ==========
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use half precision only when running on a GPU; stay in float32 on CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype
).to(device)

# Bas-relief style LoRA. "BAS-RELIEF" is presumably its trigger word, which is
# why the prompt built in full_pipeline starts with it.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    # NOTE(review): `peft_backend` is not a documented load_lora_weights
    # argument; confirm the installed diffusers version needs/accepts it.
    peft_backend="peft"
)

# DPTImageProcessor is the non-deprecated replacement for DPTFeatureExtractor
# (the latter is a thin subclass that only adds a deprecation warning). The
# variable keeps its historic name so full_pipeline needs no change.
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)


# ========== Integrated Pipeline ==========
def _super_resolve(image, scale_factor, model_type):
    """Upscale a PIL `image` by `scale_factor` with the selected Thera model; return a PIL image."""
    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn

    # PIL's `size` is (width, height); the target shape is passed as (height, width).
    target_shape = (round(image.size[1] * scale_factor),
                    round(image.size[0] * scale_factor))

    # Force 3-channel RGB so RGBA, greyscale or palette uploads do not yield an
    # array with a different channel count than the model expects.
    source = np.asarray(image.convert("RGB")) / 255.
    sr_image = process(source, sr_model, sr_params, target_shape, True)

    # `process` may hand back an array (e.g. a JAX array) rather than a PIL
    # image; the SDXL pipeline does not accept JAX arrays, so normalise to PIL.
    if not isinstance(sr_image, Image.Image):
        pixels = np.asarray(sr_image)
        if pixels.dtype != np.uint8:
            # NOTE(review): assumes non-uint8 output is float in [0, 1];
            # confirm against super_resolve.process.
            pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
        sr_image = Image.fromarray(pixels)
    return sr_image


def _to_bas_relief(sr_image, style_prompt):
    """Restyle `sr_image` as a bas-relief with the LoRA-equipped SDXL img2img pipeline."""
    subject = (style_prompt or "").strip()
    # Avoid a dangling "BAS-RELIEF , ..." when the user leaves the prompt empty.
    if subject:
        prompt = f"BAS-RELIEF {subject}, intricate carving, marble texture"
    else:
        prompt = "BAS-RELIEF, intricate carving, marble texture"
    return pipe(
        prompt=prompt,
        image=sr_image,
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5
    ).images[0]


def _normalize_depth(depth):
    """Min-max scale a depth array to uint8 [0, 255]; a flat or non-finite range yields zeros."""
    lo, hi = float(depth.min()), float(depth.max())
    span = hi - lo
    # A constant map (span == 0) or NaN/inf values would make the division
    # below produce NaNs, which turn into arbitrary bytes after astype(uint8).
    if not (np.isfinite(span) and span > 0):
        return np.zeros(depth.shape, dtype=np.uint8)
    return ((depth - lo) / span * 255).astype(np.uint8)


def _estimate_depth(image):
    """Predict a depth map for a PIL `image` with DPT and return it as an 8-bit PIL image."""
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
        # unsqueeze(1) adds a channel axis so interpolate gets a 4-D tensor
        # (presumably predicted_depth is (batch, H, W)). PIL's size is
        # (width, height) while interpolate expects (height, width).
        depth = torch.nn.functional.interpolate(
            predicted.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()
    return Image.fromarray(_normalize_depth(depth))


def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief restyling and depth estimation on `image`.

    Args:
        image: Input PIL image from the Gradio ``Image`` component, or ``None``
            when nothing has been uploaded.
        scale_factor: Upscaling factor applied to both width and height.
        model_type: ``"EDSR"`` selects the Thera-EDSR model; any other value
            selects Thera-RDN.
        style_prompt: Free-text subject/style inserted into the SDXL prompt.

    Returns:
        ``(sr_image, bas_relief, depth_map)``, all PIL images.

    Raises:
        gr.Error: If no input image was provided (shown to the user in the UI).
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # 1. Super-resolution
    sr_image = _super_resolve(image, scale_factor, model_type)
    # 2. Bas-relief style transfer
    bas_relief = _to_bas_relief(sr_image, style_prompt)
    # 3. Depth-map estimation on the stylised result
    depth_map = _estimate_depth(bas_relief)

    return sr_image, bas_relief, depth_map


# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown(value="## 🪄 Super-Resolution → Bas-Relief → Depth Map")

    with gr.Row():
        # Left column: the user-facing controls.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0,
                              label="Scale Factor")
            model_type = gr.Radio(choices=["EDSR", "RDN"], value="EDSR",
                                  label="SR Model")
            style_prompt = gr.Textbox(
                label="Style Prompt",
                placeholder="marble sculpture, ancient greek style",
            )
            process_btn = gr.Button(value="Start Pipeline")

        # Right column: one output per pipeline stage.
        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    # Order must match full_pipeline's parameters and its returned tuple.
    pipeline_inputs = [input_image, scale, model_type, style_prompt]
    pipeline_outputs = [sr_output, style_output, depth_output]
    process_btn.click(fn=full_pipeline, inputs=pipeline_inputs,
                      outputs=pipeline_outputs)

app.launch(debug=False)