Unable to use this model in RIVA
Hi @naomi-rowden,
Unfortunately, I have never used Riva before. I only checked inference in Colab and in a Python script, and it worked fine.
As for the NeMo version, I do not remember exactly, but 1.20 and 1.21 should both work.
You say you have a working "good" FastPitch + HiFi-GAN pair for another language that runs in Riva.
I would try loading both sets of checkpoints and comparing their layers, shapes, and dtypes: the "good" models versus my models.
You can do something like this:
tar xvf tts_ru_ipa_fastpitch_ruslan.nemo  # unpacks the files bundled with the model; among them you will see "model_weights.ckpt"
Then, in Python:
import torch
state_dict = torch.load("model_weights.ckpt", map_location="cpu")
for key, value in state_dict.items():  # iterate over the state dict and check the shape and dtype of every tensor
    print(key, value.shape, value.dtype)
def compare_state_dicts(dict1, dict2):
    """
    Compare two PyTorch state dictionaries.

    Args:
    - dict1 (dict): State dictionary 1.
    - dict2 (dict): State dictionary 2.

    Returns:
    - bool: True if the state dictionaries are the same, False otherwise.
    """
    # Check if they have the same keys
    keys1 = set(dict1.keys())
    keys2 = set(dict2.keys())
    if keys1 != keys2:
        print("Keys mismatch:")
        print("Keys in dict1 not in dict2:", keys1 - keys2)
        print("Keys in dict2 not in dict1:", keys2 - keys1)
        return False

    # Check that every tensor has the same shape and dtype
    for key in dict1:
        if dict1[key].shape != dict2[key].shape:
            print(f"Mismatch in shape for key {key}: {dict1[key].shape} vs {dict2[key].shape}")
            return False
        if dict1[key].dtype != dict2[key].dtype:
            print(f"Mismatch in dtype for key {key}: {dict1[key].dtype} vs {dict2[key].dtype}")
            return False

    return True
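For example, after extracting the weights from both .nemo archives, you could run something like this (the file names below are hypothetical):

my_weights = torch.load("my_model_weights.ckpt", map_location="cpu")      # weights from your FastPitch .nemo
good_weights = torch.load("good_model_weights.ckpt", map_location="cpu")  # weights from the known-good FastPitch .nemo
print(compare_state_dicts(my_weights, good_weights))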
Please let me know if you discover any differences.
Thank you for your suggestion - there is a difference. The output was:
Mismatch in shape for key fastpitch.encoder.word_emb.weight: torch.Size([89, 384]) vs torch.Size([96, 384])
(your model vs. the model I was testing). The test model is nvidia/nemo/tts_en_fastpitch:IPA_1.13.0 from NGC.
NVIDIA also got back to me with a solution to this problem, which did work, in case you (or anyone else) want to use this model in Riva. The only difference from the steps I posted above is the ONNX opset number: it should be 14 instead of 15:
nemo2riva --onnx-opset=14 --out <RIVA PATH> <NEMO PATH> --key tlt_encode
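As a side check (not part of NVIDIA's instructions), if you also end up with a standalone .onnx export at some point, you can confirm which opset it declares using the onnx package; the file name below is hypothetical:

import onnx
model = onnx.load("fastpitch_export.onnx")  # hypothetical path to an ONNX export
print([(entry.domain, entry.version) for entry in model.opset_import])  # lists (domain, opset version) pairs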
Thanks!
Interesting! Thanks for sharing!