"""Gradio demo chaining Thera super-resolution, SDXL bas-relief styling and DPT depth."""

import pickle

import gradio as gr
import torch
import numpy as np
import jax
from PIL import Image
from huggingface_hub import hf_hub_download, file_download
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

from model import build_thera
from super_resolve import process

# Compatibility fix: expose `cached_download` on `huggingface_hub.file_download`
# as an alias of `hf_hub_download`.
# NOTE(review): presumably for a dependency that still looks up the old name;
# confirm it is still required with the pinned library versions.
file_download.cached_download = file_download.hf_hub_download

# ========== Thera configuration ==========
REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"


def load_thera_model(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build the model.

    Args:
        repo_id: Hub repository id that hosts a ``model.pkl`` checkpoint.

    Returns:
        A ``(model, params)`` tuple: the object returned by
        ``build_thera(3, check['backbone'], check['size'])`` and the
        checkpoint's ``'model'`` entry.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only point this at repositories you trust.
    with open(model_path, "rb") as fh:
        check = pickle.load(fh)
    return build_thera(3, check["backbone"], check["size"]), check["model"]


model_edsr, params_edsr = load_thera_model(REPO_ID_EDSR)
model_rdn, params_rdn = load_thera_model(REPO_ID_RDN)

# ========== SDXL + depth configuration ==========
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU runs in float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft",
)

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)


# ========== Integrated flow ==========
def _as_pil_image(image):
    """Return *image* as a PIL image, converting array-like inputs if needed."""
    if isinstance(image, Image.Image):
        return image
    array = np.asarray(image)
    if array.dtype != np.uint8:
        # NOTE(review): non-uint8 arrays are assumed to hold floats in [0, 1];
        # confirm against what `super_resolve.process` actually returns.
        array = np.rint(np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(array)


def _depth_to_uint8(depth):
    """Min-max normalise a depth array to the 0-255 range as ``uint8``."""
    low = depth.min()
    span = depth.max() - low
    if not np.isfinite(span) or span <= 0:
        # A constant (or non-finite) map would divide by zero and produce NaNs;
        # return a flat black map instead.
        return np.zeros(depth.shape, dtype=np.uint8)
    return ((depth - low) / span * 255).astype(np.uint8)


def full_pipeline(image, scale_factor, model_type, style_prompt):
    """Run super-resolution, bas-relief style transfer and depth estimation.

    Args:
        image: Input PIL image (the Gradio input uses ``type="pil"``).
        scale_factor: Multiplier applied to the input width and height.
        model_type: ``"EDSR"`` selects the EDSR model; any other value selects RDN.
        style_prompt: Free text inserted into the SDXL prompt.

    Returns:
        A ``(sr_image, bas_relief, depth_image)`` tuple of PIL images.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    # The Thera models are built for 3 channels, so drop alpha / expand grayscale.
    image = image.convert("RGB")

    # 1. Super-resolution
    if model_type == "EDSR":
        sr_model, sr_params = model_edsr, params_edsr
    else:
        sr_model, sr_params = model_rdn, params_rdn
    width, height = image.size
    target_shape = (round(height * scale_factor), round(width * scale_factor))
    sr_image = process(
        np.array(image) / 255.,
        sr_model,
        sr_params,
        target_shape,
        True,
    )
    # Make sure the SDXL pipeline and the Gradio output receive a PIL image.
    sr_image = _as_pil_image(sr_image)

    # 2. Bas-relief style transfer
    prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
    bas_relief = pipe(
        prompt=prompt,
        image=sr_image,
        strength=0.6,
        num_inference_steps=25,
        guidance_scale=7.5,
    ).images[0]

    # 3. Depth map estimation
    inputs = feature_extractor(bas_relief, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
    depth = outputs.predicted_depth

    # Resize the prediction back to the stylised image; PIL's size is
    # (width, height) while interpolate expects (height, width).
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=bas_relief.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze().cpu().numpy()

    return sr_image, bas_relief, Image.fromarray(_depth_to_uint8(depth))


# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
    gr.Markdown("## 🪄 Super-Resolution → Bas-Relief → Depth Map")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
            model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
            style_prompt = gr.Textbox(
                label="Style Prompt",
                placeholder="marble sculpture, ancient greek style",
            )
            process_btn = gr.Button("Start Pipeline")

        with gr.Column():
            sr_output = gr.Image(label="Super-Resolution Result")
            style_output = gr.Image(label="Bas-Relief Result")
            depth_output = gr.Image(label="Depth Map")

    process_btn.click(
        full_pipeline,
        inputs=[input_image, scale, model_type, style_prompt],
        outputs=[sr_output, style_output, depth_output],
    )

app.launch(debug=False)