sculpt / app.py
ds1david's picture
New logic
3920f5c
raw
history blame
4.31 kB
import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
import pickle
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from model import build_thera # Importar do código original do Thera
# Device configuration: both frameworks are pinned to the CPU.
TORCH_DEVICE = "cpu"  # device string handed to torch's ``.to(...)`` calls below
JAX_DEVICE = jax.devices(backend="cpu")[0]  # first CPU device reported by JAX
# 1. Load the Thera models ----------------------------------------------------------------------
def load_thera_model(repo_id: str, filename: str):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the pickled checkpoint file inside that repository.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera`` and
        the value stored under the checkpoint's ``'model'`` key.

    Raises:
        KeyError: If the checkpoint lacks the ``'model'``, ``'backbone'`` or
            ``'size'`` entry.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this loader at repositories you trust.
    with open(checkpoint_path, 'rb') as fh:
        checkpoint = pickle.load(fh)

    try:
        params = checkpoint['model']
        backbone = checkpoint['backbone']
        size = checkpoint['size']
    except KeyError as err:
        # Same exception type as before, but say which checkpoint is malformed.
        raise KeyError(
            f"Thera checkpoint {filename!r} from {repo_id!r} is missing key {err}"
        ) from err

    # NOTE(review): the literal 3 is presumably the number of image channels
    # (RGB) -- confirm against build_thera's signature.
    model = build_thera(3, backbone, size)
    return model, params
# Everything below runs at import time: each model is downloaded (or read from
# the local Hugging Face cache) and kept in a module-level global that
# full_pipeline uses on every request.
print("Carregando Thera EDSR...")
model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
# 2. Load SDXL + LoRA ---------------------------------------------------------------------------
print("Carregando SDXL + LoRA...")
# SDXL base img2img pipeline, loaded in float32 and moved to TORCH_DEVICE.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float32
).to(TORCH_DEVICE)
# Add the bas-relief LoRA weights on top of the base pipeline.
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
# 3. Load the depth-estimation model ------------------------------------------------------------
print("Carregando DPT Depth...")
# Preprocessor and model come from the same checkpoint ("Intel/dpt-large").
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
# Main pipeline ---------------------------------------------------------------------------------
def _to_uint8_image(array):
    """Convert a float image with nominal range [0, 1] to a uint8 array."""
    # Clip first: casting out-of-range floats (e.g. super-resolution overshoot
    # above 1.0 or below 0.0) to uint8 is not well-defined and typically wraps
    # around, which shows up as speckle artifacts.
    pixels = np.clip(np.asarray(array), 0.0, 1.0)
    return np.rint(pixels * 255.0).astype(np.uint8)


def _normalize_depth(depth_map):
    """Rescale a depth map linearly to [0, 1]; a constant map becomes all zeros."""
    lowest = depth_map.min()
    span = depth_map.max() - lowest
    # A constant depth map would divide by zero and yield NaN; ``not span > 0``
    # also catches a NaN span.
    if not span > 0:
        return np.zeros_like(depth_map)
    return (depth_map - lowest) / span


def full_pipeline(image, prompt, scale_factor=2.0):
    """Upscale an image, restyle it as a bas-relief and estimate its depth.

    Args:
        image: Input PIL image (the Gradio input uses ``type="pil"``).
        prompt: Text appended to the fixed bas-relief prompt for SDXL.
        scale_factor: Multiplier applied to the input height and width to get
            the super-resolution target size.

    Returns:
        A tuple ``(upscaled_pil, bas_relief, depth_map)``: the super-resolved
        PIL image, the image produced by the SDXL img2img pipeline, and a
        NumPy depth map rescaled to [0, 1] at the bas-relief resolution.

    Raises:
        gr.Error: If no image was provided or any processing step fails.
    """
    # Gradio passes None when the image component is empty; report that
    # explicitly instead of surfacing an AttributeError message.
    if image is None:
        raise gr.Error("Erro no processamento: nenhuma imagem foi fornecida")

    try:
        # 1. Super-resolution with Thera
        source = np.array(image.convert("RGB")) / 255.0
        # (height, width) order
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))

        # Move the input onto the JAX device
        source_jax = jax.device_put(source, JAX_DEVICE)

        # NOTE(review): assumes the object returned by build_thera accepts
        # (params, image, target_shape, do_ensemble=...) through ``.apply`` --
        # confirm against the Thera reference code.
        upscaled = model_edsr.apply(
            params_edsr,
            source_jax,
            target_shape,
            do_ensemble=True,
        )
        upscaled_pil = Image.fromarray(_to_uint8_image(upscaled))

        # 2. Generate the bas-relief
        full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5,
        ).images[0]

        # 3. Compute the depth map
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
        depth = outputs.predicted_depth

        # Resize the prediction to the bas-relief resolution. PIL's ``size`` is
        # (width, height), so it is reversed to the (height, width) order that
        # ``interpolate`` expects.
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze().cpu().numpy()

        return upscaled_pil, bas_relief, _normalize_depth(depth_map)
    except Exception as e:
        # Top-level UI boundary: show the failure in Gradio while keeping the
        # original traceback chained for the server logs.
        raise gr.Error(f"Erro no processamento: {e}") from e
# Gradio interface ------------------------------------------------------------------------------
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost; it assumes two side-by-side columns inside one row,
# with the button in the input column -- confirm against the original layout.
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")

    with gr.Row():
        # Input column: source image, prompt, scale factor and trigger button.
        with gr.Column():
            img_input = gr.Image(type="pil", label="Imagem de Entrada")
            prompt = gr.Textbox(
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.",
                label="Descrição",
            )
            scale = gr.Slider(
                minimum=1.0,
                maximum=4.0,
                value=2.0,
                label="Fator de Escala",
            )
            btn = gr.Button("Processar")

        # Output column: one image per value returned by full_pipeline.
        with gr.Column():
            img_upscaled = gr.Image(label="Super Resolvida")
            img_basrelief = gr.Image(label="Bas-Relief")
            img_depth = gr.Image(label="Mapa de Profundidade")

    # Wire the button to the processing pipeline.
    btn.click(
        fn=full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()