File size: 5,130 Bytes
a164e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
import os
from typing import Generator, List, Optional
import huggingface_hub
import yaml
from huggingface_hub.utils import HFValidationError
from yaml.nodes import SequenceNode as SequenceNode
from mergekit.config import MergeConfiguration, ModelReference
# Markdown skeleton for the generated model card. It is filled in with
# str.format() by generate_card(), which supplies these placeholders:
#   {metadata}     - YAML text placed between the leading "---" markers
#   {name}         - title of the card
#   {merge_method} - markdown for the merge method (see method_md)
#   {base_text}    - optional " using ... as a base" clause, may be empty
#   {model_list}   - newline-joined "* ..." bullets, one per merged model
#   {config_yaml}  - YAML source of the merge config, shown in a code fence
CARD_TEMPLATE = """---
{metadata}
---
# {name}
This is a merge of pre-trained language models created using [mergekit](https://github.com/cg123/mergekit).
## Merge Details
### Merge Method
This model was merged using the {merge_method} merge method{base_text}.
### Models Merged
The following models were included in the merge:
{model_list}
### Configuration
The following YAML configuration was used to produce this model:
```yaml
{config_yaml}
```
"""
def is_hf(path: str) -> bool:
    """
    Determines if the given path is a Hugging Face model repository.

    The decision is made in stages: an empty path, a path starting with
    "/" or "~", or a path containing more than one "/" is treated as local.
    A path that does not exist on the local filesystem is assumed to be a
    Hugging Face repo. A path that does exist locally is looked up on the
    Hub to decide.

    Args:
        path: A string path to check.

    Returns:
        True if the path is taken to refer to a Hugging Face model repo,
        False if it is treated as a local path.
    """
    if not path:
        # An empty string cannot name a repo; guarding here also avoids
        # the IndexError that indexing path[0] would raise.
        return False
    if path.startswith(("/", "~")) or path.count("/") > 1:
        return False  # definitely a local path
    if not os.path.exists(path):
        return True  # If path doesn't exist locally, it must be a HF repo
    # The path exists locally but may still coincide with a repo id, so
    # ask the Hub (anonymously, token=False) whether such a model exists.
    try:
        return huggingface_hub.repo_exists(path, repo_type="model", token=False)
    except HFValidationError:
        # Not a syntactically valid repo id, so it can only be a local path.
        return False
def extract_hf_paths(models: List[ModelReference]) -> Generator[str, None, None]:
    """
    Lazily produces the Hugging Face paths found in a list of model references.

    For each reference the base model path is considered first, followed by
    the LoRA path when the reference has one. Only paths for which is_hf()
    returns True are yielded.

    Args:
        models: A list of ModelReference objects.

    Yields:
        Each model or LoRA path that is_hf() accepts, in input order.
    """
    for ref in models:
        base_path = ref.model.path
        if is_hf(base_path):
            yield base_path

        adapter = ref.lora
        if adapter and is_hf(adapter.path):
            yield adapter.path
def method_md(merge_method: str) -> str:
    """
    Maps a merge method identifier to the markdown shown in the model card.

    Args:
        merge_method: A string indicating the merge method used.

    Returns:
        The markdown text registered for the method, or the identifier
        itself when no entry exists for it.
    """
    known_methods = {
        "linear": "[linear](https://arxiv.org/abs/2203.05482)",
        "ties": "[TIES](https://arxiv.org/abs/2306.01708)",
        "slerp": "SLERP",
        "task_arithmetic": "[task arithmetic](https://arxiv.org/abs/2212.04089)",
        "dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)",
        "dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)",
        "model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)",
    }
    if merge_method in known_methods:
        return known_methods[merge_method]
    # Unknown methods are passed through unchanged.
    return merge_method
def maybe_link_hf(path: str) -> str:
    """
    Wraps a path in a markdown link when is_hf() reports it as a Hugging Face path.

    Args:
        path: A string path to possibly convert to a link.

    Returns:
        A markdown link to https://huggingface.co/<path> when is_hf(path)
        is true; otherwise the path unchanged.
    """
    if not is_hf(path):
        return path
    return f"[{path}](https://huggingface.co/{path})"
def modelref_md(model: ModelReference) -> str:
    """
    Builds the markdown text used to mention a ModelReference.

    Args:
        model: A ModelReference object.

    Returns:
        The (possibly linked) model path, followed by " + " and the
        (possibly linked) LoRA path when the reference has a LoRA.
    """
    pieces = [maybe_link_hf(model.model.path)]
    if model.lora:
        pieces.append(maybe_link_hf(model.lora.path))
    return " + ".join(pieces)
def generate_card(
    config: MergeConfiguration,
    config_yaml: str,
    name: Optional[str] = None,
) -> str:
    """
    Renders the markdown model card for a merge configuration.

    Args:
        config: A MergeConfiguration object.
        config_yaml: YAML source text of the config; embedded as-is in the card.
        name: An optional name for the model. A placeholder title is used
            when it is omitted or empty.

    Returns:
        CARD_TEMPLATE with every placeholder filled in.
    """
    title = name or "Untitled Model (1)"

    # Front matter for the card; only Hugging Face paths go in base_model.
    front_matter = {
        "base_model": list(extract_hf_paths(config.referenced_models())),
        "tags": ["mergekit", "merge"],
        "library_name": "transformers",
    }

    # For slerp the configured base_model is not reported as a base, so it
    # ends up in the bullet list together with the other models.
    base_ref = None if config.merge_method == "slerp" else config.base_model
    base_clause = f" using {modelref_md(base_ref)} as a base" if base_ref else ""

    # The base (when reported) is already named in base_clause, so it is
    # left out of the list of merged models.
    bullets = [
        "* " + modelref_md(ref)
        for ref in config.referenced_models()
        if ref != base_ref
    ]

    return CARD_TEMPLATE.format(
        metadata=yaml.dump(front_matter),
        model_list="\n".join(bullets),
        base_text=base_clause,
        merge_method=method_md(config.merge_method),
        name=title,
        config_yaml=config_yaml,
    )
|