# sculpt/app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
from peft import PeftModel
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel
from torchvision import transforms
# Initial configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# --- Model Loading ---
# 1. Thera: super-resolution
def load_thera_model():
    # Hypothetical hub entrypoint; adjust to match the actual Thera implementation
    model = torch.hub.load('prs-eth/thera', 'thera', trust_repo=True)
    return model.to(DEVICE)
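
# The torch.hub entrypoint above is an assumption; if the prs-eth/thera repo
# exposes no hub interface, a plain bicubic resize (sketch below) can stand in
# so the rest of the pipeline stays testable end to end.
def fallback_upscale(image: Image.Image, scale: int = 2) -> Image.Image:
    """Bicubic placeholder for Thera super-resolution (not the real model)."""
    w, h = image.size
    return image.resize((w * scale, h * scale), Image.BICUBIC)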
# 2. Depth map with PEFT (LoRA)
def load_depth_model():
    base_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    model = PeftModel.from_pretrained(base_model, "danube2024/dpt-peft-lora")
    return model.to(DEVICE).eval()
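
# Optional variant: PEFT's merge_and_unload() folds the LoRA weights into the
# base DPT model, removing adapter hook overhead at inference (standard PEFT
# API; the size of the speedup here is an assumption, not measured).
def load_depth_model_merged():
    base_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    model = PeftModel.from_pretrained(base_model, "danube2024/dpt-peft-lora")
    return model.merge_and_unload().to(DEVICE).eval()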
# 3. Bas-relief with ControlNet
def load_controlnet():
    controlnet = ControlNetModel.from_pretrained(
        "danube2024/controlnet-bas-relief",
        torch_dtype=TORCH_DTYPE
    )
    # Img2Img variant: accepts an init image plus a control image, which the
    # strength-based call in create_bas_relief() below requires
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        torch_dtype=TORCH_DTYPE
    )
    pipe.load_lora_weights("danube2024/bas-relief-lora")
    return pipe.to(DEVICE)
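
# Optional memory savings for smaller GPUs (both are standard diffusers APIs;
# enable_model_cpu_offload needs `accelerate` installed and replaces the
# pipe.to(DEVICE) call above):
#     pipe.enable_model_cpu_offload()
#     pipe.enable_vae_slicing()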
# --- Processing ---
def run_thera(image, model):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    # Undo the [-1, 1] normalization before converting back to PIL
    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5)
    return output_img
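
# Interface assumption: the hub model maps a normalized 3xHxW tensor in [-1, 1]
# to an upscaled tensor in the same range; the de-normalization above mirrors
# the Normalize([0.5], [0.5]) transform. Adjust both if Thera's real I/O differs.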
def create_depth_map(image, model, feature_extractor):
    inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth
    # Resize the prediction back to the input resolution (PIL size is (W, H))
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    return prediction.squeeze().cpu().numpy()
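
# Note: DPT (MiDaS family) predicts relative inverse depth, so larger values
# mean nearer surfaces. For a bas-relief that is the convenient orientation
# (near regions raised); invert the normalized map if the relief comes out
# embossed the wrong way.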
def create_bas_relief(prompt, image, depth_map, pipe):
    # Depth arrives normalized to [0, 1]; scale to 8-bit and expand to three
    # channels for the ControlNet conditioning image
    control_image = Image.fromarray((depth_map * 255).astype(np.uint8)).convert("RGB")
    # SDXL works natively at 1024x1024
    image = image.resize((1024, 1024))
    control_image = control_image.resize((1024, 1024))
    result = pipe(
        prompt=prompt,
        image=image,                   # init image (img2img)
        control_image=control_image,   # depth conditioning
        strength=0.8,                  # keep roughly 20% of the init image's structure
        num_inference_steps=30
    ).images[0]
    return result
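
# Optional knobs (standard diffusers parameters, values untuned): a negative
# prompt and a conditioning scale can tighten adherence to the depth map, e.g.
#     pipe(..., negative_prompt="blurry, flat, low detail",
#          controlnet_conditioning_scale=0.7)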
# --- Gradio Interface ---
with gr.Blocks() as app:
    gr.Markdown("# 🖼️ Super-Resolution + Depth Map + Bas-Relief")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                "high quality bas-relief sculpture, intricate details",
                label="Prompt"
            )
            submit_btn = gr.Button("Process")
        with gr.Column():
            upscaled_output = gr.Image(label="Super-Resolved Image")
            depth_output = gr.Image(label="Depth Map")
            basrelief_output = gr.Image(label="Bas-Relief Result")
    def process(image, prompt):
        # Load models (note: reloaded on every click; see the caching sketch
        # after the app definition)
        thera_model = load_thera_model()
        depth_model = load_depth_model()
        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
        basrelief_pipe = load_controlnet()
        # 1. Super-resolution
        upscaled = run_thera(image, thera_model)
        # 2. Depth map, min-max normalized to [0, 1] (the epsilon guards
        # against division by zero on a constant depth map)
        depth = create_depth_map(upscaled, depth_model, feature_extractor)
        depth_normalized = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
        # 3. Bas-relief
        basrelief = create_bas_relief(prompt, upscaled, depth_normalized, basrelief_pipe)
        return upscaled, depth_normalized, basrelief
    submit_btn.click(
        process,
        inputs=[input_image, prompt],
        outputs=[upscaled_output, depth_output, basrelief_output]
    )
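
# Caching sketch (an aside, not wired into process() above): loading all four
# components on every click is slow; storing them at module scope after the
# first call keeps later requests fast.
_MODELS = {}

def get_cached_models():
    """Load the models once and reuse them across Gradio callbacks."""
    if not _MODELS:
        _MODELS["thera"] = load_thera_model()
        _MODELS["depth"] = load_depth_model()
        _MODELS["extractor"] = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
        _MODELS["pipe"] = load_controlnet()
    return _MODELS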
if __name__ == "__main__":
    app.launch()