|
import logging |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import jax |
|
import pickle |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download, file_download |
|
from model import build_thera |
|
from super_resolve import process |
|
from diffusers import StableDiffusionXLImg2ImgPipeline |
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation |
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
file_download.cached_download = file_download.hf_hub_download |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
torch_dtype = torch.float16 if device == "cuda" else torch.float32 |
|
logger.info(f"Dispositivo selecionado: {device.upper()}") |
|
logger.info(f"Precisão numérica: {str(torch_dtype).replace('torch.', '')}") |
|
|
|
|
|
|
|
def carregar_modelo_thera(repo_id): |
|
"""Carrega modelos Thera do Hugging Face Hub""" |
|
try: |
|
logger.info(f"Carregando modelo Thera: {repo_id}") |
|
caminho_modelo = hf_hub_download(repo_id=repo_id, filename="model.pkl") |
|
with open(caminho_modelo, 'rb') as arquivo: |
|
dados = pickle.load(arquivo) |
|
modelo = build_thera(3, dados['backbone'], dados['size']) |
|
parametros = dados['model'] |
|
logger.success(f"Modelo {repo_id} carregado com sucesso") |
|
return modelo, parametros |
|
except Exception as erro: |
|
logger.error(f"Falha ao carregar {repo_id}: {str(erro)}") |
|
raise |
|
|
|
|
|
|
|
try: |
|
logger.divider("Carregando Modelos Thera") |
|
modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro") |
|
modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro") |
|
except Exception as erro: |
|
logger.critical("Falha crítica no carregamento dos modelos Thera") |
|
raise |
|
|
|
|
|
|
|
try: |
|
logger.divider("Configurando Pipeline de Arte") |
|
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( |
|
"stabilityai/stable-diffusion-xl-base-1.0", |
|
torch_dtype=torch_dtype, |
|
variant="fp16", |
|
use_safetensors=True |
|
).to(device) |
|
|
|
pipe.load_lora_weights( |
|
"KappaNeuro/bas-relief", |
|
weight_name="BAS-RELIEF.safetensors", |
|
adapter_name="bas_relief" |
|
) |
|
logger.success("Pipeline SDXL + LoRA configurado") |
|
except Exception as erro: |
|
logger.error(f"Erro no SDXL: {str(erro)}") |
|
pipe = None |
|
|
|
|
|
try: |
|
logger.divider("Configurando Modelo de Profundidade") |
|
processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") |
|
modelo_profundidade = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device) |
|
logger.success("Modelo de profundidade pronto") |
|
except Exception as erro: |
|
logger.error(f"Erro no modelo de profundidade: {str(erro)}") |
|
modelo_profundidade = None |
|
|
|
|
|
|
|
def pipeline_completo(imagem, fator_escala, modelo_escolhido, prompt_estilo): |
|
"""Executa todo o fluxo de processamento""" |
|
try: |
|
logger.divider("Iniciando novo processamento") |
|
|
|
|
|
logger.etapa("Processando Super-Resolução") |
|
modelo_sr = modelo_edsr if modelo_escolhido == "EDSR" else modelo_rdn |
|
parametros_sr = params_edsr if modelo_escolhido == "EDSR" else params_rdn |
|
|
|
|
|
if not isinstance(imagem, Image.Image): |
|
logger.warning("Convertendo entrada numpy para PIL Image") |
|
imagem = Image.fromarray(imagem) |
|
|
|
|
|
imagem_sr_jax = process( |
|
np.array(imagem) / 255., |
|
modelo_sr, |
|
parametros_sr, |
|
(round(imagem.size[1] * fator_escala), |
|
round(imagem.size[0] * fator_escala)), |
|
True |
|
) |
|
|
|
|
|
imagem_sr_pil = Image.fromarray(np.array(imagem_sr_jax)).convert("RGB") |
|
logger.success(f"Super-Resolução concluída: {imagem_sr_pil.size}") |
|
|
|
|
|
if device == "cpu" or not pipe: |
|
logger.warning("GPU não disponível - Pulando estilo") |
|
return imagem_sr_pil, None, None |
|
|
|
logger.etapa("Aplicando Estilo Baixo-Relevo") |
|
prompt_completo = f"BAS-RELIEF {prompt_estilo}, intricate carving, marble texture, 8k" |
|
|
|
with torch.autocast(device_type=device.split(':')[0], dtype=torch_dtype): |
|
imagem_estilizada = pipe( |
|
prompt=prompt_completo, |
|
image=imagem_sr_pil, |
|
strength=0.7, |
|
num_inference_steps=35, |
|
guidance_scale=7.5, |
|
output_type="pil" |
|
).images[0] |
|
|
|
logger.success(f"Estilo aplicado: {imagem_estilizada.size}") |
|
|
|
|
|
logger.etapa("Gerando Mapa de Profundidade") |
|
inputs = processador_profundidade( |
|
images=imagem_estilizada, |
|
return_tensors="pt" |
|
).to(device, dtype=torch_dtype) |
|
|
|
with torch.no_grad(), torch.autocast(device_type=device.split(':')[0]): |
|
outputs = modelo_profundidade(**inputs) |
|
profundidade = outputs.predicted_depth |
|
|
|
|
|
profundidade = torch.nn.functional.interpolate( |
|
profundidade.unsqueeze(1).float(), |
|
size=imagem_estilizada.size[::-1], |
|
mode="bicubic" |
|
).squeeze().cpu().numpy() |
|
|
|
|
|
profundidade = (profundidade - profundidade.min()) / (profundidade.max() - profundidade.min() + 1e-8) |
|
mapa_profundidade = Image.fromarray((profundidade * 255).astype(np.uint8)) |
|
|
|
logger.success("Processamento completo") |
|
return imagem_sr_pil, imagem_estilizada, mapa_profundidade |
|
|
|
except Exception as erro: |
|
logger.error(f"ERRO NO PIPELINE: {str(erro)}", exc_info=True) |
|
return imagem_sr_pil if 'imagem_sr_pil' in locals() else None, None, None |
|
|
|
|
|
|
|
with gr.Blocks(title="TheraSR Art Suite", theme=gr.themes.Soft()) as app: |
|
gr.Markdown(""" |
|
# 🎨 TheraSR Art Suite |
|
**Combine super-resolução aliasing-free com geração artística de baixo-relevo** |
|
""") |
|
|
|
with gr.Row(variant="panel"): |
|
with gr.Column(scale=1): |
|
entrada_imagem = gr.Image(label="🖼 Imagem de Entrada", type="pil") |
|
seletor_modelo = gr.Radio( |
|
["EDSR", "RDN"], |
|
value="EDSR", |
|
label="🔧 Modelo de Super-Resolução" |
|
) |
|
controle_escala = gr.Slider( |
|
1.0, 6.0, |
|
value=2.0, |
|
step=0.1, |
|
label="🔍 Fator de Escala" |
|
) |
|
entrada_prompt = gr.Textbox( |
|
label="📝 Prompt de Estilo", |
|
value="insanely detailed and complex engraving relief, ultra HD 8k", |
|
placeholder="Descreva o estilo desejado..." |
|
) |
|
botao_processar = gr.Button("🚀 Processar Imagem", variant="primary") |
|
|
|
with gr.Column(scale=2): |
|
saida_sr = gr.Image(label="✨ Super-Resolução", interactive=False) |
|
saida_arte = gr.Image(label="🖌 Arte em Baixo-Relevo", interactive=False) |
|
saida_profundidade = gr.Image(label="🗺 Mapa de Profundidade", interactive=False) |
|
|
|
|
|
botao_processar.click( |
|
fn=pipeline_completo, |
|
inputs=[entrada_imagem, controle_escala, seletor_modelo, entrada_prompt], |
|
outputs=[saida_sr, saida_arte, saida_profundidade] |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
app.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
show_error=True, |
|
share=False, |
|
debug=False |
|
) |