import json
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
import gc
import gguf
from dequant import dequantize_tensor # https://github.com/city96/ComfyUI-GGUF
import os
import argparse
import gradio as gr
# also requires aria2, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
import subprocess
subprocess.run('pip cache purge', shell=True)
import spaces
def spaces_dummy():
    pass
flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
system_temp_dir = "temp"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
GGUF_QTYPE = [gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q5_1,
              gguf.GGMLQuantizationType.Q5_0, gguf.GGMLQuantizationType.Q4_1,
              gguf.GGMLQuantizationType.Q4_0, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
TORCH_DTYPE = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half,
               torch.bfloat16, torch.complex32, torch.chalf, torch.complex64, torch.cfloat,
               torch.complex128, torch.cdouble, torch.uint8, torch.uint16, torch.uint32, torch.uint64,
               torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,
               torch.bool, torch.float8_e4m3fn, torch.float8_e5m2]
TORCH_QUANTIZED_DTYPE = [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]
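# These lists drive dequant_tensor() below: GGUF quant types are expanded with
# city96's dequantize_tensor(), plain torch dtypes are cast directly, and torch's
# quantized dtypes fall back to torch.dequantize().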
def list_sub(a, b):
    return [e for e in a if e not in b]
def is_repo_name(s):
    import re
    return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
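# e.g. is_repo_name("black-forest-labs/FLUX.1-dev") matches a "user/repo" id; full URLs
# fail the match (extra slashes), but a single-slash relative path like "models/x.safetensors"
# can match, which is why callers also check Path(url).exists().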
def clear_cache():
    torch.cuda.empty_cache()
    gc.collect()
def clear_sd(sd: dict):
    for k in list(sd.keys()):
        sd.pop(k)
    del sd
    torch.cuda.empty_cache()
    gc.collect()
def clone_sd(sd: dict):
    from copy import deepcopy
    print("Cloning state dict.")
    for k in list(sd.keys()):
        sd[k] = deepcopy(sd.pop(k))
        #sd[k] = sd.pop(k).detach().clone().to(device="cpu")
    torch.cuda.empty_cache()
    gc.collect()
def print_resource_usage():
    import psutil
    cpu_usage = psutil.cpu_percent()
    ram_usage = psutil.virtual_memory().used / psutil.virtual_memory().total * 100
    print(f"CPU usage: {cpu_usage}% / RAM usage: {ram_usage}%")
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start downloading...")
    url = url.strip()
    if "drive.google.com" in url:
        original_dir = os.getcwd()
        os.chdir(directory)
        os.system(f"gdown --fuzzy {url}")
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
        else:
            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
        else:
            print("You need an API key to download Civitai models.")
    else:
        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
def get_local_model_list(dir_path):
    model_list = []
    valid_extensions = ('.safetensors',) # must be a tuple; a bare string would do substring matching
    for file in Path(dir_path).glob("*"):
        if file.suffix in valid_extensions:
            file_path = str(Path(f"{dir_path}/{file.name}"))
            model_list.append(file_path)
    return model_list
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    if "http" not in url and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif "http" not in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
        print(f"File to download already exists: {url}")
        new_file = f"{temp_dir}/{url.split('/')[-1]}"
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip(), civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        after = get_local_model_list(temp_dir)
        new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
        if not new_file:
            print(f"Download failed: {url}")
            return ""
        print(f"Download completed: {url}")
    return new_file
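# Resolution order above: a bare "user/repo" id is passed through untouched, an existing
# local path is used as-is, a file already present in temp_dir is reused, and anything
# else is fetched via download_thing().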
def save_readme_md(dir, url):
    orig_url = url if "http" in url else ""
    md = """---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
- Flux
---
"""
    if orig_url:
        md += f"Converted from [{orig_url}]({orig_url}).\n"
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
def is_repo_exists(repo_id):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return api.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect to {repo_id}. {e}")
        return True # fail safe: assume the repo exists so nothing gets overwritten
def create_diffusers_repo(new_repo_id, diffusers_folder, is_private, is_overwrite, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import HfApi
    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=is_overwrite)
        for path in Path(diffusers_folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
        url = f"https://huggingface.co/{new_repo_id}"
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}.")
        print(e)
        return ""
    return url
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py
# In the SD3 original implementation of AdaLayerNormContinuous, the linear projection output
# is split into (shift, scale), while diffusers splits it into (scale, shift). Here we swap
# the halves of the linear projection weights so the diffusers implementation can be used.
@torch.no_grad()
@torch.autocast(device_type=device)
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight
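# Minimal sketch of the swap on a toy tensor (illustration only):
#   w = torch.arange(4.)   # chunk(2) -> shift=[0., 1.], scale=[2., 3.]
#   swap_scale_shift(w)    # -> tensor([2., 3., 0., 1.])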
@torch.no_grad()
@torch.autocast(device_type=device)
def convert_flux_transformer_checkpoint_to_diffusers(
        original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0,
        progress=gr.Progress(track_tqdm=True)):
    def conv(cdict: dict, odict: dict, ckey: str, okey: str):
        if okey in odict.keys():
            progress(0, desc=f"Converting {okey} => {ckey}")
            print(f"Converting {okey} => {ckey}")
            cdict[ckey] = odict.pop(okey)
            gc.collect()
    def convswap(cdict: dict, odict: dict, ckey: str, okey: str):
        if okey in odict.keys():
            progress(0, desc=f"Converting (swap) {okey} => {ckey}")
            print(f"Converting {okey} => {ckey} (swap)")
            cdict[ckey] = swap_scale_shift(odict.pop(okey))
            gc.collect()
    def convqkv(cdict: dict, odict: dict, i: int):
        keys = odict.keys()
        qkv_keys = (f"double_blocks.{i}.img_attn.qkv.weight", f"double_blocks.{i}.txt_attn.qkv.weight",
                    f"double_blocks.{i}.img_attn.qkv.bias", f"double_blocks.{i}.txt_attn.qkv.bias")
        if not all(k in keys for k in qkv_keys): # skip blocks that are absent (e.g. in a shard); warn if only partially present
            if any(k in keys for k in qkv_keys):
                progress(0, desc=f"Key error in converting Q, K, V (double_blocks.{i}).")
                print(f"Key error in converting Q, K, V (double_blocks.{i}).")
            return
        progress(0, desc=f"Converting Q, K, V (double_blocks.{i}).")
        print(f"Converting Q, K, V (double_blocks.{i}).")
        sample_q, sample_k, sample_v = torch.chunk(
            odict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            odict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            odict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            odict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        cdict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        cdict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        cdict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        cdict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        cdict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        cdict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        cdict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        cdict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        cdict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        cdict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        cdict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        cdict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        gc.collect()
    def convqkvmlp(cdict: dict, odict: dict, i: int, inner_dim: int, mlp_ratio: float):
        keys = odict.keys()
        linear1_keys = (f"single_blocks.{i}.linear1.weight", f"single_blocks.{i}.linear1.bias")
        if not all(k in keys for k in linear1_keys): # skip blocks that are absent; warn if only partially present
            if any(k in keys for k in linear1_keys):
                progress(0, desc=f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                print(f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
            return
        progress(0, desc=f"Converting Q, K, V, mlp (single_blocks.{i}).")
        print(f"Converting Q, K, V, mlp (single_blocks.{i}).")
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(odict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            odict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        cdict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        cdict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        cdict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        cdict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        cdict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        cdict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        cdict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        cdict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        gc.collect()
    converted_state_dict = {}
    progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
    ## time_text_embed.timestep_embedder <- time_in
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_1.bias", "time_in.in_layer.bias")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_2.weight", "time_in.out_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.timestep_embedder.linear_2.bias", "time_in.out_layer.bias")
    ## time_text_embed.text_embedder <- vector_in
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_1.bias", "vector_in.in_layer.bias")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight")
    conv(converted_state_dict, original_state_dict, "time_text_embed.text_embedder.linear_2.bias", "vector_in.out_layer.bias")
    # guidance
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_1.bias", "guidance_in.in_layer.bias")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight")
        conv(converted_state_dict, original_state_dict, "time_text_embed.guidance_embedder.linear_2.bias", "guidance_in.out_layer.bias")
    # context_embedder
    conv(converted_state_dict, original_state_dict, "context_embedder.weight", "txt_in.weight")
    conv(converted_state_dict, original_state_dict, "context_embedder.bias", "txt_in.bias")
    # x_embedder
    conv(converted_state_dict, original_state_dict, "x_embedder.weight", "img_in.weight")
    conv(converted_state_dict, original_state_dict, "x_embedder.bias", "img_in.bias")
    progress(0.25, desc="Converting FLUX.1 state dict to Diffusers format.")
    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1.linear.weight", f"double_blocks.{i}.img_mod.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1.linear.bias", f"double_blocks.{i}.img_mod.lin.bias")
        ## norm1_context
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1_context.linear.weight", f"double_blocks.{i}.txt_mod.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm1_context.linear.bias", f"double_blocks.{i}.txt_mod.lin.bias")
        # Q, K, V
        convqkv(converted_state_dict, original_state_dict, i)
        # qk_norm
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_q.weight", f"double_blocks.{i}.img_attn.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_k.weight", f"double_blocks.{i}.img_attn.norm.key_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_added_q.weight", f"double_blocks.{i}.txt_attn.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_added_k.weight", f"double_blocks.{i}.txt_attn.norm.key_norm.scale")
        # ff img_mlp
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.0.proj.weight", f"double_blocks.{i}.img_mlp.0.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.0.proj.bias", f"double_blocks.{i}.img_mlp.0.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.2.weight", f"double_blocks.{i}.img_mlp.2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff.net.2.bias", f"double_blocks.{i}.img_mlp.2.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.0.proj.weight", f"double_blocks.{i}.txt_mlp.0.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.0.proj.bias", f"double_blocks.{i}.txt_mlp.0.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.2.weight", f"double_blocks.{i}.txt_mlp.2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}ff_context.net.2.bias", f"double_blocks.{i}.txt_mlp.2.bias")
        # output projections.
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_out.0.weight", f"double_blocks.{i}.img_attn.proj.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_out.0.bias", f"double_blocks.{i}.img_attn.proj.bias")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_add_out.weight", f"double_blocks.{i}.txt_attn.proj.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.to_add_out.bias", f"double_blocks.{i}.txt_attn.proj.bias")
    progress(0.5, desc="Converting FLUX.1 state dict to Diffusers format.")
    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm.linear.weight", f"single_blocks.{i}.modulation.lin.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}norm.linear.bias", f"single_blocks.{i}.modulation.lin.bias")
        # Q, K, V, mlp
        convqkvmlp(converted_state_dict, original_state_dict, i, inner_dim, mlp_ratio)
        # qk norm
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_q.weight", f"single_blocks.{i}.norm.query_norm.scale")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}attn.norm_k.weight", f"single_blocks.{i}.norm.key_norm.scale")
        # output projections.
        conv(converted_state_dict, original_state_dict, f"{block_prefix}proj_out.weight", f"single_blocks.{i}.linear2.weight")
        conv(converted_state_dict, original_state_dict, f"{block_prefix}proj_out.bias", f"single_blocks.{i}.linear2.bias")
    progress(0.75, desc="Converting FLUX.1 state dict to Diffusers format.")
    conv(converted_state_dict, original_state_dict, "proj_out.weight", "final_layer.linear.weight")
    conv(converted_state_dict, original_state_dict, "proj_out.bias", "final_layer.linear.bias")
    convswap(converted_state_dict, original_state_dict, "norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight")
    convswap(converted_state_dict, original_state_dict, "norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias")
    progress(1, desc="Converting FLUX.1 state dict to Diffusers format.")
    return converted_state_dict
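# Hedged usage sketch with FLUX.1-dev dimensions (the same values used by
# convert_flux_transformer_sd_to_diffusers() below; the path is a placeholder):
#   sd = load_file("flux1-dev.safetensors", device="cpu")
#   diffusers_sd = convert_flux_transformer_checkpoint_to_diffusers(
#       sd, num_layers=19, num_single_layers=38, inner_dim=3072)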
# read safetensors metadata
def read_safetensors_metadata(path):
    with open(path, 'rb') as f:
        header_size = int.from_bytes(f.read(8), 'little')
        header_json = f.read(header_size).decode('utf-8')
        header = json.loads(header_json)
        metadata = header.get('__metadata__', {})
        return metadata.copy()
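# safetensors layout: the first 8 bytes are a little-endian uint64 header length,
# followed by a JSON header; its optional "__metadata__" entry is a str->str dict,
# which is what this function returns.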
def normalize_key(k: str):
    return k.replace("vae.", "").replace("model.diffusion_model.", "")\
        .replace("text_encoders.clip_l.transformer.", "")\
        .replace("text_encoders.t5xxl.transformer.", "")
def load_json_list(path: str):
    try:
        with open(path, encoding='utf-8') as f:
            return list(json.load(f))
    except Exception as e:
        print(e)
        return []
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/modeling_utils.py
# https://huggingface.co/docs/huggingface_hub/v0.24.5/package_reference/serialization
# https://huggingface.co/docs/huggingface_hub/index
@torch.no_grad()
def to_safetensors(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict
    print(f"Saving a temporary file to disk: {path}")
    os.makedirs(path, exist_ok=True)
    try:
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
    except Exception as e:
        print(e)
# https://discuss.huggingface.co/t/t5forconditionalgeneration-checkpoint-size-mismatch-19418/24119
# https://github.com/huggingface/transformers/issues/13769
# https://github.com/huggingface/optimum-quanto/issues/278
# https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py
# https://huggingface.co/docs/accelerate/usage_guides/big_modeling
@torch.no_grad()
def to_safetensors_flux_module(sd: dict, path: str, pattern: str, size: str,
                               quantization: bool=False, name: str = "",
                               metadata: dict | None = None, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict, save_torch_model
    from accelerate import init_empty_weights
    try:
        progress(0, desc=f"Preparing to save FLUX.1 {name} to Diffusers format.")
        print(f"Preparing to save FLUX.1 {name} to Diffusers format.")
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        progress(0, desc=f"Loading FLUX.1 {name}.")
        print(f"Loading FLUX.1 {name}.")
        os.makedirs(path, exist_ok=True)
        if quantization:
            progress(0.5, desc=f"Saving quantized FLUX.1 {name} to {path}")
            print(f"Saving quantized FLUX.1 {name} to {path}")
        else:
            progress(0.5, desc=f"Saving FLUX.1 {name} to: {path}")
        if False and path.endswith("/transformer"): # disabled: load into FluxTransformer2DModel before saving
            from diffusers import FluxTransformer2DModel
            has_guidance = any("guidance" in k for k in sd)
            with init_empty_weights():
                model = FluxTransformer2DModel(guidance_embeds=has_guidance)
            model.to("cpu")
            model.load_state_dict(sd, strict=True)
            print(f"Saving FLUX.1 {name} to: {path} (FluxTransformer2DModel)")
            if metadata is not None:
                progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}")
                save_torch_model(model=model, save_directory=path,
                                 filename_pattern=pattern, max_shard_size=size, metadata=metadata)
            else:
                save_torch_model(model=model, save_directory=path,
                                 filename_pattern=pattern, max_shard_size=size)
        else:
            print(f"Saving FLUX.1 {name} to: {path}")
            if metadata is not None:
                progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}")
                save_torch_state_dict(state_dict=sd, save_directory=path,
                                      filename_pattern=pattern, max_shard_size=size, metadata=metadata)
            else:
                save_torch_state_dict(state_dict=sd, save_directory=path,
                                      filename_pattern=pattern, max_shard_size=size)
        progress(1, desc=f"Saved FLUX.1 {name} to: {path}")
        print(f"Saved FLUX.1 {name} to: {path}")
    except Exception as e:
        print(e)
    finally:
        gc.collect()
flux_transformer_json = "flux_transformer_keys.json"
flux_t5xxl_json = "flux_t5xxl_keys.json"
flux_clip_json = "flux_clip_keys.json"
flux_vae_json = "flux_vae_keys.json"
keys_flux_t5xxl = set(load_json_list(flux_t5xxl_json))
keys_flux_transformer = set(load_json_list(flux_transformer_json))
keys_flux_clip = set(load_json_list(flux_clip_json))
keys_flux_vae = set(load_json_list(flux_vae_json))
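# Each JSON file is expected to hold a flat list of normalized tensor key names for one
# FLUX.1 module (transformer / T5-XXL / CLIP / VAE), so a single-file checkpoint can be
# split into per-module state dicts. A missing file simply yields an empty set.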
@torch.no_grad()
def dequant_tensor(v: torch.Tensor, dtype: torch.dtype, dequant: bool):
    try:
        #print(f"shape: {v.shape} / dim: {v.ndim}")
        if dequant:
            qtype = getattr(v, "tensor_type", None) # only set on GGUF tensors
            if v.dtype in TORCH_DTYPE: return v.to(dtype) if v.dtype != dtype else v
            elif qtype in GGUF_QTYPE: return dequantize_tensor(v, dtype)
            elif v.dtype in TORCH_QUANTIZED_DTYPE: return torch.dequantize(v).to(dtype)
            else: return torch.dequantize(v).to(dtype)
        else: return v.to(dtype) if v.dtype != dtype else v
    except Exception as e:
        print(e)
        return v # return the tensor unchanged rather than None on failure
@torch.no_grad()
def normalize_flux_state_dict(path: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              dequant: bool = False, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc=f"Loading and normalizing FLUX.1 safetensors: {path}")
    print(f"Loading and normalizing FLUX.1 safetensors: {path}")
    new_sd = dict()
    state_dict = load_file(path, device="cpu")
    try:
        for k in list(state_dict.keys()):
            v = state_dict.pop(k)
            nk = normalize_key(k)
            print(f"{k} => {nk}") #
            new_sd[nk] = dequant_tensor(v, dtype, dequant)
    except Exception as e:
        print(e)
        return
    finally:
        clear_sd(state_dict)
    new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix))
    metadata = read_safetensors_metadata(path)
    progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
    print(f"Saving FLUX.1 safetensors: {new_path}")
    os.makedirs(savepath, exist_ok=True)
    save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
    progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
    print(f"Saved FLUX.1 safetensors: {new_path}")
    clear_sd(new_sd)
@torch.no_grad()
def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
                                dequant: bool = False, name: str = "", keys: set = frozenset(),
                                progress=gr.Progress(track_tqdm=True)):
    progress(0, desc=f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
    print(f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
    new_sd = dict()
    state_dict = load_file(path, device="cpu")
    try:
        for k in list(state_dict.keys()):
            if k not in keys: state_dict.pop(k)
        gc.collect()
        for k in list(state_dict.keys()):
            v = state_dict.pop(k)
            if k in keys:
                nk = normalize_key(k)
                progress(0.5, desc=f"{k} => {nk}") #
                print(f"{k} => {nk}") #
                new_sd[nk] = dequant_tensor(v, dtype, dequant)
                #print_resource_usage() #
    except Exception as e:
        print(e)
        return None
    finally:
        progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
        print(f"Normalized FLUX.1 {name} safetensors: {path}")
        clear_sd(state_dict)
    return new_sd
@torch.no_grad()
def convert_flux_transformer_sd_to_diffusers(sd: dict, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
    print("Converting FLUX.1 state dict to Diffusers format.")
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0
    try:
        sd = convert_flux_transformer_checkpoint_to_diffusers(
            sd, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
        )
    except Exception as e:
        print(e)
    finally:
        progress(1, desc="Converted FLUX.1 state dict to Diffusers format.")
        print("Converted FLUX.1 state dict to Diffusers format.")
        gc.collect()
    return sd
@torch.no_grad()
def load_sharded_safetensors(path: str):
    import glob
    sd = {}
    try:
        for filepath in glob.glob(f"{path}/*.safetensors"):
            sharded_sd = load_file(str(filepath), device="cpu")
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            sd = sd | sharded_sd.copy()
            clear_sd(sharded_sd)
    except Exception as e:
        print(e)
    return sd
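# Note: `sd | sharded_sd` (dict union) requires Python 3.9+. Shards are merged one file
# at a time so that only a single shard is loaded alongside the growing merged dict.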
# https://huggingface.co/docs/safetensors/api/torch
@torch.no_grad()
def convert_flux_transformer_sd_to_diffusers_sharded(sd: dict, path: str, pattern: str,
                                                     size: str, progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict#, load_torch_model
    import glob
    try:
        progress(0, desc=f"Saving temporary files to disk: {path}")
        print(f"Saving temporary files to disk: {path}")
        os.makedirs(path, exist_ok=True)
        for k, v in sd.items():
            if k in keys_flux_transformer: sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        clear_sd(sd)
        progress(0.25, desc=f"Saved temporary files to disk: {path}")
        print(f"Saved temporary files to disk: {path}")
        for filepath in glob.glob(f"{path}/*.safetensors"):
            progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
            print(f"Processing temporary files: {str(filepath)}")
            sharded_sd = load_file(str(filepath), device="cpu")
            sharded_sd = convert_flux_transformer_sd_to_diffusers(sharded_sd)
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            save_file(sharded_sd, str(filepath))
            clear_sd(sharded_sd)
        print(f"Loading temporary files from disk: {path}")
        sd = load_sharded_safetensors(path)
        print(f"Loaded temporary files from disk: {path}")
    except Exception as e:
        print(e)
    return sd
@torch.no_grad()
def extract_normalized_flux_state_dict_sharded(loadpath: str, dtype: torch.dtype,
                                               dequant: bool, path: str, pattern: str, size: str,
                                               progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import save_torch_state_dict#, load_torch_model
    import glob
    sd = {} # keep sd bound even if loading fails below
    try:
        progress(0, desc=f"Loading model file: {loadpath}")
        print(f"Loading model file: {loadpath}")
        sd = load_file(loadpath, device="cpu")
        progress(0, desc=f"Saving temporary files to disk: {path}")
        print(f"Saving temporary files to disk: {path}")
        os.makedirs(path, exist_ok=True)
        for k, v in sd.items():
            sd[k] = v.to(device="cpu")
        save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        clear_sd(sd)
        progress(0.25, desc=f"Saved temporary files to disk: {path}")
        print(f"Saved temporary files to disk: {path}")
        for filepath in glob.glob(f"{path}/*.safetensors"):
            progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
            print(f"Processing temporary files: {str(filepath)}")
            sharded_sd = extract_norm_flux_module_sd(str(filepath), dtype, dequant,
                                                     "Transformer", keys_flux_transformer)
            for k, v in sharded_sd.items():
                sharded_sd[k] = v.to(device="cpu")
            save_file(sharded_sd, str(filepath))
            clear_sd(sharded_sd)
            print(f"Processed temporary files: {str(filepath)}")
        print(f"Loading temporary files from disk: {path}")
        sd = load_sharded_safetensors(path)
        print(f"Loaded temporary files from disk: {path}")
    except Exception as e:
        print(e)
    return sd
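# The shard-then-convert strategy above keeps peak memory near one shard (size) during
# normalization, trading RAM for extra disk round-trips through the temporary folder.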
def download_repo(repo_name, path, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    from huggingface_hub import snapshot_download
    print(f"Downloading {repo_name}.")
    try:
        if "text_encoder_2" in use_original:
            snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"])
        else:
            snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "text_encoder_2/model*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"])
    except Exception as e:
        print(e)
def copy_nontensor_files(from_path, to_path, use_original=["vae", "text_encoder"]):
    import shutil
    if "text_encoder_2" in use_original:
        te_from = str(Path(from_path, "text_encoder_2"))
        te_to = str(Path(to_path, "text_encoder_2"))
        print(f"Copying Text Encoder 2 files {te_from} to {te_to}")
        shutil.copytree(te_from, te_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    if "text_encoder" in use_original:
        te1_from = str(Path(from_path, "text_encoder"))
        te1_to = str(Path(to_path, "text_encoder"))
        print(f"Copying Text Encoder 1 files {te1_from} to {te1_to}")
        shutil.copytree(te1_from, te1_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    if "vae" in use_original:
        vae_from = str(Path(from_path, "vae"))
        vae_to = str(Path(to_path, "vae"))
        print(f"Copying VAE files {vae_from} to {vae_to}")
        shutil.copytree(vae_from, vae_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    tn2_from = str(Path(from_path, "tokenizer_2"))
    tn2_to = str(Path(to_path, "tokenizer_2"))
    print(f"Copying Tokenizer 2 files {tn2_from} to {tn2_to}")
    shutil.copytree(tn2_from, tn2_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True)
    print(f"Copying non-tensor files {from_path} to {to_path}")
    shutil.copytree(from_path, to_path, ignore=shutil.ignore_patterns("*.safetensors", "*.bin", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp", "*.index.json"), dirs_exist_ok=True)
def save_flux_other_diffusers(path: str, model_type: str = "dev", use_original: list = ["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    import shutil
    progress(0, desc="Loading FLUX.1 Components.")
    print("Loading FLUX.1 Components.")
    temppath = system_temp_dir
    if model_type == "schnell": repo = flux_schnell_repo
    else: repo = flux_dev_repo
    os.makedirs(temppath, exist_ok=True)
    os.makedirs(path, exist_ok=True)
    download_repo(repo, temppath, use_original)
    progress(0.5, desc="Saving FLUX.1 Components.")
    print("Saving FLUX.1 Components.")
    copy_nontensor_files(temppath, path, use_original)
    shutil.rmtree(temppath)
@torch.no_grad()
def fix_flux_safetensors(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                         quantization: bool = False, model_type: str = "dev", dequant: bool = False):
    save_flux_other_diffusers(savepath, model_type)
    normalize_flux_state_dict(loadpath, savepath, dtype, dequant)
    clear_cache()
# Much lower memory consumption, but higher disk load
@torch.no_grad()
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev",
                             dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                             new_repo_id: str = "", local: bool = False, progress=gr.Progress(track_tqdm=True)):
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "9.5GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "9.5GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "9.5GB"
    print_resource_usage() #
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage() #
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        clear_sd(vae_sd)
        print_resource_usage() #
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        clear_sd(clip_sd)
        print_resource_usage() #
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        clear_sd(te_sd)
        print_resource_usage() #
    unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
                                          keys_flux_transformer)
    clear_cache()
    print_resource_usage() #
    if not local:
        os.remove(loadpath)
        print("Deleted downloaded file.")
    clear_cache()
    print_resource_usage() #
    unet_sd = convert_flux_transformer_sd_to_diffusers(unet_sd)
    clear_cache()
    print_resource_usage() #
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage() #
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage() #
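# Hedged usage sketch (hypothetical local file; writes a Diffusers folder layout):
#   flux_to_diffusers_lowmem("flux1-dev.safetensors", "flux1-dev-diffusers",
#                            torch.bfloat16, model_type="dev", local=True)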
# Lowest memory consumption, but highest disk load
@torch.no_grad()
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              quantization: bool = False, model_type: str = "dev",
                              dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                              new_repo_id: str = "", progress=gr.Progress(track_tqdm=True)):
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    unet_temp_path = system_temp_dir.removesuffix("/") + "/sharded"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "10GB"
    unet_temp_size = "5GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "10GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "10GB"
    print_resource_usage() #
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage() #
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        clear_sd(vae_sd)
        print_resource_usage() #
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        clear_sd(clip_sd)
        print_resource_usage() #
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        clear_sd(te_sd)
        print_resource_usage() #
    unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
                                                         unet_temp_path, unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage() #
    unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
                                                               unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage() #
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage() #
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage() #
def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                  model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
                                  hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start converting...")
    temp_dir = "."
    print_resource_usage() #
    new_file = get_download_file(temp_dir, url, civitai_key)
    if not new_file:
        print(f"Not found: {url}")
        return ""
    new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
    dtype = torch.bfloat16
    quantization = False
    if data_type == "fp8": dtype = torch.float8_e4m3fn
    elif data_type == "fp16": dtype = torch.float16
    elif data_type == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16
    new_repo_id = f"{hf_user}/{Path(new_repo_name).stem}"
    if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
    flux_to_diffusers_lowmem(new_file, new_repo_name, dtype, quantization, model_type, dequant, use_original, new_repo_id)
    """if is_upload_sf:
        import shutil
        shutil.move(str(Path(new_file).resolve()), str(Path(new_repo_name, Path(new_file).name).resolve()))
    else: os.remove(new_file)"""
    progress(1, desc="Converted.")
    if q is not None: q.put(new_repo_name) # q is None when called directly (e.g. from the CLI)
    return new_repo_name
def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                          model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Start converting...")
    temp_dir = "."
    print_resource_usage() #
    new_file = get_download_file(temp_dir, url, civitai_key)
    if not new_file:
        print(f"Not found: {url}")
        return ""
    new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
    dtype = torch.bfloat16
    quantization = False
    if data_type == "fp8": dtype = torch.float8_e4m3fn
    elif data_type == "fp16": dtype = torch.float16
    elif data_type == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16
    fix_flux_safetensors(new_file, new_repo_name, dtype, quantization, model_type, dequant) # pass quantization so model_type/dequant land in the right positions
    os.remove(new_file)
    progress(1, desc="Converted.")
    if q is not None: q.put(new_repo_name) # q is None when called outside a worker process
    return new_repo_name
def convert_url_to_diffusers_repo_flux(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False,
                                       is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False,
                                       repo_urls=[], fix_only=False, use_original=["vae", "text_encoder"],
                                       progress=gr.Progress(track_tqdm=True)):
    import multiprocessing as mp
    import shutil
    if not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    if hf_token and not os.environ.get("HF_TOKEN"): os.environ['HF_TOKEN'] = hf_token
    if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY")
    q = mp.Queue()
    if fix_only:
        p = mp.Process(target=convert_url_to_fixed_flux_safetensors, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, q))
        #new_path = convert_url_to_fixed_flux_safetensors(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant)
    else:
        p = mp.Process(target=convert_url_to_diffusers_flux, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, use_original, hf_user, hf_repo, q))
        #new_path = convert_url_to_diffusers_flux(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant)
    p.start()
    new_path = q.get()
    p.join()
    if not new_path: return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    new_repo_id = f"{hf_user}/{Path(new_path).stem}"
    if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
    if not is_repo_name(new_repo_id):
        print(f"Invalid repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    if not is_overwrite and is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        progress(1, desc=f"Repo already exists: {new_repo_id}")
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")
    #save_readme_md(new_path, dl_url)
    repo_url = create_diffusers_repo(new_repo_id, new_path, is_private, is_overwrite)
    shutil.rmtree(new_path)
    if not repo_urls: repo_urls = []
    repo_urls.append(repo_url)
    md = "Your new repo:<br>"
    for u in repo_urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    return gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None, type=str, required=False, help="URL of the model to convert.")
    parser.add_argument("--file", default=None, type=str, required=False, help="Filename of the model to convert.")
    parser.add_argument("--fix", action="store_true", help="Only fix the keys of the local model.")
    parser.add_argument("--civitai_key", default=None, type=str, required=False, help="Civitai API key (needed to download files from Civitai).")
    parser.add_argument("--dtype", type=str, default="fp8")
    parser.add_argument("--model", type=str, default="dev")
    parser.add_argument("--dequant", action="store_true", help="Dequantize the model.")
    args = parser.parse_args()
    assert (args.url, args.file) != (None, None), "Must provide --url or --file!"
    dtype = torch.bfloat16
    quantization = False
    if args.dtype == "fp8": dtype = torch.float8_e4m3fn
    elif args.dtype == "fp16": dtype = torch.float16
    elif args.dtype == "qfloat8":
        dtype = torch.bfloat16
        quantization = True
    else: dtype = torch.bfloat16
    use_original = ["vae", "text_encoder"]
    new_repo_id = ""
    use_local = True
    if args.file is not None and Path(args.file).exists():
        if args.fix: normalize_flux_state_dict(args.file, ".", dtype, args.dequant)
        else: flux_to_diffusers_lowmem(args.file, Path(args.file).stem, dtype, quantization,
                                       args.model, args.dequant, use_original, new_repo_id, use_local)
    elif args.url is not None:
        convert_url_to_diffusers_flux(args.url, args.civitai_key, False, args.dtype, args.model,
                                      args.dequant)
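# Example invocations (hypothetical script name and paths):
#   python convert_flux_to_diffusers.py --file flux1-dev.safetensors --dtype bf16 --model dev
#   python convert_flux_to_diffusers.py --url https://huggingface.co/user/repo/blob/main/model.safetensors --dtype fp8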