# app.py
import logging
import gradio as gr
import torch
import numpy as np
import jax
import pickle
from PIL import Image
from huggingface_hub import hf_hub_download
from model import build_thera
from super_resolve import process
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
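# Pipeline overview (as implemented below): (1) Thera arbitrary-scale super-resolution,
# (2) optional SDXL img2img restyling with a bas-relief LoRA, (3) optional DPT depth-map
# estimation. Everything runs on CPU in float32.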
# ================== LOGGING CONFIGURATION ==================
class CustomLogger:
def __init__(self, name):
self.logger = logging.getLogger(name)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
def divider(self, text=None, length=60):
if text:
available_space = max(length - len(text) - 12, 1)
msg = f"{'=' * 10} {text.upper()} {'=' * available_space}"
else:
msg = "=" * length
self.logger.info(msg)
def etapa(self, text):
self.logger.info(f"▶ {text}")
    def success(self, text):
        self.logger.info(f"✓ {text}")
    def warning(self, text):  # needed: the artistic-pipeline setup below calls logger.warning()
        self.logger.warning(f"⚠ {text}")
    def error(self, text):
        self.logger.error(f"✗ {text}")
logger = CustomLogger(__name__)
# ================== FORCED CONFIGURATION ==================
device = "cpu"
torch_dtype = torch.float32
logger.divider("Configuração Forçada")
logger.success(f"Dispositivo: {device.upper()}")
logger.success(f"Precisão: {str(torch_dtype).replace('torch.', '')}")
# ================== MODEL LOADING ==================
def carregar_modelo_thera(repo_id):
    try:
        logger.divider(f"Loading {repo_id}")
        model_path = hf_hub_download(repo_id=repo_id, filename="model.pkl")
        with open(model_path, 'rb') as f:
            check = pickle.load(f)
        # The pickled checkpoint carries the backbone config, size and trained weights
        model = build_thera(3, check['backbone'], check['size'])
        params = check['model']
        logger.success(f"{repo_id} loaded")
        return model, params
    except Exception as e:
        logger.error(f"Load failure: {str(e)}")
        return None, None
try:
    modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
    modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
    # carregar_modelo_thera swallows its own exceptions, so fail loudly here if either model is missing
    if modelo_edsr is None or modelo_rdn is None:
        raise RuntimeError("One or both Thera models failed to load")
except Exception as e:
    logger.error("Critical failure loading the Thera models")
    raise
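# Both Thera variants (EDSR and RDN backbones) are loaded once at startup;
# the radio button in the UI only selects which pair of weights is used per request.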
# ================== ARTISTIC PIPELINE ==================
pipe = None
modelo_profundidade = None
processador_profundidade = None
try:
    logger.divider("Setting Up Artistic Components")
    # Main pipeline (the default SDXL weights are already float32, so no variant is requested)
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch_dtype
    ).to(device)
    # LoRA
    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
        peft_backend="peft"  # This is crucial
    )
    # Depth model
    processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).float()
    logger.success("Artistic components loaded in float32")
except Exception as e:
    logger.warning(f"Artistic features limited: {str(e)}")
    print(e)
    pipe = None
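# If any artistic component fails to load, the app degrades gracefully: super-resolution
# still works, and the later stages fall back to passing the previous stage's image through.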
# ================== MAIN PROCESSING ==================
def processar_imagem_completa(imagem, escala, modelo, prompt):
try:
logger.divider("Iniciando Processamento")
# Converter entrada
if not isinstance(imagem, Image.Image):
imagem = Image.fromarray(imagem)
        # ========= 1. SUPER-RESOLUTION =========
        logger.etapa("Running super-resolution")
modelo_sr = modelo_edsr if modelo == "EDSR" else modelo_rdn
params_sr = params_edsr if modelo == "EDSR" else params_rdn
sr_jax = process(
np.array(imagem) / 255.,
modelo_sr,
params_sr,
(round(imagem.height * escala),
round(imagem.width * escala)),
True
)
sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB")
logger.success(f"SR: {sr_pil.size[0]}x{sr_pil.size[1]}")
        # ========= 2. BAS-RELIEF STYLE =========
arte_pil = sr_pil # Fallback
if pipe:
try:
logger.etapa("Aplicando Estilo")
arte_pil = pipe(
prompt=f"BAS-RELIEF {prompt}, marble texture, 8k",
image=sr_pil,
strength=0.6,
num_inference_steps=25,
guidance_scale=7.0,
generator=torch.Generator(device).manual_seed(42)
).images[0]
logger.success("Estilo aplicado")
except Exception as e:
logger.error(f"Erro no estilo: {str(e)}")
print(e)
        # ========= 3. DEPTH MAP =========
mapa_pil = arte_pil # Fallback
if modelo_profundidade:
try:
logger.etapa("Calculando Profundidade")
inputs = processador_profundidade(arte_pil, return_tensors="pt").to(device)
with torch.no_grad():
depth = modelo_profundidade(**inputs).predicted_depth
depth = torch.nn.functional.interpolate(
depth.unsqueeze(1),
size=arte_pil.size[::-1],
mode="bicubic"
).squeeze().cpu().numpy()
                depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)  # epsilon avoids division by zero on flat maps
mapa_pil = Image.fromarray((depth * 255).astype(np.uint8))
logger.success("Profundidade calculada")
except Exception as e:
logger.error(f"Erro na profundidade: {str(e)}")
print (e)
return sr_pil, arte_pil, mapa_pil
except Exception as e:
logger.error(f"Erro fatal: {str(e)}")
print(e)
return None, None, None
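# Example (sketch) of driving the pipeline without the UI; "input.png" is a hypothetical local test image:
#   img = Image.open("input.png")
#   sr, art, depth = processar_imagem_completa(img, 2.0, "EDSR", "stone carving")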
# ================== GRADIO INTERFACE ==================
with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🏛 TheraSR - Processamento Completo em Float32")
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Imagem de Entrada", type="pil")
slider_scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
radio_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo")
text_prompt = gr.Textbox(
label="Prompt de Estilo",
value="ancient marble浮雕, ultra detailed, 8k cinematic"
)
btn_process = gr.Button("Processar", variant="primary")
with gr.Column():
output_sr = gr.Image(label="Super-Resolução", interactive=False)
output_art = gr.Image(label="Arte em Relevo", interactive=False)
output_depth = gr.Image(label="Mapa de Profundidade", interactive=False)
btn_process.click(
processar_imagem_completa,
inputs=[input_img, slider_scale, radio_model, text_prompt],
outputs=[output_sr, output_art, output_depth]
)
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=7860)