|
import os |
|
import json |
|
|
|
import torch |
|
import jittor as jt |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
def load_from_torch_shard_ckpt(model, ckpt_dir):
    """
    Load sharded checkpoints directly from huggingface dir.

    Reads ``pytorch_model.bin.index.json`` inside ``ckpt_dir``, groups the
    parameter names listed in its ``weight_map`` by the shard file that
    stores them, and hands that grouping to ``load_from_map``.

    Args:
        model: Model that receives the weights; forwarded unchanged to
            ``load_from_map``.
        ckpt_dir: Directory containing the index file and the shard files
            it references.

    Raises:
        OSError: If the index file cannot be opened.
        KeyError: If the index has no ``weight_map`` entry.
    """
    index_path = os.path.join(ckpt_dir, 'pytorch_model.bin.index.json')
    with open(index_path, encoding='utf-8') as fp:
        ckpt_index = json.load(fp)

    # Only ``weight_map`` is consumed here. The ``metadata``/``total_size``
    # entry is deliberately not read, so an index without it still loads.
    weight_map = ckpt_index['weight_map']

    # Invert {parameter name -> shard file} into {shard file -> [names]} so
    # that each shard file is opened exactly once.
    file_weight_map = {}
    for key, value in weight_map.items():
        file_weight_map.setdefault(value, []).append(key)

    load_from_map(model, ckpt_dir, file_weight_map)
|
|
|
|
|
def load_from_map(model: jt.Module, ckpt_dir, file_weight_map):
    """Load checkpoint shards into ``model`` one file at a time.

    For every shard file named by a key of ``file_weight_map`` the shard is
    read with ``torch.load``, each tensor is converted to a ``jt.Var``, and
    the converted dict is passed to ``model.load_state_dict``. The shard is
    then released before the next one is read.

    Args:
        model: Jittor module whose ``load_state_dict`` receives each shard.
        ckpt_dir: Directory that contains the shard files.
        file_weight_map: Mapping of shard file name to the parameter names
            stored in it. Only the file names are used here; every entry
            found in a shard is loaded.
    """
    for filename, _names in tqdm(file_weight_map.items()):
        # map_location='cpu' lets shards saved from a GPU be read on any
        # machine and guarantees the tensors can be viewed as NumPy arrays.
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        cur_state_dict = torch.load(
            os.path.join(ckpt_dir, filename), map_location='cpu')

        # Iterate over a snapshot of the keys and overwrite entries in place,
        # so each torch tensor can be freed as soon as it has been converted.
        for key in list(cur_state_dict):
            value = cur_state_dict[key]
            # detach() is required: Tensor.numpy() raises on a tensor that
            # requires grad, which would make the start_grad() branch below
            # impossible to reach.
            var = jt.Var(value.detach().numpy())
            if value.requires_grad:
                var.start_grad()
            else:
                var.stop_grad()
            cur_state_dict[key] = var

        model.load_state_dict(cur_state_dict)

        # Release this shard before reading the next one.
        del cur_state_dict
        jt.sync_all()
        jt.gc()
|
|
|
def check_state_dict(model: jt.Module, ckpt_dir, file_weight_map):
    """Check that ``model`` holds exactly the weights stored in the shards.

    For every shard file in ``file_weight_map`` the file is read with
    ``torch.load`` and each listed parameter is compared element-wise with
    the entry of the same name in ``model.state_dict()``.

    Args:
        model: Jittor module whose parameters are verified.
        ckpt_dir: Directory that contains the shard files.
        file_weight_map: Mapping of shard file name to the list of parameter
            names to compare from that file.

    Raises:
        AssertionError: If a parameter differs from the checkpoint in shape
            or in any element.
        KeyError: If a listed name is missing from the shard or the model.
    """
    # Build the model's state dict once instead of once per parameter name.
    model_state = model.state_dict()

    for filename, names in file_weight_map.items():
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        cur_state_dict = torch.load(
            os.path.join(ckpt_dir, filename), map_location='cpu')

        for name in names:
            expected = cur_state_dict[name].detach().numpy()
            actual = model_state[name].numpy()
            # array_equal also rejects shape mismatches that an element-wise
            # np.equal would silently broadcast. The check is raised
            # explicitly so it is not stripped when Python runs with -O.
            if not np.array_equal(actual, expected):
                raise AssertionError(
                    f"parameter '{name}' does not match checkpoint file "
                    f"'{filename}'")

        # Release this shard before reading the next one.
        del cur_state_dict
        jt.sync_all()
        jt.gc()
|
|