File size: 6,910 Bytes
19a6d73 98889c8 19a6d73 1eb87a5 a7111d1 98889c8 1665fe1 19a6d73 d160dc6 19a6d73 a7111d1 19a6d73 a7111d1 46bb495 19a6d73 46bb495 d160dc6 3920f5c 1eb87a5 19a6d73 1557411 46bb495 19a6d73 1557411 19a6d73 46bb495 19a6d73 46bb495 3920f5c 1557411 19a6d73 1557411 19a6d73 1557411 19a6d73 1557411 19a6d73 1557411 19a6d73 1557411 19a6d73 1eb87a5 1557411 d85fde4 19a6d73 1557411 3920f5c 1557411 a7111d1 19a6d73 1557411 3920f5c 1557411 d160dc6 19a6d73 1557411 d85fde4 1557411 19a6d73 1557411 19a6d73 1557411 a7111d1 1557411 19a6d73 3920f5c 1557411 19a6d73 1557411 0652978 1557411 19a6d73 3920f5c 1557411 19a6d73 1557411 3920f5c 1557411 3920f5c d160dc6 3920f5c 1557411 3920f5c 1557411 19a6d73 1557411 a7111d1 19a6d73 3920f5c 19a6d73 1557411 3920f5c 1557411 19a6d73 1665fe1 1557411 19a6d73 1557411 19a6d73 98889c8 19a6d73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# app.py
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import logging
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
from utils import make_grid, interpolate_grid
# Logging setup: records at INFO and above go to a local log file and to a stream handler.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler("processing.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Device settings: both JAX and PyTorch are pinned to the CPU.
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"
def load_thera_model(repo_id: str, filename: str):
    """Download a pickled checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository that hosts the checkpoint file.
        filename: Name of the pickle file inside that repository.

    Returns:
        A ``(model, params)`` pair: the object returned by
        ``build_thera(3, checkpoint['backbone'], checkpoint['size'])`` and the
        checkpoint's ``'model'`` entry.

    Raises:
        ValueError: If the unpickled object is not a mapping, or if any of the
            ``'model'``, ``'backbone'`` or ``'size'`` keys is missing.
        Exception: Any other failure (download, I/O, unpickling, model
            construction) is logged with its traceback and re-raised unchanged.
    """
    from collections.abc import Mapping

    required_keys = {'model', 'backbone', 'size'}
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)

        # SECURITY NOTE: unpickling can execute arbitrary code and this file
        # comes from a remote repository. Only point this at repositories you trust.
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)

        # Validate the checkpoint structure before indexing into it, so that a
        # malformed file yields a clear error instead of an AttributeError.
        if not isinstance(checkpoint, Mapping):
            raise ValueError(
                f"Checkpoint corrompido. Tipo inesperado: {type(checkpoint).__name__}"
            )
        missing = required_keys - checkpoint.keys()
        if missing:
            # sorted() keeps the message deterministic across runs.
            raise ValueError(f"Checkpoint corrompido. Chaves faltando: {sorted(missing)}")

        model = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        return model, checkpoint['model']
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Erro ao carregar modelo: {str(e)}")
        raise
# Start-up: every model is loaded once at import time; any failure is logged and re-raised.
try:
    logger.info("Inicializando modelos...")

    # Thera model and its parameters
    model_edsr, params_edsr = load_thera_model(
        repo_id="prs-eth/thera-edsr-pro",
        filename="model.pkl",
    )

    # SDXL image-to-image pipeline, moved to the Torch device, then the LoRA weights
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,
    )
    pipe = pipe.to(TORCH_DEVICE)
    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
    )

    # Depth-estimation model and its image processor
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    depth_model = depth_model.to(TORCH_DEVICE)
except Exception as e:
    logger.error(f"Falha crítica na inicialização: {str(e)}")
    raise
def safe_resize(original: tuple[int, int], scale: float) -> tuple[int, int]:
    """Scale a ``(height, width)`` pair and align each side to a multiple of 8.

    Each side is multiplied by ``scale``, truncated to an integer, reduced to
    the nearest lower multiple of 8, and finally raised to at least 32.
    """
    def _snap(side: int) -> int:
        # Truncate first, then drop the remainder modulo 8, then enforce the floor of 32.
        scaled = int(side * scale)
        return max(32, scaled - scaled % 8)

    height, width = original
    return (_snap(height), _snap(width))
def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief generation and depth estimation on one image.

    Args:
        image: Input picture; it is converted to RGB before processing.
        prompt: Text inserted after the ``BAS-RELIEF`` prefix of the SDXL prompt.
        scale_factor: Multiplier applied to the input size via ``safe_resize``.

    Returns:
        A ``(upscaled_img, bas_relief, depth_img)`` tuple of PIL images.

    Raises:
        gr.Error: Raised for any failure inside the pipeline. The underlying
            exception is logged with its traceback and chained as the cause.
    """
    try:
        # Input check: an explicit None test rather than relying on truthiness.
        if image is None:
            raise ValueError("Nenhuma imagem fornecida")

        # Normalize the colour mode so the array below always has 3 channels.
        image = image.convert("RGB")
        orig_w, orig_h = image.size  # PIL reports (width, height)
        logger.info(f"Processando imagem: {orig_w}x{orig_h}")

        # Target size
        new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
        logger.info(f"Novo tamanho calculado: {new_h}x{new_w}")

        # Coordinate grid. Here it is only used to validate the target
        # dimensions; it is not passed to the model call below.
        grid = make_grid((new_h, new_w))
        logger.debug(f"Grid gerado: {grid.shape}")
        if grid.shape[1:3] != (new_h, new_w):
            raise RuntimeError(
                f"Incompatibilidade de dimensões: "
                f"Grid {grid.shape[1:3]} vs Alvo {new_h}x{new_w}"
            )

        # Image pre-processing: floats in [0, 1] with a leading batch axis.
        source = jnp.array(image).astype(jnp.float32) / 255.0
        source = source[jnp.newaxis, ...]

        # Scale parameter
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # Thera processing
        upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))

        # Conversion to PIL. The model output is not guaranteed to stay within
        # [0, 1]; clip before the uint8 cast so out-of-range values saturate
        # instead of wrapping around, and round instead of truncating.
        upscaled_arr = np.clip(np.array(upscaled[0]), 0.0, 1.0)
        upscaled_img = Image.fromarray(np.round(upscaled_arr * 255).astype(np.uint8))
        logger.info(f"Imagem super-resolvida: {upscaled_img.size}")

        # Bas-relief generation
        result = pipe(
            prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
            image=upscaled_img,
            strength=0.7,
            num_inference_steps=30,
            guidance_scale=7.5
        )
        bas_relief = result.images[0]
        logger.info(f"Bas-Relief gerado: {bas_relief.size}")

        # Depth estimation
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            depth = depth_model(**inputs).predicted_depth

        # Resize the prediction to the bas-relief size; PIL's size is
        # (width, height), so it is reversed to (height, width).
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Min-max normalization; the epsilon avoids division by zero on a flat map.
        depth_min = depth_map.min()
        depth_max = depth_map.max()
        depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
        depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
        logger.info("Mapa de profundidade calculado")

        return upscaled_img, bas_relief, depth_img
    except Exception as e:
        logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
        # Chain the cause so the original traceback stays attached.
        raise gr.Error(f"Falha no processamento: {str(e)}") from e
# Gradio interface: one column of inputs and one column of tabbed results.
with gr.Blocks(title="SuperRes+BasRelief Pro", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    with gr.Row():
        # Both columns are created first and populated afterwards.
        input_col = gr.Column()
        output_col = gr.Column()

        # Inputs: source image, relief description, scale factor, run button
        with input_col:
            img_input = gr.Image(label="Carregar Imagem", type="pil", height=300)
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                value="A insanely detailed and complex engraving relief, ultra-high definition",
                placeholder="Descreva o estilo desejado...",
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
            process_btn = gr.Button("Iniciar Processamento", variant="primary")

        # Outputs: one tab per pipeline result
        with output_col:
            with gr.Tabs():
                with gr.TabItem("Super Resolução"):
                    upscaled_output = gr.Image(label="Resultado", show_label=False)
                with gr.TabItem("Bas-Relief"):
                    basrelief_output = gr.Image(label="Relevo", show_label=False)
                with gr.TabItem("Profundidade"):
                    depth_output = gr.Image(label="Mapa 3D", show_label=False)

    # Wire the button to the pipeline; the three results fill the three tabs in order.
    process_btn.click(
        fn=full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output],
        api_name="processar",
    )
if __name__ == "__main__":
    # Start the Gradio server on all network interfaces, port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)