Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from genie.config import GenieConfig | |
from genie.st_mask_git import STMaskGIT | |
def convert_lightning_checkpoint(lightning_checkpoint, num_layers, num_heads, d_model, save_dir):
    """
    Convert a Lightning checkpoint into an HF checkpoint.

    v0.0.1 saved models in Lightning checkpoints; this builds a fresh
    `STMaskGIT` from the given hyperparameters, loads the weights stored under
    the checkpoint's "state_dict" entry into it, and writes the result to
    `save_dir` via `save_pretrained`.

    Args:
        lightning_checkpoint: Checkpoint location, passed straight to `torch.load`.
        num_layers: Forwarded to `GenieConfig(num_layers=...)`.
        num_heads: Forwarded to `GenieConfig(num_heads=...)`.
        d_model: Forwarded to `GenieConfig(d_model=...)`.
        save_dir: Destination passed to `model.save_pretrained`.

    Raises:
        KeyError: If the loaded checkpoint has no "state_dict" entry.

    NOTE(security): `torch.load` unpickles the file, so only convert
    checkpoints that come from a trusted source.
    """
    # Lightning stores the wrapped network's weights under this key prefix.
    prefix = "model."

    config = GenieConfig(num_layers=num_layers, num_heads=num_heads, d_model=d_model)
    model = STMaskGIT(config)

    # Load onto CPU so that checkpoints written on a GPU machine can still be
    # converted on a host without CUDA; `load_state_dict` copies the values
    # into the freshly built model either way.
    checkpoint = torch.load(lightning_checkpoint, map_location="cpu")
    lightning_state_dict = checkpoint["state_dict"]

    # Strip only the *leading* `model.` prefix. A blanket
    # `str.replace("model.", "")` would also mangle keys that merely contain
    # that substring further in (e.g. `...d_model.weight` -> `...d_weight`).
    model_state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): params
        for name, params in lightning_state_dict.items()
    }

    model.load_state_dict(model_state_dict)
    model.save_pretrained(save_dir)