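"""Gradio app combining Thera aliasing-free super-resolution (JAX) with
Stable Diffusion XL bas-relief generation and DPT depth estimation (PyTorch)."""
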
import pickle

import gradio as gr
import jax
import numpy as np
import torch
from PIL import Image

from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from model import build_thera
from super_resolve import process

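# Pretrained Thera checkpoints (EDSR and RDN backbones) on the Hugging Face Hub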
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"


def load_thera_model(repo_id):
    """Download a Thera checkpoint from the Hub and rebuild the model.

    The pickled checkpoint stores the trained parameters under 'model' plus the
    backbone name and size that build_thera needs to reconstruct the network.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    with open(model_path, "rb") as fh:
        check = pickle.load(fh)
    params, backbone, size = check["model"], check["backbone"], check["size"]
    model = build_thera(3, backbone, size)
    return model, params


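# Load both Thera variants once at startup so the UI can switch between them without reloading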
model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)

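# Torch device/dtype for the SDXL and DPT models (the JAX-based Thera models
# select their own backend independently of this setting)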
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

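# Stable Diffusion XL base pipeline for text-to-image generation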
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)

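# Bas-relief style LoRA; generate_bas_relief() prepends its "BAS-RELIEF" trigger word to every prompt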
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft",
)

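# DPT-Large monocular depth estimation, used to derive a depth map from the generated artwork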
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)


def super_resolution(image, scale_factor, model_type):
    """Upscale a PIL image by `scale_factor` with the selected Thera model."""
    model = model_edsr if model_type == "EDSR" else model_rdn
    params = params_edsr if model_type == "EDSR" else params_rdn

    # Normalize to [0, 1]; force RGB so RGBA/grayscale uploads don't break the 3-channel model
    source = np.asarray(image.convert("RGB")) / 255.
    target_shape = (
        round(source.shape[0] * scale_factor),
        round(source.shape[1] * scale_factor),
    )

    output = process(source, model, params, target_shape, do_ensemble=True)
    return Image.fromarray(np.asarray(output))


def generate_bas_relief(prompt):
    """Generate a bas-relief style image with SDXL+LoRA and estimate its depth map."""
    # Prepend the LoRA's "BAS-RELIEF" trigger word
    full_prompt = f"BAS-RELIEF {prompt}"
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=25,
        guidance_scale=7.5,
        height=512,
        width=512,
    ).images[0]

    # Estimate depth with DPT, then resize the prediction back to the image resolution
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth_map = outputs.predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=image.size[::-1],  # PIL size is (width, height); interpolate expects (height, width)
        mode="bicubic",
    ).squeeze().cpu().numpy()

    # Normalize to 0-255 for display as a grayscale image
    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
    depth_map = (depth_map * 255).astype(np.uint8)

    return image, Image.fromarray(depth_map)


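# Gradio UI: one tab for Thera super-resolution, one for bas-relief generation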
with gr.Blocks(title="TheraSR + Bas-Relief Generator") as app:
    gr.Markdown("# 🔥 TheraSR + Bas-Relief Generator")
    gr.Markdown("Combine aliasing-free super-resolution with artistic bas-relief generation")

    with gr.Tabs():
        with gr.TabItem("🖼 Super-Resolution"):
            with gr.Row():
                sr_input = gr.Image(label="Input Image", type="pil")
                sr_output = gr.Image(label="Super-Resolution Result")
            sr_scale = gr.Slider(1.0, 6.0, value=2.0, label="Scale Factor")
            sr_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Model Type")
            sr_btn = gr.Button("Enhance Resolution")

        with gr.TabItem("🎨 Generate Bas-Relief"):
            with gr.Row():
                text_input = gr.Textbox(label="Art Prompt", placeholder="Roman soldier marble relief...")
            with gr.Row():
                gen_output = gr.Image(label="Generated Art")
                depth_output = gr.Image(label="Depth Map")
            gen_btn = gr.Button("Generate Artwork")

    sr_btn.click(
        super_resolution,
        inputs=[sr_input, sr_scale, sr_model],
        outputs=sr_output,
    )

    gen_btn.click(
        generate_bas_relief,
        inputs=text_input,
        outputs=[gen_output, depth_output],
    )

app.launch(debug=False, share=True)