sculpt / app.py
ds1david's picture
fixing bugs
1557411
raw
history blame
6.91 kB
# app.py
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
from utils import make_grid, interpolate_grid
# Logging: records at INFO and above go both to "processing.log" and to the
# console (StreamHandler defaults to stderr).
logging.basicConfig(
    handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Device handles for the PyTorch and JAX sides of the app; both are the CPU.
TORCH_DEVICE = "cpu"
JAX_DEVICE = jax.devices("cpu")[0]
def load_thera_model(repo_id: str, filename: str):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id, e.g. "prs-eth/thera-edsr-pro".
        filename: Checkpoint file inside the repository, e.g. "model.pkl".

    Returns:
        A ``(model, params)`` pair: the result of
        ``build_thera(3, checkpoint['backbone'], checkpoint['size'])`` and the
        checkpoint's ``'model'`` entry (used later as the first argument of
        ``model.apply``).

    Raises:
        TypeError: if the unpickled checkpoint is not a mapping.
        ValueError: if the checkpoint lacks any of 'model', 'backbone', 'size'.
        Errors from the download, unpickling or ``build_thera`` propagate
        unchanged. Every failure is logged with its traceback before re-raising.
    """
    from collections.abc import Mapping

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: unpickling can execute arbitrary code embedded in the file.
        # Only point this function at repositories you trust.
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)

        # Validate the checkpoint layout before indexing into it, so a bad file
        # yields a clear message instead of an AttributeError/KeyError.
        if not isinstance(checkpoint, Mapping):
            raise TypeError(
                f"Checkpoint inválido: esperado um dicionário, "
                f"recebido {type(checkpoint).__name__}"
            )
        required_keys = {'model', 'backbone', 'size'}
        missing = {key for key in required_keys if key not in checkpoint}
        if missing:
            raise ValueError(f"Checkpoint corrompido. Chaves faltando: {missing}")

        # NOTE(review): the leading 3 is presumably the number of output
        # channels (RGB) — confirm against model.build_thera.
        model = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        return model, checkpoint['model']
    except Exception as e:
        # logger.exception keeps the traceback in processing.log as well.
        logger.exception(f"Erro ao carregar modelo: {str(e)}")
        raise
# Model initialization: runs once at import time. Any failure is logged (with
# traceback) and re-raised, aborting startup.
try:
    logger.info("Inicializando modelos...")

    # Thera super-resolution model (EDSR backbone) and its parameters.
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

    # SDXL img2img pipeline in float32 on TORCH_DEVICE, with the bas-relief LoRA.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,
    ).to(TORCH_DEVICE)
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

    # DPT-Large depth-estimation model and its image preprocessor.
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
except Exception as e:
    # logger.exception (unlike logger.error) also writes the traceback to
    # processing.log, not just the message.
    logger.exception(f"Falha crítica na inicialização: {str(e)}")
    raise
def safe_resize(
    original: tuple[int, int],
    scale: float,
    *,
    multiple: int = 8,
    min_size: int = 32,
) -> tuple[int, int]:
    """Scale an ``(height, width)`` pair and snap it down to a model-friendly size.

    Each side is multiplied by ``scale``, truncated to an integer, rounded down
    to a multiple of ``multiple``, and finally clamped to at least ``min_size``.

    Args:
        original: ``(height, width)`` of the source image, in pixels.
        scale: Resize factor applied to both sides.
        multiple: Each side is rounded down to a multiple of this (default 8).
        min_size: Lower bound for each side (default 32). The result is only
            guaranteed to be a multiple of ``multiple`` if ``min_size`` is one.

    Returns:
        ``(new_height, new_width)``.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")

    def _snap(side: int) -> int:
        # Round away binary float noise before truncating: scales such as 2.3
        # are not exact in binary, so 400 * 2.3 == 919.9999999999999, which
        # plain int() would truncate to 919 (and then snap down to 912).
        scaled = int(round(side * scale, 6))
        return max(min_size, scaled - scaled % multiple)

    h, w = original
    return (_snap(h), _snap(w))
def _to_uint8_image(array) -> Image.Image:
    """Convert a float image with values nominally in [0, 1] to an 8-bit PIL image.

    Values are clipped to [0, 1] and rounded to the nearest 8-bit level. Without
    the clip, out-of-range floats (e.g. model overshoot) wrap around or give
    platform-dependent values when cast to uint8, producing speckle artifacts.
    """
    pixels = np.clip(np.asarray(array, dtype=np.float32), 0.0, 1.0)
    return Image.fromarray(np.round(pixels * 255.0).astype(np.uint8))


def _super_resolve(image: Image.Image, scale_factor: float) -> Image.Image:
    """Upscale an RGB image with the Thera model and return it as a PIL image."""
    orig_w, orig_h = image.size
    logger.info(f"Processando imagem: {orig_w}x{orig_h}")

    new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
    logger.info(f"Novo tamanho calculado: {new_h}x{new_w}")

    grid = make_grid((new_h, new_w))
    logger.debug(f"Grid gerado: {grid.shape}")
    # Sanity check on the spatial axes, taken as the two axes just before the
    # trailing coordinate axis. This holds whether make_grid returns (H, W, 2)
    # or a batched (1, H, W, 2); shape[1:3] only works for the batched form.
    # NOTE(review): assumes the last grid axis holds the coordinates — confirm
    # against utils.make_grid.
    grid_hw = tuple(grid.shape[-3:-1])
    if grid_hw != (new_h, new_w):
        raise RuntimeError(
            f"Incompatibilidade de dimensões: "
            f"Grid {grid_hw} vs Alvo {new_h}x{new_w}"
        )

    # (1, H, W, 3) float32 in [0, 1]; the caller converted the image to RGB.
    source = jnp.array(image).astype(jnp.float32) / 255.0
    source = source[jnp.newaxis, ...]

    # Scale conditioning t = 1 / s². s is taken from the size actually requested
    # (height ratio) rather than scale_factor, because safe_resize may have
    # moved the target size (rounding down to a multiple of 8, 32 px minimum).
    effective_scale = new_h / orig_h
    t = jnp.array([1.0 / (effective_scale ** 2)], dtype=jnp.float32)

    # NOTE(review): `grid` is only used for the check above; confirm that
    # model_edsr.apply really takes (params, source, t, target_shape) and builds
    # its own query coordinates.
    upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))

    upscaled_img = _to_uint8_image(upscaled[0])
    logger.info(f"Imagem super-resolvida: {upscaled_img.size}")
    return upscaled_img


def _generate_bas_relief(upscaled_img: Image.Image, prompt: str) -> Image.Image:
    """Restyle *upscaled_img* as a bas-relief with the SDXL img2img + LoRA pipeline."""
    result = pipe(
        prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
        image=upscaled_img,
        strength=0.7,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    bas_relief = result.images[0]
    logger.info(f"Bas-Relief gerado: {bas_relief.size}")
    return bas_relief


def _estimate_depth(bas_relief: Image.Image) -> Image.Image:
    """Estimate a depth map for *bas_relief* with DPT and return it as 8-bit grayscale.

    The prediction is resized to the relief's resolution and min-max normalised,
    so the output spans roughly the full 0-255 range.
    """
    inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
    with torch.no_grad():
        depth = depth_model(**inputs).predicted_depth

    # Add a channel axis for interpolate, resize to (height, width) — PIL's
    # size is (width, height), hence the reversal — then drop the unit axes.
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=bas_relief.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()

    depth_min = depth_map.min()
    depth_max = depth_map.max()
    # The epsilon avoids a division by zero on a perfectly flat map.
    depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
    depth_img = _to_uint8_image(depth_normalized)
    logger.info("Mapa de profundidade calculado")
    return depth_img


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief generation and depth estimation on one image.

    Args:
        image: Input image (any PIL mode; converted to RGB). ``None`` is rejected.
        prompt: Text appended to the bas-relief generation prompt.
        scale_factor: Requested upscale factor; the final size is adjusted by
            ``safe_resize``.

    Returns:
        ``(upscaled_img, bas_relief, depth_img)`` PIL images; ``depth_img`` is
        single-channel 8-bit.

    Raises:
        gr.Error: wrapping any failure, so Gradio shows it to the user. The
            original exception is logged with traceback and chained as the cause.
    """
    try:
        if image is None:
            raise ValueError("Nenhuma imagem fornecida")
        image = image.convert("RGB")

        upscaled_img = _super_resolve(image, scale_factor)
        bas_relief = _generate_bas_relief(upscaled_img, prompt)
        depth_img = _estimate_depth(bas_relief)
        return upscaled_img, bas_relief, depth_img
    except Exception as e:
        logger.exception(f"ERRO NO PIPELINE: {str(e)}")
        raise gr.Error(f"Falha no processamento: {str(e)}") from e
# Gradio user interface: inputs in the left column, results in tabs on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="SuperRes+BasRelief Pro") as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    # Both columns are created up front, in this order, and filled afterwards.
    with gr.Row():
        input_col = gr.Column()
        output_col = gr.Column()

    # Left column: source image, relief description, scale factor, trigger button.
    with input_col:
        img_input = gr.Image(type="pil", label="Carregar Imagem", height=300)
        prompt = gr.Textbox(
            value="A insanely detailed and complex engraving relief, ultra-high definition",
            label="Descrição do Relevo",
            placeholder="Descreva o estilo desejado...",
        )
        scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
        process_btn = gr.Button("Iniciar Processamento", variant="primary")

    # Right column: one tab per pipeline output.
    with output_col:
        with gr.Tabs():
            with gr.TabItem("Super Resolução"):
                upscaled_output = gr.Image(show_label=False, label="Resultado")
            with gr.TabItem("Bas-Relief"):
                basrelief_output = gr.Image(show_label=False, label="Relevo")
            with gr.TabItem("Profundidade"):
                depth_output = gr.Image(show_label=False, label="Mapa 3D")

    # Wire the button to the pipeline; also exposed under the API name "processar".
    process_btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output],
        api_name="processar",
    )


if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    app.launch(server_port=7860, server_name="0.0.0.0")