File size: 5,819 Bytes
19a6d73 98889c8 19a6d73 1eb87a5 a7111d1 98889c8 1665fe1 19a6d73 d160dc6 19a6d73 a7111d1 19a6d73 a7111d1 46bb495 19a6d73 46bb495 d160dc6 3920f5c 1eb87a5 19a6d73 46bb495 19a6d73 46bb495 19a6d73 46bb495 3920f5c 19a6d73 1eb87a5 19a6d73 d85fde4 19a6d73 3920f5c 19a6d73 a7111d1 19a6d73 3920f5c 19a6d73 d160dc6 19a6d73 d85fde4 d160dc6 19a6d73 a7111d1 19a6d73 3920f5c 19a6d73 0652978 d160dc6 19a6d73 3920f5c 19a6d73 3920f5c 19a6d73 3920f5c d160dc6 3920f5c 19a6d73 a7111d1 19a6d73 3920f5c 19a6d73 3920f5c d160dc6 19a6d73 1665fe1 19a6d73 1665fe1 19a6d73 98889c8 19a6d73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# app.py
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
from utils import make_grid, interpolate_grid
# Logging: records at INFO level and above are written both to "processing.log"
# and to a console stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("processing.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Devices: the PyTorch models in this file are moved to TORCH_DEVICE ("cpu").
# NOTE(review): JAX_DEVICE is not referenced anywhere else in the visible file;
# no JAX computation is explicitly placed on it.
TORCH_DEVICE = "cpu"
JAX_DEVICE = jax.devices("cpu")[0]
def load_thera_model(repo_id: str, filename: str):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id, e.g. ``"prs-eth/thera-edsr-pro"``.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera`` and the
        checkpoint's ``'model'`` entry (used by callers as the params for ``model.apply``).

    Raises:
        KeyError: If the checkpoint lacks any of the ``'backbone'``, ``'size'`` or
            ``'model'`` entries.
        Exception: Any download, unpickling or build error is logged (with traceback)
            and re-raised unchanged.

    Warning:
        No integrity or safety verification is performed. ``pickle.load`` can execute
        arbitrary code embedded in the file, so only load checkpoints from trusted
        repositories.
    """
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(model_path, 'rb') as fh:
            # SECURITY: unpickling executes code from the downloaded file; the
            # repository must be trusted.
            checkpoint = pickle.load(fh)

        # Fail with a descriptive message instead of a bare KeyError on the first
        # missing entry.
        missing = [key for key in ('backbone', 'size', 'model') if key not in checkpoint]
        if missing:
            raise KeyError(f"Checkpoint {repo_id}/{filename} sem as chaves: {missing}")

        # NOTE(review): the leading 3 is presumably the number of output (RGB)
        # channels -- confirm against model.build_thera.
        model = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        return model, checkpoint['model']
    except Exception as e:
        # logger.exception also records the traceback in processing.log.
        logger.exception("Erro ao carregar modelo: %s", e)
        raise
# Model initialisation. This runs once at import time, so every Gradio request
# reuses the same loaded objects. Any failure is logged with its traceback and
# aborts startup.
try:
    logger.info("Carregando modelos...")

    # Thera super-resolution model (JAX) and its parameters.
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

    # SDXL img2img pipeline in float32, since it runs on TORCH_DEVICE ("cpu").
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,
    ).to(TORCH_DEVICE)
    # Bas-relief style LoRA applied on top of the SDXL base weights.
    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

    # DPT-Large depth estimator together with its matching image preprocessor.
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
except Exception as e:
    # logger.exception writes the traceback to processing.log as well as the console.
    logger.exception("Falha na inicialização: %s", e)
    raise
def adjust_size(original: int, scale: float) -> int:
    """Scale one image dimension and snap it to a pipeline-friendly size.

    The scaled value is truncated to an int and rounded down to a multiple of 8
    (presumably because SDXL works on latents 8x smaller than the image). It is
    never allowed below 32. No upper bound is applied.

    Args:
        original: The original dimension in pixels.
        scale: Multiplicative scale factor.

    Returns:
        The adjusted dimension: a multiple of 8, at least 32.
    """
    multiple = 8
    floor_px = 32
    snapped = int(original * scale) // multiple * multiple
    return snapped if snapped > floor_px else floor_px
def _super_resolve(image: Image.Image, new_h: int, new_w: int) -> Image.Image:
    """Upscale an RGB PIL image to ``new_h`` x ``new_w`` with the Thera model.

    Raises:
        ValueError: If the coordinate grid from ``make_grid`` does not match the
            requested target size.
    """
    orig_w, orig_h = image.size

    # Target coordinate grid. Add a leading batch axis when make_grid returns an
    # unbatched 3-D grid (presumably (H, W, 2) -- confirm against utils.make_grid),
    # so that axes 1-2 are (height, width) as the check below expects.
    coords = make_grid((new_h, new_w))
    if len(coords.shape) == 3:
        coords = coords[None]
    coords = jnp.asarray(coords)
    logger.debug("Dimensões do grid: %s", coords.shape)
    if tuple(coords.shape[1:3]) != (new_h, new_w):
        raise ValueError(f"Grid incorreto: {coords.shape[1:3]} vs ({new_h}, {new_w})")

    # The caller passes an RGB image: pixels scaled to [0, 1], batch axis prepended -> (1, H, W, 3).
    source = jnp.asarray(np.asarray(image, dtype=np.float32) / 255.0)[jnp.newaxis, ...]

    # t = 1 / scale**2. Use the scale actually realised after adjust_size snapped the
    # target height, not the nominal slider value. The two differ for sizes that are
    # not multiples of 8, or for small inputs that hit the 32 px floor.
    t = jnp.array([(new_h / orig_h) ** -2], dtype=jnp.float32)

    # NOTE(review): assumes the Thera module is called as (source, target_coords, t);
    # confirm against model.py. utils.interpolate_grid is imported but unused. If the
    # model predicts a residual over an interpolated source (or expects normalised
    # input), that step is missing here; verify against Thera's reference inference.
    upscaled = model_edsr.apply(params_edsr, source, coords, t)

    # Clip before the uint8 cast: out-of-range floats are not representable in uint8
    # and would produce garbage pixel values.
    pixels = np.clip(np.asarray(upscaled[0], dtype=np.float32), 0.0, 1.0)
    return Image.fromarray(np.rint(pixels * 255.0).astype(np.uint8))


def _generate_bas_relief(image: Image.Image, prompt: str) -> Image.Image:
    """Restyle ``image`` as a bas-relief using the SDXL img2img pipeline and its LoRA."""
    # Tolerate an empty/None textbox value instead of sending the text "None".
    user_prompt = (prompt or "").strip()
    result = pipe(
        # "BAS-RELIEF" prefix is presumably the LoRA trigger token.
        prompt=f"BAS-RELIEF {user_prompt}, ultra detailed, 8K resolution",
        image=image,
        strength=0.7,
        num_inference_steps=30,
    )
    return result.images[0]


def _estimate_depth(image: Image.Image) -> Image.Image:
    """Predict a depth map for ``image`` with DPT and return it as an 8-bit grayscale image.

    The map is min-max normalised per image, so values are relative, not metric.
    """
    inputs = feature_extractor(image, return_tensors="pt").to(TORCH_DEVICE)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
        # Resize the prediction to the image resolution. PIL size is (W, H);
        # interpolate expects (H, W).
        depth_map = torch.nn.functional.interpolate(
            predicted.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze().cpu().numpy()

    depth_min = depth_map.min()
    depth_max = depth_map.max()
    # The epsilon avoids division by zero for a perfectly flat depth map.
    depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
    return Image.fromarray((depth_normalized * 255).astype(np.uint8))


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief generation and depth estimation on one image.

    Args:
        image: Input from the ``gr.Image(type="pil")`` component. Gradio passes
            ``None`` when nothing was uploaded.
        prompt: Style description inserted into the bas-relief prompt.
        scale_factor: Requested upscaling factor. Each target dimension is snapped
            down to a multiple of 8 (minimum 32) by ``adjust_size``.

    Returns:
        ``(upscaled_img, bas_relief, depth_img)`` as PIL images: the super-resolved
        input, the SDXL bas-relief rendering, and an 8-bit grayscale depth map.

    Raises:
        gr.Error: On missing input or any processing failure, so Gradio shows the
            message in the UI. Processing failures are logged with a traceback.
    """
    if image is None:
        raise gr.Error("Nenhuma imagem fornecida.")

    try:
        image = image.convert("RGB")
        orig_w, orig_h = image.size

        new_h = adjust_size(orig_h, scale_factor)
        new_w = adjust_size(orig_w, scale_factor)
        logger.info("Redimensionando: %sx%s → %sx%s", orig_h, orig_w, new_h, new_w)

        upscaled_img = _super_resolve(image, new_h, new_w)
        bas_relief = _generate_bas_relief(upscaled_img, prompt)
        depth_img = _estimate_depth(bas_relief)
        return upscaled_img, bas_relief, depth_img
    except Exception as e:
        logger.exception("ERRO NO PIPELINE: %s", e)
        raise gr.Error(f"Processamento falhou: {str(e)}") from e
# Gradio interface: an input column (image, prompt, scale, button) next to a
# tabbed results column; the button runs full_pipeline.
with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Imagem de Entrada", type="pil")
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                # Default prompt forwarded to SDXL (typo "Ainsanely" fixed).
                value="An insanely detailed and complex engraving relief, ultra-high definition",
                placeholder="Descreva o estilo desejado...",
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button("Processar Imagem", variant="primary")
        with gr.Column():
            gr.Markdown("## Resultados")
            with gr.Tabs():
                with gr.TabItem("Super Resolução"):
                    upscaled_output = gr.Image(label="Resultado Super Resolução")
                with gr.TabItem("Bas-Relief"):
                    basrelief_output = gr.Image(label="Relevo Gerado")
                with gr.TabItem("Profundidade"):
                    depth_output = gr.Image(label="Mapa de Profundidade")
    # Output order must match full_pipeline's return tuple.
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output],
    )
if __name__ == "__main__":
    # Listen on all interfaces (e.g. inside a container) on port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)