Extremely slow inference when using jax.jit
```
2023-07-25 08:41:54.423197: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %reduce-window.16 = f32[32,2,128]{2,1,0} reduce-window(f32[1024,64,128]{2,1,0} %constant.964, f32[] %constant.288), window={size=32x32x1 stride=32x32x1}, to_apply=%region_17.1831

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
```
I'm not sure what you're hoping for by opening this discussion. There is no code and no description of your execution environment here, just an error message. Sadly, I'm not omnipotent, so there's nothing I can do with this alone.
We trained on a TPU-VM using jax.pmap without problems. I also just ran a simple test in a WSL2 environment, and jit compilation proceeded without any issues there.
If you need help, please share more information.
```python
def pred_vec(wavPath, vecPath):
    feats = readwave(wavPath, normalize=False)
    model = FlaxAutoModel.from_pretrained("team-lucid/hubert-large-korean", trust_remote_code=True)
    outputs = jax.jit(model, backend='cpu')(feats.cpu().numpy())
    vec = outputs.last_hidden_state.squeeze(0)
    np.save(vecPath, vec, allow_pickle=False)
```
I'm running on a TPU-VM v3-8, and feats.shape = (1, wav_length).
The indentation is gone, but your code appears to load the model and re-jit it every time the function is called. That is inefficient: the previously compiled function is discarded, so the expensive compilation is repeated on every call. It is recommended to create the jitted function once for the already loaded model, e.g. jit_inference = jax.jit(model), and reuse it across calls, as in the sketch below. Also, there is no reason to force the CPU backend in an accelerated environment such as a TPU or GPU. And of course, the input shapes and dtypes must stay the same for jit to be effective; every new shape triggers a fresh compilation.
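If it helps, here is a minimal sketch of that pattern (readwave is your own loader from the snippet above; the fixed padding length MAX_LEN is an assumption on my part, chosen so every call sees the same shape):

```python
import numpy as np
import jax
from transformers import FlaxAutoModel

# Load and jit once, at module level, so the compiled function is reused
# across calls instead of being rebuilt for every file.
model = FlaxAutoModel.from_pretrained(
    "team-lucid/hubert-large-korean", trust_remote_code=True
)
jit_inference = jax.jit(model)  # no backend='cpu' on a TPU VM

# Hypothetical fixed length (10 s at 16 kHz); a new input shape would
# otherwise trigger a fresh compilation on every differently sized file.
MAX_LEN = 160_000

def pred_vec(wavPath, vecPath):
    feats = readwave(wavPath, normalize=False).cpu().numpy()
    feats = np.pad(feats, ((0, 0), (0, MAX_LEN - feats.shape[1])))
    outputs = jit_inference(feats)
    vec = outputs.last_hidden_state.squeeze(0)
    # Frames produced from the zero padding are still present here; slice
    # them off if your downstream code is sensitive to them.
    np.save(vecPath, np.asarray(vec), allow_pickle=False)
```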
It would also be easier to investigate if more information were available, such as a link to your GitHub repository.
You're right, the model runs fast on the TPU using pmap.
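For completeness, a minimal sketch of that pmap pattern, assuming inputs are padded to one fixed length and batched across the eight cores of the v3-8 (the shapes and placeholder batch here are illustrative, not from the original code):

```python
import numpy as np
import jax
from transformers import FlaxAutoModel

model = FlaxAutoModel.from_pretrained(
    "team-lucid/hubert-large-korean", trust_remote_code=True
)

def infer(input_values):
    # The model parameters are closed over and replicated to each device.
    return model(input_values).last_hidden_state

p_inference = jax.pmap(infer)  # one replica per TPU core

n = jax.local_device_count()                   # 8 on a v3-8
batch = np.zeros((n, 1, 160_000), np.float32)  # placeholder audio batch
hidden = p_inference(batch)                    # (8, 1, frames, hidden_size)
```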