File size: 5,556 Bytes
a164e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/env python3
"""
Merges multiple models and their dependencies into a single model
using multiple merge yaml documents in a single yaml file as the input
"""

import logging
import os
import sys
from pathlib import Path

import click
import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge
from mergekit.options import add_merge_options

# Registry of the merges that still have to be processed, keyed by merge name.
# Each value is the parsed YAML document of that merge, extended with a "deps"
# list holding the names of the models it depends on. Entries are removed as
# soon as the merge has been run or skipped.
merges = {}


def has_circular_dependency(nodes):
    """
    Detects circular dependencies between merges using a depth-first search

    Params:
        nodes: mapping of merge name to a dict whose "deps" entry lists the
            names that merge depends on

    Dependencies that are not keys of `nodes` are ignored: they have no
    outgoing edges in this graph and therefore can never be part of a cycle.

    Returns the node the search was started from when a cycle was reached
    (the cycle is reachable from it, it is not necessarily a member of it),
    or None if there is no cycle
    """
    visited = set()
    # nodes on the current DFS path; meeting one of them again closes a cycle
    on_path = set()

    def dfs(node):
        """
        Returns True if a cycle is reachable from `node`
        """
        visited.add(node)
        on_path.add(node)

        for dependency in nodes[node]["deps"]:
            # a dependency that is not itself a node would otherwise raise
            # a KeyError here
            if dependency not in nodes:
                continue
            if dependency not in visited:
                if dfs(dependency):
                    return True
            elif dependency in on_path:
                return True

        on_path.discard(node)
        return False

    for node in nodes:
        if node not in visited and dfs(node):
            return node

    return None


def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path):
    """
    Runs the merge called `m`, after first running the merges it depends on

    Params:
        m: name of the model to merge
        merge_options: MergeOptions
        force: overwrite existing merge results
        out_path: output path

    The entry for `m` is removed from `merges` once it was merged or skipped.
    """
    # an existing output is only redone when --force was given
    already_exists = os.path.exists(out_path / m)
    if already_exists and not force:
        logging.info("Skipping %s as it already exists", m)
        del merges[m]
        return
    if already_exists:
        logging.info("Overwriting %s as --force was specified", m)

    # dependencies first; the membership test has to happen right before each
    # call because an earlier recursion may already have processed (and
    # removed) a later dependency
    for dependency in merges[m]["deps"]:
        if dependency not in merges:
            continue
        merge(dependency, merge_options, force, out_path)

    logging.info("Merging model %s", m)
    spec = merges[m]
    validated: MergeConfiguration = MergeConfiguration.model_validate(spec)
    destination = str(out_path / spec["name"])
    run_merge(validated, destination, options=merge_options)
    del merges[m]


def add_model_deps(model: str, name: str, out_path: Path):
    """
    Adds a model to `name`s dependencies if it is not already there and is a merge

    Params:
        model: model reference, either "<model>" or "<model>+<lora>"
        name: name of the merge whose dependency list is updated
        out_path: output path the merges are written to

    Returns the model reference to use in the merge config: unchanged when
    the model part contains a slash, otherwise with the model part rewritten
    to its location below `out_path`. Everything after the first "+" is kept
    as is.
    """
    # partition (instead of split) keeps the complete suffix, so it is not
    # silently dropped when the reference contains more than one "+"
    base, separator, lora = model.partition("+")

    # name must not have a slash to avoid path traversal
    # therefore, we can use it to check if its a merge from the config
    if "/" in base:
        return model

    deps = merges[name]["deps"]
    # avoid duplicate deps
    if base not in deps:
        deps.append(base)

    return str(out_path / base) + separator + lora


@click.command("mergekit-mega")
@click.argument("config_file")
@click.argument("out_path")
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--force",
    "-f",
    type=bool,
    default=False,
    is_flag=True,
    help="Overwrite existing merge results instead of skipping them",
)
@click.option(
    "--require-nameless",
    "-R",
    type=bool,
    default=False,
    is_flag=True,
    help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.",
)
@add_merge_options
def main(
    merge_options: MergeOptions,
    config_file: str,
    out_path: str,
    force: bool,
    verbose: bool,
    require_nameless: bool,
):
    """
    Main entrypoint for mergekit-mega command see module docstring for more info
    Params are supplied by click decorators
    """
    # The docstring above is left untouched on purpose: click uses it as the
    # help text of the command.
    #
    # Steps:
    #   1. read every YAML document of `config_file` and register it in `merges`
    #   2. rewrite references to other merges to their path below `out_path`
    #      and record them as dependencies
    #   3. refuse to continue when the dependencies are circular
    #   4. run the merges until `merges` is empty
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    out_path = Path(out_path)
    final_found = False

    with open(config_file, "r", encoding="utf-8") as f:
        # NOTE(review): the PyYAML documentation recommends the safe loader
        # for input that is not trusted - consider yaml.safe_load_all if
        # config files may come from untrusted sources
        data = yaml.load_all(f, Loader=yaml.FullLoader)

        for d in data:
            # an empty document (e.g. after a trailing "---") is loaded as None
            if d is None:
                continue
            if not isinstance(d, dict):
                logging.error("Every merge document must be a mapping")
                sys.exit(1)

            if "name" not in d:
                if final_found:
                    logging.error("Only one merge must not have a name")
                    sys.exit(1)
                # this sets the name of the final merge to the config file name without the extension
                d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0]
                final_found = True

            name = d["name"]
            if "/" in name:
                logging.error("name must not contain a slash")
                sys.exit(1)

            # a second document with the same name would silently replace
            # the first one
            if name in merges:
                logging.error("Duplicate merge name: %s", name)
                sys.exit(1)

            merges[name] = d
            d["deps"] = []
            if "base_model" in d:
                d["base_model"] = add_model_deps(d["base_model"], name, out_path)
            if "slices" in d:
                for slc in d["slices"]:
                    for src in slc["sources"]:
                        src["model"] = add_model_deps(src["model"], name, out_path)
            if "models" in d:
                for mdl in d["models"]:
                    mdl["model"] = add_model_deps(mdl["model"], name, out_path)

    if require_nameless and not final_found:
        logging.error("No final merge found")
        sys.exit(1)

    logging.info("Merging: %s", ", ".join(merges))

    if (dep := has_circular_dependency(merges)) is not None:
        logging.error("Circular dependency detected: %s", dep)
        sys.exit(1)

    # merge() removes every entry it has processed (including dependencies),
    # so always continue with whatever is first in the remaining registry
    while merges:
        merge(next(iter(merges)), merge_options, force, out_path)


# Run the click command when this file is executed as a script
if __name__ == "__main__":
    main()