File size: 4,258 Bytes
98889c8
19a6d73
98889c8
b82dc7d
19a6d73
b82dc7d
d160dc6
a7111d1
b82dc7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bb495
 
b82dc7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19a6d73
98889c8
b82dc7d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

# ========== Thera configuration ==========
# Hugging Face Hub repository IDs of the two Thera checkpoints; each one is
# passed to load_thera_model() below to fetch its "model.pkl".
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"


# Load Thera models
def load_thera_model(repo_id: str):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository ID that hosts a ``model.pkl`` checkpoint.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera``
        and the value stored under the checkpoint's ``'model'`` key.

    Raises:
        KeyError: If the checkpoint lacks the ``'model'``, ``'backbone'`` or
            ``'size'`` entry.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")

    # SECURITY: pickle.load() can execute arbitrary code embedded in the file.
    # Only point this function at repositories you trust; consider pinning a
    # revision in hf_hub_download so the checkpoint cannot change silently.
    with open(model_path, "rb") as fh:
        checkpoint = pickle.load(fh)

    # The file handle is no longer needed once the checkpoint is in memory,
    # so the model is built outside the ``with`` block.
    params = checkpoint["model"]
    backbone = checkpoint["backbone"]
    size = checkpoint["size"]

    # NOTE(review): the leading 3 is presumably the number of image channels
    # (RGB) -- confirm against build_thera's signature.
    model = build_thera(3, backbone, size)
    return model, params


# Both Thera variants are loaded eagerly at import time (EDSR first, then
# RDN) so that requests never have to wait for a download.
_thera_edsr = load_thera_model(REPO_ID_EDSR)
model_edsr, params_edsr = _thera_edsr
_thera_rdn = load_thera_model(REPO_ID_RDN)
model_rdn, params_rdn = _thera_rdn

# ========== SDXL + depth configuration ==========
# Use the GPU with half precision when CUDA is available; otherwise fall
# back to the CPU with full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Load the generation models
# SDXL base pipeline, created with the dtype chosen above and then moved to
# the selected device.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype
)
pipe = pipe.to(device)

# Attach the bas-relief LoRA weights to the pipeline.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors", peft_backend="peft"
)

# DPT preprocessing and depth-estimation model; both come from the same
# checkpoint, and only the model is moved to the device.
_DPT_CHECKPOINT = "Intel/dpt-large"
feature_extractor = DPTFeatureExtractor.from_pretrained(_DPT_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(_DPT_CHECKPOINT)
depth_model = depth_model.to(device)


# ========== Funções Principais ==========
def super_resolution(image, scale_factor, model_type):
    """Upscale an image with one of the preloaded Thera models.

    Args:
        image: Image to upscale. The Gradio input delivers a PIL image
            (``type="pil"``) or ``None`` when nothing was uploaded; any
            other array-like accepted by ``np.asarray`` is passed through.
        scale_factor: Multiplier applied to both height and width; must be
            greater than zero.
        model_type: ``"EDSR"`` selects the EDSR checkpoint; any other value
            selects the RDN checkpoint.

    Returns:
        A PIL image built from the array returned by ``process``.

    Raises:
        ValueError: If ``image`` is ``None`` or ``scale_factor`` is not
            positive.
    """
    # Fail early with a readable message instead of an opaque TypeError from
    # dividing a None-backed array (button clicked with no upload).
    if image is None:
        raise ValueError("No input image provided; upload an image first.")
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}.")

    if model_type == "EDSR":
        model, params = model_edsr, params_edsr
    else:
        model, params = model_rdn, params_rdn

    # The Thera models in this file are built with build_thera(3, ...), so
    # PIL inputs in other modes (RGBA, L, P, ...) are normalized to RGB.
    # NOTE(review): assumes that 3 is the channel count -- confirm.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    source = np.asarray(image) / 255.
    # Clamp to at least one pixel so tiny scale factors cannot round to 0.
    target_shape = (
        max(1, round(source.shape[0] * scale_factor)),
        max(1, round(source.shape[1] * scale_factor)),
    )

    output = process(source, model, params, target_shape, do_ensemble=True)
    return Image.fromarray(np.asarray(output))


def _depth_to_uint8(depth_map):
    """Min-max normalize a depth array and scale it to uint8 (0-255)."""
    lowest = depth_map.min()
    span = depth_map.max() - lowest
    if span == 0:
        # A constant depth map would divide by zero and yield NaNs, whose
        # uint8 cast is undefined; a unit span maps every pixel to 0 instead.
        span = 1
    return ((depth_map - lowest) / span * 255).astype(np.uint8)


def generate_bas_relief(prompt):
    """Generate a bas-relief style image and a depth map estimated from it.

    Args:
        prompt: Text description of the artwork; it is prefixed with
            ``"BAS-RELIEF "`` before being sent to the SDXL pipeline.

    Returns:
        A ``(image, depth_image)`` tuple: the first image produced by the
        pipeline and a grayscale PIL image holding the normalized depth map
        resized to the generated image's dimensions.
    """
    full_prompt = f"BAS-RELIEF {prompt}"
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=25,
        guidance_scale=7.5,
        height=512,
        width=512
    ).images[0]

    # Depth estimation is inference only, so gradient tracking is disabled.
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth_map = outputs.predicted_depth

    # PIL's ``size`` is (width, height); interpolate expects (height, width),
    # hence the reversal. The channel axis added by unsqueeze(1) is required
    # by interpolate and removed again by squeeze().
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic"
    ).squeeze().cpu().numpy()

    return image, Image.fromarray(_depth_to_uint8(depth_map))


# ========== Gradio interface ==========
# Two-tab UI: one tab for Thera super-resolution, one for bas-relief
# generation with a depth map. Components are declared in layout order.
with gr.Blocks(title="TheraSR + Bas-Relief Generator") as app:
    gr.Markdown("# 🔥 TheraSR + Bas-Relief Generator")
    gr.Markdown("Combine aliasing-free super-resolution with artistic bas-relief generation")

    with gr.Tabs():
        # Tab 1: upload an image, choose a scale factor and model, upscale.
        with gr.TabItem("🖼 Super-Resolution"):
            with gr.Row():
                # type="pil" makes Gradio hand a PIL image to the callback.
                sr_input = gr.Image(label="Input Image", type="pil")
                sr_output = gr.Image(label="Super-Resolution Result")
            sr_scale = gr.Slider(1.0, 6.0, value=2.0, label="Scale Factor")
            sr_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Model Type")
            sr_btn = gr.Button("Enhance Resolution")

        # Tab 2: enter a text prompt, get the generated art and its depth map.
        with gr.TabItem("🎨 Generate Bas-Relief"):
            with gr.Row():
                text_input = gr.Textbox(label="Art Prompt", placeholder="Roman soldier marble relief...")
            with gr.Row():
                gen_output = gr.Image(label="Generated Art")
                depth_output = gr.Image(label="Depth Map")
            gen_btn = gr.Button("Generate Artwork")

    # Event handlers: wire each button to its callback. Inputs are passed
    # positionally in the listed order; outputs receive the return values.
    sr_btn.click(
        super_resolution,
        inputs=[sr_input, sr_scale, sr_model],
        outputs=sr_output
    )

    gen_btn.click(
        generate_bas_relief,
        inputs=text_input,
        outputs=[gen_output, depth_output]
    )

# Hugging Face configuration: start the Gradio server when this file runs.
# NOTE(review): this runs at import time (no ``__main__`` guard), and
# share=True asks Gradio for a public share link, which is presumably
# unnecessary when hosted on Hugging Face Spaces -- confirm before changing.
app.launch(debug=False, share=True)