File size: 4,976 Bytes
7e4953f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import argparse
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple, Union

from huggingface_hub import (
    CommitInfo,
    CommitOperationAdd,
    Discussion,
    HfApi,
    hf_hub_download,
)
from huggingface_hub.file_download import repo_folder_name
class AlreadyExists(Exception):
    """Raised by ``convert`` when the repo already has an open PR with the same title."""
def convert_single(
    model_id: str, folder: str
) -> Tuple[Union[List["CommitOperationAdd"], bool], Union[str, bool]]:
    """Download *model_id*'s generation config, patch a copy into *folder*, and stage it.

    Args:
        model_id: Hub repository id of the model to patch.
        folder: Local scratch directory that receives the patched file.

    Returns:
        ``(operations, model_type)`` where ``operations`` is a one-element list with a
        ``CommitOperationAdd`` uploading the patched ``generation_config.json`` and
        ``model_type`` is the label returned by ``convert_file``. Returns
        ``(False, False)`` when ``convert_file`` reports no conversion.
    """
    config_file_name = "generation_config.json"
    cached_config = hf_hub_download(repo_id=model_id, filename=config_file_name)
    # hf_hub_download returns a path inside the local Hub cache (normally absolute).
    # os.path.join drops `folder` when its second argument is absolute, so joining the
    # downloaded path would overwrite the cached file in place. Join the bare file
    # name so the patched copy lands in the scratch folder instead.
    new_config_file = os.path.join(folder, config_file_name)
    model_type = convert_file(cached_config, new_config_file)
    if not model_type:
        return False, False
    operations = [
        CommitOperationAdd(path_in_repo=config_file_name, path_or_fileobj=new_config_file)
    ]
    return operations, model_type
def convert_file(
    old_config: str,
    new_config: str,
    max_initial_timestamp_index: int = 50,
):
    """Write a patched copy of a Whisper generation config JSON file.

    Sets ``max_initial_timestamp_index`` and ``prev_sot_token_id`` on the loaded
    config. The result is written to *new_config* as 2-space-indented JSON with
    sorted keys and a trailing newline.

    Args:
        old_config: Path of the JSON config to read.
        new_config: Path the patched JSON is written to. It is overwritten if it
            already exists.
        max_initial_timestamp_index: Value stored under
            ``max_initial_timestamp_index`` (defaults to 50).

    Returns:
        The model-type label ``"Whisper"``. It is always truthy.

    Raises:
        KeyError: if the config has no ``suppress_tokens`` entry.
        IndexError: if ``suppress_tokens`` has fewer than two entries.
    """
    # JSON is UTF-8 by spec; avoid depending on the platform's locale encoding.
    with open(old_config, "r", encoding="utf-8") as f:
        config = json.load(f)
    config["max_initial_timestamp_index"] = max_initial_timestamp_index
    # NOTE(review): assumes the second-to-last suppressed token is the
    # "start of previous context" token. Confirm for each checkpoint being patched.
    config["prev_sot_token_id"] = config["suppress_tokens"][-2]
    with open(new_config, "w", encoding="utf-8") as f:
        f.write(json.dumps(config, indent=2, sort_keys=True) + "\n")
    return "Whisper"
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Find an open pull request on *model_id* whose title equals *pr_title*.

    Listing failures are deliberately treated as "no existing PR". Any exception
    raised by ``api.get_repo_discussions`` yields ``None``.

    Returns:
        The first matching discussion in listing order, or ``None`` when nothing matches.
    """
    try:
        listing = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None
    open_prs_with_title = (
        candidate
        for candidate in listing
        if candidate.status == "open"
        and candidate.is_pull_request
        and candidate.title == pr_title
    )
    return next(open_prs_with_title, None)
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
    """Open a PR on *model_id* that patches its Whisper long-form generation config.

    Args:
        api: Authenticated Hub client used to inspect the repo and create the commit.
        model_id: Hub repository id to patch.
        force: When True, create a new PR even if an open PR with the same title
            already exists.

    Returns:
        The result of ``api.create_commit`` for the new PR. Returns ``None`` when the
        repo has no ``generation_config.json`` or there was nothing to commit.

    Raises:
        AlreadyExists: if an open PR with the same title exists and *force* is False.
    """
    pr_title = "Correct long-form generation config parameters 'max_initial_timestamp_index' and 'prev_sot_token_id'."
    info = api.model_info(model_id)
    # Defensive: treat a missing file listing as an empty repo.
    filenames = {sibling.rfilename for sibling in (info.siblings or [])}
    if "generation_config.json" not in filenames:
        print(f"Model: {model_id} has no generation_config.json file to change")
        return None

    new_pr = None
    # The TemporaryDirectory context removes the scratch folder and everything in
    # it, so no manual cleanup is needed.
    with TemporaryDirectory() as d:
        # repo_folder_name appends "s" to repo_type itself, so pass the singular
        # type to get a "models--org--name" style directory name.
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="model"))
        os.makedirs(folder)

        pr = previous_pr(api, model_id, pr_title)
        if pr is not None and not force:
            url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
            raise AlreadyExists(
                f"Model {model_id} already has an open PR check out {url}"
            )

        operations, _model_type = convert_single(model_id, folder)
        if operations:
            contributor = model_id.split("/")[0]
            pr_description = (
                f"Hey {contributor} 👋, \n\n Your model repository seems to contain outdated generation config parameters, such as 'max_initial_timestamp_index' and is missing the 'prev_sot_token_id' parameter. "
                "These parameters need to be updated to correctly handle long-form generation as described in https://github.com/huggingface/transformers/pull/27658. "
                "This PR makes sure that everything is up to date and can be safely merged. \n\n Best, the Transformers team."
            )
            new_pr = api.create_commit(
                repo_id=model_id,
                operations=operations,
                commit_message=pr_title,
                commit_description=pr_description,
                create_pr=True,
            )
            print(f"Pr created at {new_pr.pr_url}")
        else:
            print(f"No files to convert for {model_id}")
    return new_pr
if __name__ == "__main__":
    # Help text for the CLI; argparse re-wraps the whitespace in this description.
    DESCRIPTION = """
Simple utility tool to update the long-form generation parameters of a Whisper
model's `generation_config.json` on the hub.
It works by downloading the generation config, setting `max_initial_timestamp_index`
and `prev_sot_token_id` locally, and uploading the result back as a PR on the hub.
"""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the Whisper model on the hub to update. E.g. `openai/whisper-tiny`",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if an open PR with the same title already exists.",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id, force=args.force)
|