File size: 4,077 Bytes
a164e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import List, Optional

import click
import yaml

from mergekit.config import InputModelDefinition, MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions, add_merge_options


def _assign_per_model(
    models: List[InputModelDefinition], key: str, values: List[float]
) -> None:
    """Store per-model values under ``parameters[key]`` of each model, in order.

    A single value is broadcast to every model. Supplying fewer values than
    models is allowed: only the leading models receive one, so a base model
    appended at the end of the list can be left without a value.

    Raises:
        click.UsageError: if more values than models were supplied.
    """
    if len(values) == 1:
        values = [values[0]] * len(models)
    if len(values) > len(models):
        # Without this check the surplus value would index past the end of
        # ``models`` and surface as a bare IndexError traceback.
        raise click.UsageError(
            f"Got {len(values)} --{key} values but only {len(models)} models"
        )
    for model, value in zip(models, values):
        model.parameters[key] = value


@click.command("mergekit-legacy")
@click.argument("out_path", type=str)
@click.option(
    "--merge", "merge", type=str, multiple=True, help="Add a model to the merge"
)
@click.option(
    "--density",
    "density",
    type=float,
    multiple=True,
    default=[],
    help="Fraction of weights to keep for each model (ties only)",
)
@click.option(
    "--weight",
    "weight",
    type=float,
    multiple=True,
    default=[],
    help="Weighting for a model (default 1.0 for all models if not specified)",
)
@click.option(
    "--method", "method", type=str, default="ties", help="Method used to merge models"
)
@click.option(
    "--base-model", "base_model", type=str, default=None, help="Base model for merge"
)
@click.option(
    "--normalize/--no-normalize",
    "normalize",
    is_flag=True,
    default=True,
    help="Divide merged parameters by the sum of weights",
)
@click.option(
    "--int8-mask/--no-int8-mask",
    "int8_mask",
    is_flag=True,
    help="Store intermediate masks in int8 to save memory",
)
@click.option("--bf16/--no-bf16", "bf16", is_flag=True, help="Use bfloat16")
@click.option(
    "--naive-count/--no-naive-count",
    "naive_count",
    is_flag=True,
    help="Use naive sign count instead of weight (ties only)",
)
@click.option(
    "--print-yaml/--no-print-yaml",
    "print_yaml",
    is_flag=True,
    help="Print generated YAML configuration",
)
@add_merge_options
def main(
    out_path: str,
    merge: List[str],
    density: List[float],
    weight: List[float],
    method: str,
    base_model: Optional[str],
    normalize: bool,
    int8_mask: bool,
    bf16: bool,
    naive_count: bool,
    print_yaml: bool,
    merge_options: MergeOptions,
):
    """Wrapper for using a subset of legacy-style script arguments.

    Translates the flat command-line flags into a ``MergeConfiguration`` and
    hands it to ``run_merge`` together with ``out_path`` and ``merge_options``.

    Models are listed in ``--merge`` order; ``--base-model`` is appended last
    when it is not already among them. ``--density`` and ``--weight`` values
    are matched to that list by position, and a single value applies to every
    model. With ``--method slerp`` the one ``--weight`` value becomes the
    global ``t`` parameter instead of a per-model weight.

    Raises:
        click.UsageError: if ``--method slerp`` is not given exactly one
            ``--weight``, or if more ``--density``/``--weight`` values are
            given than there are models.
    """
    models = [InputModelDefinition(model=model, parameters={}) for model in merge]
    if base_model and base_model not in merge:
        models.append(InputModelDefinition(model=base_model, parameters={}))

    parameters = {}

    if density:
        _assign_per_model(models, "density", density)

    if method == "slerp":
        # Raise explicitly rather than assert: asserts vanish under
        # ``python -O``, which would let bad input through unchecked.
        if len(weight) != 1:
            raise click.UsageError("Must specify exactly one weight for SLERP")
        parameters["t"] = weight[0]
    elif weight:
        _assign_per_model(models, "weight", weight)

    if int8_mask:
        parameters["int8_mask"] = True
    if naive_count:
        parameters["consensus_method"] = "count"
    parameters["normalize"] = normalize

    merge_config = MergeConfiguration(
        merge_method=method,
        models=models,
        parameters=parameters,
        base_model=base_model,
        dtype="bfloat16" if bf16 else None,
    )

    if print_yaml:
        print(yaml.dump(merge_config.model_dump(mode="json", exclude_none=True)))

    run_merge(
        merge_config,
        out_path,
        options=merge_options,
    )


# Run the click command only when this file is executed directly as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()