Flux ONNX model inference failures
I ran black-forest-labs/FLUX.1-schnell-onnx, and from the loader error it looks like these models were exported with ONNX IR version 11:
Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from onnx_models\black-forest-labs_FLUX.1-schnell-onnx\clip.opt\model.onnx failed:D:\a_work\1\s\onnxruntime\core\graph\model.cc:181 onnxruntime::Model::Model Unsupported model IR version: 11, max supported IR version: 10
It looks like these models do not load with the ONNX Runtime version currently installed, which only supports IR versions up to 10.
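For reference, this is the kind of check I would use to confirm the mismatch. It is only an illustrative snippet (not part of my export script); the model path is taken from the error message above:

```python
# Illustrative check: compare the model's IR/opset versions with the installed onnxruntime.
# Path is the one from the error message above; adjust as needed.
import onnx
import onnxruntime as ort

model = onnx.load("onnx_models/black-forest-labs_FLUX.1-schnell-onnx/clip.opt/model.onnx")
print("IR version:", model.ir_version)  # reports 11 for these models
print("Opsets:", [(o.domain, o.version) for o in model.opset_import])
print("onnxruntime:", ort.__version__)  # must be a build that supports that IR version
```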
I wrote a script to export the PyTorch model to ONNX myself, and it is still failing. Any input would be much appreciated.
Python code: FP32 model
```python
import torch
from pathlib import Path
import json
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel
# Base paths
base_model_dir = Path("pytorch_models/black-forest-labs_FLUX.1-schnell")
onnx_output_dir = Path("test_onnx_models") / "black-forest-labs_FLUX.1-schnell-onnx"
onnx_output_dir.mkdir(parents=True, exist_ok=True)
# Tokenizer path
tokenizer_dir = base_model_dir / "tokenizer"

# Auto-export tokenizer if tokenizer.json doesn't exist
tokenizer_json_path = tokenizer_dir / "tokenizer.json"
if not tokenizer_json_path.exists():
    print(" Tokenizer.json not found! Exporting tokenizer using transformers...")
    tokenizer = CLIPTokenizer.from_pretrained(str(tokenizer_dir))
    tokenizer.save_pretrained(str(tokenizer_dir))
    print(f" Tokenizer exported to {tokenizer_dir}")
else:
    print(f" Tokenizer already present at {tokenizer_json_path}")
# Map components and classes
component_classes = {
    "transformer": ("transformer.opt", UNet2DConditionModel),
    "vae": ("vae.opt", AutoencoderKL),
    "text_encoder": ("clip.opt", CLIPTextModel),
    "text_encoder_2": ("t5.opt", T5EncoderModel)
}

# Input names
input_names_map = {
    "transformer": ["sample", "timestep", "encoder_hidden_states"],
    "vae": ["sample"],
    "text_encoder": ["input_ids"],
    "text_encoder_2": ["input_ids"]
}

# Dynamic axes
dynamic_axes_map = {
    "transformer": {"sample": {0: "batch", 2: "height", 3: "width"}, "encoder_hidden_states": {0: "batch", 1: "tokens"}},
    "vae": {"sample": {0: "batch", 2: "height", 3: "width"}},
    "text_encoder": {"input_ids": {0: "batch", 1: "tokens"}},
    "text_encoder_2": {"input_ids": {0: "batch", 1: "tokens"}},
}
# Patch config (remove num_attention_heads if it exists)
def patch_config_json_if_needed(model_dir):
    config_path = model_dir / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config = json.load(f)
        if "num_attention_heads" in config:
            print(f"⚙️ Patching num_attention_heads in {model_dir}")
            del config["num_attention_heads"]
            with open(config_path, "w") as f:
                json.dump(config, f, indent=2)
# Load weights from safetensors index.json
def load_weights_from_index(model_dir):
    index_file = next(model_dir.glob("*safetensors.index.json"), None)
    if not index_file:
        return None
    with open(index_file, "r") as f:
        index_data = json.load(f)
    state_dict = {}
    # De-duplicate shard filenames (each shard appears once per tensor in weight_map)
    for filename in set(index_data['weight_map'].values()):
        weights = load_file(model_dir / filename)
        state_dict.update(weights)
    return state_dict
# VAE wrapper to handle input projection (4 -> latent_channels)
class VAEWrapper(torch.nn.Module):
    def __init__(self, vae, latent_channels):
        super().__init__()
        self.vae = vae
        self.latent_channels = latent_channels
        # Inject a projection layer if the VAE expects more than 4 channels
        self.input_proj = None
        if latent_channels != 4:
            self.input_proj = torch.nn.Conv2d(4, latent_channels, kernel_size=1)
            torch.nn.init.xavier_uniform_(self.input_proj.weight)
            print(f" VAE input projection added: 4 -> {latent_channels}")

    def forward(self, latents):
        if self.input_proj is not None:
            latents = self.input_proj(latents)
        if self.vae.post_quant_conv is not None:
            latents = self.vae.post_quant_conv(latents)
        return self.vae.decoder(latents)
# Main export loop
for component, (output_folder, model_class) in component_classes.items():
    print(f"\n🔹 Processing {component} -> {output_folder}")
    model_dir = base_model_dir / component
    output_dir = onnx_output_dir / output_folder
    output_dir.mkdir(parents=True, exist_ok=True)
    model_onnx_path = output_dir / "model.onnx"

    # Skip if already exported
    if model_onnx_path.exists():
        print(f" Skipped (already exists): {model_onnx_path}")
        continue

    patch_config_json_if_needed(model_dir)

    config_path = model_dir / "config.json"
    if not config_path.exists():
        print(f" Skipped {component}: config.json not found.")
        continue
    with open(config_path, "r") as f:
        config = json.load(f)

    # Load model smartly
    if model_class in [CLIPTextModel, T5EncoderModel]:
        model = model_class.from_pretrained(str(model_dir))
    else:
        state_dict = load_weights_from_index(model_dir)
        if state_dict:
            model = model_class.from_config(config)
            model.load_state_dict(state_dict, strict=False)
        else:
            model = model_class.from_pretrained(str(model_dir))
    model.eval()

    # Dummy inputs per component
    if component == "transformer":
        in_channels = config.get('in_channels', 4)
        cross_attention_dim = config.get('cross_attention_dim', 1280)
        dummy_inputs = (
            torch.randn(1, in_channels, 64, 64),      # latent input
            torch.tensor([1]),                        # timestep
            torch.randn(1, 77, cross_attention_dim)   # text embeddings
        )
    elif component == "vae":
        latent_channels = config.get('latent_channels', 16)  # could be 4 or 16 depending on VAE
        model = VAEWrapper(model, latent_channels)            # wrap VAE to always accept (B, 4, H, W)
        dummy_inputs = (torch.randn(1, 4, 64, 64),)           # VAE always takes (1, 4, 64, 64)
    else:
        dummy_inputs = (torch.randint(0, 10000, (1, 77)),)    # input_ids

    # Export to ONNX
    print(f" Exporting {component} to {model_onnx_path}")
    torch.onnx.export(
        model,
        dummy_inputs,
        str(model_onnx_path),
        input_names=input_names_map[component],
        output_names=["output"],
        opset_version=17,                   # Modern opset
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        use_external_data_format=True,      # For >2GB models
        dynamic_axes=dynamic_axes_map.get(component, None)  # Flexible inputs
    )
    print(f" Exported {component} to {model_onnx_path}")
print("\n All components exported properly with correct input shape and latest ONNX IR! ")
```
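And this is roughly how I sanity-check an exported component afterwards. It is a minimal smoke test for the CLIP text encoder only, assuming the export above completed into `test_onnx_models` and the installed onnxruntime supports the emitted IR version:

```python
# Minimal smoke test for the exported CLIP text encoder (assumes the export above succeeded).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "test_onnx_models/black-forest-labs_FLUX.1-schnell-onnx/clip.opt/model.onnx",
    providers=["CPUExecutionProvider"],
)
print("Inputs:", [(i.name, i.shape, i.type) for i in session.get_inputs()])

# Dummy token IDs matching the (1, 77) shape used during export
input_ids = np.random.randint(0, 10000, size=(1, 77), dtype=np.int64)
outputs = session.run(None, {"input_ids": input_ids})
print("Output shape:", outputs[0].shape)
```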