import glob
import os
import subprocess
from huggingface_hub import HfApi, upload_folder
import gradio as gr
import hf_utils
import utils
from safetensors import safe_open
import torch
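
# The conversion itself is done by a script shipped in the diffusers repo, so
# clone it at startup (into ./diffs) and shell out to it later.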
subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers", "diffs"])

def error_str(error, title="Error"):
    return f"""#### {title}
{error}""" if error else ""

def on_token_change(token):
    model_names, error = hf_utils.get_my_model_names(token)
    if model_names:
        model_names.append("Other")
    return (
        gr.update(visible=bool(model_names)),
        gr.update(choices=model_names, value=model_names[0] if model_names else None),
        gr.update(visible=bool(model_names)),
        gr.update(value=error_str(error)),
    )
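
# url_to_model_id accepts either a bare repo id ("user/model") or a full Hub URL
# ("https://huggingface.co/user/model") and normalizes both to "user/model".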

def url_to_model_id(model_id_str):
    if model_id_str.startswith("https://huggingface.co/"):
        return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1]
    return model_id_str

def get_ckpt_names(token, radio_model_names, input_model):
    model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
    if token == "" or model_id == "":
        return error_str("Please enter both a token and a model name.", title="Invalid input"), gr.update(choices=[]), gr.update(visible=False)
    try:
        api = HfApi(token=token)
        ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
        if not ckpt_files:
            return error_str("No checkpoint files found in the model repo."), gr.update(choices=[]), gr.update(visible=False)
        return None, gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True), gr.update(visible=True)
    except Exception as e:
        return error_str(e), gr.update(choices=[]), gr.update(visible=False)
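
# convert_and_push drives the whole pipeline: download the selected .ckpt, run
# the diffusers conversion script on it, open a PR with the converted weights,
# then clean up local files. Note that sd_version is only validated here; it is
# not forwarded to the conversion command.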

def convert_and_push(radio_model_names, input_model, ckpt_name, sd_version, token, path_in_repo, ema, safetensors):
    extract_ema = ema == "ema"
    if sd_version is None:
        return error_str("You must select a Stable Diffusion version.", title="Invalid input")
    model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
    try:
        # 1. Download the checkpoint file
        ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)
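
        # If the selected weights are safetensors, re-serialize them first as a
        # pickled torch checkpoint (presumably because the conversion script
        # expects a torch.load-able file).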
        if safetensors == "yes":
            tensors = {}
            with safe_open(ckpt_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)
            new_checkpoint_path = "/".join(ckpt_path.split("/")[:-1] + ["model_safe.ckpt"])
            torch.save(tensors, new_checkpoint_path)
            ckpt_path = new_checkpoint_path

        print("Converting", ckpt_path)

        # 2. Run the conversion script
        os.makedirs(model_id, exist_ok=True)
        run_command = [
            "python3",
            "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path",
            ckpt_path,
            "--dump_path",
            model_id,
        ]
        if extract_ema:
            run_command.append("--extract_ema")
        subprocess.run(run_command)
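
        # upload_folder with create_pr=True uploads the converted weights and opens
        # a pull request on the target repo instead of committing to main.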
        # 3. Push to the model repo
        commit_message = "Add Diffusers weights"
        upload_folder(
            folder_path=model_id,
            repo_id=model_id,
            path_in_repo=path_in_repo,
            token=token,
            create_pr=True,
            commit_message=commit_message,
            commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
        )

        # 4. Delete the downloaded checkpoint, leftover yaml files, and the converted model folder
        hf_utils.delete_file(revision)
        subprocess.run(["rm", "-rf", model_id.split("/")[0]])
        for f in glob.glob("*.yaml*"):
            subprocess.run(["rm", "-rf", f])

        return f"""Successfully converted the checkpoint and opened a PR to add the weights to the model repo.
You can view and merge the PR [here]({hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)})."""
    except Exception as e:
        return error_str(e)
DESCRIPTION = """### Convert a stable diffusion checkpoint to Diffusers🧨
With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.
You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.
You can skip the queue by running the app in the colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/qunash/f0f3152c5851c0c477b68b7b98d547fe/convert-sd-to-diffusers.ipynb)"""

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=11):
            with gr.Column():
                gr.Markdown("## 1. Load model info")
                input_token = gr.Textbox(
                    max_lines=1,
                    type="password",
                    label="Enter your Hugging Face token",
                    placeholder="READ permission is sufficient",
                )
                gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")
                with gr.Group(visible=False) as group_model:
                    radio_model_names = gr.Radio(label="Choose a model")
                    input_model = gr.Textbox(
                        max_lines=1,
                        label="Model name or URL",
                        placeholder="username/model_name",
                        visible=False,
                    )
                btn_get_ckpts = gr.Button("Load", visible=False)
        with gr.Column(scale=10):
            with gr.Column(visible=False) as group_convert:
                gr.Markdown("## 2. Convert to Diffusers🧨")
                radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
                path_in_repo = gr.Textbox(label="Path where the weights will be saved", placeholder="Leave empty for root folder")
                ema = gr.Radio(label="Extract EMA or non-EMA?", choices=["ema", "non-ema"])
                safetensors = gr.Radio(label="Extract from safetensors", choices=["yes", "no"], value="no")
                radio_sd_version = gr.Radio(label="Choose the model version", choices=["v1", "v2", "v2.1"])
                gr.Markdown("Conversion may take a few minutes.")
                btn_convert = gr.Button("Convert & Push")

    error_output = gr.Markdown(label="Output")
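
    # Event wiring: a token populates the model list, choosing "Other" reveals the
    # free-form model field, "Load" lists the repo's .ckpt files, and
    # "Convert & Push" runs the conversion pipeline above.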
    input_token.change(
        fn=on_token_change,
        inputs=input_token,
        outputs=[group_model, radio_model_names, btn_get_ckpts, error_output],
        queue=False,
        scroll_to_output=True,
    )

    radio_model_names.change(
        lambda x: gr.update(visible=x == "Other"),
        inputs=radio_model_names,
        outputs=input_model,
        queue=False,
        scroll_to_output=True,
    )

    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_token, radio_model_names, input_model],
        outputs=[error_output, radio_ckpts, group_convert],
        scroll_to_output=True,
        queue=False,
    )

    btn_convert.click(
        fn=convert_and_push,
        inputs=[radio_model_names, input_model, radio_ckpts, radio_sd_version, input_token, path_in_repo, ema, safetensors],
        outputs=error_output,
        scroll_to_output=True,
    )
# gr.Markdown("""<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/diffusers_library.jpg" width="150"/>""")
gr.HTML("""
<div style="border-top: 1px solid #303030;">
<br>
<p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
<a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
<p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>
</div>
""")
demo.queue()
demo.launch(debug=True, share=utils.is_google_colab())