import os
from unittest.mock import patch

from transformers.dynamic_module_utils import get_imports

# Exact basenames of the Florence-2 remote-code files whose imports are sanitized.
# Matching on the basename (not a "/..." suffix) works with any path separator
# and avoids accidental matches such as "myconfiguration_florence2.py".
_FLORENCE2_FILES = ("modeling_florence2.py", "configuration_florence2.py")


def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Return the imports of *filename*, dropping ``flash_attn`` for Florence-2 files.

    This is a workaround for the flash_attn import issue. It wraps
    ``transformers.dynamic_module_utils.get_imports``.

    Args:
        filename: Path to a Python source file that ``get_imports`` inspects.

    Returns:
        The list returned by ``get_imports``. If the file's basename is one of
        the Florence-2 modeling/configuration files, every ``"flash_attn"``
        entry is removed from it. Any other file's list is returned unchanged.
    """
    # ``get_imports`` here is the original function, bound when this module was
    # imported, so it is unaffected by the ``patch`` in
    # ``load_model_without_flash_attn`` and cannot recurse into this wrapper.
    imports = get_imports(filename)
    if os.path.basename(os.fspath(filename)) not in _FLORENCE2_FILES:
        return imports
    return [name for name in imports if name != "flash_attn"]


def load_model_without_flash_attn(model_loader, *args, **kwargs):
    """Call *model_loader* with the flash_attn import workaround active.

    For the duration of the call, ``transformers.dynamic_module_utils.get_imports``
    is replaced by ``fixed_get_imports``. The original is restored afterwards,
    even if the loader raises.

    Args:
        model_loader: Callable that loads and returns the model.
        *args: Positional arguments forwarded to *model_loader* (optional).
        **kwargs: Keyword arguments forwarded to *model_loader* (optional).

    Returns:
        Whatever *model_loader* returns.
    """
    # NOTE(review): presumably transformers looks up ``get_imports`` through its
    # module globals when checking remote-code imports, which is what makes
    # patching the module attribute effective. Confirm against the installed
    # transformers version.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        return model_loader(*args, **kwargs)