# Copyright Forge 2024
import time
import torch
import contextlib
from backend import stream, memory_management, utils
from backend.patcher.lora import merge_lora_to_weight
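
# Keeps (weight, bias, finished_event) tuples alive until the compute stream has
# finished with tensors that were prefetched on the mover stream (see main_stream_worker).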
stash = {}
def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
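    """Return the layer's weight and bias, optionally transformed and cast.

    `weight_fn` / `bias_fn` (e.g. a dequantizer) run before the `.to(**weight_args)` /
    `.to(**bias_args)` casts, and any online LoRA patches stored on the layer under
    `forge_online_loras` are merged in afterwards.
    """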
    patches = getattr(layer, 'forge_online_loras', None)
    weight_patches, bias_patches = None, None

    if patches is not None:
        weight_patches = patches.get('weight', None)
        bias_patches = patches.get('bias', None)

    weight = None
    if layer.weight is not None:
        weight = layer.weight
        if weight_fn is not None:
            if weight_args is not None:
                fn_device = weight_args.get('device', None)
                if fn_device is not None:
                    weight = weight.to(device=fn_device)
            weight = weight_fn(weight)
        if weight_args is not None:
            weight = weight.to(**weight_args)
        if weight_patches is not None:
            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)

    bias = None
    if layer.bias is not None:
        bias = layer.bias
        if bias_fn is not None:
            if bias_args is not None:
                fn_device = bias_args.get('device', None)
                if fn_device is not None:
                    bias = bias.to(device=fn_device)
            bias = bias_fn(bias)
        if bias_args is not None:
            bias = bias.to(**bias_args)
        if bias_patches is not None:
            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)

    return weight, bias

def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
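    """Fetch the layer's weight and bias cast to `x`'s device (and dtype unless skipped).

    When streaming is enabled, the transfer runs on the mover stream and an event
    (`signal`) is returned so the compute stream can wait on it; otherwise `signal`
    is None. Non-blocking copies are disabled on MPS.
    """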
    weight, bias, signal = None, None, None
    non_blocking = True

    if getattr(x.device, 'type', None) == 'mps':
        non_blocking = False

    target_dtype = x.dtype
    target_device = x.device

    if skip_weight_dtype:
        weight_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        weight_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if skip_bias_dtype:
        bias_args = dict(device=target_device, non_blocking=non_blocking)
    else:
        bias_args = dict(device=target_device, dtype=target_dtype, non_blocking=non_blocking)

    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
            signal = stream.mover_stream.record_event()
    else:
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)

    return weight, bias, signal

@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
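    """Context manager that runs the wrapped computation on the current stream after waiting for `signal`.

    The prefetched weight/bias are stashed together with a 'finished' event so they are
    not released before the computation completes; stash entries whose events have
    completed are garbage-collected on each call.
    """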
    if signal is None or not stream.should_use_stream():
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    garbage = []
    for k, (w, b, s) in stash.items():
        if s.query():
            garbage.append(k)

    for k in garbage:
        del stash[k]
    return

def cleanup_cache():
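    """Synchronize the compute and mover streams and drop all stashed weight/bias references."""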
    if not stream.should_use_stream():
        return

    stream.current_stream.synchronize()
    stream.mover_stream.synchronize()
    stash.clear()
    return

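
# Defaults consumed by the layer classes below; set by using_forge_operations()
# before the patched torch.nn layers are constructed.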
current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None
class ForgeOperations:
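    """Drop-in replacements for the torch.nn layers patched by using_forge_operations().

    The layers read the module-level current_* defaults at construction time, skip the
    default parameter initialization, and, when parameters_manual_cast is enabled, cast
    their weights to the input's device/dtype on the fly in forward().
    """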
    class Linear(torch.nn.Module):
        def __init__(self, in_features, out_features, *args, **kwargs):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                if prefix + 'weight' in state_dict:
                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 2
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 1
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            if self.parameters_manual_cast:
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                num_spatial_dims = 3
                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)
            else:
                return super().forward(x)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return super().forward(x)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            self.bias = None

        def reset_parameters(self):
            self.bias = None
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
            else:
                return super().forward(x)

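
# Optional bitsandbytes 4-bit (nf4/fp4) operations; silently disabled if
# backend.operations_bnb cannot be imported.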
try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit

    class ForgeOperationsBNB4bits(ForgeOperations):
        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                if self.bias is not None and self.bias.dtype != x.dtype:
                    # This cast could also be applied to the non-BNB ops, since its cost is very low:
                    # it only happens once, and most Linear layers have no bias anyway.
                    self.bias = utils.tensor2parameter(self.bias.to(x.dtype))

                if hasattr(self, 'forge_online_loras'):
                    weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return torch.nn.functional.linear(x, weight, bias)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB must use CUDA as the computation device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return functional_linear_4bits(x, weight, bias)

    bnb_available = True
except:
    bnb_available = False

from backend.operations_gguf import dequantize_tensor
class ForgeOperationsGGUF(ForgeOperations):
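    """Operation set whose Linear keeps GGUF-quantized weights and dequantizes them at forward time."""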
    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                computation_dtype = self.dummy.dtype
                if computation_dtype not in [torch.float16, torch.bfloat16]:
                    # GGUF dequantization only supports 16-bit computation dtypes; anything else is very slow.
                    computation_dtype = torch.float16

                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
                    self.weight.computation_dtype = computation_dtype

                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
                    self.bias.computation_dtype = computation_dtype

                del self.dummy
            else:
                if prefix + 'weight' in state_dict:
                    self.weight = state_dict[prefix + 'weight']
                if prefix + 'bias' in state_dict:
                    self.bias = state_dict[prefix + 'bias']
            return

        def _apply(self, fn, recurse=True):
            for k, p in self.named_parameters(recurse=False, remove_duplicate=True):
                setattr(self, k, utils.tensor2parameter(fn(p)))
            return self

        def forward(self, x):
            if self.bias is not None and self.bias.dtype != x.dtype:
                self.bias = utils.tensor2parameter(dequantize_tensor(self.bias).to(x.dtype))

            if self.weight is not None and self.weight.dtype != x.dtype and getattr(self.weight, 'gguf_cls', None) is None:
                self.weight = utils.tensor2parameter(self.weight.to(x.dtype))

            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)

@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
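    """Temporarily replace the torch.nn layer classes with a Forge operation set.

    Sets the module-level current_* defaults, then picks the operation set from
    `bnb_dtype` ('gguf' -> GGUF, 'nf4'/'fp4' -> bitsandbytes when available,
    otherwise plain ForgeOperations) unless `operations` is given explicitly.
    The original torch.nn classes are restored on exit.
    """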
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype
    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_dtype in ['gguf']:
            operations = ForgeOperationsGGUF
        elif bnb_available and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])
    return

def shift_manual_cast(model, enabled):
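    """Toggle parameters_manual_cast for every module in `model` that defines it."""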
    for m in model.modules():
        if hasattr(m, 'parameters_manual_cast'):
            m.parameters_manual_cast = enabled
    return

@contextlib.contextmanager
def automatic_memory_management():
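    """Track every module constructed or moved inside the block and offload it to CPU on exit.

    memory_management.free_memory() is asked to make about 3 GiB available before
    entering, and the soft cache is emptied after the offload.
    """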
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()
    module_list = set(module_list)

    for module in module_list:
        module.cpu()

    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
    return

class DynamicSwapInstaller:
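    """Patch modules so parameters and buffers are copied to `target_device` on attribute access.

    Each module's class is swapped for a dynamically created subclass whose __getattr__
    returns device-moved copies; install_model()/uninstall_model() apply or revert the
    patch for every submodule.
    """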
    @staticmethod
    def _install_module(module: torch.nn.Module, target_device: torch.device):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
                    else:
                        return p.to(target_device)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(target_device)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })
        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, target_device: torch.device):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, target_device)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return