# sculpt / app.py
# Last commit b82dc7d by ds1david: "fixing bugs" (file size 4.26 kB)
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
# ========== Thera configuration ==========
# Hugging Face Hub repositories that host the two Thera checkpoints
# (EDSR backbone and RDN backbone).
REPO_ID_EDSR: str = "prs-eth/thera-edsr-pro"
REPO_ID_RDN: str = "prs-eth/thera-rdn-pro"
# Load Thera models
def load_thera_model(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository containing a ``model.pkl`` checkpoint.

    Returns:
        A ``(model, params)`` tuple: the network produced by ``build_thera``
        and the parameters stored under the checkpoint's ``'model'`` key.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
    # NOTE(review): unpickling can execute arbitrary code; only load
    # checkpoints from repositories you trust.
    with open(checkpoint_path, 'rb') as handle:
        checkpoint = pickle.load(handle)
    weights = checkpoint['model']
    backbone_name = checkpoint['backbone']
    model_size = checkpoint['size']
    # The leading 3 is presumably the image channel count (RGB) — confirm
    # against build_thera's signature.
    network = build_thera(3, backbone_name, model_size)
    return network, weights
# Build both Thera variants once at import time (EDSR first, then RDN).
(model_edsr, params_edsr), (model_rdn, params_rdn) = (
    load_thera_model(repo_id) for repo_id in (REPO_ID_EDSR, REPO_ID_RDN)
)
# ========== SDXL + depth-model setup ==========
# Prefer the GPU in half precision; otherwise run on CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Text-to-image pipeline: SDXL base weights with the bas-relief LoRA applied.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype
).to(device)
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
    # NOTE(review): `peft_backend` is not among load_lora_weights' documented
    # parameters and may be silently ignored — confirm before relying on it.
    peft_backend="peft",
)

# Preprocessor and depth model from the same DPT checkpoint; only the model
# is moved to `device` here (inputs are moved at call time).
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
# ========== Core functions ==========
def super_resolution(image, scale_factor, model_type):
    """Upscale ``image`` by ``scale_factor`` with the selected Thera model.

    Args:
        image: Input image, normally a PIL image from the Gradio component
            (``None`` when the user clicks without uploading anything).
            Array-like inputs are passed through ``np.asarray`` unchanged.
        scale_factor: Factor applied to both height and width; the target
            size is rounded to whole pixels.
        model_type: ``"EDSR"`` selects the EDSR checkpoint; any other value
            selects the RDN checkpoint.

    Returns:
        The upscaled image as a PIL image.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a NumPy TypeError.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalize grayscale / palette / RGBA uploads to (H, W, 3). The models
    # are built with build_thera(3, ...), presumably for 3-channel input.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    if model_type == "EDSR":
        model, params = model_edsr, params_edsr
    else:
        model, params = model_rdn, params_rdn

    source = np.asarray(image) / 255.0  # 8-bit pixel values -> [0, 1] floats
    height, width = source.shape[:2]
    target_shape = (round(height * scale_factor), round(width * scale_factor))

    output = process(source, model, params, target_shape, do_ensemble=True)
    # NOTE(review): assumes `process` returns uint8 (H, W, 3) data, which is
    # what Image.fromarray needs for an RGB result — confirm.
    return Image.fromarray(np.asarray(output))
def generate_bas_relief(
    prompt,
    *,
    num_inference_steps=25,
    guidance_scale=7.5,
    height=512,
    width=512,
):
    """Generate a bas-relief style image from ``prompt`` together with its depth map.

    The prompt is prefixed with ``"BAS-RELIEF"`` (presumably the LoRA's
    trigger phrase — confirm against the LoRA's model card) before being
    sent to the SDXL pipeline.

    Args:
        prompt: User text describing the artwork.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        height: Generated image height in pixels.
        width: Generated image width in pixels.

    Returns:
        A ``(image, depth_image)`` tuple: the first image produced by the
        pipeline and an 8-bit depth map resized to the same dimensions.
    """
    image = pipe(
        prompt=f"BAS-RELIEF {prompt}",
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
    ).images[0]
    return image, _depth_map_image(image)


def _depth_map_image(image):
    """Run the DPT depth model on ``image`` and return the result as an 8-bit PIL image.

    The prediction is resized to the image's resolution and min-max scaled
    to 0-255. A constant prediction maps to an all-black image instead of the
    NaNs a zero-range division would produce.
    """
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted = depth_model(**inputs).predicted_depth
    # PIL's size is (width, height) while interpolate expects (height, width);
    # unsqueeze(1) inserts the channel axis bicubic interpolation requires.
    depth = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
    ).squeeze().cpu().numpy()

    lowest = depth.min()
    value_range = depth.max() - lowest
    depth = depth - lowest
    if value_range > 0:
        depth = depth / value_range
    return Image.fromarray((depth * 255).astype(np.uint8))
# ========== Gradio interface ==========
with gr.Blocks(title="TheraSR + Bas-Relief Generator") as app:
    gr.Markdown("# 🔥 TheraSR + Bas-Relief Generator")
    gr.Markdown("Combine aliasing-free super-resolution with artistic bas-relief generation")

    with gr.Tabs():
        # Tab 1: upscale an uploaded image with one of the Thera models.
        with gr.TabItem("🖼 Super-Resolution"):
            with gr.Row():
                sr_input = gr.Image(label="Input Image", type="pil")
                sr_output = gr.Image(label="Super-Resolution Result")
            sr_scale = gr.Slider(1.0, 6.0, value=2.0, label="Scale Factor")
            sr_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Model Type")
            sr_btn = gr.Button("Enhance Resolution")

        # Tab 2: text prompt -> generated bas-relief image plus depth map.
        with gr.TabItem("🎨 Generate Bas-Relief"):
            with gr.Row():
                text_input = gr.Textbox(label="Art Prompt", placeholder="Roman soldier marble relief...")
            with gr.Row():
                gen_output = gr.Image(label="Generated Art")
                depth_output = gr.Image(label="Depth Map")
            gen_btn = gr.Button("Generate Artwork")

    # Event handlers: each button sends its tab's inputs to the matching function.
    sr_btn.click(
        fn=super_resolution,
        inputs=[sr_input, sr_scale, sr_model],
        outputs=sr_output,
    )
    gen_btn.click(
        fn=generate_bas_relief,
        inputs=text_input,
        outputs=[gen_output, depth_output],
    )

# Start the server. NOTE(review): share=True requests a public share link;
# Hugging Face Spaces reportedly does not support it (Gradio warns and
# ignores it there) — confirm whether it is still wanted.
app.launch(debug=False, share=True)