File size: 4,077 Bytes
a164e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
from typing import List, Optional
import click
import yaml
from mergekit.config import InputModelDefinition, MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions, add_merge_options
def _apply_per_model_parameter(
    models: List[InputModelDefinition],
    key: str,
    values: List[float],
    param_hint: str,
) -> None:
    """Store ``values`` as ``parameters[key]`` on ``models``, positionally.

    A single value is broadcast to every model, including a base model that
    was appended after the ``--merge`` models. Otherwise the i-th value goes
    to the i-th model. Supplying fewer values than models leaves the trailing
    models (e.g. an appended base model) without this parameter.

    Args:
        models: Model definitions to update in place.
        key: Parameter name to set on each model (e.g. ``"density"``).
        values: Values collected from a ``multiple=True`` CLI option.
        param_hint: Option name used in the error message.

    Raises:
        click.BadParameter: if more values are given than there are models.
    """
    if not values:
        return
    if len(values) == 1:
        values = [values[0]] * len(models)
    elif len(values) > len(models):
        # Previously this surfaced as a bare IndexError from models[idx].
        raise click.BadParameter(
            f"got {len(values)} values but only {len(models)} models",
            param_hint=param_hint,
        )
    for model, value in zip(models, values):
        model.parameters[key] = value


@click.command("mergekit-legacy")
@click.argument("out_path", type=str)
@click.option(
    "--merge", "merge", type=str, multiple=True, help="Add a model to the merge"
)
@click.option(
    "--density",
    "density",
    type=float,
    multiple=True,
    default=(),
    help="Fraction of weights to keep for each model (ties only)",
)
@click.option(
    "--weight",
    "weight",
    type=float,
    multiple=True,
    default=(),
    help="Weighting for a model (default 1.0 for all models if not specified)",
)
@click.option(
    "--method", "method", type=str, default="ties", help="Method used to merge models"
)
@click.option(
    "--base-model", "base_model", type=str, default=None, help="Base model for merge"
)
@click.option(
    "--normalize/--no-normalize",
    "normalize",
    is_flag=True,
    default=True,
    help="Divide merged parameters by the sum of weights",
)
@click.option(
    "--int8-mask/--no-int8-mask",
    "int8_mask",
    is_flag=True,
    help="Store intermediate masks in int8 to save memory",
)
@click.option("--bf16/--no-bf16", "bf16", is_flag=True, help="Use bfloat16")
@click.option(
    "--naive-count/--no-naive-count",
    "naive_count",
    is_flag=True,
    help="Use naive sign count instead of weight (ties only)",
)
@click.option(
    "--print-yaml/--no-print-yaml",
    "print_yaml",
    is_flag=True,
    help="Print generated YAML configuration",
)
@add_merge_options
def main(
    out_path: str,
    merge: List[str],
    density: List[float],
    weight: List[float],
    method: str,
    base_model: Optional[str],
    normalize: bool,
    int8_mask: bool,
    bf16: bool,
    naive_count: bool,
    print_yaml: bool,
    merge_options: MergeOptions,
):
    """Wrapper for using a subset of legacy-style script arguments.

    Builds a merge configuration from the legacy flags and runs it with
    OUT_PATH as the output path. --density and --weight values are matched
    to models in order (--merge models first, then --base-model if it was
    not already listed); a single value applies to every model.
    """
    models = [InputModelDefinition(model=model, parameters={}) for model in merge]
    # The base model is appended as a participant unless already given via
    # --merge, so it sits after all --merge models in positional matching.
    if base_model and base_model not in merge:
        models.append(InputModelDefinition(model=base_model, parameters={}))

    parameters = {}

    _apply_per_model_parameter(models, "density", density, "'--density'")

    if method == "slerp":
        # For SLERP the single weight becomes the global `t` parameter
        # rather than a per-model weight. Validate explicitly: `assert` is
        # stripped under `python -O`.
        if len(weight) != 1:
            raise click.BadParameter(
                "Must specify exactly one weight for SLERP", param_hint="'--weight'"
            )
        parameters["t"] = weight[0]
    else:
        _apply_per_model_parameter(models, "weight", weight, "'--weight'")

    if int8_mask:
        parameters["int8_mask"] = True
    if naive_count:
        parameters["consensus_method"] = "count"
    parameters["normalize"] = normalize

    merge_config = MergeConfiguration(
        merge_method=method,
        models=models,
        parameters=parameters,
        base_model=base_model,
        dtype="bfloat16" if bf16 else None,
    )

    if print_yaml:
        print(yaml.dump(merge_config.model_dump(mode="json", exclude_none=True)))

    run_merge(
        merge_config,
        out_path,
        options=merge_options,
    )
# Invoke the click command when this file is executed directly as a script.
if __name__ == "__main__":
    main()
|