|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random
from typing import List, Optional

import click
import yaml

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference
from mergekit.config import (
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
)
from mergekit.merge import run_merge
from mergekit.options import MergeOptions, add_merge_options
|
|
|
|
|
def _resolve_weights(num_models: int, weight: List[float]) -> List[float]:
    """Validate the per-model weights, defaulting to uniform weights when none are given.

    Raises click usage errors so the CLI reports a clear message instead of a
    traceback from ``random.choices`` or silently truncating via ``zip``.
    """
    if num_models == 0:
        raise click.UsageError("At least one --model is required")
    if not weight:
        # Uniform default: in contiguous mode only relative weights matter; in
        # full-random mode 1/n per model yields roughly one model's worth of layers.
        return [1.0 / num_models] * num_models
    if len(weight) != num_models:
        raise click.BadParameter(
            f"got {len(weight)} weight(s) for {num_models} model(s); "
            "pass exactly one --weight per --model",
            param_hint="'--weight'",
        )
    if any(w < 0 for w in weight):
        raise click.BadParameter("weights must be non-negative", param_hint="'--weight'")
    if sum(weight) <= 0:
        raise click.BadParameter(
            "at least one weight must be positive", param_hint="'--weight'"
        )
    return list(weight)


def _layer_slice(model_name: str, start: int, end: int) -> OutputSliceDefinition:
    """Build an output slice taking layers [start, end) from a single model."""
    return OutputSliceDefinition(
        sources=[InputSliceDefinition(model=model_name, layer_range=(start, end))]
    )


def _full_random_slices(
    models: List[ModelReference], weights: List[float], arch_info
) -> List[OutputSliceDefinition]:
    """Draw int(num_layers * weight) random single layers per model, then shuffle them."""
    slices: List[OutputSliceDefinition] = []
    for src_model, frac in zip(models, weights):
        model_layers = arch_info.num_layers(src_model.config())
        for _ in range(int(model_layers * frac)):
            # Sample over the model's full depth, not just the first `frac` share of it.
            src_idx = random.randrange(0, model_layers)
            slices.append(_layer_slice(str(src_model), src_idx, src_idx + 1))
    random.shuffle(slices)
    return slices


def _contiguous_slices(
    models: List[ModelReference], weights: List[float], num_layers: int
) -> List[OutputSliceDefinition]:
    """Fill each layer position i with layer i of a weighted-randomly chosen model.

    Consecutive positions drawn from the same model are coalesced into one slice.
    """
    runs: List[list] = []  # each entry: [model_name, start, end)
    for layer_idx in range(num_layers):
        src_name = str(random.choices(models, weights=weights, k=1)[0])
        if runs and runs[-1][0] == src_name:
            runs[-1][2] = layer_idx + 1
        else:
            runs.append([src_name, layer_idx, layer_idx + 1])
    return [_layer_slice(name, start, end) for name, start, end in runs]


@click.command("mergekit-layershuffle")
@click.argument("out_path", type=str)
@click.option("--model", "-m", multiple=True, type=str, help="Add a model to the merge")
@click.option(
    "--weight",
    "-w",
    multiple=True,
    type=float,
    default=[],
    show_default=False,
    help="Weighting for a model",
)
@click.option(
    "--print-yaml/--no-print-yaml",
    is_flag=True,
    help="Print YAML merge config for resulting model",
)
@click.option(
    "--write-yaml",
    type=click.Path(writable=True),
    help="Path to write YAML merge config to",
)
@click.option(
    "--dry-run", is_flag=True, help="Generate a config but do not run the merge"
)
@click.option("--fp16/--no-fp16", is_flag=True, help="Use FP16 precision")
@click.option(
    "--full-random/--no-full-random",
    is_flag=True,
    help="Randomize layer index as well as source model",
)
@add_merge_options
def main(
    out_path: str,
    model: List[str],
    weight: List[float],
    print_yaml: bool,
    write_yaml: Optional[str],
    dry_run: bool,
    fp16: bool,
    full_random: bool,
    merge_options: MergeOptions,
):
    """Assemble a passthrough merge from layers chosen at random across models.

    By default each output layer position is filled, in order, by the same
    layer of a randomly chosen model (probability proportional to --weight).
    With --full-random, each model instead contributes int(num_layers * weight)
    randomly chosen layers and the resulting stack is shuffled.
    """
    models = [ModelReference.parse(m) for m in model]
    weights = _resolve_weights(len(models), list(weight))

    # Architecture and layer count are taken from the first model.
    m0_cfg = models[0].config()
    arch_info = get_architecture_info(m0_cfg)
    total_num_layers = arch_info.num_layers(m0_cfg)

    if full_random:
        out_slices = _full_random_slices(models, weights, arch_info)
    else:
        # NOTE(review): assumes every model has at least as many layers as
        # models[0]; this is not checked here.
        out_slices = _contiguous_slices(models, weights, total_num_layers)

    if not out_slices:
        # Possible in full-random mode when int(num_layers * weight) rounds to 0.
        raise click.UsageError("The given weights select no layers; increase --weight")

    merge_config = MergeConfiguration(
        merge_method="passthrough", slices=out_slices, dtype="float16" if fp16 else None
    )

    if print_yaml or write_yaml:
        yaml_str = yaml.dump(merge_config.model_dump(exclude_none=True, mode="json"))

        if print_yaml:
            print(yaml_str)
        if write_yaml:
            with open(write_yaml, "w", encoding="utf-8") as file:
                file.write(yaml_str)

    if dry_run:
        return

    run_merge(
        merge_config,
        out_path,
        options=merge_options,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses the
    # command-line arguments itself.
    main()
|
|