patrickvonplaten anzorq commited on
Commit
d8ca2a9
·
0 Parent(s):

Duplicate from anzorq/sd-to-diffusers

Browse files

Co-authored-by: AQ <[email protected]>

Files changed (6) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +181 -0
  4. hf_utils.py +50 -0
  5. requirements.txt +8 -0
  6. utils.py +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SD To Diffusers
3
+ emoji: 🎨➡️🧨
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: anzorq/sd-to-diffusers
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from huggingface_hub import HfApi, upload_folder
4
+ import gradio as gr
5
+ import hf_utils
6
+ import utils
7
+
8
+ subprocess.run(["git", "clone", "https://github.com/qunash/diffusers", "diffs"])
9
+
10
+ def error_str(error, title="Error"):
11
+ return f"""#### {title}
12
+ {error}""" if error else ""
13
+
14
+ def on_token_change(token):
15
+ model_names, error = hf_utils.get_my_model_names(token)
16
+ if model_names:
17
+ model_names.append("Other")
18
+
19
+ return gr.update(visible=bool(model_names)), gr.update(choices=model_names, value=model_names[0] if model_names else None), gr.update(visible=bool(model_names)), gr.update(value=error_str(error))
20
+
21
+ def url_to_model_id(model_id_str):
22
+ return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1] if model_id_str.startswith("https://huggingface.co/") else model_id_str
23
+
24
+ def get_ckpt_names(token, radio_model_names, input_model):
25
+
26
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
27
+
28
+ if token == "" or model_id == "":
29
+ return error_str("Please enter both a token and a model name.", title="Invalid input"), gr.update(choices=[]), gr.update(visible=False)
30
+
31
+ try:
32
+ api = HfApi(token=token)
33
+ ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
34
+
35
+ if not ckpt_files:
36
+ return error_str("No checkpoint files found in the model repo."), gr.update(choices=[]), gr.update(visible=False)
37
+
38
+ return None, gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True), gr.update(visible=True)
39
+
40
+ except Exception as e:
41
+ return error_str(e), gr.update(choices=[]), None
42
+
43
+ def convert_and_push(radio_model_names, input_model, ckpt_name, sd_version, token, path_in_repo):
44
+
45
+ if sd_version == None:
46
+ return error_str("You must select a stable diffusion version.", title="Invalid input")
47
+
48
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
49
+
50
+ try:
51
+ model_id = url_to_model_id(model_id)
52
+
53
+ # 1. Download the checkpoint file
54
+ ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)
55
+
56
+ # 2. Run the conversion script
57
+ os.makedirs(model_id, exist_ok=True)
58
+ subprocess.run(
59
+ [
60
+ "python3",
61
+ "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
62
+ "--checkpoint_path",
63
+ ckpt_path,
64
+ "--dump_path" ,
65
+ model_id,
66
+ "--sd_version",
67
+ sd_version
68
+ ]
69
+ )
70
+
71
+ # 3. Push to the model repo
72
+ commit_message="Add Diffusers weights"
73
+ upload_folder(
74
+ folder_path=model_id,
75
+ repo_id=model_id,
76
+ path_in_repo=path_in_repo,
77
+ token=token,
78
+ create_pr=True,
79
+ commit_message=commit_message,
80
+ commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
81
+ )
82
+
83
+ # # 4. Delete the downloaded checkpoint file, yaml files, and the converted model folder
84
+ hf_utils.delete_file(revision)
85
+ subprocess.run(["rm", "-rf", model_id.split('/')[0]])
86
+ import glob
87
+ for f in glob.glob("*.yaml*"):
88
+ subprocess.run(["rm", "-rf", f])
89
+
90
+ return f"""Successfully converted the checkpoint and opened a PR to add the weights to the model repo.
91
+ You can view and merge the PR [here]({hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)})."""
92
+
93
+ return "Done"
94
+
95
+ except Exception as e:
96
+ return error_str(e)
97
+
98
+
99
+ DESCRIPTION = """### Convert a stable diffusion checkpoint to Diffusers🧨
100
+ With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.
101
+ You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.
102
+ You can skip the queue by running the app in the colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/qunash/f0f3152c5851c0c477b68b7b98d547fe/convert-sd-to-diffusers.ipynb)"""
103
+
104
+ with gr.Blocks() as demo:
105
+
106
+ gr.Markdown(DESCRIPTION)
107
+ with gr.Row():
108
+
109
+ with gr.Column(scale=11):
110
+ with gr.Column():
111
+ gr.Markdown("## 1. Load model info")
112
+ input_token = gr.Textbox(
113
+ max_lines=1,
114
+ type="password",
115
+ label="Enter your Hugging Face token",
116
+ placeholder="READ permission is sufficient"
117
+ )
118
+ gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")
119
+ with gr.Group(visible=False) as group_model:
120
+ radio_model_names = gr.Radio(label="Choose a model")
121
+ input_model = gr.Textbox(
122
+ max_lines=1,
123
+ label="Model name or URL",
124
+ placeholder="username/model_name",
125
+ visible=False,
126
+ )
127
+
128
+ btn_get_ckpts = gr.Button("Load", visible=False)
129
+
130
+ with gr.Column(scale=10):
131
+ with gr.Column(visible=False) as group_convert:
132
+ gr.Markdown("## 2. Convert to Diffusers🧨")
133
+ radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
134
+ path_in_repo = gr.Textbox(label="Path where the weights will be saved", placeholder="Leave empty for root folder")
135
+ radio_sd_version = gr.Radio(label="Choose the model version", choices=["v1", "v2", "v2.1"])
136
+ gr.Markdown("Conversion may take a few minutes.")
137
+ btn_convert = gr.Button("Convert & Push")
138
+
139
+ error_output = gr.Markdown(label="Output")
140
+
141
+ input_token.change(
142
+ fn=on_token_change,
143
+ inputs=input_token,
144
+ outputs=[group_model, radio_model_names, btn_get_ckpts, error_output],
145
+ queue=False,
146
+ scroll_to_output=True)
147
+
148
+ radio_model_names.change(
149
+ lambda x: gr.update(visible=x == "Other"),
150
+ inputs=radio_model_names,
151
+ outputs=input_model,
152
+ queue=False,
153
+ scroll_to_output=True)
154
+
155
+ btn_get_ckpts.click(
156
+ fn=get_ckpt_names,
157
+ inputs=[input_token, radio_model_names, input_model],
158
+ outputs=[error_output, radio_ckpts, group_convert],
159
+ scroll_to_output=True,
160
+ queue=False
161
+ )
162
+
163
+ btn_convert.click(
164
+ fn=convert_and_push,
165
+ inputs=[radio_model_names, input_model, radio_ckpts, radio_sd_version, input_token, path_in_repo],
166
+ outputs=error_output,
167
+ scroll_to_output=True
168
+ )
169
+
170
+ # gr.Markdown("""<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/diffusers_library.jpg" width="150"/>""")
171
+ gr.HTML("""
172
+ <div style="border-top: 1px solid #303030;">
173
+ <br>
174
+ <p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
175
+ <a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
176
+ <p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>
177
+ </div>
178
+ """)
179
+
180
+ demo.queue()
181
+ demo.launch(debug=True, share=utils.is_google_colab())
hf_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import get_hf_file_metadata, hf_hub_url, hf_hub_download, scan_cache_dir, whoami, list_models
2
+
3
+
4
+ def get_my_model_names(token):
5
+
6
+ try:
7
+ author = whoami(token=token)
8
+ model_infos = list_models(author=author["name"], use_auth_token=token)
9
+ return [model.modelId for model in model_infos], None
10
+
11
+ except Exception as e:
12
+ return [], e
13
+
14
+ def download_file(repo_id: str, filename: str, token: str):
15
+ """Download a file from a repo on the Hugging Face Hub.
16
+
17
+ Returns:
18
+ file_path (:obj:`str`): The path to the downloaded file.
19
+ revision (:obj:`str`): The commit hash of the file.
20
+ """
21
+
22
+ md = get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename=filename), token=token)
23
+ revision = md.commit_hash
24
+
25
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
26
+
27
+ return file_path, revision
28
+
29
+ def delete_file(revision: str):
30
+ """Delete a file from local cache.
31
+
32
+ Args:
33
+ revision (:obj:`str`): The commit hash of the file.
34
+ Returns:
35
+ None
36
+ """
37
+ scan_cache_dir().delete_revisions(revision).execute()
38
+
39
+ def get_pr_url(api, repo_id, title):
40
+ try:
41
+ discussions = api.get_repo_discussions(repo_id=repo_id)
42
+ except Exception:
43
+ return None
44
+ for discussion in discussions:
45
+ if (
46
+ discussion.status == "open"
47
+ and discussion.is_pull_request
48
+ and discussion.title == title
49
+ ):
50
+ return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/huggingface_hub@main
2
+ git+https://github.com/huggingface/diffusers.git
3
+ torch
4
+ #transformers
5
+ git+https://github.com/huggingface/transformers
6
+ pytorch_lightning
7
+ OmegaConf
8
+ ftfy
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def is_google_colab():
2
+ try:
3
+ import google.colab
4
+ return True
5
+ except:
6
+ return False