File size: 2,026 Bytes
19a6d73 1eb87a5 d160dc6 1eb87a5 19a6d73 d85fde4 19a6d73 d85fde4 19a6d73 054a11a 19a6d73 d160dc6 19a6d73 d160dc6 19a6d73 054a11a d85fde4 19a6d73 d160dc6 19a6d73 d160dc6 19a6d73 054a11a 19a6d73 054a11a 4a3fe77 d160dc6 4a3fe77 054a11a 19a6d73 d160dc6 054a11a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# utils.py
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
def make_grid(patch_size: int | tuple[int, int]):
"""Gera grid de coordenadas com validação robusta"""
if isinstance(patch_size, int):
h = w = max(16, patch_size) # Novo mínimo seguro
else:
h, w = (max(16, ps) for ps in patch_size) # 16x16 mínimo
# Cálculo preciso das coordenadas
y_coords = np.linspace(-0.5 + 1 / (2 * h), 0.5 - 1 / (2 * h), h)
x_coords = np.linspace(-0.5 + 1 / (2 * w), 0.5 - 1 / (2 * w), w)
# Grid com dimensões (1, H, W, 2)
grid = np.stack(np.meshgrid(y_coords, x_coords, indexing='ij'), axis=-1)
return grid[np.newaxis, ...]
def interpolate_grid(coords, grid, order=0):
"""Interpolação com tratamento completo de dimensões"""
try:
# Converter e garantir 4D
coords = jnp.asarray(coords)
if coords.ndim == 1: # Caso de erro reportado
coords = coords.reshape(1, 1, 1, -1)
while coords.ndim < 4:
coords = coords[jnp.newaxis, ...]
# Validação final
if coords.shape[-1] != 2 or coords.ndim != 4:
raise ValueError(
f"Dimensões inválidas: {coords.shape}. Formato esperado: (B, H, W, 2)"
)
# Transformação de coordenadas
coords = coords.transpose((0, 3, 1, 2))
coords = coords.at[:, 0].set(
coords[:, 0] * (grid.shape[-3] - 1) + (grid.shape[-3] - 1) / 2
)
coords = coords.at[:, 1].set(
coords[:, 1] * (grid.shape[-2] - 1) + (grid.shape[-2] - 1) / 2
)
# Interpolação vetorizada
map_coordinates = partial(jax.scipy.ndimage.map_coordinates,
order=order,
mode='nearest')
return jax.vmap(jax.vmap(map_coordinates, in_axes=(2, None), out_axes=2))(grid, coords)
except Exception as e:
raise RuntimeError(f"Erro de interpolação: {str(e)}") from e |