Shaleen123's picture
Upload folder using huggingface_hub
a164e13 verified
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
from typing import List, Optional
import click
import yaml
from mergekit.config import InputModelDefinition, MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions, add_merge_options
def _assign_per_model(
    models: List[InputModelDefinition],
    key: str,
    values: List[float],
    option: str,
) -> None:
    """Store one entry of ``values`` under ``key`` on each model, in order.

    A single value is broadcast to every model. When several values are
    given they are paired with the models positionally; models beyond the
    last value are left untouched.

    Args:
        models: Model definitions whose ``parameters`` dict is updated in place.
        key: Parameter name to set on each model.
        values: Values collected from a repeatable command-line option.
        option: Option name used in the error message (e.g. ``"--weight"``).

    Raises:
        click.UsageError: If more values than models were supplied.
    """
    if not values:
        return
    if len(values) == 1:
        values = [values[0]] * len(models)
    elif len(values) > len(models):
        # Indexing past the end of ``models`` would otherwise surface as a
        # bare IndexError with no hint about which option was wrong.
        raise click.UsageError(
            f"Got {len(values)} {option} values but only {len(models)} "
            "model(s) to apply them to"
        )
    for model, value in zip(models, values):
        model.parameters[key] = value


@click.command("mergekit-legacy")
@click.argument("out_path", type=str)
@click.option(
    "--merge", "merge", type=str, multiple=True, help="Add a model to the merge"
)
@click.option(
    "--density",
    "density",
    type=float,
    multiple=True,
    default=[],
    help="Fraction of weights to keep for each model (ties only)",
)
@click.option(
    "--weight",
    "weight",
    type=float,
    multiple=True,
    default=[],
    help="Weighting for a model (default 1.0 for all models if not specified)",
)
@click.option(
    "--method", "method", type=str, default="ties", help="Method used to merge models"
)
@click.option(
    "--base-model", "base_model", type=str, default=None, help="Base model for merge"
)
@click.option(
    "--normalize/--no-normalize",
    "normalize",
    is_flag=True,
    default=True,
    help="Divide merged parameters by the sum of weights",
)
@click.option(
    "--int8-mask/--no-int8-mask",
    "int8_mask",
    is_flag=True,
    help="Store intermediate masks in int8 to save memory",
)
@click.option("--bf16/--no-bf16", "bf16", is_flag=True, help="Use bfloat16")
@click.option(
    "--naive-count/--no-naive-count",
    "naive_count",
    is_flag=True,
    help="Use naive sign count instead of weight (ties only)",
)
@click.option(
    "--print-yaml/--no-print-yaml",
    "print_yaml",
    is_flag=True,
    help="Print generated YAML configuration",
)
@add_merge_options
def main(
    out_path: str,
    merge: List[str],
    density: List[float],
    weight: List[float],
    method: str,
    base_model: Optional[str],
    normalize: bool,
    int8_mask: bool,
    bf16: bool,
    naive_count: bool,
    print_yaml: bool,
    merge_options: MergeOptions,
):
    """Wrapper for using a subset of legacy-style script arguments."""
    # NOTE: the docstring above doubles as the command's --help text, so the
    # detailed description lives in comments instead.
    #
    # Translates the flat command-line arguments into a MergeConfiguration
    # and hands it to run_merge, writing the result to ``out_path``.
    # ``merge_options`` is presumably injected by ``add_merge_options`` --
    # confirm against mergekit.options.
    #
    # Raises click.UsageError when the SLERP method is not given exactly one
    # --weight, or when more --density/--weight values than models are given.

    models = [InputModelDefinition(model=model, parameters={}) for model in merge]
    # The base model takes part in the merge even when it was not listed
    # explicitly with --merge.
    if base_model and base_model not in merge:
        models.append(InputModelDefinition(model=base_model, parameters={}))

    parameters = {}

    _assign_per_model(models, "density", density, "--density")

    if method == "slerp":
        # SLERP uses a single interpolation factor instead of per-model
        # weights. Raise explicitly rather than assert: assertions are
        # stripped when Python runs with -O.
        if len(weight) != 1:
            raise click.UsageError("Must specify exactly one weight for SLERP")
        parameters["t"] = weight[0]
    else:
        _assign_per_model(models, "weight", weight, "--weight")

    if int8_mask:
        parameters["int8_mask"] = True
    if naive_count:
        parameters["consensus_method"] = "count"
    parameters["normalize"] = normalize

    merge_config = MergeConfiguration(
        merge_method=method,
        models=models,
        parameters=parameters,
        base_model=base_model,
        dtype="bfloat16" if bf16 else None,
    )

    if print_yaml:
        print(yaml.dump(merge_config.model_dump(mode="json", exclude_none=True)))

    run_merge(
        merge_config,
        out_path,
        options=merge_options,
    )
# Invoke the command only when this file is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    main()