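"""Utility script that opens Hub pull requests to correct Whisper `generation_config.json`
files: it sets `max_initial_timestamp_index` and adds `prev_sot_token_id` so that
long-form generation works as intended (see
https://github.com/huggingface/transformers/pull/27658).

Example invocation (the script name below is illustrative):

    python update_whisper_generation_config.py openai/whisper-large-v2
"""
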
import argparse
import json
import os
import shutil
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple

from huggingface_hub import (
    CommitInfo,
    CommitOperationAdd,
    Discussion,
    HfApi,
    hf_hub_download,
)
from huggingface_hub.file_download import repo_folder_name


class AlreadyExists(Exception):
    pass


def convert_single(
    model_id: str, folder: str
) -> Tuple[List["CommitOperationAdd"], Optional[str]]:
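    """Download the model's `generation_config.json`, rewrite it locally under `folder`,
    and return the commit operations together with the detected model type.

    Returns `([], None)` when nothing could be converted.
    """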
    config_file_name = "generation_config.json"
    old_config_file = hf_hub_download(repo_id=model_id, filename=config_file_name)

    # Write the updated config into the temporary folder, not over the cached download.
    new_config_file = os.path.join(folder, config_file_name)
    model_type = convert_file(old_config_file, new_config_file)
    if model_type:
        operations = [
            CommitOperationAdd(
                path_in_repo=config_file_name, path_or_fileobj=new_config_file
            )
        ]
        return operations, model_type
    else:
        return [], None


def convert_file(old_config: str, new_config: str) -> str:
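    """Load the old generation config, set `max_initial_timestamp_index` to 50, derive
    `prev_sot_token_id` from the second-to-last entry of `suppress_tokens`, and write the
    updated JSON to `new_config`. Returns the model type ("Whisper").
    """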
    with open(old_config, "r") as f:
        old_dict = json.load(f)

    old_dict["max_initial_timestamp_index"] = 50
    old_dict["prev_sot_token_id"] = old_dict["suppress_tokens"][-2]

    with open(new_config, "w") as f:
        json_str = json.dumps(old_dict, indent=2, sort_keys=True) + "\n"
        f.write(json_str)

    return "Whisper"


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
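    """Return the open pull request on `model_id` whose title matches `pr_title`, or
    None if there is none (or the discussions could not be fetched).
    """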
    try:
        discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None
    for discussion in discussions:
        if (
            discussion.status == "open"
            and discussion.is_pull_request
            and discussion.title == pr_title
        ):
            return discussion
    return None


def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
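    """Open a pull request against `model_id` that updates its `generation_config.json`.

    Models without a `generation_config.json` are skipped, and so are models that
    already have an open PR with the same title unless `force` is set.
    """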
    pr_title = "Correct long-form generation config parameters 'max_initial_timestamp_index' and 'prev_sot_token_id'."
    info = api.model_info(model_id)
    filenames = set(s.rfilename for s in info.siblings)

    if "generation_config.json" not in filenames:
        print(f"Model: {model_id} has no generation_config.json file to change")
        return

    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        new_pr = None
        try:
            operations = None
            pr = previous_pr(api, model_id, pr_title)
            if pr is not None and not force:
                url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
                new_pr = pr
                raise AlreadyExists(
                    f"Model {model_id} already has an open PR check out {url}"
                )
            else:
                operations, model_type = convert_single(model_id, folder)

            if operations:
                pr_title = pr_title.format(model_type)
                contributor = model_id.split("/")[0]
                pr_description = (
                    f"Hey {contributor} 👋, \n\n Your model repository seems to contain outdated generation config parameters, such as 'max_initial_timestamp_index', and is missing the 'prev_sot_token_id' parameter. "
                    "These parameters need to be updated to correctly handle long-form generation, as described in https://github.com/huggingface/transformers/pull/27658. "
                    "This PR makes sure that everything is up to date and can be safely merged. \n\n Best, the Transformers team."
                )
                new_pr = api.create_commit(
                    repo_id=model_id,
                    operations=operations,
                    commit_message=pr_title,
                    commit_description=pr_description,
                    create_pr=True,
                )
                print(f"Pr created at {new_pr.pr_url}")
            else:
                print(f"No files to convert for {model_id}")
        finally:
            shutil.rmtree(folder)
        return new_pr


if __name__ == "__main__":
    DESCRIPTION = """
    Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
    It is PyTorch exclusive for now.
    It works by downloading the weights (PT), converting them locally, and uploading them back
    as a PR on the hub.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "model_id",
        type=str,
        help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Create the PR even if it already exists of if the model was already converted.",
    )
    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    convert(api, model_id, force=args.force)