|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional |
|
|
|
import click |
|
import yaml |
|
|
|
from mergekit.config import InputModelDefinition, MergeConfiguration |
|
from mergekit.merge import run_merge |
|
from mergekit.options import MergeOptions, add_merge_options |
|
|
|
|
|
@click.command("mergekit-legacy") |
|
@click.argument("out_path", type=str) |
|
@click.option( |
|
"--merge", "merge", type=str, multiple=True, help="Add a model to the merge" |
|
) |
|
@click.option( |
|
"--density", |
|
"density", |
|
type=float, |
|
multiple=True, |
|
default=[], |
|
help="Fraction of weights to keep for each model (ties only)", |
|
) |
|
@click.option( |
|
"--weight", |
|
"weight", |
|
type=float, |
|
multiple=True, |
|
default=[], |
|
help="Weighting for a model (default 1.0 for all models if not specified)", |
|
) |
|
@click.option( |
|
"--method", "method", type=str, default="ties", help="Method used to merge models" |
|
) |
|
@click.option( |
|
"--base-model", "base_model", type=str, default=None, help="Base model for merge" |
|
) |
|
@click.option( |
|
"--normalize/--no-normalize", |
|
"normalize", |
|
is_flag=True, |
|
default=True, |
|
help="Divide merged parameters by the sum of weights", |
|
) |
|
@click.option( |
|
"--int8-mask/--no-int8-mask", |
|
"int8_mask", |
|
is_flag=True, |
|
help="Store intermediate masks in int8 to save memory", |
|
) |
|
@click.option("--bf16/--no-bf16", "bf16", is_flag=True, help="Use bfloat16") |
|
@click.option( |
|
"--naive-count/--no-naive-count", |
|
"naive_count", |
|
is_flag=True, |
|
help="Use naive sign count instead of weight (ties only)", |
|
) |
|
@click.option( |
|
"--print-yaml/--no-print-yaml", |
|
"print_yaml", |
|
is_flag=True, |
|
help="Print generated YAML configuration", |
|
) |
|
@add_merge_options |
|
def main( |
|
out_path: str, |
|
merge: List[str], |
|
density: List[float], |
|
weight: List[float], |
|
method: str, |
|
base_model: Optional[str], |
|
normalize: bool, |
|
int8_mask: bool, |
|
bf16: bool, |
|
naive_count: bool, |
|
print_yaml: bool, |
|
merge_options: MergeOptions, |
|
): |
|
"""Wrapper for using a subset of legacy-style script arguments.""" |
|
models = [InputModelDefinition(model=model, parameters={}) for model in merge] |
|
if base_model and base_model not in merge: |
|
models.append(InputModelDefinition(model=base_model, parameters={})) |
|
|
|
parameters = {} |
|
|
|
if density: |
|
if len(density) == 1: |
|
density = [density[0]] * len(models) |
|
for idx, d in enumerate(density): |
|
models[idx].parameters["density"] = d |
|
|
|
if method == "slerp": |
|
assert len(weight) == 1, "Must specify exactly one weight for SLERP" |
|
parameters["t"] = weight[0] |
|
else: |
|
if weight: |
|
if len(weight) == 1: |
|
weight = [weight[0]] * len(models) |
|
for idx, w in enumerate(weight): |
|
models[idx].parameters["weight"] = w |
|
|
|
if int8_mask: |
|
parameters["int8_mask"] = True |
|
if naive_count: |
|
parameters["consensus_method"] = "count" |
|
parameters["normalize"] = normalize |
|
|
|
merge_config = MergeConfiguration( |
|
merge_method=method, |
|
models=models, |
|
parameters=parameters, |
|
base_model=base_model, |
|
dtype="bfloat16" if bf16 else None, |
|
) |
|
|
|
if print_yaml: |
|
print(yaml.dump(merge_config.model_dump(mode="json", exclude_none=True))) |
|
|
|
run_merge( |
|
merge_config, |
|
out_path, |
|
options=merge_options, |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|