Shaleen123's picture
Upload folder using huggingface_hub
a164e13 verified
#!/usr/bin/env python3
"""
Merges multiple models and their dependencies into a single model
using multiple merge yaml documents in a single yaml file as the input
"""
import logging
import os
import sys
from pathlib import Path
import click
import yaml
from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge
from mergekit.options import add_merge_options
# Registry of pending merges keyed by merge name. main() stores each parsed
# YAML document here together with a "deps" list of the merge names it depends
# on; merge() deletes an entry once that merge has been run or skipped.
merges = {}
def has_circular_dependency(nodes):
    """
    Detects circular dependencies between merges using DFS

    Params:
        nodes: mapping of merge name to a dict whose "deps" entry lists the
            names that merge depends on. Names in "deps" that are not keys of
            `nodes` are ignored, since they have no outgoing edges here and
            therefore cannot be part of a cycle.

    Returns the first node (in iteration order of `nodes`) whose traversal
    runs into a cycle, or None if there is no cycle. The returned node can
    reach the cycle but is not necessarily a member of it.
    """

    def dfs(node, visited, stack):
        """
        Returns True if a cycle is reachable from `node`
        """
        visited[node] = True
        # `stack` marks the nodes on the current DFS path; reaching one of
        # them again means a back edge was followed, i.e. there is a cycle
        stack[node] = True
        for dependency in nodes[node]["deps"]:
            if dependency not in nodes:
                # not a tracked merge: indexing visited/stack would raise
                # KeyError, and such a leaf cannot close a cycle anyway
                continue
            if not visited[dependency]:
                if dfs(dependency, visited, stack):
                    return True
            elif stack[dependency]:
                return True
        # done with this path; the node may still be reached from elsewhere
        stack[node] = False
        return False

    visited = dict.fromkeys(nodes, False)
    stack = dict.fromkeys(nodes, False)
    for node in nodes:
        if not visited[node]:
            if dfs(node, visited, stack):
                return node
    return None
def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path):
    """
    Runs the merge named `m`, running its still-pending dependencies first

    Params:
        m: key of the merge in the module-level `merges` dict
        merge_options: MergeOptions passed through to run_merge
        force: re-run the merge even if its output path already exists
        out_path: base directory; the result is written to out_path / name

    The entry for `m` is deleted from `merges` once it was merged or skipped.
    """
    already_built = os.path.exists(out_path / m)
    if already_built and not force:
        logging.info("Skipping %s as it already exists", m)
        del merges[m]
        return
    if already_built:
        logging.info("Overwriting %s as --force was specified", m)

    spec = merges[m]

    # The membership test has to happen on every iteration: a recursive call
    # may already have handled (and removed) a dependency listed later on.
    for dep_name in spec["deps"]:
        if dep_name not in merges:
            continue
        merge(dep_name, merge_options, force, out_path)

    logging.info("Merging model %s", m)
    merge_config: MergeConfiguration = MergeConfiguration.model_validate(spec)
    destination = str(out_path / spec["name"])
    run_merge(merge_config, destination, options=merge_options)
    del merges[m]
def add_model_deps(model: str, name: str, out_path: Path):
    """
    Registers `model` as a dependency of the merge `name` when it looks like
    another merge from the config, and returns the model reference to use

    A reference is split on "+" into a base part and an optional LoRA part.
    A base part containing a slash is returned untouched. Otherwise the base
    part is added once to `merges[name]["deps"]` and the returned reference
    points at out_path / base, with the LoRA part re-attached when the
    reference had exactly one "+".
    """
    parts = model.split("+")
    base = parts[0]

    # merge names must not contain a slash (to avoid path traversal), so a
    # slash means this cannot be the name of a merge from the config
    if "/" in base:
        return model

    deps = merges[name]["deps"]
    # avoid duplicate deps
    if base not in deps:
        deps.append(base)

    resolved = str(out_path / base)
    if len(parts) == 2:
        resolved = resolved + "+" + parts[1]
    return resolved
@click.command("mergekit-mega")
@click.argument("config_file")
@click.argument("out_path")
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--force",
    "-f",
    type=bool,
    default=False,
    is_flag=True,
    help="Overwrite existing merge results instead of skipping them",
)
@click.option(
    "--require-nameless",
    "-R",
    type=bool,
    default=False,
    is_flag=True,
    help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.",
)
@add_merge_options
def main(
    merge_options: MergeOptions,
    config_file: str,
    out_path: str,
    force: bool,
    verbose: bool,
    require_nameless: bool,
):
    """
    Main entrypoint for mergekit-mega command see module docstring for more info
    Params are supplied by click decorators
    """
    # NOTE: the docstring above is left untouched on purpose; click shows a
    # command function's docstring as its --help text.
    #
    # Flow: read every YAML document of `config_file` into `merges`, rewrite
    # references to other merges so they point below `out_path`, refuse
    # circular dependencies, then run merges until none is pending.
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    out_path = Path(out_path)
    final_found = False

    with open(config_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader accepts more than plain YAML data types; if
        # config files can come from an untrusted source, consider switching
        # to yaml.safe_load_all.
        data = yaml.load_all(f, Loader=yaml.FullLoader)

        # load_all is lazy, so the documents are consumed while `f` is open
        for d in data:
            if d is None:
                # an empty document (e.g. after a trailing `---`) is loaded as
                # None; there is nothing to merge, so skip it
                continue

            if "name" not in d:
                if final_found:
                    logging.error("Only one merge must not have a name")
                    sys.exit(1)
                # this sets the name of the final merge to the config file name without the extension
                d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0]
                final_found = True

            # the name is joined onto out_path later, so it must stay a single
            # path component
            if "/" in d["name"]:
                logging.error("name must not contain a slash")
                sys.exit(1)

            merges[d["name"]] = d
            merges[d["name"]]["deps"] = []

            # every place a model can be referenced is rewritten in place and
            # recorded as a dependency when it looks like another merge
            if "base_model" in d:
                d["base_model"] = add_model_deps(d["base_model"], d["name"], out_path)
            if "slices" in d:
                for slc in d["slices"]:
                    for src in slc["sources"]:
                        src["model"] = add_model_deps(src["model"], d["name"], out_path)
            if "models" in d:
                for mdl in d["models"]:
                    mdl["model"] = add_model_deps(mdl["model"], d["name"], out_path)

    if require_nameless and not final_found:
        logging.error("No final merge found")
        sys.exit(1)

    logging.info("Merging: %s", ", ".join(merges))

    if (dep := has_circular_dependency(merges)) is not None:
        logging.error("Circular dependency detected: %s", dep)
        sys.exit(1)

    # merge() removes every entry it handles (merged or skipped), so keep
    # picking the first pending merge until nothing is left
    while merges:
        merge(next(iter(merges)), merge_options, force, out_path)
# Run the click command only when this file is executed as a script
if __name__ == "__main__":
    main()