|
|
|
import gradio as gr |
|
import torch |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from PIL import Image |
|
import pickle |
|
import logging |
|
from huggingface_hub import hf_hub_download |
|
from diffusers import StableDiffusionXLImg2ImgPipeline |
|
from transformers import DPTImageProcessor, DPTForDepthEstimation |
|
from model import build_thera |
|
from utils import make_grid, interpolate_grid |
|
|
|
|
|
# Log INFO and above to both a file and the console.
# The file handler is opened explicitly as UTF-8: log messages in this module
# contain non-ASCII text (accented Portuguese words, the "→" arrow), which the
# platform default encoding (e.g. cp1252 on Windows) cannot always encode.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("processing.log", encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
|
|
# All models run on the CPU: PyTorch models are moved to TORCH_DEVICE, and
# JAX_DEVICE holds the first CPU device JAX reports.
# NOTE(review): JAX_DEVICE is not referenced in the visible code of this
# file — confirm whether it is still needed.
TORCH_DEVICE = "cpu"
JAX_DEVICE = next(iter(jax.devices("cpu")))
|
|
|
|
|
def load_thera_model(repo_id: str, filename: str, revision=None):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build the model.

    No integrity or safety verification is performed on the downloaded file;
    it is unpickled as-is (see the security note below).

    Args:
        repo_id: Hub repository id, e.g. ``"prs-eth/thera-edsr-pro"``.
        filename: Name of the checkpoint file inside the repository.
        revision: Optional branch, tag or commit hash to download. ``None``
            (the default) keeps the Hub's default revision, exactly as before.
            Pinning a commit hash guards against the file changing upstream.

    Returns:
        A ``(model, params)`` tuple: the object returned by ``build_thera``
        and the checkpoint's ``'model'`` entry (the parameters later passed to
        ``model.apply``).

    Raises:
        Exception: any download, unpickling, missing-key or model-construction
            error is logged with its traceback and re-raised unchanged.
    """
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load checkpoints from repositories you trust, ideally pinned to a
        # specific commit via `revision`.
        with open(model_path, 'rb') as fh:
            checkpoint = pickle.load(fh)
        # The leading 3 is presumably the number of image channels (RGB) —
        # confirm against model.build_thera.
        model = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        return model, checkpoint['model']
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Erro ao carregar modelo: {str(e)}")
        raise
|
|
|
|
|
|
|
# Load every model once at import time so each request to full_pipeline
# reuses them. Any failure is logged and re-raised, aborting startup.
try:
    logger.info("Carregando modelos...")

    # Thera super-resolution model plus the parameters passed to its `apply`.
    model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

    # SDXL img2img pipeline in float32 (matching the CPU-only TORCH_DEVICE),
    # with the bas-relief LoRA weights loaded on top.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float32,
    )
    pipe = pipe.to(TORCH_DEVICE)
    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
    )

    # DPT depth estimator and the processor that prepares its inputs; both
    # come from the same checkpoint.
    _dpt_checkpoint = "Intel/dpt-large"
    feature_extractor = DPTImageProcessor.from_pretrained(_dpt_checkpoint)
    depth_model = DPTForDepthEstimation.from_pretrained(_dpt_checkpoint)
    depth_model = depth_model.to(TORCH_DEVICE)
except Exception as e:
    logger.error(f"Falha na inicialização: {str(e)}")
    raise
|
|
|
|
|
def adjust_size(original: int, scale: float) -> int:
    """Scale a pixel dimension, snap it down to a multiple of 8, and floor it at 32.

    Snapping to multiples of 8 is presumably meant to suit the SDXL pipeline
    downstream — confirm before changing it.

    Args:
        original: Original size in pixels (width or height).
        scale: Multiplicative scale factor.

    Returns:
        ``original * scale`` truncated to an integer, rounded down to the
        nearest multiple of 8, and never less than 32.
    """
    # Round to 6 decimals before truncating so floating-point noise in the
    # product (e.g. 800 * 1.15 == 919.9999999999999) does not lose a pixel
    # and drop the result to the next-lower multiple of 8 (912 instead of 920).
    scaled = int(round(original * scale, 6))
    adjusted = (scaled // 8) * 8
    return max(32, adjusted)
|
|
|
|
|
def _to_uint8_image(array) -> Image.Image:
    """Convert a float image nominally in [0, 1] to an 8-bit PIL image.

    Values are clipped to [0, 1] first, so any overshoot saturates instead of
    wrapping around when cast to ``uint8`` (e.g. 1.01 * 255 would become 1).
    """
    scaled = np.clip(np.asarray(array, dtype=np.float32), 0.0, 1.0) * 255.0
    return Image.fromarray(np.round(scaled).astype(np.uint8))


def _super_resolve(image: Image.Image, new_h: int, new_w: int, scale_factor: float) -> Image.Image:
    """Upscale an RGB ``image`` to ``(new_h, new_w)`` with the Thera EDSR model."""
    # NOTE(review): coords is only used for this sanity check; the model
    # receives the target size directly.
    coords = make_grid((new_h, new_w))
    logger.debug(f"Dimensões do grid: {coords.shape}")
    if coords.shape[1:3] != (new_h, new_w):
        raise ValueError(f"Grid incorreto: {coords.shape[1:3]} vs ({new_h}, {new_w})")

    # Convert through NumPy, which understands PIL images natively, instead of
    # relying on jnp.array to accept a PIL object. Add a leading batch axis.
    source = jnp.asarray(np.asarray(image, dtype=np.float32) / 255.0)[jnp.newaxis, ...]

    # Conditioning value for model_edsr.apply, computed as 1 / scale_factor**2;
    # its exact meaning is defined by the Thera model — confirm there.
    t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
    upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
    return _to_uint8_image(upscaled[0])


def _render_bas_relief(image: Image.Image, prompt: str) -> Image.Image:
    """Restyle ``image`` as a bas-relief with the SDXL img2img pipeline and LoRA."""
    # "BAS-RELIEF" is presumably the LoRA's trigger word — confirm on the model card.
    result = pipe(
        prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
        image=image,
        strength=0.7,
        num_inference_steps=30,
    )
    return result.images[0]


def _estimate_depth(image: Image.Image) -> Image.Image:
    """Estimate a DPT depth map for ``image``, min-max normalized to an 8-bit image."""
    inputs = feature_extractor(image, return_tensors="pt").to(TORCH_DEVICE)
    with torch.no_grad():
        depth = depth_model(**inputs).predicted_depth

    # Resize the prediction to the image resolution. PIL's size is (w, h),
    # while interpolate expects (h, w).
    depth_map = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()

    depth_min = depth_map.min()
    depth_max = depth_map.max()
    # The epsilon avoids division by zero on a perfectly flat prediction.
    depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
    return _to_uint8_image(depth_normalized)


def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
    """Run super-resolution, bas-relief restyling and depth estimation on one image.

    Args:
        image: Input image from the Gradio component; may be ``None`` when the
            input is empty.
        prompt: Style description inserted into the bas-relief prompt.
        scale_factor: Upscaling factor for the super-resolution stage.

    Returns:
        ``(upscaled, bas_relief, depth)`` PIL images, in the order of the
        three output components.

    Raises:
        gr.Error: with a user-facing message when no image is given or any
            stage fails (the original exception is chained as the cause).
    """
    if image is None:
        raise gr.Error("Nenhuma imagem fornecida.")

    try:
        image = image.convert("RGB")
        orig_w, orig_h = image.size

        # Target size: scale_factor x original, snapped by adjust_size.
        new_h = adjust_size(orig_h, scale_factor)
        new_w = adjust_size(orig_w, scale_factor)
        logger.info(f"Redimensionando: {orig_h}x{orig_w} → {new_h}x{new_w}")

        upscaled_img = _super_resolve(image, new_h, new_w, scale_factor)
        bas_relief = _render_bas_relief(upscaled_img, prompt)
        depth_img = _estimate_depth(bas_relief)

        return upscaled_img, bas_relief, depth_img

    except Exception as e:
        logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
        raise gr.Error(f"Processamento falhou: {str(e)}") from e
|
|
|
|
|
|
|
# Gradio UI: inputs in the left column, the three pipeline outputs in tabs on
# the right, and a button that runs full_pipeline.
with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
    gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Imagem de Entrada", type="pil")
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                # Default style text, inserted into the SDXL prompt by full_pipeline.
                value="An insanely detailed and complex engraving relief, ultra-high definition",
                placeholder="Descreva o estilo desejado..."
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button("Processar Imagem", variant="primary")

        with gr.Column():
            gr.Markdown("## Resultados")
            with gr.Tabs():
                with gr.TabItem("Super Resolução"):
                    upscaled_output = gr.Image(label="Resultado Super Resolução")
                with gr.TabItem("Bas-Relief"):
                    basrelief_output = gr.Image(label="Relevo Gerado")
                with gr.TabItem("Profundidade"):
                    depth_output = gr.Image(label="Mapa de Profundidade")

    # `inputs` follow full_pipeline's parameter order; `outputs` must match
    # the order of the tuple it returns.
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[upscaled_output, basrelief_output, depth_output]
    )
|
|
|
if __name__ == "__main__":
    # Serve the UI on port 7860, bound to all network interfaces ("0.0.0.0")
    # rather than localhost only.
    app.launch(server_port=7860, server_name="0.0.0.0")