# Sharp-It / torch_patch.py
"""
Patch for torch module to make it compatible with newer diffusers versions
while using PyTorch 2.0.1
"""
import torch
import sys
import warnings
import types
import functools
# Some newer diffusers releases reference 8-bit / bfloat dtypes that PyTorch
# 2.0.1 does not define. Install float16 stand-ins so that attribute lookups
# (and therefore imports) succeed; actual low-precision math will NOT work.
_PLACEHOLDER_DTYPE = torch.float16

if not hasattr(torch, 'float8_e4m3fn'):
    torch.float8_e4m3fn = _PLACEHOLDER_DTYPE
    warnings.warn(
        "Added placeholder for torch.float8_e4m3fn. Actual 8-bit operations won't work, "
        "but imports should succeed. Using PyTorch 2.0.1 with newer diffusers."
    )

# Remaining dtype names are patched silently, only when absent.
for _dtype_name in ('float8_e5m2', 'bfloat16', 'bfloat8', 'float8_e4m3fnuz'):
    if not hasattr(torch, _dtype_name):
        setattr(torch, _dtype_name, _PLACEHOLDER_DTYPE)
# Create a placeholder for torch._dynamo if it doesn't exist. (It ships with
# PyTorch >= 2.0, so this branch only runs on older or stripped builds.)
if not hasattr(torch, '_dynamo'):
    torch._dynamo = types.ModuleType('torch._dynamo')
    # Register in sys.modules so `import torch._dynamo` also succeeds.
    sys.modules['torch._dynamo'] = torch._dynamo
    # Common attributes/functions referenced by diffusers/transformers.
    torch._dynamo.config = types.SimpleNamespace(suppress_errors=True)
    # optimize(...) returns a decorator that leaves the function unchanged.
    torch._dynamo.optimize = lambda *args, **kwargs: lambda f: f
    # Bug fix: the real torch._dynamo.disable is used both as a decorator
    # (@torch._dynamo.disable / disable(fn)) and as a decorator factory
    # (disable()); the previous `lambda: None` mock raised TypeError on the
    # decorator form. Mirror the real signature instead: return the function
    # untouched, or a pass-through decorator when called without one.
    torch._dynamo.disable = (
        lambda fn=None, recursive=True: fn if fn is not None else (lambda f: f)
    )
    torch._dynamo.reset_repro_cache = lambda: None
# Add torch.compile if it doesn't exist (introduced in PyTorch 2.0).
if not hasattr(torch, 'compile'):
    def _noop_compile(model=None, **kwargs):
        """Return *model* unchanged (no graph compilation is performed).

        Bug fix: the previous `lambda fn, **kwargs: fn` required a positional
        argument, so the decorator-factory form `torch.compile(mode=...)`
        raised TypeError. Mirror the real signature: when called without a
        model, return an identity decorator.
        """
        if model is None:
            return lambda fn: fn
        return model

    torch.compile = _noop_compile
# Create a placeholder for torch.cuda.amp if it doesn't exist, with a no-op
# autocast that works both as a context manager and as a decorator.
if not hasattr(torch.cuda, 'amp'):
    _amp_module = types.ModuleType('torch.cuda.amp')
    torch.cuda.amp = _amp_module
    # Register in sys.modules so `import torch.cuda.amp` also succeeds.
    sys.modules['torch.cuda.amp'] = _amp_module

    class MockAutocast:
        """Do-nothing stand-in for torch.cuda.amp.autocast."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            # Context-manager form: no precision change, just yield self.
            return self

        def __exit__(self, *exc_info):
            pass

        def __call__(self, func):
            # Decorator form: wrap the function without altering behavior.
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

    _amp_module.autocast = MockAutocast

print("PyTorch patched for compatibility with newer diffusers - using latest diffusers with PyTorch 2.0.1")