""" Patch for torch module to make it compatible with newer diffusers versions while using PyTorch 2.0.1 """ import torch import sys import warnings import types import functools # Check if the attributes already exist if not hasattr(torch, 'float8_e4m3fn'): # Add missing attributes for compatibility # These won't actually function, but they'll allow imports to succeed torch.float8_e4m3fn = torch.float16 # Use float16 as a placeholder type warnings.warn( "Added placeholder for torch.float8_e4m3fn. Actual 8-bit operations won't work, " "but imports should succeed. Using PyTorch 2.0.1 with newer diffusers." ) if not hasattr(torch, 'float8_e5m2'): torch.float8_e5m2 = torch.float16 # Use float16 as a placeholder type # Add other missing torch types that might be referenced for type_name in ['bfloat16', 'bfloat8', 'float8_e4m3fnuz']: if not hasattr(torch, type_name): setattr(torch, type_name, torch.float16) # Create a placeholder for torch._dynamo if it doesn't exist if not hasattr(torch, '_dynamo'): torch._dynamo = types.ModuleType('torch._dynamo') sys.modules['torch._dynamo'] = torch._dynamo # Add common attributes/functions used by torch._dynamo torch._dynamo.config = types.SimpleNamespace(suppress_errors=True) torch._dynamo.optimize = lambda *args, **kwargs: lambda f: f torch._dynamo.disable = lambda: None torch._dynamo.reset_repro_cache = lambda: None # Add torch.compile if it doesn't exist if not hasattr(torch, 'compile'): # Just return the function unchanged torch.compile = lambda fn, **kwargs: fn # Create a placeholder for torch.cuda.amp if it doesn't exist if not hasattr(torch.cuda, 'amp'): torch.cuda.amp = types.ModuleType('torch.cuda.amp') sys.modules['torch.cuda.amp'] = torch.cuda.amp # Mock autocast class MockAutocast: def __init__(self, *args, **kwargs): pass def __enter__(self): return self def __exit__(self, *args): pass def __call__(self, func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper torch.cuda.amp.autocast = MockAutocast print("PyTorch patched for compatibility with newer diffusers - using latest diffusers with PyTorch 2.0.1")