# merging/merge.py — created by theblackcat102 in commit 465eb1b ("Create merging/merge.py"), 13.4 kB.
import json
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file, load_file

from kernel import weight_dequant
# Placeholder checkpoint directories, written without a trailing slash.
# TGT_PATH is the dense checkpoint this script rewrites; SRC_PATH is the
# checkpoint the weights are read from.
TGT_PATH = "PATH_TO_Deepseek-v3-dense-model"
SRC_PATH = "PATH_TO_Deepseek-v3-BASE"


def _read_shard_index(checkpoint_dir):
    """Parse and return ``<checkpoint_dir>/model.safetensors.index.json``."""
    with open(f'{checkpoint_dir}/model.safetensors.index.json', 'r') as index_file:
        return json.load(index_file)


# Shard indexes; their 'weight_map' maps a tensor name to its shard file name.
dense_large_index = _read_shard_index(TGT_PATH)
model_index = _read_shard_index(SRC_PATH)
def _read_source_tensor_bf16(k):
    """Read source tensor ``k`` as a CPU bfloat16 tensor, dequantizing it if it has a scale.

    A tensor is treated as quantized when the source index also lists
    ``<k>_scale_inv``; it is then dequantized on the GPU with ``weight_dequant``.
    """
    src_map = model_index['weight_map']
    src_safe_tensors = os.path.join(SRC_PATH, src_map[k])
    scale_key = k + '_scale_inv'
    if scale_key not in src_map:
        with safe_open(src_safe_tensors, framework="pt", device='cpu') as f:
            return f.get_tensor(k).bfloat16()
    print(k, 'scale_inv')
    with safe_open(src_safe_tensors, framework="pt", device='cuda') as f:
        weight = f.get_tensor(k)
    with safe_open(os.path.join(SRC_PATH, src_map[scale_key]), framework="pt", device='cuda') as f:
        scale_inv = f.get_tensor(scale_key)
    # NOTE(review): weight_dequant (kernel module, not shown) may not return
    # bfloat16; cast explicitly so the shard stays bfloat16 like the plain branch,
    # and move to CPU before saving.
    return weight_dequant(weight.bfloat16(), scale_inv).bfloat16().cpu()


def init_non_experts_weights():
    """Copy every tensor shared by the source and dense-target checkpoints into the target.

    For each key in the target index that also exists in the source index, the
    source tensor is read as bfloat16 (dequantized first when the source has a
    ``<key>_scale_inv`` companion) and written into the target shard the target
    index assigns it to, under ``TGT_PATH``.

    Returns:
        list[str]: the keys that were overwritten.
    """
    src_map = model_index['weight_map']
    # Group keys by target shard so each shard is loaded and saved once,
    # rather than once per key.
    keys_by_shard = {}
    for k, filename in dense_large_index['weight_map'].items():
        if k in src_map:
            keys_by_shard.setdefault(filename, []).append(k)

    updated = []
    for filename, keys in keys_by_shard.items():
        tgt_safe_tensors = os.path.join(TGT_PATH, filename)
        tensors = load_file(tgt_safe_tensors)
        initial_size = len(tensors)
        for k in keys:
            tensors[k] = _read_source_tensor_bf16(k)
            updated.append(k)
        # Every key must already exist in its target shard; a larger dict means
        # the target index and the shard contents disagree.
        assert initial_size == len(tensors)
        save_file(tensors, tgt_safe_tensors, metadata={'format': 'pt'})
    return updated
def get_adjacent_filenames(filename):
    """Return the names of the shards just before and after ``filename``.

    ``filename`` must contain ``-<current>-of-<total>``, e.g.
    ``model-00002-of-000163.safetensors``. Shards are numbered 1..total, and the
    numbering wraps around: the shard before 1 is ``total`` and the shard after
    ``total`` is 1. The current number keeps its own zero-padding, and the total
    is copied verbatim (the two paddings may differ, 5 vs 6 digits above).
    Any directory prefix is kept.

    Returns:
        tuple[str, str]: ``(prev_file, next_file)``.

    Raises:
        ValueError: if ``filename`` has no ``-<n>-of-<m>`` part.
    """
    import re
    match = re.search(r'-(\d+)-of-(\d+)', filename)
    if match is None:
        raise ValueError(f"not a sharded file name: {filename!r}")
    current_str, total_str = match.groups()
    current_num = int(current_str)
    total_files = int(total_str)
    padding = len(current_str)
    # Wrap within 1..total so that neither neighbour is a non-existent shard 0.
    prev_num = total_files if current_num <= 1 else current_num - 1
    next_num = 1 if current_num >= total_files else current_num + 1
    base = filename[:match.start()]
    ext = filename.rsplit('.', 1)[-1]
    prev_file = f"{base}-{prev_num:0{padding}d}-of-{total_str}.{ext}"
    next_file = f"{base}-{next_num:0{padding}d}-of-{total_str}.{ext}"
    return prev_file, next_file
def get_safetensors_mapping(target_layer, weight_map=None):
    """Find the shard files that hold one layer's tensors.

    Keys are split on '.'. The third field must be an integer equal to
    ``target_layer``; keys where it is missing or not an integer are skipped, as
    are routed-expert keys whose sixth field is not an integer.

    Args:
        target_layer: layer index to collect.
        weight_map: ``{tensor_name: shard_filename}`` to search. Defaults to the
            source checkpoint's ``model_index['weight_map']``.

    Returns:
        dict with:
          'pre_mlp_keys': {key: shard} for keys containing 'self_attn',
              'input_layernorm' or 'post_attention_layernorm';
          'gate_safetensors': {key: shard} for keys containing '.gate.';
          'share_expert_safetensors': shard of the last 'shared_experts' key
              seen, or None;
          'experts_safetensors': {expert_id: shard of the last key seen for
              that expert}. Only one shard is kept per expert.
    """
    if weight_map is None:
        weight_map = model_index['weight_map']
    gate_name = {}
    share_experts_name = None
    experts = {}
    pre_mlp_norm = {}
    for key_name, filename in weight_map.items():
        parts = key_name.split('.')
        try:
            if int(parts[2]) != target_layer:
                continue
            if ('self_attn' in key_name or 'input_layernorm' in key_name
                    or 'post_attention_layernorm' in key_name):
                pre_mlp_norm[key_name] = filename
            elif '.gate.' in key_name:
                gate_name[key_name] = filename
            elif 'shared_experts' in key_name:
                share_experts_name = filename
            elif 'experts' in key_name:
                experts[int(parts[5])] = filename
        except (ValueError, IndexError):
            continue
    return {
        'pre_mlp_keys': pre_mlp_norm,
        'gate_safetensors': gate_name,
        'share_expert_safetensors': share_experts_name,
        'experts_safetensors': experts,
    }
def load_related_tensors(mapping):
    """Read the named tensors from the source checkpoint onto the CPU.

    Args:
        mapping: ``{tensor_name: shard_filename}``, with shard names relative to
            ``SRC_PATH`` (e.g. the 'pre_mlp_keys' dict from
            ``get_safetensors_mapping``).

    Returns:
        ``{tensor_name: tensor}`` in the same key order as ``mapping``.
    """
    # Group names by shard so each file is opened once.
    names_by_shard = {}
    for key, filename in mapping.items():
        names_by_shard.setdefault(filename, []).append(key)
    loaded = {}
    for filename, keys in names_by_shard.items():
        # Resolve against SRC_PATH like the other source loaders, not a
        # machine-specific absolute directory.
        with safe_open(os.path.join(SRC_PATH, filename), framework="pt", device='cpu') as f:
            for key in keys:
                loaded[key] = f.get_tensor(key)
    return {key: loaded[key] for key in mapping}
def _parse_routed_expert_key(key):
    """Split a routed-expert tensor name into ``(layer, expert_id, suffix)``.

    Names are expected as ``model.layers.<layer>.mlp.experts.<id>.<suffix...>``.
    Returns None for names without 'experts' or with 'shared_experts'.
    """
    if 'experts' not in key or 'shared_experts' in key:
        return None
    parts = key.split('.')
    return int(parts[2]), int(parts[5]), '.'.join(parts[6:])


def load_experts_weights(experts_safetensors_map, expert_range=None, target_layer=-1):
    """Load the raw routed-expert tensors for a group of experts from the source checkpoint.

    Args:
        experts_safetensors_map: ``{expert_id: shard_filename}`` (shard names
            relative to ``SRC_PATH``), one shard per expert.
        expert_range: iterable of expert ids to load; None or empty loads nothing.
        target_layer: only tensors of this layer are kept. A negative value
            (the default) disables the layer filter.

    Returns:
        ``{tensor_name: tensor}`` on the CPU, not dequantized.

    The map records only one shard per expert, but an expert's tensors may span
    several shards. Any expert with fewer than 6 distinct tensor suffixes after
    the first pass (presumably three projections, each a weight plus a
    ``_scale_inv``) is also searched for in the previous, next and
    second-previous shards.
    """
    wanted = set(expert_range or ())
    tensors = {}
    # expert id -> tensor-name suffixes found so far
    found = {}

    def scan(filename, accept):
        with safe_open(os.path.join(SRC_PATH, filename), framework="pt", device='cpu') as f:
            for k in f.keys():
                parsed = _parse_routed_expert_key(k)
                if parsed is None:
                    continue
                layer_idx, expert_idx, suffix = parsed
                # A shard may hold experts of more than one layer; keep only the
                # target layer so the completeness count is per-layer.
                if target_layer >= 0 and layer_idx != target_layer:
                    continue
                if accept(expert_idx):
                    tensors[k] = f.get_tensor(k)
                    found.setdefault(expert_idx, set()).add(suffix)

    # Scan each primary shard once, however many wanted experts it holds.
    primary_shards = sorted({fn for eid, fn in experts_safetensors_map.items() if eid in wanted})
    for filename in primary_shards:
        scan(filename, lambda e: e in wanted)

    for expert_id, suffixes in list(found.items()):
        if len(suffixes) == 6:
            continue
        prev_filename, next_filename = get_adjacent_filenames(experts_safetensors_map[expert_id])
        prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
        for filename in (prev_filename, next_filename, prev_prev_filename):
            scan(filename, lambda e, eid=expert_id: e == eid)
    return tensors
def load_shared_experts_weights(safe_tensor_file, target_layer=-1):
    """Load the raw shared-expert tensors of one layer from the source checkpoint.

    Args:
        safe_tensor_file: shard name (relative to ``SRC_PATH``) recorded for the
            layer's shared expert.
        target_layer: layer whose tensors to keep. A negative value (the
            default) keeps shared-expert tensors of every layer in
            ``safe_tensor_file`` and skips the neighbour search.

    Returns:
        ``{tensor_name: tensor}`` on the CPU, not dequantized.

    Raises:
        ValueError: if ``safe_tensor_file`` is None.
    """
    if safe_tensor_file is None:
        raise ValueError(f"no shared-expert shard recorded for layer {target_layer}")

    def wanted(key):
        if 'shared_experts' not in key:
            return False
        return target_layer < 0 or int(key.split('.')[2]) == target_layer

    def scan(filename, tensors):
        with safe_open(os.path.join(SRC_PATH, filename), framework="pt", device='cpu') as f:
            for k in f.keys():
                if wanted(k):
                    tensors[k] = f.get_tensor(k)

    tensors = {}
    scan(safe_tensor_file, tensors)
    # Presumably six tensors make a complete shared expert (three projections,
    # each a weight plus a _scale_inv). Fewer means it spans a shard boundary,
    # so also search the neighbouring shards.
    if target_layer >= 0 and len(tensors) < 6:
        prev_filename, next_filename = get_adjacent_filenames(safe_tensor_file)
        prev_prev_filename, _ = get_adjacent_filenames(prev_filename)
        for filename in (prev_filename, next_filename, prev_prev_filename):
            scan(filename, tensors)
    return tensors
# Routed experts per MoE layer, and the number of contiguous groups they are
# averaged into. Each group becomes one block of the dense MLP.
N_ROUTED_EXPERTS = 256
N_EXPERT_GROUPS = 8
# Intermediate width the merged dense MLP must have (checked before writing).
DENSE_INTERMEDIATE_SIZE = 18432
_PROJECTIONS = ('up_proj', 'gate_proj', 'down_proj')


def _to_dense_weight(weights, name):
    """Return ``weights[name]``, dequantized on the GPU if it is stored as FP8.

    A 1-byte element size is taken to mean FP8. Such a weight is dequantized
    with its ``<name>_scale_inv`` companion and returned on the CPU; other
    weights are returned unchanged.

    Raises:
        KeyError: if ``name`` is absent, or it is FP8 and its scale is absent.
    """
    weight = weights[name]
    if weight.element_size() != 1:
        return weight
    scale_inv_name = f"{name}_scale_inv"
    if scale_inv_name not in weights:
        raise KeyError(f"Missing scale_inv tensor for {name}, cannot dequantize")
    return weight_dequant(weight.bfloat16().cuda(), weights[scale_inv_name].cuda()).cpu()


def _write_dense_mlp(target_layer, dense_weights):
    """Store one layer's merged MLP weights in the dense target checkpoint.

    ``dense_weights`` maps a projection name (e.g. 'gate_proj') to its tensor.
    Each target shard is loaded and saved once, even if it holds several of the
    projections. Every tensor must match the shape already in the shard and is
    written as bfloat16.
    """
    updates_by_shard = {}
    for proj, value in dense_weights.items():
        key = f"model.layers.{target_layer}.mlp.{proj}.weight"
        updates_by_shard.setdefault(dense_large_index['weight_map'][key], {})[key] = value
    for filename, updates in updates_by_shard.items():
        path = os.path.join(TGT_PATH, filename)
        tensors = load_file(path)
        for key, value in updates.items():
            assert tensors[key].shape == value.shape
            tensors[key] = value.bfloat16()
        print(len(tensors), path)
        save_file(tensors, path, metadata={'format': 'pt'})


if __name__ == "__main__":
    init_non_experts_weights()
    group_size = N_ROUTED_EXPERTS // N_EXPERT_GROUPS
    # Contiguous expert-id groups: 0-31, 32-63, ..., 224-255.
    expert_ranges = [
        list(range(g * group_size, (g + 1) * group_size))
        for g in range(N_EXPERT_GROUPS)
    ]
    for target_layer in range(3, 62):
        result = get_safetensors_mapping(target_layer)
        if len(result['experts_safetensors']) == 0:
            print('empty at ', target_layer)
            # No routed experts to merge; the expert lookups below would raise KeyError.
            continue
        mlp_prefix = f'model.layers.{target_layer}.mlp'
        # Per projection: one averaged block per expert group, then the shared expert.
        blocks = {proj: [] for proj in _PROJECTIONS}
        for expert_range in expert_ranges:
            experts_weights = load_experts_weights(result['experts_safetensors'], expert_range, target_layer)
            for proj in _PROJECTIONS:
                # Only the tensors that are actually averaged get dequantized.
                group = [
                    _to_dense_weight(experts_weights, f'{mlp_prefix}.experts.{expert_id}.{proj}.weight')
                    for expert_id in expert_range
                ]
                blocks[proj].append(torch.mean(torch.stack(group, dim=0), dim=0))
            del experts_weights
        shared_experts_weight = load_shared_experts_weights(result['share_expert_safetensors'], target_layer)
        for proj in _PROJECTIONS:
            blocks[proj].append(
                _to_dense_weight(shared_experts_weight, f'{mlp_prefix}.shared_experts.{proj}.weight')
            )
        del shared_experts_weight
        dense_up_proj = torch.cat(blocks['up_proj'], dim=0)
        dense_gate_proj = torch.cat(blocks['gate_proj'], dim=0)
        # down_proj's intermediate dimension is dim 1 (see the assert below), so
        # its blocks are joined along dim 1.
        dense_down_proj = torch.cat(blocks['down_proj'], dim=1).contiguous()
        assert dense_down_proj.shape[1] == DENSE_INTERMEDIATE_SIZE
        assert dense_gate_proj.shape[0] == DENSE_INTERMEDIATE_SIZE
        assert dense_up_proj.shape[0] == DENSE_INTERMEDIATE_SIZE
        _write_dense_mlp(target_layer, {
            'gate_proj': dense_gate_proj,
            'up_proj': dense_up_proj,
            'down_proj': dense_down_proj,
        })