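"""Gradio app: super-resolution + depth estimation + bas-relief generation.

Pipeline (as implemented below): the input image is upscaled with Thera,
a depth map is predicted with a PEFT-tuned DPT model, and the depth map
conditions an SDXL ControlNet img2img pipeline to render a bas-relief.
"""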
import gradio as gr
import torch
import numpy as np
from PIL import Image
from peft import PeftModel
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline, ControlNetModel
from torchvision import transforms
# Initial configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# --- Model Loading ---
# 1. Thera: super-resolution
def load_thera_model():
    # Hypothetical model: adjust to match the actual Thera implementation
    model = torch.hub.load('prs-eth/thera', 'thera', trust_repo=True)
    return model.to(DEVICE)
# 2. Depth map with PEFT
def load_depth_model():
    base_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
    model = PeftModel.from_pretrained(base_model, "danube2024/dpt-peft-lora")
    return model.to(DEVICE).eval()
# 3. Bas-relief with ControlNet
def load_controlnet():
    controlnet = ControlNetModel.from_pretrained(
        "danube2024/controlnet-bas-relief",
        torch_dtype=TORCH_DTYPE
    )
    # The img2img variant accepts an init image plus a separate control image,
    # which is what create_bas_relief() passes in below
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        torch_dtype=TORCH_DTYPE
    )
    pipe.load_lora_weights("danube2024/bas-relief-lora")
    return pipe.to(DEVICE)
# --- Processing ---
def run_thera(image, model):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    # Map the output from [-1, 1] back to [0, 1] before converting to PIL
    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5)
    return output_img
def create_depth_map(image, model, feature_extractor):
    inputs = feature_extractor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_depth = outputs.predicted_depth
    # Upsample the predicted depth to the input resolution (PIL size is (W, H))
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    return prediction.squeeze().cpu().numpy()
def create_bas_relief(prompt, image, depth_map, pipe):
    # Convert the normalized depth map to an RGB control image
    control_image = Image.fromarray((depth_map * 255).astype(np.uint8)).convert("RGB")
    image = image.resize((1024, 1024))
    control_image = control_image.resize((1024, 1024))
    result = pipe(
        prompt=prompt,
        image=image,
        control_image=control_image,
        strength=0.8,
        num_inference_steps=30
    ).images[0]
    return result
# --- Gradio Interface ---
with gr.Blocks() as app:
    gr.Markdown("# 🖼️ Super-Resolution + Depth Map + Bas-Relief")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            prompt = gr.Textbox(
                "high quality bas-relief sculpture, intricate details",
                label="Prompt"
            )
            submit_btn = gr.Button("Process")
        with gr.Column():
            upscaled_output = gr.Image(label="Super-Resolved Image")
            depth_output = gr.Image(label="Depth Map")
            basrelief_output = gr.Image(label="Bas-Relief Result")
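
    # Note: the models below are reloaded on every click. A simple cache
    # (a hypothetical tweak, not part of the original script) would avoid this:
    #
    #   _MODELS = {}
    #   def get_models():
    #       if not _MODELS:
    #           _MODELS["thera"] = load_thera_model()
    #           _MODELS["depth"] = load_depth_model()
    #           _MODELS["extractor"] = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    #           _MODELS["pipe"] = load_controlnet()
    #       return _MODELS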
    def process(image, prompt):
        # Load models (see the caching note above)
        thera_model = load_thera_model()
        depth_model = load_depth_model()
        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
        basrelief_pipe = load_controlnet()

        # 1. Super-resolution
        upscaled = run_thera(image, thera_model)

        # 2. Depth map, normalized to [0, 1]
        depth = create_depth_map(upscaled, depth_model, feature_extractor)
        depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())

        # 3. Bas-relief
        basrelief = create_bas_relief(prompt, upscaled, depth_normalized, basrelief_pipe)
        return upscaled, depth_normalized, basrelief
    submit_btn.click(
        process,
        inputs=[input_image, prompt],
        outputs=[upscaled_output, depth_output, basrelief_output]
    )
if __name__ == "__main__":
    app.launch()