ds1david commited on
Commit
19a6d73
·
1 Parent(s): 054a11a

fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +94 -75
  2. utils.py +24 -26
app.py CHANGED
@@ -1,106 +1,106 @@
1
- import logging
2
- import pickle
3
- import warnings
4
-
5
  import gradio as gr
 
6
  import jax
7
  import jax.numpy as jnp
8
  import numpy as np
9
- import torch
10
  from PIL import Image
11
- from diffusers import StableDiffusionXLImg2ImgPipeline
 
12
  from huggingface_hub import hf_hub_download
 
13
  from transformers import DPTImageProcessor, DPTForDepthEstimation
14
-
15
  from model import build_thera
16
- from utils import make_grid
17
 
18
  # Configuração de logging
19
  logging.basicConfig(
20
  level=logging.INFO,
21
  format='%(asctime)s - %(levelname)s - %(message)s',
22
- handlers=[
23
- logging.FileHandler("processing.log"),
24
- logging.StreamHandler()
25
- ]
26
  )
27
  logger = logging.getLogger(__name__)
28
 
29
  # Configurações
30
- warnings.filterwarnings("ignore")
31
  JAX_DEVICE = jax.devices("cpu")[0]
32
  TORCH_DEVICE = "cpu"
33
 
34
 
35
- def load_thera_model(repo_id, filename):
 
36
  try:
37
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
38
  with open(model_path, 'rb') as fh:
39
- check = pickle.load(fh)
40
- variables = check['model']
41
- backbone, size = check['backbone'], check['size']
42
- return build_thera(3, backbone, size), variables
43
  except Exception as e:
44
- logger.error(f"Erro ao carregar Thera: {str(e)}")
45
  raise
46
 
47
 
48
- logger.info("Carregando modelos...")
49
- model_edsr, variables_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
50
- pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
51
- "stabilityai/stable-diffusion-xl-base-1.0",
52
- torch_dtype=torch.float32
53
- ).to(TORCH_DEVICE)
54
- pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
55
- feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
56
- depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
 
 
 
 
57
 
58
 
59
- def adjust_size(size):
60
- return max(8, (size // 8) * 8)
 
 
 
61
 
62
 
63
- def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
 
64
  try:
65
- progress(0.1, desc="Iniciando...")
66
  image = image.convert("RGB")
67
- source = np.array(image) / 255.0
68
 
69
- # Ajuste de dimensões
70
- target_shape = (
71
- adjust_size(int(image.height * scale_factor)),
72
- adjust_size(int(image.width * scale_factor))
73
- )
74
- logger.info(f"Transformação: {image.size} → {target_shape}")
75
 
76
- # Gerar grid
77
- coords = make_grid(target_shape)
78
- logger.debug(f"Coords shape: {coords.shape}")
 
 
 
 
79
 
80
  # Super-resolução
81
- progress(0.3, desc="Processando super-resolução...")
82
- source_jax = jax.device_put(source[np.newaxis, ...], JAX_DEVICE)
 
83
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
 
84
 
85
- upscaled = model_edsr.apply(
86
- variables_edsr,
87
- source_jax,
88
- t,
89
- target_shape
90
- )
91
- upscaled_pil = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
92
 
93
  # Bas-Relief
94
- progress(0.6, desc="Gerando relevo...")
95
- bas_relief = pipe(
96
- prompt=f"BAS-RELIEF {prompt}, ultra detailed engraving, 16K resolution",
97
- image=upscaled_pil,
98
  strength=0.7,
99
- num_inference_steps=25
100
- ).images[0]
 
101
 
102
- # Depth Map
103
- progress(0.8, desc="Calculando profundidade...")
104
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
105
  with torch.no_grad():
106
  depth = depth_model(**inputs).predicted_depth
@@ -111,30 +111,49 @@ def full_pipeline(image, prompt, scale_factor=2.0, progress=gr.Progress()):
111
  mode="bicubic"
112
  ).squeeze().cpu().numpy()
113
 
114
- depth_normalized = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
115
- depth_pil = Image.fromarray((depth_normalized * 255).astype(np.uint8))
 
 
 
116
 
117
- return upscaled_pil, bas_relief, depth_pil
118
 
119
  except Exception as e:
120
- logger.error(f"ERRO: {str(e)}", exc_info=True)
121
- raise gr.Error(f"Erro no processamento: {str(e)}")
122
 
123
 
124
  # Interface
125
- with gr.Blocks(title="SuperRes + BasRelief") as app:
126
- gr.Markdown("## 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
 
127
  with gr.Row():
128
  with gr.Column():
129
- img_input = gr.Image(type="pil", label="Entrada")
130
- prompt = gr.Textbox("Escultura detalhada em mármore, alto relevo", label="Descrição")
131
- scale = gr.Slider(1.0, 4.0, value=2.0, label="Escala")
132
- btn = gr.Button("Processar ▶️")
 
 
 
 
 
133
  with gr.Column():
134
- img_upscaled = gr.Image(label="Super Resolução")
135
- img_basrelief = gr.Image(label="Bas-Relief")
136
- img_depth = gr.Image(label="Profundidade")
137
- btn.click(full_pipeline, [img_input, prompt, scale], [img_upscaled, img_basrelief, img_depth])
 
 
 
 
 
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
- app.launch()
 
1
+ # app.py
 
 
 
2
  import gradio as gr
3
+ import torch
4
  import jax
5
  import jax.numpy as jnp
6
  import numpy as np
 
7
  from PIL import Image
8
+ import pickle
9
+ import logging
10
  from huggingface_hub import hf_hub_download
11
+ from diffusers import StableDiffusionXLImg2ImgPipeline
12
  from transformers import DPTImageProcessor, DPTForDepthEstimation
 
13
  from model import build_thera
14
+ from utils import make_grid, interpolate_grid
15
 
16
  # Configuração de logging
17
  logging.basicConfig(
18
  level=logging.INFO,
19
  format='%(asctime)s - %(levelname)s - %(message)s',
20
+ handlers=[logging.FileHandler("processing.log"), logging.StreamHandler()]
 
 
 
21
  )
22
  logger = logging.getLogger(__name__)
23
 
24
  # Configurações
 
25
  JAX_DEVICE = jax.devices("cpu")[0]
26
  TORCH_DEVICE = "cpu"
27
 
28
 
29
+ def load_thera_model(repo_id: str, filename: str):
30
+ """Carrega modelo com verificação de segurança"""
31
  try:
32
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
33
  with open(model_path, 'rb') as fh:
34
+ checkpoint = pickle.load(fh)
35
+ return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
 
 
36
  except Exception as e:
37
+ logger.error(f"Erro ao carregar modelo: {str(e)}")
38
  raise
39
 
40
 
41
+ # Inicialização dos modelos
42
+ try:
43
+ logger.info("Carregando modelos...")
44
+ model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
45
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
46
+ "stabilityai/stable-diffusion-xl-base-1.0",
47
+ torch_dtype=torch.float32
48
+ ).to(TORCH_DEVICE)
49
+ pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
50
+ feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
51
+ depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
52
+ except Exception as e:
53
+ logger.error(f"Falha na inicialização: {str(e)}")
54
+ raise
55
 
56
 
57
+ def adjust_size(original: int, scale: float) -> int:
58
+ """Ajuste de tamanho com limites seguros"""
59
+ scaled = int(original * scale)
60
+ adjusted = (scaled // 8) * 8 # Divisível por 8
61
+ return max(32, adjusted) # Mínimo absoluto
62
 
63
 
64
+ def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
65
+ """Pipeline completo com tratamento robusto"""
66
  try:
67
+ # Pré-processamento
68
  image = image.convert("RGB")
69
+ orig_w, orig_h = image.size
70
 
71
+ # Cálculo do tamanho alvo
72
+ new_h = adjust_size(orig_h, scale_factor)
73
+ new_w = adjust_size(orig_w, scale_factor)
74
+ logger.info(f"Redimensionando: {orig_h}x{orig_w} → {new_h}x{new_w}")
 
 
75
 
76
+ # Gerar grid de coordenadas
77
+ coords = make_grid((new_h, new_w))
78
+ logger.debug(f"Dimensões do grid: {coords.shape}")
79
+
80
+ # Verificação crítica
81
+ if coords.shape[1:3] != (new_h, new_w):
82
+ raise ValueError(f"Grid incorreto: {coords.shape[1:3]} vs ({new_h}, {new_w})")
83
 
84
  # Super-resolução
85
+ source = jnp.array(image).astype(jnp.float32) / 255.0
86
+ source = source[jnp.newaxis, ...] # Adicionar batch
87
+
88
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
89
+ upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
90
 
91
+ # Pós-processamento
92
+ upscaled_img = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
 
 
 
 
 
93
 
94
  # Bas-Relief
95
+ result = pipe(
96
+ prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
97
+ image=upscaled_img,
 
98
  strength=0.7,
99
+ num_inference_steps=30
100
+ )
101
+ bas_relief = result.images[0]
102
 
103
+ # Mapa de profundidade
 
104
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
105
  with torch.no_grad():
106
  depth = depth_model(**inputs).predicted_depth
 
111
  mode="bicubic"
112
  ).squeeze().cpu().numpy()
113
 
114
+ # Normalização
115
+ depth_min = depth_map.min()
116
+ depth_max = depth_map.max()
117
+ depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
118
+ depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
119
 
120
+ return upscaled_img, bas_relief, depth_img
121
 
122
  except Exception as e:
123
+ logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
124
+ raise gr.Error(f"Processamento falhou: {str(e)}")
125
 
126
 
127
  # Interface
128
+ with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
129
+ gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
130
+
131
  with gr.Row():
132
  with gr.Column():
133
+ img_input = gr.Image(label="Imagem de Entrada", type="pil")
134
+ prompt = gr.Textbox(
135
+ label="Descrição do Relevo",
136
+ value="Ainsanely detailed and complex engraving relief, ultra-high definition",
137
+ placeholder="Descreva o estilo desejado..."
138
+ )
139
+ scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
140
+ btn = gr.Button("Processar Imagem", variant="primary")
141
+
142
  with gr.Column():
143
+ gr.Markdown("## Resultados")
144
+ with gr.Tabs():
145
+ with gr.TabItem("Super Resolução"):
146
+ upscaled_output = gr.Image(label="Resultado Super Resolução")
147
+ with gr.TabItem("Bas-Relief"):
148
+ basrelief_output = gr.Image(label="Relevo Gerado")
149
+ with gr.TabItem("Profundidade"):
150
+ depth_output = gr.Image(label="Mapa de Profundidade")
151
+
152
+ btn.click(
153
+ full_pipeline,
154
+ inputs=[img_input, prompt, scale],
155
+ outputs=[upscaled_output, basrelief_output, depth_output]
156
+ )
157
 
158
  if __name__ == "__main__":
159
+ app.launch(server_name="0.0.0.0", server_port=7860)
utils.py CHANGED
@@ -1,40 +1,40 @@
1
- from functools import partial
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
-
6
-
7
- def repeat_vmap(fun, in_axes=[0]):
8
- for axes in in_axes:
9
- fun = jax.vmap(fun, in_axes=axes)
10
- return fun
11
 
12
 
13
  def make_grid(patch_size: int | tuple[int, int]):
14
- """Gera grid de coordenadas com segurança numérica"""
15
- # Garantir tamanho mínimo de 8x8
16
  if isinstance(patch_size, int):
17
- h = w = max(8, patch_size)
18
  else:
19
- h, w = (max(8, ps) for ps in patch_size)
20
 
21
- # Espaçamento preciso entre pontos
22
- y_space = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
23
- x_space = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
24
 
25
- # Criar grid com dimensões (1, H, W, 2)
26
- grid = np.stack(np.meshgrid(y_space, x_space, indexing='ij'), axis=-1)
27
  return grid[np.newaxis, ...]
28
 
29
 
30
  def interpolate_grid(coords, grid, order=0):
31
- """Interpolação segura com verificação de dimensões"""
32
  try:
33
- # Converter para JAX array e validar formato
34
  coords = jnp.asarray(coords)
35
- if coords.ndim != 4 or coords.shape[-1] != 2:
 
 
 
 
 
 
36
  raise ValueError(
37
- f"Dimensões inválidas: {coords.shape}. Esperado (B, H, W, 2)"
38
  )
39
 
40
  # Transformação de coordenadas
@@ -47,12 +47,10 @@ def interpolate_grid(coords, grid, order=0):
47
  )
48
 
49
  # Interpolação vetorizada
50
- map_fn = jax.vmap(jax.vmap(
51
- partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest'),
52
- in_axes=(2, None),
53
- out_axes=2
54
- ))
55
- return map_fn(grid, coords)
56
 
57
  except Exception as e:
58
  raise RuntimeError(f"Erro de interpolação: {str(e)}") from e
 
1
+ # utils.py
2
  import jax
3
  import jax.numpy as jnp
4
  import numpy as np
5
+ from functools import partial
 
 
 
 
 
6
 
7
 
8
  def make_grid(patch_size: int | tuple[int, int]):
9
+ """Gera grid de coordenadas com validação robusta"""
 
10
  if isinstance(patch_size, int):
11
+ h = w = max(16, patch_size) # Novo mínimo seguro
12
  else:
13
+ h, w = (max(16, ps) for ps in patch_size) # 16x16 mínimo
14
 
15
+ # Cálculo preciso das coordenadas
16
+ y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
17
+ x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
18
 
19
+ # Grid com dimensões (1, H, W, 2)
20
+ grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
21
  return grid[np.newaxis, ...]
22
 
23
 
24
  def interpolate_grid(coords, grid, order=0):
25
+ """Interpolação com tratamento completo de dimensões"""
26
  try:
27
+ # Converter e garantir 4D
28
  coords = jnp.asarray(coords)
29
+ if coords.ndim == 1: # Caso de erro reportado
30
+ coords = coords.reshape(1, 1, 1, -1)
31
+ while coords.ndim < 4:
32
+ coords = coords[jnp.newaxis, ...]
33
+
34
+ # Validação final
35
+ if coords.shape[-1] != 2 or coords.ndim != 4:
36
  raise ValueError(
37
+ f"Dimensões inválidas: {coords.shape}. Formato esperado: (B, H, W, 2)"
38
  )
39
 
40
  # Transformação de coordenadas
 
47
  )
48
 
49
  # Interpolação vetorizada
50
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
51
+ order=order,
52
+ mode='nearest')
53
+ return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
 
 
54
 
55
  except Exception as e:
56
  raise RuntimeError(f"Erro de interpolação: {str(e)}") from e