File size: 289 Bytes
8c02843
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import os

class CheckpointNotFoundError(FileNotFoundError, AssertionError):
    """Raised when a directory contains no entry whose name includes "ckpt".

    Also subclasses ``AssertionError`` so callers that caught the ``assert``
    failure raised by the previous implementation keep working.
    """


def get_checkpoint_path_from_dir(checkpoint_dir):
    """Return the path of a checkpoint entry inside *checkpoint_dir*.

    An entry is any name returned by ``os.listdir(checkpoint_dir)`` (files
    or sub-directories alike) that contains the substring ``"ckpt"``.

    Args:
        checkpoint_dir: Directory to search (str or path-like).

    Returns:
        ``os.path.join(checkpoint_dir, name)`` for the chosen entry. When
        several entries match, the lexicographically greatest name is
        returned, so the result does not depend on ``os.listdir`` order.
        Note that this is string order (e.g. "epoch=9" sorts after
        "epoch=10").

    Raises:
        CheckpointNotFoundError: if no entry name contains ``"ckpt"``. It is
            both a ``FileNotFoundError`` and an ``AssertionError``.
        OSError: propagated from ``os.listdir`` (e.g. missing directory).
    """
    # os.listdir returns names in arbitrary order; sorting makes the choice
    # among multiple matching checkpoints reproducible.
    matches = sorted(
        name for name in os.listdir(checkpoint_dir) if "ckpt" in name
    )
    if not matches:
        # Explicit raise instead of ``assert`` so the check survives -O.
        raise CheckpointNotFoundError(
            f"no checkpoint ('ckpt') found in {checkpoint_dir!r}"
        )
    return os.path.join(checkpoint_dir, matches[-1])