ds1david commited on
Commit
e0956f1
·
1 Parent(s): 19a6d73

fixing bugs

Browse files
Files changed (1) hide show
  1. utils.py +8 -0
utils.py CHANGED
@@ -5,6 +5,14 @@ import numpy as np
5
  from functools import partial
6
 
7
 
 
 
 
 
 
 
 
 
8
  def make_grid(patch_size: int | tuple[int, int]):
9
  """Gera grid de coordenadas com validação robusta"""
10
  if isinstance(patch_size, int):
 
5
  from functools import partial
6
 
7
 
8
+ def repeat_vmap(fun, in_axes=None):
9
+ if in_axes is None:
10
+ in_axes = [0]
11
+ for axes in in_axes:
12
+ fun = jax.vmap(fun, in_axes=axes)
13
+ return fun
14
+
15
+
16
  def make_grid(patch_size: int | tuple[int, int]):
17
  """Gera grid de coordenadas com validação robusta"""
18
  if isinstance(patch_size, int):