bababababooey committed on
Commit d7e01c8
Parent(s): ffd11f3

Upload 32to31.py

Files changed (1)
  1. swapper/32to31.py +182 -0
swapper/32to31.py ADDED
@@ -0,0 +1,182 @@
+ import json
+ import os
+ import re
+ from huggingface_hub import snapshot_download
+
+ import torch
+ from safetensors import safe_open
+ from transformers import AutoProcessor, MllamaForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+ #total_layers = 80  # 70B model has 80 layers
+ total_layers = 32  # 8B model has 32 layers
+
+ #cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98]  # 90B
+ cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]  # 11B
+
+ # Update paths - switch source and target
+ target_model = "meta-llama/Llama-3.1-8B-Instruct"
+ print(f"Target model: {target_model}")
+
+ source_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ print(f"Source model: {source_model}")
+
+ def create_inverse_layer_mapping(total_layers=total_layers, cross_attn_layers=cross_attention_layers):
+     """
+     Creates a mapping from 90B/11B layer indices to 70B/8B layer indices.
+     """
+     mapping = {}
+     removed_layers = []
+
+     #for i in range(100):  # 90B has 100 layers (80 + 20 cross-attention layers)
+     for i in range(40):  # 11B has 40 layers (32 + 8 cross-attention layers)
+         if i not in cross_attn_layers and len(mapping) < total_layers:
+             mapping[i] = len(mapping)
+         else:
+             removed_layers.append(i)
+     return mapping, removed_layers
+
+ def load_sharded_state_dict(model_id):
+     """
+     Load a sharded state dict from either a local directory or a Hugging Face model ID.
+
+     Args:
+         model_id: Either a local path or a Hugging Face model ID (e.g., "meta-llama/Llama-2-7b")
+
+     Returns:
+         dict: The loaded state dictionary
+     """
+     # Check if model_id is a local path
+     if os.path.isdir(model_id):
+         model_dir = model_id
+     else:
+         # If not local, assume it's a Hugging Face model ID and download it
+         print(f"Downloading model from Hugging Face: {model_id}")
+         model_dir = snapshot_download(
+             model_id,
+             allow_patterns=["*.safetensors*", "*.json"],
+             ignore_patterns=["*.bin", "*.md", "*.py"]
+         )
+
+     # Load the index file
+     index_file = os.path.join(model_dir, 'model.safetensors.index.json')
+     if not os.path.exists(index_file):
+         raise FileNotFoundError(f"Could not find index file: {index_file}")
+
+     with open(index_file, 'r') as f:
+         index_data = json.load(f)
+
+     weight_map = index_data['weight_map']
+     state_dict = {}
+     shard_to_params = {}
+
+     # Group parameters by shard file
+     for param_name, shard_file in weight_map.items():
+         if shard_file not in shard_to_params:
+             shard_to_params[shard_file] = []
+         shard_to_params[shard_file].append(param_name)
+
+     # Load parameters from each shard
+     for shard_file, params_in_shard in shard_to_params.items():
+         shard_path = os.path.join(model_dir, shard_file)
+         with safe_open(shard_path, framework="pt", device="cpu") as f:
+             for name in params_in_shard:
+                 state_dict[name] = f.get_tensor(name)
+
+     return state_dict
+
+ def compare_model_states(model, new_state_dict):
+     current_state = model.state_dict()
+     unchanged_params = []
+     changed_params = []
+     missing_params = []
+
+     for name, param in current_state.items():
+         if name not in new_state_dict:
+             missing_params.append(name)
+         elif torch.equal(param, new_state_dict[name]):
+             unchanged_params.append(name)
+         else:
+             sum_abs_diff = torch.sum(torch.abs(param - new_state_dict[name]))
+             changed_params.append({'name': name, 'sum_abs_diff': sum_abs_diff.item()})
+
+     return {
+         'unchanged': unchanged_params,
+         'changed': changed_params,
+         'missing': missing_params
+     }
+
+
+ layer_mapping, removed_layers = create_inverse_layer_mapping()
+
+ # Load source (11B/90B) state dict
+ source_state_dict = load_sharded_state_dict(source_model)
+
+ # Create new state dict for the target (8B/70B) text-only model
+ target_state_dict = {}
+
+ # Convert parameter names and copy tensors
+ for name, param in source_state_dict.items():
+     # Skip everything that is not part of the language model (vision tower, etc.)
+     if not (name.startswith('language_model.model.layers.') or
+             name == 'language_model.model.embed_tokens.weight' or
+             name == 'language_model.lm_head.weight' or
+             name == 'language_model.model.norm.weight'):
+         continue
+
+     if name.startswith('language_model.model.layers.'):
+         # Handle layer parameters: remap the source layer index to the target index;
+         # cross-attention layers are dropped because they have no entry in layer_mapping
+         layer_match = re.match(r'language_model\.model\.layers\.(\d+)\.(.+)', name)
+         if layer_match:
+             source_layer = int(layer_match.group(1))
+             if source_layer in layer_mapping:
+                 target_layer = layer_mapping[source_layer]
+                 new_name = f'model.layers.{target_layer}.{layer_match.group(2)}'
+                 target_state_dict[new_name] = param
+     elif name == 'language_model.lm_head.weight':
+         # Handle lm_head weight
+         target_state_dict['lm_head.weight'] = param
+     elif name == 'language_model.model.embed_tokens.weight':
+         # Handle embeddings - keep only the original text vocab (128256 rows);
+         # the Vision model appends extra special-token embeddings (e.g. <|image|>)
+         original_embed_size = 128256
+         target_state_dict['model.embed_tokens.weight'] = param[:original_embed_size, :]
+     elif name == 'language_model.model.norm.weight':
+         # Handle final model norm weight
+         target_state_dict['model.norm.weight'] = param
+
+
+ # Write target_state_dict keys to file for verification
+ with open('target_state_dict.txt', 'w') as f:
+     f.write('\n'.join(target_state_dict.keys()))
+
+ config = AutoConfig.from_pretrained(target_model)
+
+ # Instantiate the target architecture directly from the converted state dict
+ # (passing None as the model path is supported when config and state_dict are given)
+ model = AutoModelForCausalLM.from_pretrained(
+     None,
+     config=config,
+     state_dict=target_state_dict,
+     torch_dtype=torch.bfloat16,
+ )
+
+ '''
+
+ origmodel = AutoModelForCausalLM.from_pretrained(
+     target_model,
+     torch_dtype=torch.bfloat16,
+ )
+
+ result = compare_model_states(model, origmodel.state_dict())
+ print("Unchanged parameters:", len(result['unchanged']))
+ print("Changed parameters:", len(result['changed']))
+ print("Missing parameters:", len(result['missing']))
+
+ # Write result to file
+ with open('result.txt', 'w') as f:
+     f.write(json.dumps(result, indent=2))
+
+ '''
+
+ processor = AutoTokenizer.from_pretrained(target_model)  # 8B/70B
+ #processor = AutoProcessor.from_pretrained(source_model)  # 11B/90B
+
+ model.save_pretrained("Llama-3.2-8B-extracted")
+ processor.save_pretrained("Llama-3.2-8B-extracted")