import pickle
import warnings

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTImageProcessor, DPTForDepthEstimation

from model import build_thera
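# Silence noisy deprecation warnings from the diffusers/transformers stacks.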
warnings.filterwarnings("ignore", category=FutureWarning) |
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
|
|
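# Everything runs on CPU: JAX hosts the Thera super-resolution model,
# while PyTorch hosts the SDXL img2img pipeline and the DPT depth model.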
JAX_DEVICE = jax.devices("cpu")[0]
TORCH_DEVICE = "cpu"

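# The pickled checkpoint is assumed to store the trained parameters under
# 'model', plus the 'backbone' name and 'size' needed to rebuild the network.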
def load_thera_model(repo_id, filename):
    """Download a pickled Thera checkpoint from the Hugging Face Hub and rebuild the model."""
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(model_path, 'rb') as fh:
        checkpoint = pickle.load(fh)

    variables = checkpoint['model']
    backbone, size = checkpoint['backbone'], checkpoint['size']
    model = build_thera(3, backbone, size)  # 3 output channels (RGB)
    return model, variables

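# Thera: arbitrary-scale super-resolution with an EDSR backbone ("pro" variant).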
print("Carregando Thera EDSR...") |
|
model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl") |
|
|
|
|
|
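# SDXL is kept in float32 since this runs on CPU. The "BAS-RELIEF" token
# prepended in full_pipeline appears to be this LoRA's trigger phrase.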
print("Carregando SDXL + LoRA...") |
|
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-xl-base-1.0", |
|
torch_dtype=torch.float32 |
|
).to(TORCH_DEVICE) |
|
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors") |
|
|
|
|
|
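# DPT-Large turns the stylized render into a pseudo height map for preview.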
print("Carregando DPT Depth...") |
|
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large") |
|
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE) |
|
|
|
|
|
|
|
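# Full pipeline: upscale with Thera, re-render as a bas-relief with SDXL img2img,
# then estimate depth with DPT.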
def full_pipeline(image, prompt, scale_factor=2.0):
    try:
        # Step 1: Thera super-resolution (JAX).
        image = image.convert("RGB")
        source = np.array(image) / 255.0  # HWC float array in [0, 1]
        target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))

        source_jax = jax.device_put(source, JAX_DEVICE)
        # Thera is conditioned on t = 1 / scale**2, the relative pixel area.
        t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

        upscaled = model_edsr.apply(
            variables_edsr,
            source_jax,
            t,
            target_shape
        )

        # Clip before quantizing: the network output may slightly overshoot [0, 1].
        upscaled_pil = Image.fromarray((np.asarray(upscaled).clip(0, 1) * 255).astype(np.uint8))
full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution" |
|
bas_relief = pipe( |
|
prompt=full_prompt, |
|
image=upscaled_pil, |
|
strength=0.7, |
|
num_inference_steps=25, |
|
guidance_scale=7.5 |
|
).images[0] |
|
|
|
|
|
        # Step 3: DPT monocular depth estimation on the stylized image.
        inputs = feature_extractor(images=bas_relief, return_tensors="pt").to(TORCH_DEVICE)
        with torch.no_grad():
            outputs = depth_model(**inputs)
            depth = outputs.predicted_depth

        # Upsample the raw depth prediction back to the image resolution.
        # PIL reports size as (width, height); torch expects (height, width).
        depth_map = torch.nn.functional.interpolate(
            depth.unsqueeze(1),
            size=bas_relief.size[::-1],
            mode="bicubic",
            align_corners=False
        ).squeeze().cpu().numpy()
        # Min-max normalize for display; the epsilon guards against division
        # by zero on a constant depth map.
        depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-8)
        depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))

        return upscaled_pil, bas_relief, depth_pil
    except Exception as e:
        raise gr.Error(f"Processing error: {e}")

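# --- Gradio interface ---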
with gr.Blocks(title="Super Res + Bas-Relief") as app:
    gr.Markdown("## 🔍 Super-Resolution + 🗿 Bas-Relief + 🗺️ Depth Map")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            # full_pipeline prepends "BAS-RELIEF" and appends the style suffix
            # to whatever is entered here.
            prompt = gr.Textbox(
                label="Relief Description",
                value="insanely detailed and complex engraving relief, ultra-high definition, rich in detail, and 16K resolution."
            )
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            btn = gr.Button("Process")

        with gr.Column():
            img_upscaled = gr.Image(label="Upscaled Image")
            img_basrelief = gr.Image(label="Bas-Relief Result")
            img_depth = gr.Image(label="Depth Map")

    btn.click(
        full_pipeline,
        inputs=[img_input, prompt, scale],
        outputs=[img_upscaled, img_basrelief, img_depth]
    )
if __name__ == "__main__":
    app.launch(share=False)