File size: 5,556 Bytes
a164e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
#!/usr/bin/env python3
"""
Merges multiple models and their dependencies into a single model
using multiple merge yaml documents in a single yaml file as the input
"""
import logging
import os
import sys
from pathlib import Path
import click
import yaml
from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge
from mergekit.options import add_merge_options
# Registry of the merges parsed from the config file, keyed by merge name.
# Each value is the parsed YAML document plus a "deps" list of the slash-free
# model names it references; entries are deleted once a merge was run or skipped.
merges = {}
def has_circular_dependency(nodes):
    """
    Detects circular dependencies in merges using DFS

    Params:
        nodes: mapping of merge name to a mapping whose "deps" entry lists
            the names that merge depends on

    Dependencies that are not keys of `nodes` are skipped: they have no
    outgoing edges in this graph, so they can never be part of a cycle

    Returns the node the search was started from when a cycle was reached
    (not necessarily a member of the cycle), or None if there is no cycle
    """

    def dfs(node, visited, stack):
        """
        Returns True if a cycle is reachable from `node`
        """
        visited[node] = True
        # `stack` marks the nodes that are on the current DFS path
        stack[node] = True
        for dependency in nodes[node]["deps"]:
            if dependency not in nodes:
                # unknown name: looking it up in `visited` would raise KeyError
                continue
            if not visited[dependency]:
                if dfs(dependency, visited, stack):
                    return True
            elif stack[dependency]:
                # dependency is already on the current path -> back edge
                return True
        stack[node] = False
        return False

    visited = dict.fromkeys(nodes, False)
    stack = dict.fromkeys(nodes, False)
    for node in nodes:
        if not visited[node] and dfs(node, visited, stack):
            return node
    return None
def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path):
    """
    Runs the merge called `m` after running the merges it depends on

    Params:
        m: key of the merge in the module-level `merges` registry
        merge_options: MergeOptions handed through to `run_merge`
        force: when True an existing result is overwritten, otherwise it is skipped
        out_path: directory below which the result is written

    The entry for `m` is removed from `merges` both when it is skipped
    and after it has been merged
    """
    already_there = os.path.exists(out_path / m)
    if already_there and not force:
        logging.info("Skipping %s as it already exists", m)
        del merges[m]
        return
    if already_there:
        logging.info("Overwriting %s as --force was specified", m)

    # dependencies that are no longer (or never were) in the registry are left alone
    for dep in merges[m]["deps"]:
        if dep not in merges:
            continue
        merge(dep, merge_options, force, out_path)

    logging.info("Merging model %s", m)
    spec = merges[m]
    merge_config: MergeConfiguration = MergeConfiguration.model_validate(spec)
    destination = str(out_path / spec["name"])
    run_merge(merge_config, destination, options=merge_options)
    del merges[m]
def add_model_deps(model: str, name: str, out_path: Path):
    """
    Adds a model to `name`s dependencies if it is not already there and is a merge

    Params:
        model: model reference from the config, optionally followed by "+"
            and a suffix (presumably a LoRA reference - confirm against mergekit)
        name: key in `merges` of the merge that references `model`
        out_path: output path below which merge results live

    Returns the model reference to use in the config: `model` unchanged if
    the part before the "+" contains a slash, otherwise that part rewritten
    to a path below `out_path` with the "+" suffix kept as it was
    """
    # split on the first "+" only, so everything after it is carried over
    # unchanged instead of being dropped when there is more than one "+"
    base, sep, lora = model.partition("+")
    # name must not have a slash to avoid path traversal
    # therefore, we can use it to check if its a merge from the config
    if "/" in base:
        return model
    deps = merges[name]["deps"]
    # avoid duplicate deps
    if base not in deps:
        deps.append(base)
    # `sep` and `lora` are both empty when there was no "+"
    return str(out_path / base) + sep + lora
@click.command("mergekit-mega")
@click.argument("config_file")
@click.argument("out_path")
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--force",
    "-f",
    type=bool,
    default=False,
    is_flag=True,
    help="Overwrite existing merge results instead of skipping them",
)
@click.option(
    "--require-nameless",
    "-R",
    type=bool,
    default=False,
    is_flag=True,
    help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.",
)
@add_merge_options
def main(
    merge_options: MergeOptions,
    config_file: str,
    out_path: str,
    force: bool,
    verbose: bool,
    require_nameless: bool,
):
    """
    Main entrypoint for mergekit-mega command see module docstring for more info
    Params are supplied by click decorators
    """
    # NOTE: the docstring above is left untouched on purpose; click uses the
    # docstring of a command function as its --help text.
    #
    # Steps: read every YAML document of `config_file` into `merges`, rewrite
    # references to other merges to paths below `out_path`, refuse circular
    # dependencies, then run the merges until the registry is empty.
    # Exits with status 1 on: a second unnamed merge, a name containing "/",
    # no unnamed merge while --require-nameless is set, a circular dependency.
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    out_path = Path(out_path)
    final_found = False

    with open(config_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader can still construct more than plain data
        # types; consider yaml.safe_load_all if configs may be untrusted.
        data = yaml.load_all(f, Loader=yaml.FullLoader)

        # load_all is lazy, so the documents must be consumed while `f` is open
        for d in data:
            if d is None:
                # an empty document (e.g. after a trailing `---`) is loaded as
                # None and would crash the `"name" not in d` test below
                continue

            if "name" not in d:
                if final_found:
                    logging.error("Only one merge must not have a name")
                    sys.exit(1)
                # this sets the name of the final merge to the config file name without the extension
                d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0]
                final_found = True

            if "/" in d["name"]:
                logging.error("name must not contain a slash")
                sys.exit(1)

            merges[d["name"]] = d
            merges[d["name"]]["deps"] = []

            # every place a model can be referenced is rewritten in place and
            # recorded as a dependency when it names another merge
            if "base_model" in d:
                d["base_model"] = add_model_deps(d["base_model"], d["name"], out_path)
            if "slices" in d:
                for slc in d["slices"]:
                    for src in slc["sources"]:
                        src["model"] = add_model_deps(src["model"], d["name"], out_path)
            if "models" in d:
                for mdl in d["models"]:
                    mdl["model"] = add_model_deps(mdl["model"], d["name"], out_path)

    if require_nameless and not final_found:
        logging.error("No final merge found")
        sys.exit(1)

    logging.info("Merging: %s", ", ".join(merges))

    if (dep := has_circular_dependency(merges)) is not None:
        logging.error("Circular dependency detected: %s", dep)
        sys.exit(1)

    # merge() removes the entry it handled (and those of its dependencies),
    # so always continue with the first merge that is still registered
    while merges:
        merge(next(iter(merges)), merge_options, force, out_path)
# Run the click command only when the file is executed as a script, not on import
if __name__ == "__main__":
    main()
|