File size: 3,900 Bytes
98889c8 19a6d73 98889c8 b82dc7d 19a6d73 b82dc7d f41a4a7 a7111d1 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d 46bb495 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 19a6d73 98889c8 f41a4a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download, file_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# Compatibility shim: expose `cached_download` on huggingface_hub.file_download
# as an alias of `hf_hub_download`.
# NOTE(review): this assignment runs after the `diffusers` import above, so it
# cannot help a library that resolves the name at import time — confirm the
# ordering is what was intended.
file_download.cached_download = file_download.hf_hub_download
# ========== Thera configuration ==========
# Hugging Face Hub repository IDs handed to load_thera_model().
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
def load_thera_model(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository ID that hosts a ``model.pkl`` checkpoint.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera``
        and the value stored under the checkpoint's ``'model'`` key.
    """
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # NOTE(security): unpickling can execute arbitrary code; only load
    # checkpoints from repositories you trust.
    with open(checkpoint_file, 'rb') as stream:
        checkpoint = pickle.load(stream)
    backbone, size = checkpoint['backbone'], checkpoint['size']
    model = build_thera(3, backbone, size)
    return model, checkpoint['model']
# Load both Thera variants once at import time; full_pipeline() picks one of
# them per request based on the selected model type.
model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)
# ========== SDXL + depth configuration ==========
# Use the GPU in half precision when CUDA is available; otherwise fall back to
# the CPU in full precision.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# SDXL image-to-image pipeline with the bas-relief LoRA weights loaded on top.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft",
)

# DPT preprocessor and depth-estimation model, both from the same checkpoint.
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
depth_model = depth_model.to(device)
# ========== Integrated pipeline ==========
def _depth_to_uint8(depth):
    """Min-max scale a depth array to the 0-255 range and cast it to uint8."""
    lowest = depth.min()
    span = depth.max() - lowest
    if span > 0:
        scaled = (depth - lowest) / span
    else:
        # A constant map has no range to stretch; dividing by zero would
        # produce NaNs, so emit an all-zero map instead.
        scaled = np.zeros_like(depth)
    return (scaled * 255).astype(np.uint8)


def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief stylisation and depth estimation.

    Args:
        image: Input PIL image (``None`` when nothing was uploaded).
        scale_factor: Multiplier applied to the image height and width for
            the super-resolution target size.
        model_type: ``"EDSR"`` selects the EDSR-based Thera model; any other
            value selects the RDN-based one.
        style_prompt: Free text inserted into the SDXL prompt.

    Returns:
        A ``(sr_image, bas_relief, depth_image)`` tuple: the output of
        ``process``, the first image produced by the SDXL pipeline, and a
        PIL image holding the normalised 8-bit depth map.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    # The Thera models are built with 3 as their first argument (presumably
    # the channel count), so normalise RGBA / grayscale uploads to RGB.
    image = image.convert("RGB")

    # 1. Super-resolution
    use_edsr = model_type == "EDSR"
    sr_model = model_edsr if use_edsr else model_rdn
    sr_params = params_edsr if use_edsr else params_rdn
    width, height = image.size  # PIL reports (width, height)
    target_shape = (round(height * scale_factor), round(width * scale_factor))
    sr_image = process(np.array(image) / 255., sr_model, sr_params,
                       target_shape, True)

    # 2. Bas-relief style transfer
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    bas_relief = pipe(
        prompt=prompt,
        image=sr_image,
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]

    # 3. Depth map estimation
    inputs = feature_extractor(bas_relief, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
        # Resize the prediction back to the stylised image's (height, width).
        resized = torch.nn.functional.interpolate(
            predicted.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
            align_corners=False,
        )
    # Index the batch and channel axes explicitly: a bare squeeze() would also
    # drop a height or width of 1.
    depth = resized[0, 0].cpu().numpy()
    return sr_image, bas_relief, Image.fromarray(_depth_to_uint8(depth))
# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            scale = gr.Slider(
                minimum=1.0,
                maximum=4.0,
                value=2.0,
                label="Scale Factor",
            )
            model_type = gr.Radio(
                choices=["EDSR", "RDN"],
                value="EDSR",
                label="SR Model",
            )
            style_prompt = gr.Textbox(
                label="Style Prompt",
                placeholder="marble sculpture, ancient greek style",
            )
            process_btn = gr.Button("Start Pipeline")
        # Right column: one image per pipeline stage.
        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")
    # Wire the button to the pipeline; outputs map onto its returned tuple.
    process_btn.click(
        fn=full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

app.launch(debug=False)