import gradio as gr
import torch
import jax
import numpy as np
from PIL import Image
import pickle
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from model import build_thera  # import from the original Thera codebase
# Configure devices
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"
# 1. Load Thera models --------------------------------------------------------------------------
def load_thera_model(repo_id, filename):
    # The checkpoint is a pickle holding the trained parameters plus the metadata
    # needed to rebuild the model (backbone name and size)
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(model_path, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']
    model = build_thera(3, backbone, size)  # 3 output channels (RGB)
    return model, params
print("Carregando Thera EDSR...")
model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
# 2. Load SDXL + LoRA ---------------------------------------------------------------------------
print("Loading SDXL + LoRA...")
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32  # float32: half precision is poorly supported on CPU
).to(TORCH_DEVICE)
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
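# Two optional tweaks for constrained hardware (a sketch, not required; both are
# standard diffusers APIs): fusing the LoRA removes its per-call overhead, and
# attention slicing lowers peak memory at some cost in speed.
# pipe.fuse_lora()
# pipe.enable_attention_slicing()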
# 3. Load the depth estimation model ------------------------------------------------------------
print("Loading DPT Depth...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
# Main pipeline ---------------------------------------------------------------------------------
def full_pipeline(image, prompt, scale_factor=2.0):
    try:
        # 1. Super-resolution with Thera
        source = np.array(image.convert("RGB")) / 255.0
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))

        # Convert to a JAX array on the target device
        source_jax = jax.device_put(source, JAX_DEVICE)

        # Run Thera; do_ensemble enables geometric self-ensembling
        upscaled = model_edsr.apply(
            params_edsr,
            source_jax,
            target_shape,
            do_ensemble=True
        )
        # Clip before the uint8 cast: the network output can slightly overshoot [0, 1]
        upscaled_pil = Image.fromarray(
            (np.clip(np.array(upscaled), 0.0, 1.0) * 255).astype(np.uint8)
        )
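        # SDXL img2img works best near its 1024x1024 training resolution; as a
        # precaution, a sketch that caps very large Thera outputs before diffusion
        # (the 1024 cap is an assumption, not part of the original app):
        # upscaled_pil.thumbnail((1024, 1024), Image.LANCZOS)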
        # 2. Generate the bas-relief with SDXL + LoRA
        full_prompt = f"BAS-RELIEF {prompt}, intricate carving, marble relief"
        bas_relief = pipe(
            prompt=full_prompt,
            image=upscaled_pil,
            strength=0.7,
            num_inference_steps=25,
            guidance_scale=7.5
        ).images[0]
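        # strength=0.7 re-noises most of the input: lower values stay closer to the
        # upscaled image, higher values give the prompt (and LoRA style) more freedom.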
        # 3. Compute the depth map with DPT
        inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # Resize the prediction back to image resolution; PIL size is (W, H),
        # while interpolate expects (H, W)
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic"
        ).squeeze().cpu().numpy()

        # Normalize to [0, 1] for display
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
        return upscaled_pil, bas_relief, depth_map

    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}")
# Gradio interface ------------------------------------------------------------------------------
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                "insanely detailed and complex engraving relief, ultra-high definition, "
                "rich in detail, and 16K resolution.",
                label="Description"
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")
        with gr.Column():
            img_upscaled = gr.Image(label="Super-Resolved")
            img_basrelief = gr.Image(label="Bas-Relief")
            img_depth = gr.Image(label="Depth Map")
    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )
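    # CPU inference is slow; Gradio's queue serializes concurrent requests so they
    # don't compete for RAM (a sketch; queue() is a standard gr.Blocks method):
    # app.queue(max_size=4)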
if __name__ == "__main__":
    app.launch()