ds1david commited on
Commit
1557411
·
1 Parent(s): e0956f1

fixing bugs

Browse files
Files changed (2) hide show
  1. app.py +89 -53
  2. utils.py +19 -14
app.py CHANGED
@@ -27,132 +27,168 @@ TORCH_DEVICE = "cpu"
27
 
28
 
29
  def load_thera_model(repo_id: str, filename: str):
30
- """Carrega modelo com verificação de segurança"""
31
  try:
32
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
33
  with open(model_path, 'rb') as fh:
34
  checkpoint = pickle.load(fh)
 
 
 
 
 
 
 
35
  return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
36
  except Exception as e:
37
  logger.error(f"Erro ao carregar modelo: {str(e)}")
38
  raise
39
 
40
 
41
- # Inicialização dos modelos
42
  try:
43
- logger.info("Carregando modelos...")
44
  model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
 
 
45
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
46
  "stabilityai/stable-diffusion-xl-base-1.0",
47
  torch_dtype=torch.float32
48
  ).to(TORCH_DEVICE)
49
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
 
 
50
  feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
51
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
 
52
  except Exception as e:
53
- logger.error(f"Falha na inicialização: {str(e)}")
54
  raise
55
 
56
 
57
- def adjust_size(original: int, scale: float) -> int:
58
- """Ajuste de tamanho com limites seguros"""
59
- scaled = int(original * scale)
60
- adjusted = (scaled // 8) * 8 # Divisível por 8
61
- return max(32, adjusted) # Mínimo absoluto
 
 
 
 
 
 
62
 
63
 
64
  def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
65
- """Pipeline completo com tratamento robusto"""
66
  try:
67
- # Pré-processamento
 
 
 
 
68
  image = image.convert("RGB")
69
  orig_w, orig_h = image.size
 
70
 
71
- # Cálculo do tamanho alvo
72
- new_h = adjust_size(orig_h, scale_factor)
73
- new_w = adjust_size(orig_w, scale_factor)
74
- logger.info(f"Redimensionando: {orig_h}x{orig_w} → {new_h}x{new_w}")
75
 
76
  # Gerar grid de coordenadas
77
- coords = make_grid((new_h, new_w))
78
- logger.debug(f"Dimensões do grid: {coords.shape}")
79
-
80
- # Verificação crítica
81
- if coords.shape[1:3] != (new_h, new_w):
82
- raise ValueError(f"Grid incorreto: {coords.shape[1:3]} vs ({new_h}, {new_w})")
 
 
 
83
 
84
- # Super-resolução
85
  source = jnp.array(image).astype(jnp.float32) / 255.0
86
- source = source[jnp.newaxis, ...] # Adicionar batch
87
 
 
88
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
 
 
89
  upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
90
 
91
- # Pós-processamento
92
  upscaled_img = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
 
93
 
94
- # Bas-Relief
95
  result = pipe(
96
  prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
97
  image=upscaled_img,
98
  strength=0.7,
99
- num_inference_steps=30
 
100
  )
101
  bas_relief = result.images[0]
 
102
 
103
- # Mapa de profundidade
104
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
105
  with torch.no_grad():
106
  depth = depth_model(**inputs).predicted_depth
107
 
 
108
  depth_map = torch.nn.functional.interpolate(
109
  depth.unsqueeze(1),
110
  size=bas_relief.size[::-1],
111
  mode="bicubic"
112
  ).squeeze().cpu().numpy()
113
 
114
- # Normalização
115
  depth_min = depth_map.min()
116
  depth_max = depth_map.max()
117
  depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
118
  depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
 
119
 
120
  return upscaled_img, bas_relief, depth_img
121
 
122
  except Exception as e:
123
  logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
124
- raise gr.Error(f"Processamento falhou: {str(e)}")
125
 
126
 
127
- # Interface
128
- with gr.Blocks(title="SuperRes+BasRelief", theme=gr.themes.Default()) as app:
129
  gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
130
 
131
  with gr.Row():
132
- with gr.Column():
133
- img_input = gr.Image(label="Imagem de Entrada", type="pil")
134
- prompt = gr.Textbox(
135
- label="Descrição do Relevo",
136
- value="Ainsanely detailed and complex engraving relief, ultra-high definition",
137
- placeholder="Descreva o estilo desejado..."
138
- )
139
- scale = gr.Slider(1.0, 4.0, value=2.0, label="Fator de Escala")
140
- btn = gr.Button("Processar Imagem", variant="primary")
141
-
142
- with gr.Column():
143
- gr.Markdown("## Resultados")
144
- with gr.Tabs():
145
- with gr.TabItem("Super Resolução"):
146
- upscaled_output = gr.Image(label="Resultado Super Resolução")
147
- with gr.TabItem("Bas-Relief"):
148
- basrelief_output = gr.Image(label="Relevo Gerado")
149
- with gr.TabItem("Profundidade"):
150
- depth_output = gr.Image(label="Mapa de Profundidade")
151
-
152
- btn.click(
 
 
153
  full_pipeline,
154
  inputs=[img_input, prompt, scale],
155
- outputs=[upscaled_output, basrelief_output, depth_output]
 
156
  )
157
 
158
  if __name__ == "__main__":
 
27
 
28
 
29
  def load_thera_model(repo_id: str, filename: str):
30
+ """Carrega modelo com múltiplas verificações"""
31
  try:
32
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
33
  with open(model_path, 'rb') as fh:
34
  checkpoint = pickle.load(fh)
35
+
36
+ # Verificar estrutura do checkpoint
37
+ required_keys = {'model', 'backbone', 'size'}
38
+ if not required_keys.issubset(checkpoint.keys()):
39
+ missing = required_keys - checkpoint.keys()
40
+ raise ValueError(f"Checkpoint corrompido. Chaves faltando: {missing}")
41
+
42
  return build_thera(3, checkpoint['backbone'], checkpoint['size']), checkpoint['model']
43
  except Exception as e:
44
  logger.error(f"Erro ao carregar modelo: {str(e)}")
45
  raise
46
 
47
 
48
+ # Inicialização segura
49
  try:
50
+ logger.info("Inicializando modelos...")
51
  model_edsr, params_edsr = load_thera_model("prs-eth/thera-edsr-pro", "model.pkl")
52
+
53
+ # Pipeline SDXL
54
  pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
55
  "stabilityai/stable-diffusion-xl-base-1.0",
56
  torch_dtype=torch.float32
57
  ).to(TORCH_DEVICE)
58
  pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")
59
+
60
+ # Modelo de profundidade
61
  feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
62
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(TORCH_DEVICE)
63
+
64
  except Exception as e:
65
+ logger.error(f"Falha crítica na inicialização: {str(e)}")
66
  raise
67
 
68
 
69
+ def safe_resize(original: tuple[int, int], scale: float) -> tuple[int, int]:
70
+ """Calcula tamanho garantindo estabilidade numérica"""
71
+ h, w = original
72
+ new_h = int(h * scale)
73
+ new_w = int(w * scale)
74
+
75
+ # Ajustar para múltiplo de 8
76
+ new_h = max(32, new_h - new_h % 8)
77
+ new_w = max(32, new_w - new_w % 8)
78
+
79
+ return (new_h, new_w)
80
 
81
 
82
  def full_pipeline(image: Image.Image, prompt: str, scale_factor: float = 2.0):
83
+ """Pipeline completo com tratamento de erros robusto"""
84
  try:
85
+ # Verificação inicial
86
+ if not image:
87
+ raise ValueError("Nenhuma imagem fornecida")
88
+
89
+ # Conversão segura para RGB
90
  image = image.convert("RGB")
91
  orig_w, orig_h = image.size
92
+ logger.info(f"Processando imagem: {orig_w}x{orig_h}")
93
 
94
+ # Cálculo do novo tamanho
95
+ new_h, new_w = safe_resize((orig_h, orig_w), scale_factor)
96
+ logger.info(f"Novo tamanho calculado: {new_h}x{new_w}")
 
97
 
98
  # Gerar grid de coordenadas
99
+ grid = make_grid((new_h, new_w))
100
+ logger.debug(f"Grid gerado: {grid.shape}")
101
+
102
+ # Verificação crítica do grid
103
+ if grid.shape[1:3] != (new_h, new_w):
104
+ raise RuntimeError(
105
+ f"Incompatibilidade de dimensões: "
106
+ f"Grid {grid.shape[1:3]} vs Alvo {new_h}x{new_w}"
107
+ )
108
 
109
+ # Pré-processamento da imagem
110
  source = jnp.array(image).astype(jnp.float32) / 255.0
111
+ source = source[jnp.newaxis, ...] # Adicionar dimensão de batch
112
 
113
+ # Parâmetro de escala
114
  t = jnp.array([1.0 / (scale_factor ** 2)], dtype=jnp.float32)
115
+
116
+ # Processamento Thera
117
  upscaled = model_edsr.apply(params_edsr, source, t, (new_h, new_w))
118
 
119
+ # Conversão para PIL
120
  upscaled_img = Image.fromarray((np.array(upscaled[0]) * 255).astype(np.uint8))
121
+ logger.info(f"Imagem super-resolvida: {upscaled_img.size}")
122
 
123
+ # Geração do Bas-Relief
124
  result = pipe(
125
  prompt=f"BAS-RELIEF {prompt}, ultra detailed, 8K resolution",
126
  image=upscaled_img,
127
  strength=0.7,
128
+ num_inference_steps=30,
129
+ guidance_scale=7.5
130
  )
131
  bas_relief = result.images[0]
132
+ logger.info(f"Bas-Relief gerado: {bas_relief.size}")
133
 
134
+ # Cálculo da profundidade
135
  inputs = feature_extractor(bas_relief, return_tensors="pt").to(TORCH_DEVICE)
136
  with torch.no_grad():
137
  depth = depth_model(**inputs).predicted_depth
138
 
139
+ # Redimensionamento
140
  depth_map = torch.nn.functional.interpolate(
141
  depth.unsqueeze(1),
142
  size=bas_relief.size[::-1],
143
  mode="bicubic"
144
  ).squeeze().cpu().numpy()
145
 
146
+ # Normalização e conversão
147
  depth_min = depth_map.min()
148
  depth_max = depth_map.max()
149
  depth_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
150
  depth_img = Image.fromarray((depth_normalized * 255).astype(np.uint8))
151
+ logger.info("Mapa de profundidade calculado")
152
 
153
  return upscaled_img, bas_relief, depth_img
154
 
155
  except Exception as e:
156
  logger.error(f"ERRO NO PIPELINE: {str(e)}", exc_info=True)
157
+ raise gr.Error(f"Falha no processamento: {str(e)}")
158
 
159
 
160
+ # Interface Gradio
161
+ with gr.Blocks(title="SuperRes+BasRelief Pro", theme=gr.themes.Soft()) as app:
162
  gr.Markdown("# 🖼️ Super Resolução + 🗿 Bas-Relief + 🗺️ Mapa de Profundidade")
163
 
164
  with gr.Row():
165
+ input_col = gr.Column()
166
+ output_col = gr.Column()
167
+
168
+ with input_col:
169
+ img_input = gr.Image(label="Carregar Imagem", type="pil", height=300)
170
+ prompt = gr.Textbox(
171
+ label="Descrição do Relevo",
172
+ value="A insanely detailed and complex engraving relief, ultra-high definition",
173
+ placeholder="Descreva o estilo desejado..."
174
+ )
175
+ scale = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="Fator de Escala")
176
+ process_btn = gr.Button("Iniciar Processamento", variant="primary")
177
+
178
+ with output_col:
179
+ with gr.Tabs():
180
+ with gr.TabItem("Super Resolução"):
181
+ upscaled_output = gr.Image(label="Resultado", show_label=False)
182
+ with gr.TabItem("Bas-Relief"):
183
+ basrelief_output = gr.Image(label="Relevo", show_label=False)
184
+ with gr.TabItem("Profundidade"):
185
+ depth_output = gr.Image(label="Mapa 3D", show_label=False)
186
+
187
+ process_btn.click(
188
  full_pipeline,
189
  inputs=[img_input, prompt, scale],
190
+ outputs=[upscaled_output, basrelief_output, depth_output],
191
+ api_name="processar"
192
  )
193
 
194
  if __name__ == "__main__":
utils.py CHANGED
@@ -13,36 +13,41 @@ def repeat_vmap(fun, in_axes=None):
13
  return fun
14
 
15
 
16
- def make_grid(patch_size: int | tuple[int, int]):
17
- """Gera grid de coordenadas com validação robusta"""
18
- if isinstance(patch_size, int):
19
- h = w = max(16, patch_size) # Novo mínimo seguro
20
- else:
21
- h, w = (max(16, ps) for ps in patch_size) # 16x16 mínimo
22
 
23
- # Cálculo preciso das coordenadas
 
 
 
 
 
 
24
  y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
25
  x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
26
 
27
- # Grid com dimensões (1, H, W, 2)
28
  grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
29
  return grid[np.newaxis, ...]
30
 
31
 
32
  def interpolate_grid(coords, grid, order=0):
33
- """Interpolação com tratamento completo de dimensões"""
34
  try:
35
- # Converter e garantir 4D
36
  coords = jnp.asarray(coords)
37
- if coords.ndim == 1: # Caso de erro reportado
38
- coords = coords.reshape(1, 1, 1, -1)
 
39
  while coords.ndim < 4:
40
  coords = coords[jnp.newaxis, ...]
41
 
42
  # Validação final
43
  if coords.shape[-1] != 2 or coords.ndim != 4:
44
  raise ValueError(
45
- f"Dimensões inválidas: {coords.shape}. Formato esperado: (B, H, W, 2)"
 
46
  )
47
 
48
  # Transformação de coordenadas
@@ -61,4 +66,4 @@ def interpolate_grid(coords, grid, order=0):
61
  return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
62
 
63
  except Exception as e:
64
- raise RuntimeError(f"Erro de interpolação: {str(e)}") from e
 
13
  return fun
14
 
15
 
16
+ def make_grid(target_shape: tuple[int, int]):
17
+ """Gera grid de coordenadas com validação rigorosa"""
18
+ h, w = target_shape
 
 
 
19
 
20
+ # Garantir tamanho mínimo e divisibilidade
21
+ h = max(32, h)
22
+ w = max(32, w)
23
+ h = h if h % 8 == 0 else h + (8 - h % 8)
24
+ w = w if w % 8 == 0 else w + (8 - w % 8)
25
+
26
+ # Espaçamento preciso
27
  y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
28
  x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
29
 
30
+ # Criar grid 4D (1, H, W, 2)
31
  grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
32
  return grid[np.newaxis, ...]
33
 
34
 
35
  def interpolate_grid(coords, grid, order=0):
36
+ """Interpolação segura com verificações em tempo real"""
37
  try:
38
+ # Converter e garantir formato 4D
39
  coords = jnp.asarray(coords)
40
+ original_shape = coords.shape
41
+
42
+ # Adicionar dimensões faltantes
43
  while coords.ndim < 4:
44
  coords = coords[jnp.newaxis, ...]
45
 
46
  # Validação final
47
  if coords.shape[-1] != 2 or coords.ndim != 4:
48
  raise ValueError(
49
+ f"Formato inválido: {original_shape} → {coords.shape}. "
50
+ f"Esperado (B, H, W, 2)"
51
  )
52
 
53
  # Transformação de coordenadas
 
66
  return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
67
 
68
  except Exception as e:
69
+ raise RuntimeError(f"Erro na interpolação: {str(e)}") from e