File size: 472 Bytes
258fd02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from tqdm import tqdm
import torchaudio
from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config
import numpy as np
import os
import json

def get_model(model_config, path):
    """Build an autoencoder from a JSON config file and load checkpoint weights.

    Args:
        model_config: Path to a JSON file. Its parsed contents are passed to
            ``create_autoencoder_from_config``.
        path: Path to a checkpoint readable by ``torch.load``. It may be either
            a dict that holds the weights under a ``'state_dict'`` key, or a
            bare state dict.

    Returns:
        The model returned by ``create_autoencoder_from_config`` with the
        checkpoint weights loaded into it.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If ``model_config`` is not valid JSON.
        RuntimeError: From ``load_state_dict`` when checkpoint keys or shapes
            do not match the model (strict loading).
    """
    with open(model_config, encoding="utf-8") as f:
        config = json.load(f)

    # Map tensors to CPU: load_state_dict copies values into the model's own
    # parameters anyway, so this avoids requiring (and filling) the device the
    # checkpoint was originally saved from.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")

    # Training checkpoints wrap the weights under 'state_dict'; also accept a
    # bare state dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model = create_autoencoder_from_config(config)
    model.load_state_dict(state_dict)
    return model