Cosmos-Predict2 / diffusers_repo /scripts /convert_omnigen_to_diffusers.py
multimodalart's picture
Upload 2025 files
22a452a verified
raw
history blame
7.5 kB
import argparse
import os
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
# One-to-one renames from original OmniGen checkpoint keys to the key names
# expected by ``OmniGenTransformer2DModel``.
_KEY_MAPPING = {
    "pos_embed": "patch_embedding.pos_embed",
    "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
    "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
    "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
    "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
    "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
    "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
    "final_layer.linear.weight": "proj_out.weight",
    "final_layer.linear.bias": "proj_out.bias",
    "time_token.mlp.0.weight": "time_token.linear_1.weight",
    "time_token.mlp.0.bias": "time_token.linear_1.bias",
    "time_token.mlp.2.weight": "time_token.linear_2.weight",
    "time_token.mlp.2.bias": "time_token.linear_2.bias",
    "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
    "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
    "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
    "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
    "llm.embed_tokens.weight": "embed_tokens.weight",
}

# Prefix that is dropped from every key not covered by an explicit rule.
_LLM_PREFIX = "llm."


def _resolve_checkpoint_dir(path):
    """Return a local checkpoint directory, downloading ``path`` as a Hub repo id if it does not exist locally."""
    if os.path.exists(path):
        return path

    print("Model not found, downloading...")
    cache_folder = os.getenv("HF_HUB_CACHE")
    local_dir = snapshot_download(
        repo_id=path,
        cache_dir=cache_folder,
        ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
    )
    print(f"Downloaded model to {local_dir}")
    return local_dir


def _convert_state_dict(ckpt):
    """Translate an original OmniGen state dict into diffusers key names.

    Rules are tried in this order for every key:

    1. keys listed in ``_KEY_MAPPING`` are renamed directly;
    2. keys containing ``"qkv"`` are split into three equal chunks along the
       first dimension and stored as ``to_q`` / ``to_k`` / ``to_v`` weights;
    3. keys containing ``"o_proj"`` become the ``to_out.0`` weight;
    4. every other key must start with ``"llm."`` and has that prefix removed.

    Raises:
        ValueError: if a key matches none of the rules. Previously such a key
            silently lost its first four characters and only failed later
            inside ``load_state_dict``.
    """
    converted = {}
    for key, tensor in ckpt.items():
        if key in _KEY_MAPPING:
            converted[_KEY_MAPPING[key]] = tensor
        elif "qkv" in key:
            # Assumes the key layout ``llm.layers.<idx>. ...`` so that the third
            # dot-separated field is the layer index -- TODO confirm for other checkpoints.
            attn = f"layers.{key.split('.')[2]}.self_attn"
            to_q, to_k, to_v = tensor.chunk(3)
            converted[f"{attn}.to_q.weight"] = to_q
            converted[f"{attn}.to_k.weight"] = to_k
            converted[f"{attn}.to_v.weight"] = to_v
        elif "o_proj" in key:
            converted[f"layers.{key.split('.')[2]}.self_attn.to_out.0.weight"] = tensor
        elif key.startswith(_LLM_PREFIX):
            converted[key[len(_LLM_PREFIX) :]] = tensor
        else:
            raise ValueError(
                f"Unexpected checkpoint key {key!r}: it has no explicit mapping and "
                f"does not start with {_LLM_PREFIX!r}."
            )
    return converted


def _build_transformer():
    """Instantiate an ``OmniGenTransformer2DModel`` with the hard-coded OmniGen-v1 configuration."""
    # The rope scaling factors are fixed literals; presumably they mirror the
    # original model's configuration -- verify against the upstream config.
    return OmniGenTransformer2DModel(
        rope_scaling={
            "long_factor": [
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0799999237060547,
                1.2299998998641968,
                1.2299998998641968,
                1.2999999523162842,
                1.4499999284744263,
                1.5999999046325684,
                1.6499998569488525,
                1.8999998569488525,
                2.859999895095825,
                3.68999981880188,
                5.419999599456787,
                5.489999771118164,
                5.489999771118164,
                9.09000015258789,
                11.579999923706055,
                15.65999984741211,
                15.769999504089355,
                15.789999961853027,
                18.360000610351562,
                21.989999771118164,
                23.079999923706055,
                30.009998321533203,
                32.35000228881836,
                32.590003967285156,
                35.56000518798828,
                39.95000457763672,
                53.840003967285156,
                56.20000457763672,
                57.95000457763672,
                59.29000473022461,
                59.77000427246094,
                59.920005798339844,
                61.190006256103516,
                61.96000671386719,
                62.50000762939453,
                63.3700065612793,
                63.48000717163086,
                63.48000717163086,
                63.66000747680664,
                63.850006103515625,
                64.08000946044922,
                64.760009765625,
                64.80001068115234,
                64.81001281738281,
                64.81001281738281,
            ],
            "short_factor": [
                1.05,
                1.05,
                1.05,
                1.1,
                1.1,
                1.1,
                1.2500000000000002,
                1.2500000000000002,
                1.4000000000000004,
                1.4500000000000004,
                1.5500000000000005,
                1.8500000000000008,
                1.9000000000000008,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.1000000000000005,
                2.1000000000000005,
                2.2,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3999999999999995,
                2.3999999999999995,
                2.6499999999999986,
                2.6999999999999984,
                2.8999999999999977,
                2.9499999999999975,
                3.049999999999997,
                3.049999999999997,
                3.049999999999997,
            ],
            "type": "su",
        },
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )


def main(args):
    """Convert an original OmniGen checkpoint into a diffusers ``OmniGenPipeline``.

    Args:
        args: namespace with two string attributes:
            ``origin_ckpt_path`` -- a local directory containing
            ``model.safetensors``, a ``vae`` sub-folder and tokenizer files, or
            a Hub repo id to download when no such local path exists. It is
            overwritten with the resolved local directory.
            ``dump_path`` -- directory the converted pipeline is saved to.

    Raises:
        ValueError: if the checkpoint contains a key the converter cannot map.
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
    args.origin_ckpt_path = _resolve_checkpoint_dir(args.origin_ckpt_path)

    ckpt_file = os.path.join(args.origin_ckpt_path, "model.safetensors")
    converted_state_dict = _convert_state_dict(load_file(ckpt_file, device="cpu"))

    transformer = _build_transformer()
    # strict=True: any missing or unexpected key aborts the conversion.
    transformer.load_state_dict(converted_state_dict, strict=True)
    # nn.Module.to casts the parameters in place, so the return value is not needed.
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    # The VAE is loaded in float32, unlike the bfloat16 transformer.
    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)
if __name__ == "__main__":
    # Script entry point: parse the two optional string flags and run the conversion.
    cli = argparse.ArgumentParser()

    # (flag, default value, help text) for every supported option.
    option_table = (
        ("--origin_ckpt_path", "Shitao/OmniGen-v1", "Path to the checkpoint to convert."),
        ("--dump_path", "OmniGen-v1-diffusers", "Path to the output pipeline."),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, default=fallback, type=str, required=False, help=description)

    args = cli.parse_args()
    main(args)