Model does not work with device set to `mps`

#2
Opened by akbir

Inference on the M1 GPU (`device="mps"`) does not work.

The tensor shapes end up wrong here; I'm not sure whether this is a PyTorch bug or a bug in the model implementation.

Can you share a snippet of code to reproduce?
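A reproduction along these lines might help narrow it down. This is a minimal sketch, assuming PyTorch 1.12 or later (the release that added the `mps` backend). It runs the same forward pass on CPU and, when available, on MPS, then compares the output shapes. `build_model` and `compare_cpu_vs_mps` are hypothetical names, and the small `nn.Sequential` is only a stand-in. Swap in the actual model and a real input to see where the shapes diverge.

```python
import copy

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in: replace with the model from this repository.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))


def run_on(device: str, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run one forward pass on `device` and return the output moved back to CPU."""
    # Deep-copy so that moving to `device` does not mutate the caller's model.
    model = copy.deepcopy(model).to(device).eval()
    with torch.no_grad():
        out = model(x.to(device))
    return out.cpu()


def compare_cpu_vs_mps(x: torch.Tensor):
    """Return (cpu_out, mps_out); mps_out is None when MPS is unavailable."""
    torch.manual_seed(0)  # same weights on every run
    model = build_model()
    cpu_out = run_on("cpu", model, x)
    if not torch.backends.mps.is_available():
        print("MPS not available; only the CPU result was computed.")
        return cpu_out, None
    mps_out = run_on("mps", model, x)
    print("cpu shape:", tuple(cpu_out.shape), "mps shape:", tuple(mps_out.shape))
    if cpu_out.shape == mps_out.shape:
        # Shapes agree, so also report how far the values drift between devices.
        print("max abs diff:", (cpu_out - mps_out).abs().max().item())
    return cpu_out, mps_out


if __name__ == "__main__":
    compare_cpu_vs_mps(torch.randn(4, 16))
```

Posting the printed shapes from both devices, together with the PyTorch version, makes it easier to tell a backend bug apart from a model-side issue.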
