|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
from typing import List, Optional

import click
import yaml
from pydantic import BaseModel

from mergekit.common import MergeOptions
from mergekit.config import (
    ConditionalParameter,
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
)
from mergekit.merge import run_merge
|
|
|
|
|
class LayerSlice(BaseModel):
    """A contiguous run of layers copied from a single source model.

    Corresponds to one entry of ``layer_slices`` in a legacy bakllama YAML
    configuration file.
    """

    # Source model identifier; forwarded unchanged as the input slice's model.
    model: str
    # Layer bounds, forwarded as ``layer_range=(start, end)``.
    # NOTE(review): presumably half-open like mergekit's layer_range -- confirm.
    start: int
    end: int
    # Optional multiplier. When set, the bakllama command attaches it as a
    # "scale" parameter restricted to the "down_proj" filter; None adds no
    # scale parameter at all.
    scale: Optional[float] = None
|
|
|
|
|
class BakllamaConfig(BaseModel):
    """Schema for a whole legacy bakllama YAML configuration file."""

    # Layer slices, converted in order into the merge's output slices.
    layer_slices: List[LayerSlice]
    # Accepted so that legacy configs still validate; the bakllama command in
    # this module does not use these values when building the merge.
    embedding_source: Optional[str] = None
    lm_head_source: Optional[str] = None
|
|
|
|
|
def _to_output_slice(layer_slice: LayerSlice) -> OutputSliceDefinition:
    """Convert one legacy layer slice into a single-source output slice.

    ``MergeConfiguration.slices`` takes output slice definitions, each of which
    lists its input ``sources``; a bare InputSliceDefinition is not accepted
    there, so every bakllama slice becomes the sole source of its own output
    slice.
    """
    parameters = {}
    if layer_slice.scale is not None:
        # The scale applies only to tensors matching the "down_proj" filter.
        parameters["scale"] = ConditionalParameter(
            value=layer_slice.scale, filter="down_proj"
        )
    return OutputSliceDefinition(
        sources=[
            InputSliceDefinition(
                model=layer_slice.model,
                layer_range=(layer_slice.start, layer_slice.end),
                parameters=parameters,
            )
        ]
    )


@click.command("bakllama")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=str)
@click.option(
    "--clone-tensors/--no-clone-tensors",
    type=bool,
    is_flag=True,
    help="Clone tensors before saving, to allow multiple occurrences of the same layer",
    default=False,
)
@click.option(
    "--fp16/--no-fp16",
    type=bool,
    default=False,
    help="Set the merge dtype to float16 (otherwise no dtype is set)",
)
def main(
    config_path: str,
    out_path: str,
    clone_tensors: bool,
    fp16: bool,
):
    """Wrapper for using legacy bakllama configuration files.

    Reads the bakllama YAML file at CONFIG_PATH, turns each layer slice into
    a passthrough output slice, and runs the merge into OUT_PATH.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        # safe_load: the config is user-supplied, so never use yaml.load here.
        config = BakllamaConfig.model_validate(yaml.safe_load(file))

    # These legacy fields validate but are not used by this wrapper; tell the
    # user instead of silently dropping them.
    for field_name in ("embedding_source", "lm_head_source"):
        if getattr(config, field_name) is not None:
            logging.warning(
                "bakllama: '%s' is set in %s but is not supported; ignoring it",
                field_name,
                config_path,
            )

    merge_config = MergeConfiguration(
        merge_method="passthrough",
        slices=[_to_output_slice(s) for s in config.layer_slices],
        # None leaves the dtype unset in the merge configuration.
        dtype="float16" if fp16 else None,
    )
    run_merge(merge_config, out_path, MergeOptions(clone_tensors=clone_tensors))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand control to the click command, which parses argv.
    main()
|
|