New logic
app.py
CHANGED
@@ -6,11 +6,24 @@ import numpy as np
 from PIL import Image
 import pickle
 import warnings
+import logging
+from datetime import datetime
 from huggingface_hub import hf_hub_download
 from diffusers import StableDiffusionXLImg2ImgPipeline
 from transformers import DPTImageProcessor, DPTForDepthEstimation
 from model import build_thera

+# Logging configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler("processing.log"),
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger(__name__)
+
 # Settings and warning suppression
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=UserWarning)
@@ -22,56 +35,86 @@ TORCH_DEVICE = "cpu"

 # 1. Load Thera models ----------------------------------------------------------------
 def load_thera_model(repo_id, filename):
+    try:
+        logger.info(f"Loading Thera model from {repo_id}")
+        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+        with open(model_path, 'rb') as fh:
+            check = pickle.load(fh)
+        variables = check['model']
+        backbone, size = check['backbone'], check['size']
+        model = build_thera(3, backbone, size)
+        logger.info("Thera model loaded successfully")
+        return model, variables
+    except Exception as e:
+        logger.error(f"Failed to load Thera model: {str(e)}")
+        raise


+logger.info("Loading Thera EDSR...")
 model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")

 # 2. Load SDXL + LoRA ---------------------------------------------------------------------
+try:
+    logger.info("Loading SDXL + LoRA...")
+    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0",
+        torch_dtype=torch.float32
+    ).to(TORCH_DEVICE)
+    pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
+    logger.info("SDXL + LoRA loaded successfully")
+except Exception as e:
+    logger.error(f"Failed to load SDXL: {str(e)}")
+    raise

 # 3. Load depth model ----------------------------------------------------------
+try:
+    logger.info("Loading DPT depth model...")
+    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
+    depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
+    logger.info("DPT model loaded successfully")
+except Exception as e:
+    logger.error(f"Failed to load DPT: {str(e)}")
+    raise


 # Main pipeline --------------------------------------------------------------------------
-def full_pipeline(image, prompt, scale_factor=2.0):
+def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
     try:
+        progress(0, desc="Starting processing...")
+
         # 1. Super-resolution with Thera
+        progress(0.1, desc="Converting image to RGB...")
         image = image.convert("RGB")
+
+        progress(0.2, desc="Preparing input for super-resolution...")
         source = np.array(image) / 255.0
+        original_size = image.size
         target_shape = (int(image.height * scale_factor), int(image.width * scale_factor))
+        logger.info(f"Super-resolution: {original_size} → {target_shape} (scale: {scale_factor}x)")

+        progress(0.3, desc="Processing with Thera...")
         source_jax = jax.device_put(source, JAX_DEVICE)
         t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)

+        start_time = datetime.now()
         upscaled = model_edsr.apply(
+            variables_edsr,
             source_jax,
             t,
             target_shape
         )
+        logger.info(f"Super-resolution finished in {datetime.now() - start_time}")

+        progress(0.5, desc="Converting result...")
         upscaled_pil = Image.fromarray((np.array(upscaled) * 255).astype(np.uint8))
+        logger.info(f"Size after super-resolution: {upscaled_pil.size}")

         # 2. Generate bas-relief
+        progress(0.6, desc="Generating bas-relief...")
         full_prompt = f"BAS-RELIEF {prompt}, insanely detailed and complex engraving relief, ultra-high definition, rich in detail, 16K resolution"
+        logger.info(f"Final prompt: {full_prompt}")
+
+        start_time = datetime.now()
         bas_relief = pipe(
             prompt=full_prompt,
             image=upscaled_pil,
@@ -79,9 +122,13 @@ def full_pipeline(image, prompt, scale_factor=2.0):
             num_inference_steps=25,
             guidance_scale=7.5
         ).images[0]
+        logger.info(f"Bas-relief generated in {datetime.now() - start_time}")

         # 3. Compute depth map
+        progress(0.8, desc="Computing depth map...")
+        start_time = datetime.now()
         inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
+
         with torch.no_grad():
             outputs = depth_model(**inputs)
             depth = outputs.predicted_depth
@@ -92,12 +139,18 @@ def full_pipeline(image, prompt, scale_factor=2.0):
             mode="bicubic"
         ).squeeze().cpu().numpy()

+        progress(0.9, desc="Processing depth map...")
+        depth_min = depth_map.min()
+        depth_max = depth_map.max()
+        depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
         depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
+        logger.info(f"Depth computed in {datetime.now() - start_time} | Range: {depth_min:.2f}-{depth_max:.2f}")

+        progress(1.0, desc="Done!")
         return upscaled_pil, bas_relief, depth_pil

     except Exception as e:
+        logger.error(f"Processing error: {str(e)}", exc_info=True)
         raise gr.Error(f"Processing error: {str(e)}")


@@ -127,4 +180,5 @@ with gr.Blocks(title="Super Res + Bas-Relief") as app:
     )

 if __name__ == "__main__":
+    logger.info("Starting Gradio application")
     app.launch(share=False)
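The Blocks UI itself (new lines 157-179) is unchanged by this commit and is therefore elided from the diff; only its closing lines appear as context in the last hunk. For orientation, below is a minimal sketch of how full_pipeline is typically wired into a gr.Blocks layout like the one named in that hunk header. The component names, labels, and slider range are assumptions for illustration, not the app's actual code; only full_pipeline, the Blocks title, and app.launch(share=False) come from the diff.

import gradio as gr  # full_pipeline is the function defined above in app.py

with gr.Blocks(title="Super Res + Bas-Relief") as app:
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil", label="Input image")        # assumed component
            prompt_in = gr.Textbox(label="Prompt")                    # assumed component
            scale_in = gr.Slider(1.0, 4.0, value=2.0, label="Scale")  # assumed range
            run_btn = gr.Button("Process")
        with gr.Column():
            out_upscaled = gr.Image(label="Super-resolution")
            out_basrelief = gr.Image(label="Bas-Relief")
            out_depth = gr.Image(label="Depth map")

    # full_pipeline returns (upscaled_pil, bas_relief, depth_pil); the new
    # progress=gr.Progress() parameter is injected by Gradio automatically,
    # so it is not listed in `inputs`.
    run_btn.click(
        full_pipeline,
        inputs=[img_in, prompt_in, scale_in],
        outputs=[out_upscaled, out_basrelief, out_depth],
    )

if __name__ == "__main__":
    app.launch(share=False)

With wiring of this shape, the progress(fraction, desc=...) calls added in this commit appear as a progress bar in the UI while the pipeline runs, and the same stages are written to processing.log and the console by the new logging handlers.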