QuietImpostor committed
Commit 259c99a
1 Parent(s): 6a172c6

Upload conversion script


Adding the script used to convert Gemini Nano to Gemma weights as requested by ethanc8.

Files changed (1)
  1. gemmafy_gemini.py +108 -0
gemmafy_gemini.py ADDED
@@ -0,0 +1,108 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from safetensors import safe_open
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download, login
import os
from tqdm import tqdm

def load_gemini_weights(repo_id):
    print("Downloading Gemini Nano weights...")
    gemini_model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    return gemini_model_path

def adapt_gemini_to_gemma(gemini_path, output_path, custom_config):
    print("Adapting Gemini weights to Gemma format...")

    with safe_open(gemini_path, framework="pt", device="cpu") as f:
        gemini_keys = list(f.keys())

        # Process embedding layer
        embed_weight = f.get_tensor('model.embed_tokens.weight')
        vocab_size = embed_weight.size(0) // 2048
        embed_weight = embed_weight[:vocab_size * 2048].view(vocab_size, 2048)
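        # (Note: the Gemini Nano checkpoint appears to store the embedding as a
        # flattened 1-D tensor, so the row count is recovered by dividing by the
        # hidden size. The hardcoded 2048 matches Gemma-2b's hidden_size; a
        # different target model would need a different divisor.)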

        adapted_weights = {'model.embed_tokens.weight': embed_weight,
                           'lm_head.weight': embed_weight.clone()}
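        # (lm_head is initialized as a copy of the embedding matrix, i.e. weight
        # tying as in Gemma; with tie_word_embeddings=True in the config, the
        # saved lm_head.weight may simply be ignored at load time.)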

        # Process other layers
        for key in tqdm(gemini_keys, desc="Processing layers"):
            if key.startswith('model.layers.'):
                parts = key.split('.')
                layer_num = int(parts[2])
                if layer_num >= custom_config.num_hidden_layers:
                    continue

                weight = f.get_tensor(key)
                adapted_weights[key] = weight
            elif key == 'model.norm.weight':
                adapted_weights[key] = f.get_tensor(key)
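        # (Per-layer tensors are copied through unchanged, which assumes their
        # shapes already match the target Gemma layout; any layers beyond
        # num_hidden_layers are dropped.)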

    # Save adapted weights with metadata
    print("Saving adapted weights...")
    metadata = {"format": "pt"}
    save_file(adapted_weights, output_path, metadata=metadata)
    return adapted_weights

def create_custom_gemma_config(gemma_repo, gemini_path):
    with safe_open(gemini_path, framework="pt", device="cpu") as f:
        embed_weight = f.get_tensor('model.embed_tokens.weight')
        ffn_weight = f.get_tensor('model.layers.0.mlp.gate_proj.weight')

    custom_config = AutoConfig.from_pretrained(gemma_repo)
    custom_config.vocab_size = embed_weight.size(0) // 2048
    custom_config.intermediate_size = ffn_weight.size(0)
    custom_config.num_hidden_layers = 32  # Assuming Gemini Nano has 32 layers
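    # (Untested sketch: the layer count could instead be derived from the
    #  checkpoint keys inside the `with` block above, e.g.
    #      layer_ids = [int(k.split('.')[2]) for k in f.keys()
    #                   if k.startswith('model.layers.')]
    #      custom_config.num_hidden_layers = max(layer_ids) + 1
    #  rather than hardcoding 32.)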
    return custom_config

def test_model(model_path, tokenizer, prompt, max_length=50):
    print("Testing the adapted model...")
    # Load the model in 8-bit quantization
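    # (load_in_8bit requires bitsandbytes; recent transformers releases prefer
    # passing a BitsAndBytesConfig via quantization_config. The flag is kept
    # here as in the original script.)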
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        load_in_8bit=True,
        ignore_mismatched_sizes=True
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)
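        # (max_length caps prompt + completion tokens together; passing
        # max_new_tokens instead would bound only the generated continuation.)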
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

def main():
    gemini_repo = "QuietImpostor/Gemini-Nano-Safetensors"
    gemma_repo = "google/gemma-2b-it"
    output_path = "/kaggle/temp/gemini-gemmafied/"
    login(token="hf_...")
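    # (token redacted in the published script; replace with your own
    # Hugging Face access token)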

    # Load Gemini Nano weights
    gemini_path = load_gemini_weights(gemini_repo)

    # Create custom Gemma config
    custom_config = create_custom_gemma_config(gemma_repo, gemini_path)

    # Adapt Gemini weights to Gemma format
    adapted_weights_path = os.path.join(output_path, "model.safetensors")
    os.makedirs(output_path, exist_ok=True)
    adapt_gemini_to_gemma(gemini_path, adapted_weights_path, custom_config)

    # Save the custom config
    custom_config.save_pretrained(output_path)

    # Load Gemini Nano tokenizer and save it
    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(gemini_repo)
    tokenizer.save_pretrained(output_path)

    print("Adaptation complete!")

    # Test the model
    prompt = "The future of artificial intelligence is"
    generated_text = test_model(output_path, tokenizer, prompt)
    print(f"Prompt: {prompt}")
    print(f"Generated text: {generated_text}")

if __name__ == "__main__":
    main()