ds1david commited on
Commit
4a3fe77
·
1 Parent(s): d85fde4

fixing bugs

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. utils.py +34 -11
requirements.txt CHANGED
@@ -25,7 +25,7 @@ opt-einsum==3.3.0
25
  optax==0.2.0
26
  orbax-checkpoint==0.2.4
27
  peft
28
- pillow
29
  scipy==1.10.1
30
  timm==0.9.6
31
  torch
 
25
  optax==0.2.0
26
  orbax-checkpoint==0.2.4
27
  peft
28
+ Pillow==10.0.0
29
  scipy==1.10.1
30
  timm==0.9.6
31
  torch
utils.py CHANGED
@@ -4,7 +4,9 @@ import jax
4
  import numpy as np
5
 
6
 
7
- def repeat_vmap(fun, in_axes=[0]):
 
 
8
  for axes in in_axes:
9
  fun = jax.vmap(fun, in_axes=axes)
10
  return fun
@@ -21,16 +23,37 @@ def make_grid(patch_size: int | tuple[int, int]):
21
 
22
  def interpolate_grid(coords, grid, order=0):
23
  """
24
- args:
25
- coords: Tensor of shape (B, H, W, 2) with coordinates in [-0.5, 0.5]
26
- grid: Tensor of shape (B, H', W', C)
27
- returns:
28
- Tensor of shape (B, H, W, C) with interpolated values
29
  """
30
- # convert [-0.5, 0.5] -> [0, size], where pixel centers are expected at
31
- # [-0.5 + 1 / (2*size), ..., 0.5 - 1 / (2*size)]
32
- coords = coords.transpose((0, 3, 1, 2))
 
 
 
 
 
 
 
 
 
 
 
 
33
  coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
34
  coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
35
- map_coordinates = partial(jax.scipy.ndimage.map_coordinates, order=order, mode='nearest')
36
- return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
 
6
 
7
+ def repeat_vmap(fun, in_axes=None):
8
+ if in_axes is None:
9
+ in_axes = [0]
10
  for axes in in_axes:
11
  fun = jax.vmap(fun, in_axes=axes)
12
  return fun
 
23
 
24
  def interpolate_grid(coords, grid, order=0):
25
  """
26
+ Args:
27
+ coords: Tensor de shape (B, H, W, 2) ou (H, W, 2)
28
+ grid: Tensor de shape (B, H', W', C)
 
 
29
  """
30
+ # Adicionar dimensão de batch se necessário
31
+ if coords.ndim == 3:
32
+ coords = coords[np.newaxis, ...]
33
+
34
+ # Verificar dimensões
35
+ assert coords.ndim == 4, f"Dimensões inválidas para coords: {coords.shape}"
36
+ assert grid.ndim == 4, f"Dimensões inválidas para grid: {grid.shape}"
37
+
38
+ # Ajustar transposição de forma segura
39
+ try:
40
+ coords = coords.transpose((0, 3, 1, 2))
41
+ except ValueError as e:
42
+ raise ValueError(f"Falha na transposição: {coords.shape} → (0,3,1,2)") from e
43
+
44
+ # Conversão de coordenadas
45
  coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
46
  coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
47
+
48
+ # Interpolação com JAX
49
+ map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
50
+ order=order,
51
+ mode='nearest')
52
+
53
+ return jax.vmap( # Sobre batches
54
+ jax.vmap( # Sobre canais
55
+ map_coordinates,
56
+ in_axes=(2, None), # (C, H', W'), (B, 2, H, W)
57
+ out_axes=2
58
+ )
59
+ )(grid, coords)