File size: 4,950 Bytes
98889c8 1eb87a5 a7111d1 98889c8 1665fe1 3920f5c a7111d1 3920f5c 1eb87a5 a7111d1 1eb87a5 3920f5c 1eb87a5 a7111d1 3920f5c a02c6d7 42a2e7b 3920f5c a02c6d7 3920f5c a02c6d7 3920f5c a7111d1 3920f5c 1eb87a5 3920f5c 1eb87a5 a7111d1 3920f5c a7111d1 3920f5c 1eb87a5 a7111d1 1eb87a5 3920f5c a7111d1 3920f5c a7111d1 3920f5c a02c6d7 3920f5c a02c6d7 3920f5c a7111d1 0652978 3920f5c 0652978 3920f5c a7111d1 3920f5c a7111d1 3920f5c a7111d1 3920f5c 1665fe1 1eb87a5 a7111d1 1eb87a5 1665fe1 a7111d1 1eb87a5 1665fe1 98889c8 eb02bc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
import torch
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
import pickle
import warnings
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation
from model import build_thera
# Settings: silence the noisy warning categories emitted by the ML libraries.
for _silenced_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_silenced_category)

# Device selection: both the PyTorch and the JAX side are pinned to the CPU.
TORCH_DEVICE = "cpu"
JAX_DEVICE = jax.devices("cpu")[0]
# 1. Load the Thera models ---------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Name of the pickle file inside that repository.

    Returns:
        A ``(model, variables)`` tuple: the object returned by ``build_thera`` and
        whatever the checkpoint stores under its ``'model'`` key.

    Raises:
        KeyError: If the checkpoint lacks the ``'model'``, ``'backbone'`` or
            ``'size'`` entry.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the file.
    # Only point this function at repositories you trust.
    with open(checkpoint_path, 'rb') as stream:
        checkpoint = pickle.load(stream)

    # Full variable structure; presumably shaped like {'params': ...} -- confirm
    # against the checkpoint format.
    variables = checkpoint['model']
    backbone = checkpoint['backbone']
    size = checkpoint['size']

    return build_thera(3, backbone, size), variables
print("Carregando Thera EDSR...")
model_edsr, variables_edsr = load_thera_model(
    repo_id="prs-eth/thera-edsr-pro",
    filename="model.pkl",
)

# 2. Load SDXL + LoRA --------------------------------------------------------------------------
print("Carregando SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
)
pipe = pipe.to(TORCH_DEVICE)
# The bas-relief LoRA is applied on top of the base SDXL weights.
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")

# 3. Load the depth-estimation model -----------------------------------------------------------
print("Carregando DPT Depth...")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
depth_model = depth_model.to(TORCH_DEVICE)
# Main pipeline --------------------------------------------------------------------------------
def _to_uint8_image(values):
    """Convert floats nominally in [0, 1] to uint8 pixel values, clipping overshoot."""
    # Without the clip, values slightly below 0 or above 1 wrap around when cast
    # to uint8 and show up as speckle artefacts.
    return (np.clip(np.asarray(values), 0.0, 1.0) * 255).astype(np.uint8)


def _normalize_depth(depth_map):
    """Rescale ``depth_map`` linearly to [0, 1]; a flat map becomes all zeros."""
    lowest = depth_map.min()
    span = depth_map.max() - lowest
    if not span > 0:
        # Constant (or NaN) map: dividing by the span would produce NaN/inf.
        return np.zeros_like(depth_map)
    return (depth_map - lowest) / span


def full_pipeline(image, prompt, scale_factor=2.0):
    """Run super resolution, bas-relief generation and depth estimation on an image.

    Args:
        image: Input PIL image (the Gradio input component is declared with
            ``type="pil"``). ``None`` is rejected with a ``gr.Error``.
        prompt: Text inserted into the bas-relief prompt template.
        scale_factor: Multiplier applied to the image height and width to get
            the super-resolution target shape.

    Returns:
        A tuple ``(upscaled_pil, bas_relief, depth_pil)`` of PIL images: the
        super-resolved input, the SDXL img2img result, and the depth map
        normalised to the 0-255 range.

    Raises:
        gr.Error: If no image was supplied or any processing stage fails; the
            underlying exception is chained as the cause.
    """
    # Guard outside the try block so this message is not wrapped a second time.
    if image is None:
        raise gr.Error("Erro no processamento: nenhuma imagem de entrada foi fornecida.")

    try:
        # 1. Super resolution with Thera
        image = image.convert("RGB")
        source = np.array(image) / 255.0
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
        source_jax = jax.device_put(source, JAX_DEVICE)
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        # NOTE(review): assumes the model built by build_thera accepts
        # (variables, source, t, target_shape) through .apply and returns an
        # image-shaped array of floats in [0, 1] -- confirm against the model code.
        upscaled = model_edsr.apply(
            variables_edsr,  # full variable structure, presumably {'params': ...}
            source_jax,
            t,
            target_shape,
        )
        upscaled_pil = Image.fromarray(_to_uint8_image(upscaled))

        # 2. Generate the bas-relief
        full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5,
        ).images[0]

        # 3. Compute the depth map
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # PIL reports size as (width, height); interpolate expects (height, width).
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
        ).squeeze().cpu().numpy()

        depth_pil = Image.fromarray(_to_uint8_image(_normalize_depth(depth_map)))

        return upscaled_pil, bas_relief, depth_pil

    except Exception as e:
        # Top-level UI boundary: surface any failure to the user, keeping the cause.
        raise gr.Error(f"Erro no processamento: {str(e)}") from e
# Gradio interface -----------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super Resolução + 🗿 Bas-Relief + 🗺️ Profundidade")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            img_input = gr.Image(type="pil", label="Imagem de Entrada")
            prompt = gr.Textbox(
                label="Descrição do Relevo",
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution.",
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
            btn = gr.Button("Processar")

        # Right column: one output image per pipeline stage.
        with gr.Column():
            img_upscaled = gr.Image(label="Imagem Super Resolvida")
            img_basrelief = gr.Image(label="Resultado Bas-Relief")
            img_depth = gr.Image(label="Mapa de Profundidade")

    # Event wiring has to happen inside the Blocks context.
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth],
    )

# Launch only when run as a script; the stray trailing "|" token that followed
# this call was a syntax error and has been removed.
if __name__ == "__main__":
    app.launch(share=False)