Spaces:
Running
Running
File size: 4,438 Bytes
2cc021a e331aa7 2cc021a e331aa7 2cc021a e331aa7 2cc021a e331aa7 d90dedd e331aa7 d5cdfff e331aa7 d5cdfff e331aa7 e7c1d62 2cc021a e331aa7 2cc021a d5cdfff 2cc021a d5cdfff 2cc021a e331aa7 2cc021a e331aa7 d5cdfff e331aa7 d5cdfff d90dedd e331aa7 9c57b91 d5cdfff 2cc021a e331aa7 9c57b91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import requests
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import torch
from io import BytesIO
from huggingface_hub import CommitInfo, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt, download_controlnet_from_original_ckpt
)
from transformers import CONFIG_MAPPING
# Description attached to the conversion commit/PR; the `{}` placeholder is
# filled in with the model repo id via str.format.
COMMIT_MESSAGE: str = (
    " This PR adds fp32 and fp16 weights in PyTorch and safetensors format to {}"
)
_CONFIG_REQUEST_TIMEOUT = 60  # seconds to wait when downloading the YAML config


def _original_config_url(model_id: str, filename: str, model_type: str, sample_size: int) -> str:
    """Return the URL of the original YAML inference config for a checkpoint.

    Args:
        model_id: Hub repo id holding the checkpoint (used for ControlNet configs).
        filename: Checkpoint filename inside the repo.
        model_type: One of ``"v1"``, ``"v2"`` or ``"ControlNet"``.
        sample_size: Image size; for ``"v2"`` it selects between two configs.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type == "v1":
        return (
            "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/"
            "configs/stable-diffusion/v1-inference.yaml"
        )
    if model_type == "v2":
        # 512 selects v2-inference.yaml; any other size selects v2-inference-v.yaml
        # (presumably the 768px v-prediction config — confirm).
        config_name = "v2-inference.yaml" if sample_size == 512 else "v2-inference-v.yaml"
        return (
            "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/"
            "configs/stable-diffusion/" + config_name
        )
    if model_type == "ControlNet":
        # The config is expected next to the checkpoint in the same repo, with the
        # extension swapped for ".yaml". Build the URL with plain string operations:
        # pathlib would insert "\" separators on Windows.
        stem, _ = os.path.splitext(filename)
        return f"https://huggingface.co/{model_id}/resolve/main/{stem}.yaml"
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'v1', 'v2' or 'ControlNet'."
    )


def convert_single(model_id: str, filename: str, model_type: str, sample_size: int, scheduler_type: str, extract_ema: bool, folder: str, progress):
    """Convert one original-format checkpoint to diffusers format and save it to ``folder``.

    The converted weights are saved twice: once with the default serialization
    and once with ``safe_serialization=True``. They are then cast to float16 and
    saved again in both formats as the ``"fp16"`` variant.

    Args:
        model_id: Hub repo id, or a local directory that contains ``filename``.
        filename: Checkpoint filename; a ``.safetensors`` suffix selects
            safetensors loading.
        model_type: ``"v1"``, ``"v2"`` or ``"ControlNet"``.
        sample_size: Image size passed to the converter (also picks the v2 config).
        scheduler_type: Scheduler name for non-ControlNet pipelines.
        extract_ema: Forwarded to the converter.
        folder: Output directory.
        progress: Callable ``progress(fraction, desc=...)`` used for UI updates.

    Returns:
        ``folder``.

    Raises:
        ValueError: If ``model_type`` is unsupported. This is raised before
            anything is downloaded.
        requests.HTTPError: If the YAML config cannot be fetched.
    """
    # Validate and resolve the config first, so a bad model_type fails fast,
    # before the potentially large checkpoint download.
    config_url = _original_config_url(model_id, filename, model_type, sample_size)
    from_safetensors = filename.endswith(".safetensors")

    progress(0, desc="Downloading model")
    # Prefer a local copy at <model_id>/<filename>; otherwise fetch it from the Hub.
    local_file = os.path.join(model_id, filename)
    ckpt_file = local_file if os.path.isfile(local_file) else hf_hub_download(repo_id=model_id, filename=filename)

    # Bound the wait, and fail loudly on HTTP errors instead of handing an
    # error page to the YAML parser.
    response = requests.get(config_url, timeout=_CONFIG_REQUEST_TIMEOUT)
    response.raise_for_status()
    config_file = BytesIO(response.content)

    if model_type == "ControlNet":
        progress(0.2, desc="Converting ControlNet Model")
        pipeline = download_controlnet_from_original_ckpt(ckpt_file, config_file, image_size=sample_size, from_safetensors=from_safetensors, extract_ema=extract_ema)
        # The ControlNet model and the pipeline take the dtype under different keywords.
        to_args = {"dtype": torch.float16}
    else:
        progress(0.1, desc="Converting Model")
        pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, config_file, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)
        to_args = {"torch_dtype": torch.float16}

    # Weights as converted, in both serialization formats.
    pipeline.save_pretrained(folder)
    pipeline.save_pretrained(folder, safe_serialization=True)
    # Half-precision copies, saved alongside as the "fp16" variant.
    pipeline = pipeline.to(**to_args)
    pipeline.save_pretrained(folder, variant="fp16")
    pipeline.save_pretrained(folder, safe_serialization=True, variant="fp16")
    return folder
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
    """Find an open pull request on ``model_id`` titled ``pr_title`` that targets main.

    Args:
        api: Hub client used to list discussions and fetch their details.
        model_id: Repo id to search.
        pr_title: Exact title the pull request must have.

    Returns:
        The first matching discussion whose target branch is ``refs/heads/main``.
        Returns ``None`` if there is no match or the discussions cannot be listed.
    """
    try:
        # get_repo_discussions returns a lazy iterator: the Hub is queried while
        # iterating, so consume it here for the except clause to catch listing errors.
        discussions = list(api.get_repo_discussions(repo_id=model_id))
    except Exception:
        # Best effort: failing to list discussions is treated as "no previous PR".
        return None
    for discussion in discussions:
        if not (discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title):
            continue
        details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
        if details.target_branch == "refs/heads/main":
            return discussion
    return None
def _pr_number(commit_info) -> str:
    """Extract the pull-request number from the value returned by ``upload_folder``."""
    pr_num = getattr(commit_info, "pr_num", None)
    if pr_num is not None:
        return str(pr_num)
    # Fallback for a plain URL string; presumably shaped like
    # ".../refs%2Fpr%2F<N>/..." (same parsing as before — confirm for your hub version).
    return str(commit_info).split("%2F")[-1].split("/")[0]


def convert(token: str, model_id: str, filename: str, model_type: str, sample_size: int = 512, scheduler_type: str = "pndm", extract_ema: bool = True, progress=gr.Progress()):
    """Convert a checkpoint to diffusers format and open a PR with the result.

    Args:
        token: Hub token used for the upload.
        model_id: Repo id holding the checkpoint; the PR is opened on it.
        filename: Checkpoint filename inside the repo.
        model_type: ``"v1"``, ``"v2"`` or ``"ControlNet"``.
        sample_size: Image size forwarded to the converter.
        scheduler_type: Scheduler name forwarded to the converter.
        extract_ema: Forwarded to the converter.
        progress: Gradio progress tracker.

    Returns:
        A message containing the URL of the created pull request.

    Raises:
        gr.exceptions.Error: Wraps any failure during conversion or upload; the
            original exception is chained as ``__cause__``.
    """
    api = HfApi()
    pr_title = "Adding `diffusers` weights of this model"
    with TemporaryDirectory() as d:
        folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
        os.makedirs(folder)
        try:
            folder = convert_single(model_id, filename, model_type, sample_size, scheduler_type, extract_ema, folder, progress)
            progress(0.7, desc="Uploading to Hub")
            new_pr = api.upload_folder(folder_path=folder, path_in_repo="./", repo_id=model_id, repo_type="model", token=token, commit_message=pr_title, commit_description=COMMIT_MESSAGE.format(model_id), create_pr=True)
            pr_number = _pr_number(new_pr)
            # Plain string formatting: os.path.join would insert "\" on Windows.
            link = f"Pr created at: https://huggingface.co/{model_id}/discussions/{pr_number}"
            progress(1, desc="Done")
        except Exception as e:
            raise gr.exceptions.Error(str(e)) from e
        finally:
            # TemporaryDirectory cleans up on exit anyway. Ignore errors here so
            # a cleanup failure can never mask the real exception.
            shutil.rmtree(folder, ignore_errors=True)
    return link
|