# sculpt / app.py — web file-view header kept as a comment
# Author: ds1david · commit 85a119c ("fixing bugs") · links: raw, history, blame · size: 4.27 kB
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download, file_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# Compatibility fix: expose `hf_hub_download` under the name `cached_download`
# on the `huggingface_hub.file_download` module.
# NOTE(review): presumably needed by a dependency that still looks up the old
# `cached_download` name — confirm which package requires it.
file_download.cached_download = file_download.hf_hub_download
# ========== Thera configuration ==========
# Hugging Face Hub repository IDs passed to `load_thera_model`, one per
# super-resolution checkpoint.
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
def load_thera_model(repo_id):
    """Fetch a pickled Thera checkpoint from the Hub and rebuild its model.

    Args:
        repo_id: Hugging Face Hub repository that contains ``model.pkl``.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera`` for
        the checkpoint's ``backbone`` and ``size`` entries, and the checkpoint's
        ``model`` entry.
    """
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename="model.pkl")

    # NOTE(security): unpickling can execute arbitrary code, so only point this
    # function at repositories you trust.
    with open(checkpoint_file, "rb") as stream:
        checkpoint = pickle.load(stream)

    thera = build_thera(3, checkpoint["backbone"], checkpoint["size"])
    weights = checkpoint["model"]
    return thera, weights
# Load both Thera checkpoints once at import time; `full_pipeline` picks one of
# the two (model, params) pairs per request.
model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)
# ========== SDXL + depth configuration ==========
# Inference is pinned to CPU with float32 weights. A disabled alternative that
# would select CUDA/float16 when available is kept here for reference:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if device == "cuda" else torch.float32
device = "cpu"
torch_dtype = torch.float32

# Load the SDXL base img2img pipeline, then move it to the selected device.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)

# Attach the bas-relief LoRA weights to the pipeline.
# NOTE(review): `peft_backend` is passed through as an extra keyword — confirm
# the installed diffusers version actually consumes it.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft",
)
# ========== Depth model configuration ==========
# The pre-processor and the network are both loaded from the "Intel/dpt-large"
# checkpoint; only the network is moved to the selected device.
depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
depth_model = depth_model.to(device)
# ========== Integrated pipeline ==========
def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief stylization and depth estimation.

    Args:
        image: Input PIL image. Gradio passes ``None`` when nothing was
            uploaded, which is rejected with a user-facing error.
        scale_factor: Multiplier applied to the image height and width to get
            the super-resolution target size.
        model_type: ``"EDSR"`` selects the EDSR Thera model; any other value
            selects the RDN model.
        style_prompt: Text inserted after the ``BAS-RELIEF`` trigger word in
            the SDXL prompt.

    Returns:
        A 3-tuple ``(sr_image, bas_relief, depth_map)``: the upscaled PIL
        image, the first image produced by the SDXL img2img pipeline, and a
        PIL image holding the depth map rescaled to 8-bit (0-255).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Clicking the button without an upload hands us None; fail with a clear
    # message instead of an AttributeError on `image.size`.
    if image is None:
        raise gr.Error("Please upload an input image before starting the pipeline.")

    # Normalize to 3-channel RGB: grayscale, palette or RGBA inputs would
    # otherwise yield a differently shaped array below.
    image = image.convert("RGB")

    # 1. Super-resolution (JAX)
    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn

    # PIL reports (width, height); the target shape is given as (height, width).
    width, height = image.size
    target_shape = (round(height * scale_factor), round(width * scale_factor))

    # Pixel values are scaled to [0, 1] before being handed to `process`; the
    # trailing positional flag is passed as True, its meaning is defined by
    # `super_resolve.process`.
    sr_jax = process(np.array(image) / 255., sr_model, sr_params, target_shape, True)

    # Critical conversion: JAX array -> NumPy array -> PIL image.
    sr_np = np.asarray(sr_jax)
    sr_pil = Image.fromarray(sr_np)

    # 2. Style transfer (PyTorch)
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    bas_relief = pipe(
        prompt=prompt,
        image=sr_pil,  # the PIL image is passed directly
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]

    # 3. Depth map
    inputs = depth_processor(bas_relief, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        depth = outputs.predicted_depth

    # Resize the prediction back to the stylized image's (height, width).
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        mode="bicubic",
        size=bas_relief.size[::-1],
    ).squeeze().cpu().numpy()

    # Min-max normalize to [0, 1]; the epsilon avoids dividing by zero on a
    # constant depth map. Then quantize to 8-bit for display.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    depth = (depth * 255).astype(np.uint8)

    return sr_pil, bas_relief, Image.fromarray(depth)
# ========== Gradio interface ==========
# Layout: a heading followed by one row with two columns — pipeline inputs on
# the left, the three result images on the right.
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")

    with gr.Row():
        # Input controls.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            scale = gr.Slider(
                minimum=1.0,
                maximum=4.0,
                value=2.0,
                label="Scale Factor",
            )
            model_type = gr.Radio(
                ["EDSR", "RDN"],
                value="EDSR",
                label="SR Model",
            )
            style_prompt = gr.Textbox(
                label="Style Prompt",
                value="insanely detailed and complex engraving relief, ultra-high definition",
            )
            process_btn = gr.Button("Start Pipeline")

        # Result images, in the same order `full_pipeline` returns them.
        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    # Wire the button to the pipeline; the listener is registered while the
    # Blocks context is still open.
    process_btn.click(
        fn=full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

app.launch(debug=False)