|
|
|
import logging |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import jax |
|
import pickle |
|
from PIL import Image |
|
from huggingface_hub import hf_hub_download |
|
from model import build_thera |
|
from super_resolve import process |
|
from diffusers import StableDiffusionXLImg2ImgPipeline |
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation |
|
|
|
|
|
|
|
class CustomLogger:
    """Thin wrapper around a stdlib logger that prefixes messages with status glyphs.

    The wrapped ``logging.Logger`` is exposed as ``self.logger``; it is set to
    INFO level and gets one timestamped ``StreamHandler``.
    """

    def __init__(self, name):
        """Configure the stdlib logger called *name*.

        Args:
            name: Logger name forwarded to ``logging.getLogger``.
        """
        self.logger = logging.getLogger(name)
        # logging.getLogger returns the same object for the same name, so
        # creating several CustomLogger instances (or re-running this module)
        # would otherwise stack handlers and print every message repeatedly.
        already_configured = any(
            getattr(h, "_custom_logger_handler", False) for h in self.logger.handlers
        )
        if not already_configured:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            handler._custom_logger_handler = True  # marker for the check above
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)

    def divider(self, text=None, length=60):
        """Log a horizontal rule of *length* characters, optionally titled.

        With *text*, the line is ``10 '='``, a space, the upper-cased text, a
        space, then enough ``'='`` to reach *length* (always at least one).
        """
        if text:
            # 12 = the 10 leading '=' plus the two spaces around the title.
            available_space = max(length - len(text) - 12, 1)
            msg = f"{'=' * 10} {text.upper()} {'=' * available_space}"
        else:
            msg = "=" * length
        self.logger.info(msg)

    def etapa(self, text):
        """Log the start of a processing step at INFO level."""
        self.logger.info(f"▶ {text}")

    def success(self, text):
        """Log a successful outcome at INFO level."""
        self.logger.info(f"✓ {text}")

    def warning(self, text):
        """Log a non-fatal problem at WARNING level."""
        self.logger.warning(f"⚠ {text}")

    def error(self, text):
        """Log a failure at ERROR level."""
        self.logger.error(f"✗ {text}")
|
|
|
|
|
logger = CustomLogger(__name__)

# Execution is deliberately pinned to CPU with full float32 precision; the
# banner below labels this as a forced ("Forçada") configuration.
device, torch_dtype = "cpu", torch.float32

logger.divider("Configuração Forçada")
logger.success(f"Dispositivo: {device.upper()}")
logger.success(f"Precisão: {str(torch_dtype).replace('torch.', '')}")
|
|
|
|
|
|
|
def carregar_modelo_thera(repo_id):
    """Download a Thera checkpoint from the Hugging Face Hub and build its model.

    Args:
        repo_id: Hub repository that contains a ``model.pkl`` checkpoint.

    Returns:
        ``(model, params)`` on success, or ``(None, None)`` if any step fails.
        Failures are logged, not raised.
    """
    try:
        logger.divider(f"Carregando {repo_id}")
        caminho_checkpoint = hf_hub_download(repo_id=repo_id, filename="model.pkl")
        # SECURITY NOTE: unpickling can execute arbitrary code; only load
        # checkpoints from repositories you trust.
        with open(caminho_checkpoint, 'rb') as arquivo:
            checkpoint = pickle.load(arquivo)
        # The checkpoint dict provides the 'backbone', 'size' and 'model' entries;
        # the leading 3 is presumably the channel count — confirm in build_thera.
        modelo = build_thera(3, checkpoint['backbone'], checkpoint['size'])
        pesos = checkpoint['model']
        logger.success(f"{repo_id} carregado")
    except Exception as erro:
        logger.error(f"Falha no carregamento: {str(erro)}")
        return None, None
    return modelo, pesos
|
|
|
|
|
# Both Thera backbones are required by the UI; abort start-up if either is missing.
try:
    modelo_edsr, params_edsr = carregar_modelo_thera("prs-eth/thera-edsr-pro")
    modelo_rdn, params_rdn = carregar_modelo_thera("prs-eth/thera-rdn-pro")
    # carregar_modelo_thera reports failure by returning (None, None) instead of
    # raising. Check explicitly; otherwise the app would start without a model
    # and only fail once a request arrives.
    if modelo_edsr is None or modelo_rdn is None:
        raise RuntimeError("Não foi possível carregar os modelos Thera")
except Exception:
    logger.error("Falha crítica nos modelos Thera")
    raise
|
|
|
|
|
# Optional "artistic" components. Each one is loaded independently, so a failure
# in one does not disable the other. Anything that fails to load is left as None.
pipe = None
processador_profundidade = None
modelo_profundidade = None

logger.divider("Configurando Componentes Artísticos")

try:
    # The default (non-variant) SDXL weights are already float32. The repo
    # publishes no "fp32" variant, and requesting one makes recent diffusers
    # versions refuse to load, so no `variant` is passed.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch_dtype,
    ).to(device)

    pipe.load_lora_weights(
        "KappaNeuro/bas-relief",
        weight_name="BAS-RELIEF.safetensors",
        # NOTE(review): not a documented load_lora_weights argument; confirm
        # whether it is needed or silently ignored by the installed diffusers.
        peft_backend="peft",
    )
    logger.success("Pipeline de estilo carregado")
except Exception as e:
    # Warnings go through the wrapped stdlib logger directly.
    logger.logger.warning(f"Recursos artísticos limitados: {str(e)}")
    pipe = None

try:
    processador_profundidade = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
    modelo_profundidade = (
        DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device).float()
    )
    logger.success("Estimador de profundidade carregado")
except Exception as e:
    logger.logger.warning(f"Mapa de profundidade indisponível: {str(e)}")
    processador_profundidade = None
    modelo_profundidade = None

if pipe is not None and modelo_profundidade is not None:
    logger.success("Componentes artísticos em float32")
|
|
|
|
|
|
|
def _normalizar_profundidade(depth):
    """Min-max scale a depth array to an 8-bit grayscale array.

    A constant (or non-finite) depth range would divide by zero and produce
    NaNs, so such maps become all zeros instead.
    """
    depth = np.asarray(depth, dtype=np.float32)
    minimo = float(depth.min())
    amplitude = float(depth.max()) - minimo
    if not np.isfinite(amplitude) or amplitude <= 0:
        return np.zeros(depth.shape, dtype=np.uint8)
    return ((depth - minimo) / amplitude * 255).astype(np.uint8)


def _super_resolver(imagem, escala, modelo):
    """Upscale *imagem* by *escala* with Thera ("EDSR" selects EDSR, anything else RDN)."""
    logger.etapa("Processando Super-Resolução")
    if modelo == "EDSR":
        modelo_sr, params_sr = modelo_edsr, params_edsr
    else:
        modelo_sr, params_sr = modelo_rdn, params_rdn

    # Target shape is (height, width); never let it collapse to zero pixels.
    altura = max(1, round(imagem.height * escala))
    largura = max(1, round(imagem.width * escala))
    sr_jax = process(
        np.array(imagem) / 255.,
        modelo_sr,
        params_sr,
        (altura, largura),
        True,  # passed unchanged (presumably the ensemble flag — confirm in super_resolve.process)
    )
    sr_pil = Image.fromarray(np.array(sr_jax)).convert("RGB")
    logger.success(f"SR: {sr_pil.size[0]}x{sr_pil.size[1]}")
    return sr_pil


def _aplicar_estilo(imagem, prompt):
    """Run the SDXL bas-relief img2img pass; return *imagem* unchanged if unavailable or on error."""
    if pipe is None:
        return imagem
    try:
        logger.etapa("Aplicando Estilo")
        descricao = (prompt or "").strip()
        # "BAS-RELIEF" is the LoRA trigger word; skip the user part when empty.
        if descricao:
            texto = f"BAS-RELIEF {descricao}, marble texture, 8k"
        else:
            texto = "BAS-RELIEF, marble texture, 8k"
        resultado = pipe(
            prompt=texto,
            image=imagem,
            strength=0.6,
            num_inference_steps=25,
            guidance_scale=7.0,
            generator=torch.Generator(device).manual_seed(42),  # fixed seed for reproducibility
        ).images[0]
    except Exception as e:
        logger.error(f"Erro no estilo: {str(e)}")
        return imagem
    logger.success("Estilo aplicado")
    return resultado


def _calcular_profundidade(imagem):
    """Estimate a grayscale depth map for *imagem*; return *imagem* itself if unavailable or on error."""
    if modelo_profundidade is None:
        return imagem
    try:
        logger.etapa("Calculando Profundidade")
        entradas = processador_profundidade(imagem, return_tensors="pt").to(device)
        with torch.no_grad():
            # assumes predicted_depth is (batch, H, W) with batch == 1 — TODO confirm
            depth = modelo_profundidade(**entradas).predicted_depth
            depth = torch.nn.functional.interpolate(
                depth.unsqueeze(1),
                size=imagem.size[::-1],  # PIL size is (W, H); interpolate wants (H, W)
                mode="bicubic",
            )
        # Index the batch/channel axes explicitly; squeeze() would also drop
        # a spatial axis of length 1.
        depth = depth[0, 0].cpu().numpy()
        mapa = Image.fromarray(_normalizar_profundidade(depth))
    except Exception as e:
        logger.error(f"Erro na profundidade: {str(e)}")
        return imagem
    logger.success("Profundidade calculada")
    return mapa


def processar_imagem_completa(imagem, escala, modelo, prompt):
    """Run super-resolution, optional bas-relief styling and optional depth estimation.

    Args:
        imagem: Input image (PIL image or an array accepted by ``Image.fromarray``).
        escala: Upscaling factor.
        modelo: ``"EDSR"`` to use the EDSR backbone; any other value uses RDN.
        prompt: User style prompt appended after the ``BAS-RELIEF`` trigger word.

    Returns:
        ``(sr_image, art_image, depth_map)``. If styling or depth estimation is
        unavailable or fails, the previous stage's image is passed through in
        its place. Returns ``(None, None, None)`` when there is no input or on
        an unexpected error.
    """
    try:
        logger.divider("Iniciando Processamento")
        if imagem is None:
            logger.error("Nenhuma imagem fornecida")
            return None, None, None
        if not isinstance(imagem, Image.Image):
            imagem = Image.fromarray(imagem)
        # Thera is built for 3 channels; normalise RGBA / grayscale / palette input.
        imagem = imagem.convert("RGB")

        sr_pil = _super_resolver(imagem, escala, modelo)
        arte_pil = _aplicar_estilo(sr_pil, prompt)
        mapa_pil = _calcular_profundidade(arte_pil)
        return sr_pil, arte_pil, mapa_pil
    except Exception as e:
        logger.error(f"Erro fatal: {str(e)}")
        return None, None, None
|
|
|
|
|
|
|
# Gradio interface: inputs in the left column, the three pipeline outputs in the right.
with gr.Blocks(title="TheraSR Universal", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🏛 TheraSR - Processamento Completo em Float32")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Imagem de Entrada", type="pil")
            slider_scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
            radio_model = gr.Radio(["EDSR", "RDN"], value="EDSR", label="Modelo")
            # Keep the default prompt in English: SDXL's text encoders were
            # trained on mostly English captions, so a Chinese word fused onto
            # "marble" (no separator) contributes little to the conditioning.
            text_prompt = gr.Textbox(
                label="Prompt de Estilo",
                value="ancient marble bas-relief, ultra detailed, 8k cinematic",
            )
            btn_process = gr.Button("Processar", variant="primary")

        with gr.Column():
            output_sr = gr.Image(label="Super-Resolução", interactive=False)
            output_art = gr.Image(label="Arte em Relevo", interactive=False)
            output_depth = gr.Image(label="Mapa de Profundidade", interactive=False)

    # Event listeners must be registered inside the Blocks context.
    btn_process.click(
        processar_imagem_completa,
        inputs=[input_img, slider_scale, radio_model, text_prompt],
        outputs=[output_sr, output_art, output_depth],
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)