# sculpt / app.py — author: ds1david
# Commit a02c6d7: "New logic" (4.95 kB)
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import warnings
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
# Configuration: suppress every FutureWarning and UserWarning process-wide.
# NOTE: this also hides warnings that can be useful while debugging.
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category

# Devices: both frameworks run on the CPU.
TORCH_DEVICE = "cpu"
JAX_DEVICE = jax.devices("cpu")[0]
# 1. Load Thera models ------------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository that holds the checkpoint (e.g. "prs-eth/thera-edsr-pro").
        filename: Name of the pickled checkpoint file inside the repository.

    Returns:
        ``(model, variables)``: the module returned by ``build_thera`` and the
        variables mapping to pass as the first argument of ``model.apply``.

    Raises:
        KeyError: if the checkpoint lacks any of the ``model``, ``backbone``
            or ``size`` entries.
    """
    from collections.abc import Mapping

    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints from repositories you trust (ideally pin a revision).
    with open(model_path, "rb") as fh:
        checkpoint = pickle.load(fh)

    missing = [key for key in ("model", "backbone", "size") if key not in checkpoint]
    if missing:
        raise KeyError(f"Checkpoint {model_path} sem as chaves obrigatórias: {missing}")

    variables = checkpoint["model"]
    # Flax's Module.apply expects a mapping of variable collections ({"params": ...}).
    # Wrap a bare parameter tree so either checkpoint layout works; this is a
    # no-op when the checkpoint already stores the full variables structure.
    if isinstance(variables, Mapping) and "params" not in variables:
        variables = {"params": variables}

    # The leading 3 is presumably the number of output (RGB) channels — confirm in model.build_thera.
    model = build_thera(3, checkpoint["backbone"], checkpoint["size"])
    return model, variables
# Load the Thera EDSR super-resolution model once, at import time.
print("Carregando Thera EDSR...")
model_edsr, variables_edsr = load_thera_model(
    repo_id="prs-eth/thera-edsr-pro",
    filename="model.pkl",
)
# 2. Load SDXL img2img + bas-relief LoRA --------------------------------------------------------
# float32 weights are used because everything runs on TORCH_DEVICE ("cpu").
print("Carregando SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
)
pipe = pipe.to(TORCH_DEVICE)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)
# 3. Load the depth-estimation model ------------------------------------------------------------
print("Carregando DPT Depth...")
_DPT_CHECKPOINT = "Intel/dpt-large"  # shared by the processor and the model
feature_extractor = DPTImageProcessor.from_pretrained(_DPT_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(_DPT_CHECKPOINT)
depth_model = depth_model.to(TORCH_DEVICE)
# Main pipeline -------------------------------------------------------------------------------
def _to_uint8_image(array):
    """Convert a float array expected in [0, 1] into an 8-bit PIL image.

    Values are clipped first: casting out-of-range floats straight to uint8
    wraps around (e.g. 1.02 * 255 -> 4), which shows up as speckle noise.
    """
    clipped = np.clip(np.asarray(array), 0.0, 1.0)
    return Image.fromarray((clipped * 255).astype(np.uint8))


def _super_resolve(image, scale_factor):
    """Upscale an RGB PIL *image* by *scale_factor* with the Thera EDSR model."""
    source = np.asarray(image, dtype=np.float32) / 255.0
    target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
    source_jax = jax.device_put(source, JAX_DEVICE)
    # t = 1 / scale_factor**2, passed to the model together with the target shape.
    t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
    # NOTE(review): confirm that model.apply accepts (source, t, target_shape) in this
    # order with an unbatched (H, W, 3) source — the signature lives in model.build_thera.
    upscaled = model_edsr.apply(variables_edsr, source_jax, t, target_shape)
    return _to_uint8_image(upscaled)


def _generate_bas_relief(init_image, prompt, strength=0.7, num_inference_steps=25,
                         guidance_scale=7.5):
    """Run SDXL img2img with the bas-relief LoRA on *init_image* and return the first image."""
    full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
    result = pipe(
        prompt=full_prompt,
        image=init_image,
        strength=strength,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    return result.images[0]


def _estimate_depth(image):
    """Predict a depth map for PIL *image* and return it as a min-max normalised grayscale image."""
    inputs = feature_extractor(image, return_tensors="pt").to(TORCH_DEVICE)
    with torch.no_grad():
        depth = depth_model(**inputs).predicted_depth
    # Resize the prediction back to the image resolution; PIL size is (W, H), torch wants (H, W).
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()
    depth_min, depth_max = depth_map.min(), depth_map.max()
    span = depth_max - depth_min
    # A constant depth map would make min-max normalisation divide by zero (NaN output).
    if span > 0:
        depth_normalized = (depth_map - depth_min) / span
    else:
        depth_normalized = np.zeros_like(depth_map)
    return _to_uint8_image(depth_normalized)


def full_pipeline(image, prompt, scale_factor=2.0):
    """Super-resolve *image*, turn it into a bas-relief, and estimate its depth.

    Args:
        image: Input PIL image from the Gradio component (``None`` when nothing is uploaded).
        prompt: User description, embedded into the bas-relief prompt.
        scale_factor: Upscaling factor applied to both image dimensions.

    Returns:
        ``(upscaled, bas_relief, depth)``: the super-resolved PIL image, the first
        image produced by the SDXL pipeline, and the 8-bit depth-map PIL image.

    Raises:
        gr.Error: if no image was given or any stage fails (the original
            exception is chained as the cause).
    """
    if image is None:
        raise gr.Error("Nenhuma imagem fornecida.")
    try:
        upscaled_pil = _super_resolve(image.convert("RGB"), scale_factor)
        bas_relief = _generate_bas_relief(upscaled_pil, prompt)
        depth_pil = _estimate_depth(bas_relief)
    except Exception as e:
        raise gr.Error(f"Erro no processamento: {str(e)}") from e
    return upscaled_pil, bas_relief, depth_pil
# Gradio interface ----------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")
    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            img_input = gr.Image(type="pil", label="Imagem de Entrada")
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.",
            )
            scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button("Processar")
        # Right column: one output image per pipeline stage.
        with gr.Column():
            img_upscaled = gr.Image(label="Imagem Super Resolvida")
            img_basrelief = gr.Image(label="Resultado Bas-Relief")
            img_depth = gr.Image(label="Mapa de Profundidade")
    # Wire the button to the pipeline: 3 inputs in, 3 images out (in stage order).
    btn.click(
        fn=full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

if __name__ == "__main__":
    app.launch(share=False)