Flux ONNX Model Inference failures

#1
by chandruflux - opened

I ran the black-forest-labs/FLUX.1-schnell-onnx models, and it looks like they were exported with ONNX IR version 11:

```
Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from onnx_models\black-forest-labs_FLUX.1-schnell-onnx\clip.opt\model.onnx failed:D:\a_work\1\s\onnxruntime\core\graph\model.cc:181 onnxruntime::Model::Model Unsupported model IR version: 11, max supported IR version: 10
```
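
For context, here is how I check which IR version my installed onnx package writes and which ONNX Runtime version I am running (both of these are standard attributes of the packages):

```python
# Check which ONNX IR version the installed onnx package emits,
# and which ONNX Runtime version is installed.
import onnx
import onnxruntime

print("onnx:", onnx.__version__, "-> writes IR version", onnx.IR_VERSION)
print("onnxruntime:", onnxruntime.__version__)
```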

It looks like these models do not load with the ONNX Runtime version I have installed.
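
One possible workaround, which I have not fully verified and which is only safe if the graph uses nothing newer than IR version 10, is to patch the ir_version field down with the onnx package (the larger components store weights in external data files, so onnx.load/onnx.save would need the external-data options for those). Upgrading onnxruntime to the latest release may also make the load succeed, since the supported IR cap rises with runtime versions.

```python
# Unverified workaround: rewrite the IR version so an older ONNX Runtime
# will load the model. Only safe if the graph has no IR-11-only features.
import onnx

path = "onnx_models/black-forest-labs_FLUX.1-schnell-onnx/clip.opt/model.onnx"
model = onnx.load(path)
print("original IR version:", model.ir_version)

model.ir_version = 10  # highest IR version my onnxruntime build accepts
onnx.save(model, "clip_ir10.onnx")
```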

I also wrote an export script to convert the PyTorch model to ONNX myself, and it is still failing. Any input would be much appreciated.

Python code (FP32 model):

```python
import torch
from pathlib import Path
import json
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel

# Base paths

base_model_dir = Path("pytorch_models/black-forest-labs_FLUX.1-schnell")
onnx_output_dir = Path("test_onnx_models") / "black-forest-labs_FLUX.1-schnell-onnx"
onnx_output_dir.mkdir(parents=True, exist_ok=True)

# Tokenizer path

tokenizer_dir = base_model_dir / "tokenizer"

# Auto-export tokenizer if tokenizer.json doesn't exist
tokenizer_json_path = tokenizer_dir / "tokenizer.json"
if not tokenizer_json_path.exists():
    print("tokenizer.json not found! Exporting tokenizer using transformers...")
    tokenizer = CLIPTokenizer.from_pretrained(str(tokenizer_dir))
    tokenizer.save_pretrained(str(tokenizer_dir))
    print(f"Tokenizer exported to {tokenizer_dir}")
else:
    print(f"Tokenizer already present at {tokenizer_json_path}")

# Map components to output folders and model classes
component_classes = {
    "transformer": ("transformer.opt", UNet2DConditionModel),
    "vae": ("vae.opt", AutoencoderKL),
    "text_encoder": ("clip.opt", CLIPTextModel),
    "text_encoder_2": ("t5.opt", T5EncoderModel),
}

# Input names per component
input_names_map = {
    "transformer": ["sample", "timestep", "encoder_hidden_states"],
    "vae": ["sample"],
    "text_encoder": ["input_ids"],
    "text_encoder_2": ["input_ids"],
}

# Dynamic axes so batch size, resolution, and token count stay flexible
dynamic_axes_map = {
    "transformer": {"sample": {0: "batch", 2: "height", 3: "width"}, "encoder_hidden_states": {0: "batch", 1: "tokens"}},
    "vae": {"sample": {0: "batch", 2: "height", 3: "width"}},
    "text_encoder": {"input_ids": {0: "batch", 1: "tokens"}},
    "text_encoder_2": {"input_ids": {0: "batch", 1: "tokens"}},
}

# Patch config (remove num_attention_heads if present)
def patch_config_json_if_needed(model_dir):
    config_path = model_dir / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config = json.load(f)
        if "num_attention_heads" in config:
            print(f"⚙️ Patching num_attention_heads in {model_dir}")
            del config["num_attention_heads"]
            with open(config_path, "w") as f:
                json.dump(config, f, indent=2)

# Load weights from a safetensors index.json (sharded checkpoints)
def load_weights_from_index(model_dir):
    index_file = next(model_dir.glob("*safetensors.index.json"), None)
    if not index_file:
        return None
    with open(index_file, "r") as f:
        index_data = json.load(f)
    state_dict = {}
    # weight_map maps tensor name -> shard file; dedupe so each shard loads once
    for filename in set(index_data["weight_map"].values()):
        weights = load_file(model_dir / filename)
        state_dict.update(weights)
    return state_dict

# VAE wrapper to handle input projection (4 -> latent_channels)
class VAEWrapper(torch.nn.Module):
    def __init__(self, vae, latent_channels):
        super().__init__()
        self.vae = vae
        self.latent_channels = latent_channels
        # Inject a projection layer if the VAE expects more than 4 channels
        self.input_proj = None
        if latent_channels != 4:
            self.input_proj = torch.nn.Conv2d(4, latent_channels, kernel_size=1)
            torch.nn.init.xavier_uniform_(self.input_proj.weight)
            print(f"VAE input projection added: 4 -> {latent_channels}")

    def forward(self, latents):
        if self.input_proj is not None:
            latents = self.input_proj(latents)
        if self.vae.post_quant_conv is not None:
            latents = self.vae.post_quant_conv(latents)
        return self.vae.decoder(latents)

# Main export loop
for component, (output_folder, model_class) in component_classes.items():
    print(f"\n🔹 Processing {component} -> {output_folder}")

    model_dir = base_model_dir / component
    output_dir = onnx_output_dir / output_folder
    output_dir.mkdir(parents=True, exist_ok=True)
    model_onnx_path = output_dir / "model.onnx"

    # Skip if already exported
    if model_onnx_path.exists():
        print(f"Skipped (already exists): {model_onnx_path}")
        continue

    patch_config_json_if_needed(model_dir)

    config_path = model_dir / "config.json"
    if not config_path.exists():
        print(f"Skipped {component}: config.json not found.")
        continue

    with open(config_path, "r") as f:
        config = json.load(f)

    # Load the model: text encoders via from_pretrained, diffusers models
    # via from_config + sharded weights when an index file is present
    if model_class in [CLIPTextModel, T5EncoderModel]:
        model = model_class.from_pretrained(str(model_dir))
    else:
        state_dict = load_weights_from_index(model_dir)
        if state_dict:
            model = model_class.from_config(config)
            model.load_state_dict(state_dict, strict=False)
        else:
            model = model_class.from_pretrained(str(model_dir))

    model.eval()

    # Dummy inputs per component
    if component == "transformer":
        in_channels = config.get("in_channels", 4)
        cross_attention_dim = config.get("cross_attention_dim", 1280)
        dummy_inputs = (
            torch.randn(1, in_channels, 64, 64),      # latent input
            torch.tensor([1]),                        # timestep
            torch.randn(1, 77, cross_attention_dim),  # text embeddings
        )
    elif component == "vae":
        latent_channels = config.get("latent_channels", 16)  # 4 or 16 depending on the VAE
        model = VAEWrapper(model, latent_channels)  # wrap VAE so it always accepts (B, 4, H, W)
        dummy_inputs = (torch.randn(1, 4, 64, 64),)
    else:
        dummy_inputs = (torch.randint(0, 10000, (1, 77)),)  # input_ids

    # Export to ONNX. Note: use_external_data_format was removed from
    # torch.onnx.export in recent PyTorch; external data files are written
    # automatically when a model exceeds the 2GB protobuf limit.
    print(f"Exporting {component} to {model_onnx_path}")
    torch.onnx.export(
        model,
        dummy_inputs,
        str(model_onnx_path),
        input_names=input_names_map[component],
        output_names=["output"],
        opset_version=17,  # modern opset
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        dynamic_axes=dynamic_axes_map.get(component),  # flexible inputs
    )
    print(f"Exported {component} to {model_onnx_path}")

print("\n All components exported properly with correct input shape and latest ONNX IR! ")`
