fixing bugs
Browse files- requirements.txt +1 -1
- utils.py +34 -11
requirements.txt
CHANGED
@@ -25,7 +25,7 @@ opt-einsum==3.3.0
|
|
25 |
optax==0.2.0
|
26 |
orbax-checkpoint==0.2.4
|
27 |
peft
|
28 |
-
|
29 |
scipy==1.10.1
|
30 |
timm==0.9.6
|
31 |
torch
|
|
|
25 |
optax==0.2.0
|
26 |
orbax-checkpoint==0.2.4
|
27 |
peft
|
28 |
+
Pillow==10.0.0
|
29 |
scipy==1.10.1
|
30 |
timm==0.9.6
|
31 |
torch
|
utils.py
CHANGED
@@ -4,7 +4,9 @@ import jax
|
|
4 |
import numpy as np
|
5 |
|
6 |
|
7 |
-
def repeat_vmap(fun, in_axes=
|
|
|
|
|
8 |
for axes in in_axes:
|
9 |
fun = jax.vmap(fun, in_axes=axes)
|
10 |
return fun
|
@@ -21,16 +23,37 @@ def make_grid(patch_size: int | tuple[int, int]):
|
|
21 |
|
22 |
def interpolate_grid(coords, grid, order=0):
|
23 |
"""
|
24 |
-
|
25 |
-
coords: Tensor
|
26 |
-
grid: Tensor
|
27 |
-
returns:
|
28 |
-
Tensor of shape (B, H, W, C) with interpolated values
|
29 |
"""
|
30 |
-
#
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
|
34 |
coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
|
6 |
|
7 |
+
def repeat_vmap(fun, in_axes=None):
|
8 |
+
if in_axes is None:
|
9 |
+
in_axes = [0]
|
10 |
for axes in in_axes:
|
11 |
fun = jax.vmap(fun, in_axes=axes)
|
12 |
return fun
|
|
|
23 |
|
24 |
def interpolate_grid(coords, grid, order=0):
|
25 |
"""
|
26 |
+
Args:
|
27 |
+
coords: Tensor de shape (B, H, W, 2) ou (H, W, 2)
|
28 |
+
grid: Tensor de shape (B, H', W', C)
|
|
|
|
|
29 |
"""
|
30 |
+
# Adicionar dimensão de batch se necessário
|
31 |
+
if coords.ndim == 3:
|
32 |
+
coords = coords[np.newaxis, ...]
|
33 |
+
|
34 |
+
# Verificar dimensões
|
35 |
+
assert coords.ndim == 4, f"Dimensões inválidas para coords: {coords.shape}"
|
36 |
+
assert grid.ndim == 4, f"Dimensões inválidas para grid: {grid.shape}"
|
37 |
+
|
38 |
+
# Ajustar transposição de forma segura
|
39 |
+
try:
|
40 |
+
coords = coords.transpose((0, 3, 1, 2))
|
41 |
+
except ValueError as e:
|
42 |
+
raise ValueError(f"Falha na transposição: {coords.shape} → (0,3,1,2)") from e
|
43 |
+
|
44 |
+
# Conversão de coordenadas
|
45 |
coords = coords.at[:, 0].set(coords[:, 0] * grid.shape[-3] + (grid.shape[-3] - 1) / 2)
|
46 |
coords = coords.at[:, 1].set(coords[:, 1] * grid.shape[-2] + (grid.shape[-2] - 1) / 2)
|
47 |
+
|
48 |
+
# Interpolação com JAX
|
49 |
+
map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
|
50 |
+
order=order,
|
51 |
+
mode='nearest')
|
52 |
+
|
53 |
+
return jax.vmap( # Sobre batches
|
54 |
+
jax.vmap( # Sobre canais
|
55 |
+
map_coordinates,
|
56 |
+
in_axes=(2, None), # (C, H', W'), (B, 2, H, W)
|
57 |
+
out_axes=2
|
58 |
+
)
|
59 |
+
)(grid, coords)
|