File size: 368 Bytes
b925bf3
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from whisper_jax import FlaxWhisperForConditionalGeneration
import jax.numpy as jnp

import os

# Convert a fine-tuned Whisper checkpoint stored as PyTorch weights into Flax
# weights usable by whisper_jax.
#
# Configuration (optional environment variables; defaults reproduce the
# original hard-coded behavior):
#   WHISPER_CHECKPOINT_DIR   - checkpoint location passed to from_pretrained.
#   WHISPER_FLAX_OUTPUT_DIR  - where save_pretrained writes the Flax model;
#                              defaults to the checkpoint location itself.
# Environment variables are used rather than sys.argv so the script also works
# when pasted into a notebook, whose kernel injects its own argv entries.
checkpoint_id = os.environ.get(
    "WHISPER_CHECKPOINT_DIR",
    "/media/user01/HDWINDOWS/whisper-medium-portuguese",
)
output_dir = os.environ.get("WHISPER_FLAX_OUTPUT_DIR", checkpoint_id)

# Convert PyTorch weights to Flax: from_pt=True makes from_pretrained read the
# PyTorch checkpoint and build the Flax parameters in memory.
model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)

# Save the converted Flax model. With the default output_dir this is the same
# directory as the source checkpoint, matching the original script.
model.save_pretrained(output_dir)