# Spaces: Running on Zero
"""
Patch the torch module so that newer diffusers versions can be imported
while running PyTorch 2.0.1.
"""
import torch | |
import sys | |
import warnings | |
import types | |
import functools | |
# Dtype names that newer diffusers releases may reference but PyTorch 2.0.1
# does not define. Missing ones are aliased to torch.float16 so attribute
# lookups and imports succeed; no real 8-bit arithmetic is provided.
# torch.bfloat16 is intentionally NOT listed: it is a genuine dtype, and
# aliasing it to float16 would silently change numerics for code using it.
_FLOAT16_PLACEHOLDER_DTYPES = (
    "float8_e4m3fn",
    "float8_e5m2",
    "bfloat8",
    "float8_e4m3fnuz",
)


def _add_dtype_placeholders(module, names, placeholder):
    """Alias every attribute in *names* that *module* lacks to *placeholder*.

    Attributes that already exist are never overwritten.

    Args:
        module: Object to patch (normally the ``torch`` module).
        names: Attribute names that should exist afterwards.
        placeholder: Value assigned to each missing attribute.

    Returns:
        list[str]: The names that were actually added, in input order.
    """
    added = []
    for name in names:
        if not hasattr(module, name):
            setattr(module, name, placeholder)
            added.append(name)
    return added


_added_dtype_placeholders = _add_dtype_placeholders(
    torch, _FLOAT16_PLACEHOLDER_DTYPES, torch.float16
)
if _added_dtype_placeholders:
    # A single warning covering every alias, so none is added silently.
    warnings.warn(
        "Added float16 placeholders for "
        + ", ".join(f"torch.{name}" for name in _added_dtype_placeholders)
        + ". Actual 8-bit operations won't work, but imports should succeed. "
        "Using PyTorch 2.0.1 with newer diffusers."
    )
def _make_dynamo_stub():
    """Build a no-op stand-in for ``torch._dynamo``.

    The stub exposes only ``config``, ``optimize``, ``disable`` and
    ``reset_repro_cache``. Nothing is compiled: every decorator it returns
    hands the wrapped callable back unchanged.

    Returns:
        types.ModuleType: A fresh module object named ``torch._dynamo``.
    """
    stub = types.ModuleType("torch._dynamo")
    stub.config = types.SimpleNamespace(suppress_errors=True)
    # optimize(...) returns a decorator that leaves the function as-is.
    stub.optimize = lambda *args, **kwargs: lambda f: f

    def disable(fn=None, recursive=True):
        # Supports both ``@disable`` (fn given) and ``@disable()`` (fn None);
        # a zero-argument lambda raised TypeError for the first form.
        if fn is None:
            return lambda f: f
        return fn

    stub.disable = disable
    stub.reset_repro_cache = lambda: None
    return stub


# A subpackage is not an attribute of its parent until it is imported (unless
# the parent imports or lazily exposes it), so hasattr(torch, "_dynamo") can be
# False even when the real module is installed. Try the real import first and
# fall back to the stub only when it is genuinely unavailable.
try:
    import torch._dynamo  # noqa: F401
except ImportError:
    torch._dynamo = _make_dynamo_stub()
    sys.modules["torch._dynamo"] = torch._dynamo
def _compile_passthrough(fn=None, **kwargs):
    """No-op stand-in for ``torch.compile`` that never compiles anything.

    Accepts the call shapes of the real API: ``compile(model)``,
    ``compile(model=model, ...)`` and the decorator-factory form
    ``@compile(mode=...)``. The legacy ``fn=`` keyword still works. All other
    keyword options are ignored.

    Returns:
        The target callable unchanged, or an identity decorator when no
        target was supplied.
    """
    if fn is None:
        # The real torch.compile names its first parameter ``model``.
        fn = kwargs.get("model")
    if fn is None:
        # Called without a target: behave as a decorator factory.
        return lambda f: f
    return fn


if not hasattr(torch, "compile"):
    torch.compile = _compile_passthrough
class MockAutocast:
    """No-op replacement for ``torch.cuda.amp.autocast``.

    Usable as a context manager (``with MockAutocast(...):``) or as a
    decorator (``@MockAutocast(...)``). Constructor arguments are accepted
    and ignored; no mixed-precision casting takes place.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Falsy return value: exceptions raised in the block propagate.
        return None

    def __call__(self, func):
        # functools.wraps preserves the wrapped function's __name__, __doc__
        # and __wrapped__, so the decorator stays transparent to introspection.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper


# Install a placeholder torch.cuda.amp only when the real one is missing.
if not hasattr(torch.cuda, "amp"):
    torch.cuda.amp = types.ModuleType("torch.cuda.amp")
    sys.modules["torch.cuda.amp"] = torch.cuda.amp
    torch.cuda.amp.autocast = MockAutocast
# Announce the patch in the process log, reporting the torch version that is
# actually installed instead of assuming 2.0.1.
print(
    "PyTorch patched for compatibility with newer diffusers - "
    f"using latest diffusers with PyTorch {torch.__version__}"
)