File size: 4,273 Bytes
98889c8 19a6d73 98889c8 b82dc7d 19a6d73 b82dc7d f41a4a7 a7111d1 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d f41a4a7 b82dc7d 1f384c6 b82dc7d f41a4a7 b82dc7d 46bb495 8c7829e b82dc7d f41a4a7 8c7829e f41a4a7 8c7829e f41a4a7 8c7829e f41a4a7 b82dc7d f41a4a7 b82dc7d 8c7829e b82dc7d f41a4a7 b82dc7d f41a4a7 8c7829e b82dc7d 8c7829e f41a4a7 b82dc7d 8c7829e b82dc7d f41a4a7 8c7829e f41a4a7 19a6d73 98889c8 f41a4a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import pickle

import gradio as gr
import jax
import numpy as np
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from huggingface_hub import file_download, hf_hub_download
from PIL import Image
from transformers import (
    DPTFeatureExtractor,
    DPTForDepthEstimation,
    DPTImageProcessor,
)

from model import build_thera
from super_resolve import process
# Compatibility shim: recent huggingface_hub releases removed `cached_download`,
# but some dependent code may still look it up on `huggingface_hub.file_download`.
# Only install the alias when the attribute is missing. On older hub versions the
# real `cached_download` (which takes a URL, not a repo_id) must not be clobbered
# by a function with an incompatible signature.
# NOTE(review): this runs after `diffusers` has already been imported, so it
# cannot fix an import-time lookup inside diffusers itself. It only affects code
# that resolves `file_download.cached_download` later — confirm it is still needed.
if not hasattr(file_download, "cached_download"):
    file_download.cached_download = file_download.hf_hub_download

# ========== Thera configuration ==========
# Hugging Face Hub repositories holding the pickled Thera checkpoints.
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
def load_thera_model(repo_id):
    """Download the pickled Thera checkpoint from *repo_id* and build its model.

    The checkpoint file ``model.pkl`` is fetched from the Hugging Face Hub. It is
    unpickled into a mapping whose ``'backbone'`` and ``'size'`` entries
    configure ``build_thera`` and whose ``'model'`` entry is returned alongside it.

    Returns:
        A ``(model, params)`` tuple: the object built by ``build_thera`` and the
        checkpoint's ``'model'`` entry.
    """
    checkpoint_file = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # SECURITY NOTE: pickle.load can execute arbitrary code stored in the file;
    # only load checkpoints from repositories you trust.
    with open(checkpoint_file, "rb") as handle:
        checkpoint = pickle.load(handle)
    backbone = checkpoint["backbone"]
    size = checkpoint["size"]
    # 3 is presumably the RGB channel count — confirm against model.build_thera.
    built = build_thera(3, backbone, size)
    return built, checkpoint["model"]
# Load both Thera variants once at import time; requests reuse them.
(model_edsr, params_edsr), (model_rdn, params_rdn) = (
    load_thera_model(REPO_ID_EDSR),
    load_thera_model(REPO_ID_RDN),
)
# ========== SDXL img2img + bas-relief LoRA configuration ==========
# Everything runs on CPU in float32. To use a GPU instead, set
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   torch_dtype = torch.float16 if device == "cuda" else torch.float32
device = "cpu"
torch_dtype = torch.float32

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)

# Bas-relief style LoRA. Presumably activated by the "BAS-RELIEF" prefix that
# the prompt in full_pipeline starts with.
# (The former `peft_backend="peft"` kwarg is not a load_lora_weights option and
# was silently ignored, so it has been dropped.)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)
# ========== Depth-estimation model configuration ==========
# Hub id of the DPT checkpoint used for both preprocessing and inference.
DEPTH_MODEL_ID = "Intel/dpt-large"

# DPTImageProcessor is the current name for the deprecated DPTFeatureExtractor
# alias; the preprocessing it applies is the same.
depth_processor = DPTImageProcessor.from_pretrained(DEPTH_MODEL_ID)
depth_model = DPTForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(device)
# ========== Integrated pipeline ==========
def _super_resolve(image, scale_factor, model_type):
    """Upscale *image* by *scale_factor* with the selected Thera model (JAX)."""
    # "EDSR" selects the EDSR variant; any other value falls back to RDN.
    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn

    # Thera is built for 3 input channels. Normalise RGBA/greyscale/palette
    # uploads to RGB so the array always has shape (H, W, 3).
    rgb = image.convert("RGB")
    width, height = rgb.size  # PIL reports (width, height)
    target_shape = (round(height * scale_factor), round(width * scale_factor))

    # The trailing True is passed positionally as before; presumably the
    # ensemble flag of super_resolve.process — confirm.
    sr_jax = process(np.asarray(rgb) / 255., sr_model, sr_params,
                     target_shape, True)
    # JAX array -> numpy -> PIL. Assumes process returns a uint8 (H, W, 3)
    # array — TODO confirm in super_resolve.process.
    return Image.fromarray(np.asarray(sr_jax))


def _stylize(image, style_prompt):
    """Restyle *image* as a bas-relief with the SDXL img2img pipeline."""
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    result = pipe(
        prompt=prompt,
        image=image,  # the pipeline accepts a PIL image directly
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    )
    return result.images[0]


def _estimate_depth(image):
    """Return an 8-bit greyscale depth map of *image*, min-max scaled to 0-255."""
    inputs = depth_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
    # unsqueeze adds the channel axis interpolate expects; resize to the
    # image's (height, width) so the map lines up pixel-for-pixel.
    resized = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()
    lo, hi = resized.min(), resized.max()
    # The epsilon avoids division by zero on a perfectly flat prediction.
    normalized = (resized - lo) / (hi - lo + 1e-8)
    return Image.fromarray((normalized * 255).astype(np.uint8))


def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution -> bas-relief restyling -> depth estimation.

    Args:
        image: Input PIL image from the Gradio component, or None if nothing
            was uploaded.
        scale_factor: Upscaling factor applied to both image dimensions.
        model_type: "EDSR" selects the EDSR Thera model; any other value
            selects RDN.
        style_prompt: Free-text style description inserted into the SDXL prompt.

    Returns:
        A tuple ``(super_resolved, bas_relief, depth_map)`` of PIL images.

    Raises:
        gr.Error: If no input image was provided, so the UI shows a readable message.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    sr_pil = _super_resolve(image, scale_factor, model_type)
    bas_relief = _stylize(sr_pil, style_prompt)
    depth_map = _estimate_depth(bas_relief)
    return sr_pil, bas_relief, depth_map
# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")
    with gr.Row():
        with gr.Column():
            # Inputs: source image plus the super-resolution and style controls.
            input_image = gr.Image(label="Input Image", type="pil")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
            style_prompt = gr.Textbox(
                label="Style Prompt",
                value="insanely detailed and complex engraving relief, ultra-high definition",
            )
            process_btn = gr.Button("Start Pipeline")
        with gr.Column():
            # Outputs, listed in the same order as full_pipeline's return tuple.
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    # Event listeners must be registered inside the Blocks context.
    process_btn.click(
        full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    app.launch(debug=False)