|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import jax |
|
import pickle |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download, file_download |
|
from model import build_thera |
|
from super_resolve import process |
|
from diffusers import StableDiffusionXLImg2ImgPipeline |
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation |
|
|
|
|
|
# Compatibility shim: expose `hf_hub_download` under the legacy
# `cached_download` attribute of `huggingface_hub.file_download`.
# NOTE(review): this assignment runs after the `diffusers` import at the top
# of the file, so it cannot rescue an import-time lookup of
# `cached_download` -- confirm it is still required.
file_download.cached_download = file_download.hf_hub_download

# Hugging Face Hub repositories that hold the pickled Thera checkpoints
# (one per backbone variant).
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
|
|
|
|
|
def load_thera_model(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Fetches ``model.pkl`` from ``repo_id``, unpickles it, and constructs the
    network from the checkpoint's ``'backbone'`` and ``'size'`` entries.

    NOTE(review): ``pickle.load`` can execute arbitrary code contained in the
    file; only point this at repositories you trust.

    Args:
        repo_id: Hub repository identifier that contains ``model.pkl``.

    Returns:
        A ``(model, params)`` pair: the object returned by ``build_thera`` and
        the checkpoint's ``'model'`` entry.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    with open(checkpoint_path, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    backbone = checkpoint['backbone']
    size = checkpoint['size']
    model = build_thera(3, backbone, size)
    return model, checkpoint['model']
|
|
|
|
|
# Load both Thera super-resolution variants once, at import time, so that
# each request only has to pick one of them.
model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)

# The torch-based models below run on CPU in float32.
device = "cpu"
torch_dtype = torch.float32

# Stable Diffusion XL image-to-image pipeline with the bas-relief LoRA
# weights loaded on top of the base checkpoint.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft",
)

# DPT preprocessing + depth-estimation model, both from the same checkpoint.
# NOTE(review): `DPTFeatureExtractor` is deprecated in recent `transformers`
# releases in favour of `DPTImageProcessor` -- confirm against the pinned
# version before upgrading.
depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
depth_model = depth_model.to(device)
|
|
|
|
|
def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief stylisation and depth estimation.

    Args:
        image: Input ``PIL.Image`` coming from the Gradio image component;
            ``None`` when nothing has been uploaded.
        scale_factor: Multiplier applied to both image dimensions to obtain
            the super-resolution target size.
        model_type: ``"EDSR"`` selects the EDSR checkpoint; any other value
            selects the RDN checkpoint.
        style_prompt: Free text spliced into the diffusion prompt.

    Returns:
        A ``(sr_image, bas_relief_image, depth_image)`` tuple: the
        super-resolved image, the image produced by the diffusion pipeline,
        and an 8-bit depth map of that stylised image.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking the button without an upload passes None; surface a readable
    # message in the UI instead of failing with an AttributeError below.
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # Uploads may be grayscale, palette or RGBA. The Thera models are created
    # with `build_thera(3, ...)` (presumably 3 colour channels), so normalise
    # to RGB before converting to an array.
    image = image.convert("RGB")

    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn

    # --- Stage 1: super-resolution -------------------------------------
    # PIL reports size as (width, height); the target is passed as
    # (height, width).
    width, height = image.size
    target_shape = (round(height * scale_factor), round(width * scale_factor))
    # The trailing positional `True` is forwarded unchanged; its meaning is
    # defined by `super_resolve.process`.
    sr_jax = process(np.array(image) / 255., sr_model, sr_params,
                     target_shape, True)
    sr_pil = Image.fromarray(np.asarray(sr_jax))

    # --- Stage 2: bas-relief stylisation -------------------------------
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    bas_relief = pipe(
        prompt=prompt,
        image=sr_pil,
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]

    # --- Stage 3: depth estimation -------------------------------------
    inputs = depth_processor(bas_relief, return_tensors="pt").to(device)
    with torch.no_grad():
        depth = depth_model(**inputs).predicted_depth

    # Resample the prediction back to the stylised image's resolution;
    # `size[::-1]` turns PIL's (width, height) into (height, width).
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        mode="bicubic",
        size=bas_relief.size[::-1],
    ).squeeze().cpu().numpy()

    # Min-max normalise to [0, 1]; the epsilon avoids division by zero for a
    # constant depth map. Then quantise to 8 bits for display.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    depth = (depth * 255).astype(np.uint8)

    return sr_pil, bas_relief, Image.fromarray(depth)
|
|
|
|
|
# Gradio front end: inputs in the left column, the three pipeline outputs in
# the right column, and one button that triggers `full_pipeline`.
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown(value="## 🪄 Super-Resolution → Bas-Relief → Depth Map")

    with gr.Row():
        # Left column: user-controlled inputs.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            scale = gr.Slider(
                minimum=1.0,
                maximum=4.0,
                value=2.0,
                label="Scale Factor",
            )
            model_type = gr.Radio(
                choices=["EDSR", "RDN"],
                value="EDSR",
                label="SR Model",
            )
            style_prompt = gr.Textbox(
                label="Style Prompt",
                value="insanely detailed and complex engraving relief, ultra-high definition",
            )
            process_btn = gr.Button(value="Start Pipeline")

        # Right column: one image per pipeline stage.
        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    # Output order matches the tuple returned by `full_pipeline`.
    process_btn.click(
        fn=full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

app.launch(debug=False)