# convert https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format
# the goal is to be able to reuse convert_hf_to_gguf.py afterwards to create a GGUF file with the WavTokenizer decoder
#
# TODO: this script is LLM-generated and probably very inefficient and should be rewritten
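#
# usage (the script name is illustrative; the checkpoint path defaults to ./model.pt when omitted):
#
#   python convert_pt_to_hf.py /path/to/model.pt
#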

import torch
import json
import os
import sys
import re

from safetensors.torch import save_file

# default checkpoint path
model_path = './model.pt'

# read from CLI
if len(sys.argv) > 1:
    model_path = sys.argv[1]

# get the directory of the input model
path_dst = os.path.dirname(model_path)

print(f"Loading model from {model_path}")

# the checkpoint contains non-tensor objects (e.g. hyper_parameters), which
# PyTorch >= 2.6 refuses to load unless weights_only is explicitly disabled
model = torch.load(model_path, map_location='cpu', weights_only=False)

# print all checkpoint keys and dump the hyper-parameters as pretty JSON
for key in model.keys():
    print(key)
    if key == 'hyper_parameters':
        print(json.dumps(model[key], indent=4))

# Check if the loaded model is a state_dict or a model instance
if isinstance(model, torch.nn.Module):
    state_dict = model.state_dict()
else:
    state_dict = model

# Print the structure of the state_dict to understand its format
print("State dictionary keys:")
for key in state_dict.keys():
    print(key)

# recursively flatten a (possibly nested) state dict into a flat {dotted.key: tensor} mapping
def flatten_state_dict(state_dict, parent_key='', sep='.'):
    items = []

    for k, v in state_dict.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, torch.Tensor):
            items.append((new_key, v))
        elif isinstance(v, dict):
            items.extend(flatten_state_dict(v, new_key, sep=sep).items())

    return dict(items)

# keep only the tensors needed for inference and rename them to the layout
# expected by convert_hf_to_gguf.py
def convert_state_dict(state_dict):
    items_new = []

    size_total_mb = 0

    for key, value in flatten_state_dict(state_dict).items():
        # keep only what we need for inference
        if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
           not key.startswith('state_dict.backbone.') and \
           not key.startswith('state_dict.head.out'):
            print('Skipping key: ', key)
            continue

        new_key = key

        new_key = new_key.replace('state_dict.', '')
        new_key = new_key.replace('pos_net', 'posnet')

        # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight"
        if new_key.startswith("backbone.posnet."):
            match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key)
            if match:
                new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}"

        # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
        if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
            new_key = "backbone.embedding.weight"

        # only row 0 of the norm scale/shift tables is used for inference
        # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
        if new_key.endswith("norm.scale.weight"):
            new_key = new_key.replace("norm.scale.weight", "norm.weight")
            value = value[0]

        if new_key.endswith("norm.shift.weight"):
            new_key = new_key.replace("norm.shift.weight", "norm.bias")
            value = value[0]

        if new_key.endswith("gamma"):
            new_key = new_key.replace("gamma", "gamma.weight")

        # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias
        if (new_key.endswith("norm.weight")  or
            new_key.endswith("norm1.weight") or
            new_key.endswith("norm2.weight") or
            new_key.endswith(".bias")) and \
           (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")):
            value = value.unsqueeze(1)

        if new_key.endswith("dwconv.bias"):
            value = value.unsqueeze(1)

        size_mb = value.element_size() * value.nelement() / (1024 * 1024)
        print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

        size_total_mb += size_mb

        items_new.append((new_key, value))

    print(f"Total size: {size_total_mb:8.2f} MB")

    return dict(items_new)
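
# illustrative examples of the renames above (derived from the rules, not an exhaustive list):
#
#   state_dict.backbone.pos_net.3.weight                                        -> backbone.posnet.3.norm.weight
#   state_dict.feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed  -> backbone.embedding.weight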

flattened_state_dict = convert_state_dict(state_dict)


# Convert the model to the safetensors format
output_path = os.path.join(path_dst, 'model.safetensors')
save_file(flattened_state_dict, output_path)

print(f"Model has been successfully converted and saved to {output_path}")

# Calculate the total size of the .safetensors file
total_size = os.path.getsize(output_path)

# Create the weight map
weight_map = {
    "model.safetensors": ["*"]  # Assuming all weights are in one file
}

# Create metadata for the index.json file
metadata = {
    "total_size": total_size,
    "weight_map": weight_map
}

# Save the metadata to index.json
index_path = os.path.join(path_dst, 'index.json')
with open(index_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"Metadata has been saved to {index_path}")

config = {
    "architectures": [
        "WavTokenizerDec"
    ],
    "hidden_size": 1282, # or 2402 for 40t/s
    "n_embd_features": 512,
    "n_ff": 2304,
    "vocab_size": 4096,
    "n_head": 1,
    "layer_norm_epsilon": 1e-6,
    "group_norm_epsilon": 1e-6,
    "group_norm_groups": 32,
    "max_position_embeddings": 8192, # ?
    "n_layer": 12,
    "posnet": {
        "n_embd": 768,
        "n_layer": 6
    },
    "convnext": {
        "n_embd": 768,
        "n_layer": 12
    },
}

config_path = os.path.join(path_dst, 'config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=4)

print(f"Config has been saved to {config_path}")