Spaces:
Runtime error
Runtime error
Ubuntu
commited on
Commit
β’
b7fe3c7
1
Parent(s):
968cf44
add revision option
Browse files- app.py +4 -2
- convert.py +15 -9
app.py
CHANGED
@@ -19,7 +19,7 @@ if HF_TOKEN:
|
|
19 |
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
20 |
|
21 |
|
22 |
-
def run(token: str, model_id: str) -> str:
|
23 |
if token == "" or model_id == "":
|
24 |
return """
|
25 |
### Invalid input π
|
@@ -31,7 +31,7 @@ def run(token: str, model_id: str) -> str:
|
|
31 |
is_private = api.model_info(repo_id=model_id).private
|
32 |
print("is_private", is_private)
|
33 |
|
34 |
-
commit_info = convert(api=api, model_id=model_id, force=True)
|
35 |
print("[commit_info]", commit_info)
|
36 |
|
37 |
# save in a (public) dataset:
|
@@ -72,6 +72,7 @@ The steps are the following:
|
|
72 |
|
73 |
- Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
|
74 |
- Input a model id from the Hub
|
|
|
75 |
- Click "Submit"
|
76 |
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR π₯
|
77 |
|
@@ -86,6 +87,7 @@ demo = gr.Interface(
|
|
86 |
inputs=[
|
87 |
gr.Text(max_lines=1, label="your_hf_token"),
|
88 |
gr.Text(max_lines=1, label="model_id"),
|
|
|
89 |
],
|
90 |
outputs=[gr.Markdown(label="output")],
|
91 |
fn=run,
|
|
|
19 |
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, token=HF_TOKEN)
|
20 |
|
21 |
|
22 |
+
def run(token: str, model_id: str, revision: str = "main") -> str:
|
23 |
if token == "" or model_id == "":
|
24 |
return """
|
25 |
### Invalid input π
|
|
|
31 |
is_private = api.model_info(repo_id=model_id).private
|
32 |
print("is_private", is_private)
|
33 |
|
34 |
+
commit_info = convert(api=api, model_id=model_id, revision=revision, force=True)
|
35 |
print("[commit_info]", commit_info)
|
36 |
|
37 |
# save in a (public) dataset:
|
|
|
72 |
|
73 |
- Paste a read-access token from hf.co/settings/tokens. Read access is enough given that we will open a PR against the source repo.
|
74 |
- Input a model id from the Hub
|
75 |
+
- Optionally select a revision like fp16
|
76 |
- Click "Submit"
|
77 |
- That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the opened PR π₯
|
78 |
|
|
|
87 |
inputs=[
|
88 |
gr.Text(max_lines=1, label="your_hf_token"),
|
89 |
gr.Text(max_lines=1, label="model_id"),
|
90 |
+
gr.Text(max_lines=1, label="revision", default="main"),
|
91 |
],
|
92 |
outputs=[gr.Markdown(label="output")],
|
93 |
fn=run,
|
convert.py
CHANGED
@@ -51,15 +51,15 @@ def rename(pt_filename: str) -> str:
|
|
51 |
return local
|
52 |
|
53 |
|
54 |
-
def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
|
55 |
-
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
|
56 |
with open(filename, "r") as f:
|
57 |
data = json.load(f)
|
58 |
|
59 |
filenames = set(data["weight_map"].values())
|
60 |
local_filenames = []
|
61 |
for filename in filenames:
|
62 |
-
pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
|
63 |
|
64 |
sf_filename = rename(pt_filename)
|
65 |
sf_filename = os.path.join(folder, sf_filename)
|
@@ -143,14 +143,14 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
|
|
143 |
return discussion
|
144 |
|
145 |
|
146 |
-
def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["CommitOperationAdd"]:
|
147 |
operations = []
|
148 |
|
149 |
extensions = set([".bin", ".ckpt"])
|
150 |
for filename in filenames:
|
151 |
prefix, ext = os.path.splitext(filename)
|
152 |
if ext in extensions:
|
153 |
-
pt_filename = hf_hub_download(model_id, filename=filename)
|
154 |
dirname, raw_filename = os.path.split(filename)
|
155 |
if raw_filename == "pytorch_model.bin":
|
156 |
# XXX: This is a special case to handle `transformers` and the
|
@@ -164,9 +164,9 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
164 |
return operations
|
165 |
|
166 |
|
167 |
-
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
168 |
pr_title = "Adding `safetensors` variant of this model"
|
169 |
-
info = api.model_info(model_id)
|
170 |
|
171 |
def is_valid_filename(filename):
|
172 |
return len(filename.split("/")) > 1 or filename in ["pytorch_model.bin", "diffusion_pytorch_model.bin"]
|
@@ -190,7 +190,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
190 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
191 |
else:
|
192 |
print("Convert generic")
|
193 |
-
operations = convert_generic(model_id, folder, filenames)
|
194 |
|
195 |
if operations:
|
196 |
new_pr = api.create_commit(
|
@@ -225,7 +225,13 @@ if __name__ == "__main__":
|
|
225 |
action="store_true",
|
226 |
help="Create the PR even if it already exists of if the model was already converted.",
|
227 |
)
|
|
|
|
|
|
|
|
|
|
|
228 |
args = parser.parse_args()
|
229 |
model_id = args.model_id
|
|
|
230 |
api = HfApi()
|
231 |
-
convert(api, model_id, force=args.force)
|
|
|
51 |
return local
|
52 |
|
53 |
|
54 |
+
def convert_multi(model_id: str, folder: str, revision: str = "main") -> List["CommitOperationAdd"]:
|
55 |
+
filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", revision=revision)
|
56 |
with open(filename, "r") as f:
|
57 |
data = json.load(f)
|
58 |
|
59 |
filenames = set(data["weight_map"].values())
|
60 |
local_filenames = []
|
61 |
for filename in filenames:
|
62 |
+
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, revision=revision)
|
63 |
|
64 |
sf_filename = rename(pt_filename)
|
65 |
sf_filename = os.path.join(folder, sf_filename)
|
|
|
143 |
return discussion
|
144 |
|
145 |
|
146 |
+
def convert_generic(model_id: str, folder: str, filenames: Set[str], revision: str = "main") -> List["CommitOperationAdd"]:
|
147 |
operations = []
|
148 |
|
149 |
extensions = set([".bin", ".ckpt"])
|
150 |
for filename in filenames:
|
151 |
prefix, ext = os.path.splitext(filename)
|
152 |
if ext in extensions:
|
153 |
+
pt_filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
154 |
dirname, raw_filename = os.path.split(filename)
|
155 |
if raw_filename == "pytorch_model.bin":
|
156 |
# XXX: This is a special case to handle `transformers` and the
|
|
|
164 |
return operations
|
165 |
|
166 |
|
167 |
+
def convert(api: "HfApi", model_id: str, force: bool = False, revision: str = "main") -> Optional["CommitInfo"]:
|
168 |
pr_title = "Adding `safetensors` variant of this model"
|
169 |
+
info = api.model_info(model_id, revision=revision)
|
170 |
|
171 |
def is_valid_filename(filename):
|
172 |
return len(filename.split("/")) > 1 or filename in ["pytorch_model.bin", "diffusion_pytorch_model.bin"]
|
|
|
190 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
191 |
else:
|
192 |
print("Convert generic")
|
193 |
+
operations = convert_generic(model_id, folder, filenames, revision=revision)
|
194 |
|
195 |
if operations:
|
196 |
new_pr = api.create_commit(
|
|
|
225 |
action="store_true",
|
226 |
help="Create the PR even if it already exists of if the model was already converted.",
|
227 |
)
|
228 |
+
parser.add_argument(
|
229 |
+
"revision",
|
230 |
+
default="main",
|
231 |
+
help="Branch to convert. E.g. main, fp16, bf16"
|
232 |
+
)
|
233 |
args = parser.parse_args()
|
234 |
model_id = args.model_id
|
235 |
+
revision = args.revision
|
236 |
api = HfApi()
|
237 |
+
convert(api, model_id, force=args.force, revision=revision)
|