fixing bugs
Browse files
app.py
CHANGED
@@ -1,106 +1,106 @@
|
|
1 |
-
|
2 |
-
import pickle
|
3 |
-
import warnings
|
4 |
-
|
5 |
import gradio as gr
|
|
|
6 |
import jax
|
7 |
import jax.numpy as jnp
|
8 |
import numpy as np
|
9 |
-
import torch
|
10 |
from PIL import Image
|
11 |
-
|
|
|
12 |
from huggingface_hub import hf_hub_download
|
|
|
13 |
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
14 |
-
|
15 |
from model import build_thera
|
16 |
-
from utils import make_grid
|
17 |
|
18 |
# Configuração de logging
|
19 |
logging.basicConfig(
|
20 |
level=logging.INFO,
|
21 |
format='%(asctime)s - %(levelname)s - %(message)s',
|
22 |
-
handlers=[
|
23 |
-
logging.FileHandler("processing.log"),
|
24 |
-
logging.StreamHandler()
|
25 |
-
]
|
26 |
)
|
27 |
logger = logging.getLogger(__name__)
|
28 |
|
29 |
# Configurações
|
30 |
-
warnings.filterwarnings("ignore")
|
31 |
JAX_DEVICE = jax.devices("cpu")[0]
|
32 |
TORCH_DEVICE = "cpu"
|
33 |
|
34 |
|
35 |
-
def load_thera_model(repo_id, filename):
|
|
|
36 |
try:
|
37 |
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
38 |
with open(model_path, 'rb') as fh:
|
39 |
-
|
40 |
-
|
41 |
-
backbone, size = check['backbone'], check['size']
|
42 |
-
return build_thera(3, backbone, size), variables
|
43 |
except Exception as e:
|
44 |
-
logger.error(f"Erro ao carregar
|
45 |
raise
|
46 |
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
"
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
-
def adjust_size(
|
60 |
-
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
-
def full_pipeline(image, prompt, scale_factor=2.0
|
|
|
64 |
try:
|
65 |
-
|
66 |
image = image.convert("RGB")
|
67 |
-
|
68 |
|
69 |
-
#
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
)
|
74 |
-
logger.info(f"Transformação: {image.size} → {target_shape}")
|
75 |
|
76 |
-
# Gerar grid
|
77 |
-
coords = make_grid(
|
78 |
-
logger.debug(f"
|
|
|
|
|
|
|
|
|
79 |
|
80 |
# Super-resolução
|
81 |
-
|
82 |
-
|
|
|
83 |
t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
|
|
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
source_jax,
|
88 |
-
t,
|
89 |
-
target_shape
|
90 |
-
)
|
91 |
-
upscaled_pil = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
|
92 |
|
93 |
# Bas-Relief
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
image=upscaled_pil,
|
98 |
strength=0.7,
|
99 |
-
num_inference_steps=
|
100 |
-
)
|
|
|
101 |
|
102 |
-
#
|
103 |
-
progress(0.8, desc="Calculando profundidade...")
|
104 |
inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
|
105 |
with torch.no_grad():
|
106 |
depth = depth_model(**inputs).predicted_depth
|
@@ -111,30 +111,49 @@ def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
|
|
111 |
mode="bicubic"
|
112 |
).squeeze().cpu().numpy()
|
113 |
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
116 |
|
117 |
-
return
|
118 |
|
119 |
except Exception as e:
|
120 |
-
logger.error(f"ERRO: {str(e)}", exc_info=True)
|
121 |
-
raise gr.Error(f"
|
122 |
|
123 |
|
124 |
# Interface
|
125 |
-
with gr.Blocks(title="SuperRes
|
126 |
-
gr.Markdown("
|
|
|
127 |
with gr.Row():
|
128 |
with gr.Column():
|
129 |
-
img_input = gr.Image(
|
130 |
-
prompt = gr.Textbox(
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
133 |
with gr.Column():
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
if __name__ == "__main__":
|
140 |
-
app.launch()
|
|
|
1 |
+
# app.py
|
|
|
|
|
|
|
2 |
import gradio as gr
|
3 |
+
import torch
|
4 |
import jax
|
5 |
import jax.numpy as jnp
|
6 |
import numpy as np
|
|
|
7 |
from PIL import Image
|
8 |
+
import pickle
|
9 |
+
import logging
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
+
from diffusers import StableDiffusionXLImg2ImgPipeline
|
12 |
from transformers import DPTImageProcessor, DPTForDepthEstimation
|
|
|
13 |
from model import build_thera
|
14 |
+
from utils import make_grid, interpolate_grid
|
15 |
|
16 |
# Configuração de logging
|
17 |
logging.basicConfig(
|
18 |
level=logging.INFO,
|
19 |
format='%(asctime)s - %(levelname)s - %(message)s',
|
20 |
+
handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()]
|
|
|
|
|
|
|
21 |
)
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
# Configurações
|
|
|
25 |
JAX_DEVICE = jax.devices("cpu")[0]
|
26 |
TORCH_DEVICE = "cpu"
|
27 |
|
28 |
|
29 |
+
def load_thera_model(repo_id: str, filename: str):
|
30 |
+
"""Carrega modelo com verificação de segurança"""
|
31 |
try:
|
32 |
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
33 |
with open(model_path, 'rb') as fh:
|
34 |
+
checkpoint = pickle.load(fh)
|
35 |
+
return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
|
|
|
|
|
36 |
except Exception as e:
|
37 |
+
logger.error(f"Erro ao carregar modelo: {str(e)}")
|
38 |
raise
|
39 |
|
40 |
|
41 |
+
# Inicialização dos modelos
|
42 |
+
try:
|
43 |
+
logger.info("Carregando modelos...")
|
44 |
+
model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
|
45 |
+
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
46 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
47 |
+
torch_dtype=torch.float32
|
48 |
+
).to(TORCH_DEVICE)
|
49 |
+
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
|
50 |
+
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
|
51 |
+
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
|
52 |
+
except Exception as e:
|
53 |
+
logger.error(f"Falha na inicialização: {str(e)}")
|
54 |
+
raise
|
55 |
|
56 |
|
57 |
+
def adjust_size(original: int, scale: float) -> int:
|
58 |
+
"""Ajuste de tamanho com limites seguros"""
|
59 |
+
scaled = int(original * scale)
|
60 |
+
adjusted = (scaled // 8) * 8 # Divisível por 8
|
61 |
+
return max(32, adjusted) # Mínimo absoluto
|
62 |
|
63 |
|
64 |
+
def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
|
65 |
+
"""Pipeline completo com tratamento robusto"""
|
66 |
try:
|
67 |
+
# Pré-processamento
|
68 |
image = image.convert("RGB")
|
69 |
+
orig_w, orig_h = image.size
|
70 |
|
71 |
+
# Cálculo do tamanho alvo
|
72 |
+
new_h = adjust_size(orig_h, scale_factor)
|
73 |
+
new_w = adjust_size(orig_w, scale_factor)
|
74 |
+
logger.info(f"Redimensionando: {orig_h}x{orig_w} → {new_h}x{new_w}")
|
|
|
|
|
75 |
|
76 |
+
# Gerar grid de coordenadas
|
77 |
+
coords = make_grid((new_h, new_w))
|
78 |
+
logger.debug(f"Dimensões do grid: {coords.shape}")
|
79 |
+
|
80 |
+
# Verificação crítica
|
81 |
+
if coords.shape[1:3] != (new_h, new_w):
|
82 |
+
raise ValueError(f"Grid incorreto: {coords.shape[1:3]} vs ({new_h}, {new_w})")
|
83 |
|
84 |
# Super-resolução
|
85 |
+
source = jnp.array(image).astype(jnp.float32) / 255.0
|
86 |
+
source = source[jnp.newaxis, ...] # Adicionar batch
|
87 |
+
|
88 |
t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
|
89 |
+
upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
|
90 |
|
91 |
+
# Pós-processamento
|
92 |
+
upscaled_img = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
# Bas-Relief
|
95 |
+
result = pipe(
|
96 |
+
prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
|
97 |
+
image=upscaled_img,
|
|
|
98 |
strength=0.7,
|
99 |
+
num_inference_steps=30
|
100 |
+
)
|
101 |
+
bas_relief = result.images[0]
|
102 |
|
103 |
+
# Mapa de profundidade
|
|
|
104 |
inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
|
105 |
with torch.no_grad():
|
106 |
depth = depth_model(**inputs).predicted_depth
|
|
|
111 |
mode="bicubic"
|
112 |
).squeeze().cpu().numpy()
|
113 |
|
114 |
+
# Normalização
|
115 |
+
depth_min = depth_map.min()
|
116 |
+
depth_max = depth_map.max()
|
117 |
+
depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
|
118 |
+
depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
|
119 |
|
120 |
+
return upscaled_img, bas_relief, depth_img
|
121 |
|
122 |
except Exception as e:
|
123 |
+
logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
|
124 |
+
raise gr.Error(f"Processamento falhou: {str(e)}")
|
125 |
|
126 |
|
127 |
# Interface
|
128 |
+
with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
|
129 |
+
gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
|
130 |
+
|
131 |
with gr.Row():
|
132 |
with gr.Column():
|
133 |
+
img_input = gr.Image(label="Imagem de Entrada", type="pil")
|
134 |
+
prompt = gr.Textbox(
|
135 |
+
label="Descrição do Relevo",
|
136 |
+
value="Ainsanely detailed and complex engraving relief, ultra-high definition",
|
137 |
+
placeholder="Descreva o estilo desejado..."
|
138 |
+
)
|
139 |
+
scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
|
140 |
+
btn = gr.Button("Processar Imagem", variant="primary")
|
141 |
+
|
142 |
with gr.Column():
|
143 |
+
gr.Markdown("## Resultados")
|
144 |
+
with gr.Tabs():
|
145 |
+
with gr.TabItem("Super Resolução"):
|
146 |
+
upscaled_output = gr.Image(label="Resultado Super Resolução")
|
147 |
+
with gr.TabItem("Bas-Relief"):
|
148 |
+
basrelief_output = gr.Image(label="Relevo Gerado")
|
149 |
+
with gr.TabItem("Profundidade"):
|
150 |
+
depth_output = gr.Image(label="Mapa de Profundidade")
|
151 |
+
|
152 |
+
btn.click(
|
153 |
+
full_pipeline,
|
154 |
+
inputs=[img_input, prompt, scale],
|
155 |
+
outputs=[upscaled_output, basrelief_output, depth_output]
|
156 |
+
)
|
157 |
|
158 |
if __name__ == "__main__":
|
159 |
+
app.launch(server_name="0.0.0.0", server_port=7860)
|
utils.py
CHANGED
@@ -1,40 +1,40 @@
|
|
1 |
-
|
2 |
import jax
|
3 |
import jax.numpy as jnp
|
4 |
import numpy as np
|
5 |
-
|
6 |
-
|
7 |
-
def repeat_vmap(fun, in_axes=[0]):
|
8 |
-
for axes in in_axes:
|
9 |
-
fun = jax.vmap(fun, in_axes=axes)
|
10 |
-
return fun
|
11 |
|
12 |
|
13 |
def make_grid(patch_size: int | tuple[int, int]):
|
14 |
-
"""Gera grid de coordenadas com
|
15 |
-
# Garantir tamanho mínimo de 8x8
|
16 |
if isinstance(patch_size, int):
|
17 |
-
h = w = max(
|
18 |
else:
|
19 |
-
h, w = (max(
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
|
25 |
-
#
|
26 |
-
grid = np.stack(np.meshgrid(
|
27 |
return grid[np.newaxis, ...]
|
28 |
|
29 |
|
30 |
def interpolate_grid(coords, grid, order=0):
|
31 |
-
"""Interpolação
|
32 |
try:
|
33 |
-
# Converter
|
34 |
coords = jnp.asarray(coords)
|
35 |
-
if coords.ndim
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
raise ValueError(
|
37 |
-
f"Dimensões inválidas: {coords.shape}.
|
38 |
)
|
39 |
|
40 |
# Transformação de coordenadas
|
@@ -47,12 +47,10 @@ def interpolate_grid(coords, grid, order=0):
|
|
47 |
)
|
48 |
|
49 |
# Interpolação vetorizada
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
))
|
55 |
-
return map_fn(grid, coords)
|
56 |
|
57 |
except Exception as e:
|
58 |
raise RuntimeError(f"Erro de interpolação: {str(e)}") from e
|
|
|
1 |
+
# utils.py
|
2 |
import jax
|
3 |
import jax.numpy as jnp
|
4 |
import numpy as np
|
5 |
+
from functools import partial
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
def make_grid(patch_size: int | tuple[int, int]):
|
9 |
+
"""Gera grid de coordenadas com validação robusta"""
|
|
|
10 |
if isinstance(patch_size, int):
|
11 |
+
h = w = max(16, patch_size) # Novo mínimo seguro
|
12 |
else:
|
13 |
+
h, w = (max(16, ps) for ps in patch_size) # 16x16 mínimo
|
14 |
|
15 |
+
# Cálculo preciso das coordenadas
|
16 |
+
y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
|
17 |
+
x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
|
18 |
|
19 |
+
# Grid com dimensões (1, H, W, 2)
|
20 |
+
grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
|
21 |
return grid[np.newaxis, ...]
|
22 |
|
23 |
|
24 |
def interpolate_grid(coords, grid, order=0):
|
25 |
+
"""Interpolação com tratamento completo de dimensões"""
|
26 |
try:
|
27 |
+
# Converter e garantir 4D
|
28 |
coords = jnp.asarray(coords)
|
29 |
+
if coords.ndim == 1: # Caso de erro reportado
|
30 |
+
coords = coords.reshape(1, 1, 1, -1)
|
31 |
+
while coords.ndim < 4:
|
32 |
+
coords = coords[jnp.newaxis, ...]
|
33 |
+
|
34 |
+
# Validação final
|
35 |
+
if coords.shape[-1] != 2 or coords.ndim != 4:
|
36 |
raise ValueError(
|
37 |
+
f"Dimensões inválidas: {coords.shape}. Formato esperado: (B, H, W, 2)"
|
38 |
)
|
39 |
|
40 |
# Transformação de coordenadas
|
|
|
47 |
)
|
48 |
|
49 |
# Interpolação vetorizada
|
50 |
+
map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
|
51 |
+
order=order,
|
52 |
+
mode='nearest')
|
53 |
+
return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
|
|
|
|
|
54 |
|
55 |
except Exception as e:
|
56 |
raise RuntimeError(f"Erro de interpolação: {str(e)}") from e
|