ds1david commited on
Commit
054a11a
·
1 Parent(s): d160dc6

fixing bugs

Browse files
Files changed (1) hide show
  1. utils.py +27 -26
utils.py CHANGED
@@ -4,48 +4,49 @@ import jax.numpy as jnp
4
  import numpy as np
5
 
6
 
7
- def repeat_vmap(fun, in_axes=None):
8
- if in_axes is None:
9
- in_axes = [0]
10
  for axes in in_axes:
11
  fun = jax.vmap(fun, in_axes=axes)
12
  return fun
13
 
14
 
15
  def make_grid(patch_size: int | tuple[int, int]):
 
 
16
  if isinstance(patch_size, int):
17
- patch_size = (max(1, patch_size), max(1, patch_size))
 
 
18
 
19
- offset_h, offset_w = 1 / (2 * np.array(patch_size))
20
- space_h = np.linspace(-0.5 + offset_h, 0.5 - offset_h, patch_size[0])
21
- space_w = np.linspace(-0.5 + offset_w, 0.5 - offset_w, patch_size[1])
22
 
23
- grid = np.stack(np.meshgrid(space_h, space_w, indexing='ij'), axis=-1)
24
- return grid[np.newaxis, ...] # Adiciona dimensão de batch
 
25
 
26
 
27
  def interpolate_grid(coords, grid, order=0):
28
- """Args:
29
- coords: Tensor de shape (B, H, W, 2) ou (H, W, 2)
30
- grid: Tensor de shape (B, H', W', C)
31
- order: default 0
32
- """
33
  try:
34
- # Converter para array JAX e ajustar dimensões
35
  coords = jnp.asarray(coords)
36
- while coords.ndim < 4:
37
- coords = coords[jnp.newaxis, ...]
38
-
39
- # Verificação final de dimensões
40
- if coords.shape[-1] != 2 or coords.ndim != 4:
41
- raise ValueError(f"Formato inválido: {coords.shape}. Esperado (B, H, W, 2)")
42
 
43
  # Transformação de coordenadas
44
  coords = coords.transpose((0, 3, 1, 2))
45
- coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
46
- coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
47
-
48
- # Função de interpolação vetorizada
 
 
 
 
49
  map_fn = jax.vmap(jax.vmap(
50
  partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest'),
51
  in_axes=(2, None),
@@ -54,4 +55,4 @@ def interpolate_grid(coords, grid, order=0):
54
  return map_fn(grid, coords)
55
 
56
  except Exception as e:
57
- raise RuntimeError(f"Falha na interpolação: {str(e)}") from e
 
4
  import numpy as np
5
 
6
 
7
+ def repeat_vmap(fun, in_axes=[0]):
 
 
8
  for axes in in_axes:
9
  fun = jax.vmap(fun, in_axes=axes)
10
  return fun
11
 
12
 
13
  def make_grid(patch_size: int | tuple[int, int]):
14
+ """Gera grid de coordenadas com segurança numérica"""
15
+ # Garantir tamanho mínimo de 8x8
16
  if isinstance(patch_size, int):
17
+ h = w = max(8, patch_size)
18
+ else:
19
+ h, w = (max(8, ps) for ps in patch_size)
20
 
21
+ # Espaçamento preciso entre pontos
22
+ y_space = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
23
+ x_space = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
24
 
25
+ # Criar grid com dimensões (1, H, W, 2)
26
+ grid = np.stack(np.meshgrid(y_space, x_space, indexing='ij'), axis=-1)
27
+ return grid[np.newaxis, ...]
28
 
29
 
30
  def interpolate_grid(coords, grid, order=0):
31
+ """Interpolação segura com verificação de dimensões"""
 
 
 
 
32
  try:
33
+ # Converter para JAX array e validar formato
34
  coords = jnp.asarray(coords)
35
+ if coords.ndim != 4 or coords.shape[-1] != 2:
36
+ raise ValueError(
37
+ f"Dimensões inválidas: {coords.shape}. Esperado (B, H, W, 2)"
38
+ )
 
 
39
 
40
  # Transformação de coordenadas
41
  coords = coords.transpose((0, 3, 1, 2))
42
+ coords = coords.at[:, 0].set(
43
+ coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
44
+ )
45
+ coords = coords.at[:, 1].set(
46
+ coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
47
+ )
48
+
49
+ # Interpolação vetorizada
50
  map_fn = jax.vmap(jax.vmap(
51
  partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest'),
52
  in_axes=(2, None),
 
55
  return map_fn(grid, coords)
56
 
57
  except Exception as e:
58
+ raise RuntimeError(f"Erro de interpolação: {str(e)}") from e