# File size: 368 Bytes
# Revision: b925bf3
from whisper_jax import FlaxWhisperForConditionalGeneration
import jax.numpy as jnp
# Local checkpoint directory holding the fine-tuned Whisper model with PyTorch weights.
checkpoint_id = "/media/user01/HDWINDOWS/whisper-medium-portuguese"


def convert_pt_to_flax(checkpoint_dir, output_dir=None):
    """Load a Whisper checkpoint's PyTorch weights as a Flax model and save it.

    Args:
        checkpoint_dir: Location of the checkpoint containing PyTorch weights,
            passed straight to ``from_pretrained``.
        output_dir: Directory the converted Flax model is saved to. When
            ``None`` (the default) the model is saved back into
            ``checkpoint_dir``, next to the original PyTorch weights.

    Returns:
        The converted ``FlaxWhisperForConditionalGeneration`` model.
    """
    if output_dir is None:
        output_dir = checkpoint_dir
    # from_pt=True: the checkpoint holds PyTorch weights, which are converted to Flax on load.
    flax_model = FlaxWhisperForConditionalGeneration.from_pretrained(
        checkpoint_dir, from_pt=True
    )
    flax_model.save_pretrained(output_dir)
    return flax_model


# Convert the checkpoint in place (Flax weights are saved into the same directory).
model = convert_pt_to_flax(checkpoint_id)