File size: 765 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

from genie.config import GenieConfig
from genie.st_mask_git import STMaskGIT


def convert_lightning_checkpoint(lightning_checkpoint, num_layers, num_heads, d_model, save_dir):
    """
    Convert a Lightning checkpoint into an HF checkpoint.

    v0.0.1 saved models in Lightning checkpoints. This builds a fresh `STMaskGIT` from the
    given hyperparameters, loads the checkpoint's weights into it, and writes the result
    with `model.save_pretrained(save_dir)`.

    Args:
        lightning_checkpoint: Checkpoint location, passed straight to `torch.load`.
        num_layers: Forwarded to `GenieConfig(num_layers=...)`.
        num_heads: Forwarded to `GenieConfig(num_heads=...)`.
        d_model: Forwarded to `GenieConfig(d_model=...)`.
        save_dir: Destination passed to `model.save_pretrained`.

    Raises:
        KeyError: If the loaded checkpoint has no "state_dict" entry.

    NOTE(review): the three hyperparameters presumably have to match the ones the checkpoint
    was trained with, otherwise `load_state_dict` should reject the weights - confirm.
    """
    config = GenieConfig(num_layers=num_layers, num_heads=num_heads, d_model=d_model)
    model = STMaskGIT(config)

    # SECURITY: `torch.load` unpickles the file, so only convert checkpoints from a trusted source.
    # `map_location="cpu"` lets checkpoints that were saved from GPU tensors be converted on a
    # machine without that device; the freshly built model is not moved to another device here.
    checkpoint = torch.load(lightning_checkpoint, map_location="cpu")
    lightning_state_dict = checkpoint["state_dict"]

    # Remove only the leading `model.` prefix. A plain `str.replace` would also rewrite any later
    # occurrence of "model." inside a key (e.g. "...d_model.weight") and corrupt the name.
    prefix = "model."
    model_state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): params
        for name, params in lightning_state_dict.items()
    }
    model.load_state_dict(model_state_dict)
    model.save_pretrained(save_dir)