# NOTE: "Spaces: Runtime error" appears to be a status banner captured from the hosting web page; it is not part of this module.
# Copyright Forge 2024 | |
import time | |
import torch | |
import contextlib | |
from backend import stream, memory_management, utils | |
from backend.patcher.lora import merge_lora_to_weight | |
# Maps id(event) -> (weight, bias, event). Entries are added by main_stream_worker
# after it records an event on the current stream, and are dropped once the event's
# query() reports completion or when cleanup_cache() clears the dict.
stash = {}
def _transfer_and_patch(tensor, cast_args, transform, patches, key):
    """Run one parameter tensor through the optional transform / cast / LoRA merge steps.

    Returns ``None`` untouched when ``tensor`` is ``None``.
    """
    if tensor is None:
        return None

    if transform is not None:
        # Move to the requested device first so ``transform`` runs there.
        target = cast_args.get('device', None) if cast_args is not None else None
        if target is not None:
            tensor = tensor.to(device=target)
        tensor = transform(tensor)

    if cast_args is not None:
        tensor = tensor.to(**cast_args)

    if patches is not None:
        tensor = merge_lora_to_weight(patches=patches, weight=tensor, key=key, computation_dtype=tensor.dtype)

    return tensor


def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
    """Return ``(weight, bias)`` for ``layer``, transformed, cast and LoRA-patched as requested.

    Args:
        layer: object exposing ``weight`` and ``bias`` attributes (either may be
            ``None``) and optionally a ``forge_online_loras`` dict with
            ``'weight'`` / ``'bias'`` patch entries.
        weight_args: keyword arguments forwarded to ``weight.to(...)``, or ``None``
            to skip the cast.
        bias_args: same as ``weight_args`` but for the bias.
        weight_fn: optional callable applied to the weight before the cast; when
            ``weight_args`` carries a ``'device'`` the weight is moved there first.
        bias_fn: same as ``weight_fn`` but for the bias.

    Returns:
        A ``(weight, bias)`` tuple; an element is ``None`` when the layer has no
        such parameter.
    """
    online_loras = getattr(layer, 'forge_online_loras', None)
    weight_patches = bias_patches = None
    if online_loras is not None:
        weight_patches = online_loras.get('weight', None)
        bias_patches = online_loras.get('bias', None)

    weight = _transfer_and_patch(layer.weight, weight_args, weight_fn, weight_patches, "online weight lora")
    bias = _transfer_and_patch(layer.bias, bias_args, bias_fn, bias_patches, "online bias lora")
    return weight, bias
def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
    """Fetch ``layer``'s weight and bias moved to the device (and dtype) of ``x``.

    Args:
        layer: layer passed through to ``get_weight_and_bias``.
        x: input whose ``device`` and ``dtype`` are the cast targets.
        skip_weight_dtype: when true, the weight is moved to the device only.
        skip_bias_dtype: when true, the bias is moved to the device only.
        weight_fn: optional weight transform forwarded to ``get_weight_and_bias``.
        bias_fn: optional bias transform forwarded to ``get_weight_and_bias``.

    Returns:
        ``(weight, bias, signal)``.  ``signal`` is the event recorded on
        ``stream.mover_stream`` when streams are in use, otherwise ``None``.
    """
    device = x.device
    dtype = x.dtype

    # Non-blocking transfers are turned off when the input lives on an 'mps' device.
    non_blocking = getattr(device, 'type', None) != 'mps'

    weight_args = dict(device=device, non_blocking=non_blocking)
    if not skip_weight_dtype:
        weight_args['dtype'] = dtype

    bias_args = dict(device=device, non_blocking=non_blocking)
    if not skip_bias_dtype:
        bias_args['dtype'] = dtype

    if not stream.should_use_stream():
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
        return weight, bias, None

    # Issue the transfers on the mover stream and record an event the consumer can wait on.
    with stream.stream_context()(stream.mover_stream):
        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
        signal = stream.mover_stream.record_event()
    return weight, bias, signal
@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
    """Context manager that runs its body on the current stream after ``signal``.

    The function is a single-yield generator and every caller in this module
    uses it as ``with main_stream_worker(...):``, so it must be wrapped with
    ``contextlib.contextmanager``; a bare generator object does not implement
    the context-manager protocol.

    Args:
        weight: tensor kept referenced in ``stash`` until the body's work finished.
        bias: tensor kept referenced in ``stash`` alongside ``weight``.
        signal: event to wait for before running the body, or ``None`` when no
            stream synchronisation is needed.
    """
    if signal is None or not stream.should_use_stream():
        # No stream hand-off: just run the body.
        yield
        return

    with stream.stream_context()(stream.current_stream):
        stream.current_stream.wait_event(signal)
        yield
        finished_signal = stream.current_stream.record_event()
        # Hold references until the event completes (presumably so the tensors are
        # not released while queued work still reads them).
        stash[id(finished_signal)] = (weight, bias, finished_signal)

    # Drop every stash entry whose event has already completed.
    finished = [key for key, (_, _, event) in stash.items() if event.query()]
    for key in finished:
        del stash[key]
def cleanup_cache():
    """Synchronise both streams and empty ``stash``; a no-op when streams are not in use."""
    if stream.should_use_stream():
        # Wait for all queued work before releasing the stashed tensor references.
        stream.current_stream.synchronize()
        stream.mover_stream.synchronize()
        stash.clear()
# Construction-time defaults read by the layer classes below; they are
# reassigned by using_forge_operations().
current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = None, None, False, None
class ForgeOperations:
    """Namespace of replacement layer classes named after their ``torch.nn`` counterparts.

    Every nested class reads the module-level ``current_device``, ``current_dtype``
    and ``current_manual_cast_enabled`` values when it is constructed.  When a
    layer's ``parameters_manual_cast`` flag is true, ``forward`` obtains weight and
    bias from ``weights_manual_cast`` (moved to the input's device and, unless
    skipped, dtype) and runs the op inside the ``main_stream_worker`` context.
    Otherwise Linear and the convolution classes use ``get_weight_and_bias`` on the
    stored parameters, while the norm and embedding classes defer to the stock
    ``torch.nn`` forward.

    ``reset_parameters`` is overridden to do nothing on the ``torch.nn`` subclasses,
    which skips torch's default initialisation (presumably because the values are
    always loaded from a checkpoint afterwards).
    """

    class Linear(torch.nn.Module):
        """Linear layer whose weight/bias are only created when a state dict is loaded."""

        def __init__(self, in_features, out_features, *args, **kwargs):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            # One-element placeholder that only records the target device/dtype
            # until the first state-dict load.
            self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if hasattr(self, 'dummy'):
                # First load: adopt the tensors directly, converted to the
                # placeholder's dtype/device, then drop the placeholder.
                if prefix + 'weight' in state_dict:
                    self.weight = torch.nn.Parameter(state_dict[prefix + 'weight'].to(self.dummy))
                if prefix + 'bias' in state_dict:
                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
                del self.dummy
            else:
                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.linear(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.linear(x, weight, bias)

    class Conv2d(torch.nn.Conv2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                return super()._conv_forward(x, weight, bias)

    class Conv3d(torch.nn.Conv3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                # Pass the actual input tensor ``x`` (not the builtin ``input``).
                return super()._conv_forward(x, weight, bias)

    class Conv1d(torch.nn.Conv1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return self._conv_forward(x, weight, bias)
            else:
                weight, bias = get_weight_and_bias(self)
                # Pass the actual input tensor ``x`` (not the builtin ``input``).
                return super()._conv_forward(x, weight, bias)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose1d(torch.nn.ConvTranspose1d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                # Use the 1-D op on both paths; the 2-D op does not match this layer's rank.
                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class ConvTranspose3d(torch.nn.ConvTranspose3d):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x, output_size=None):
            num_spatial_dims = 3
            output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)

            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
            else:
                weight, bias = get_weight_and_bias(self)
                # Use the 3-D op on both paths; the 2-D op does not match this layer's rank.
                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)

    class GroupNorm(torch.nn.GroupNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.group_norm(x, self.num_groups, weight, bias, self.eps)
            else:
                return super().forward(x)

    class LayerNorm(torch.nn.LayerNorm):
        def __init__(self, *args, **kwargs):
            kwargs['device'] = current_device
            kwargs['dtype'] = current_dtype
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled

        def reset_parameters(self):
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                weight, bias, signal = weights_manual_cast(self, x)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
            else:
                return super().forward(x)

    class Embedding(torch.nn.Embedding):
        def __init__(self, *args, **kwargs):
            # Only the device is forced here; the dtype is left to the caller.
            kwargs['device'] = current_device
            super().__init__(*args, **kwargs)
            self.parameters_manual_cast = current_manual_cast_enabled
            # weights_manual_cast reads ``layer.bias``, so give the layer one.
            self.bias = None

        def reset_parameters(self):
            self.bias = None
            return None

        def forward(self, x):
            if self.parameters_manual_cast:
                # ``x`` holds indices, so its dtype must not be applied to the weight.
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return torch.nn.functional.embedding(x, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
            else:
                return super().forward(x)
# Optional 4-bit (bitsandbytes) support: if anything in this block fails to
# import or define, the module still loads and ``bnb_avaliable`` is False.
try:
    from backend.operations_bnb import ForgeLoader4Bit, ForgeParams4bit, functional_linear_4bits, functional_dequantize_4bit

    class ForgeOperationsBNB4bits(ForgeOperations):
        """``ForgeOperations`` variant whose ``Linear`` is backed by ``ForgeLoader4Bit``."""

        class Linear(ForgeLoader4Bit):
            def __init__(self, *args, **kwargs):
                super().__init__(device=current_device, dtype=current_dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Maybe this can also be set to all non-bnb ops since the cost is very low.
                    # And it only invokes one time, and most linear does not have bias
                    self.bias = utils.tensor2parameter(self.bias.to(x.dtype))

                if hasattr(self, 'forge_online_loras'):
                    # With online LoRAs attached, dequantize the weight and run a plain linear.
                    weight, bias, signal = weights_manual_cast(self, x, weight_fn=functional_dequantize_4bit, bias_fn=None, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return torch.nn.functional.linear(x, weight, bias)

                if not self.parameters_manual_cast:
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                    # Quantize on the input's device, run, then move the weight back.
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                    with main_stream_worker(weight, bias, signal):
                        return functional_linear_4bits(x, weight, bias)

    # The spelling "avaliable" is kept: other code in this module reads this name.
    bnb_avaliable = True
except Exception:
    # Still best-effort, but KeyboardInterrupt / SystemExit are no longer swallowed.
    bnb_avaliable = False
from backend.operations_gguf import dequantize_tensor | |
class ForgeOperationsGGUF(ForgeOperations):
    """``ForgeOperations`` variant whose ``Linear`` dequantizes its weight with ``dequantize_tensor``."""

    class Linear(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            # One-element placeholder recording the target device/dtype until the first load.
            placeholder = torch.empty(1, device=current_device, dtype=current_dtype)
            self.dummy = torch.nn.Parameter(placeholder)
            self.weight = None
            self.bias = None
            self.parameters_manual_cast = current_manual_cast_enabled

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            weight_key = prefix + 'weight'
            bias_key = prefix + 'bias'

            if not hasattr(self, 'dummy'):
                # Later loads: adopt the tensors exactly as stored.
                if weight_key in state_dict:
                    self.weight = state_dict[weight_key]
                if bias_key in state_dict:
                    self.bias = state_dict[bias_key]
                return

            computation_dtype = self.dummy.dtype
            if computation_dtype not in (torch.float16, torch.bfloat16):
                # GGUF cast only supports 16bits otherwise super slow
                computation_dtype = torch.float16

            if weight_key in state_dict:
                self.weight = state_dict[weight_key].to(device=self.dummy.device)
                self.weight.computation_dtype = computation_dtype
            if bias_key in state_dict:
                self.bias = state_dict[bias_key].to(device=self.dummy.device)
                self.bias.computation_dtype = computation_dtype

            del self.dummy
            return

        def _apply(self, fn, recurse=True):
            # Only this module's own parameters are mapped through ``fn``.
            for param_name, param in self.named_parameters(recurse=False, remove_duplicate=True):
                converted = utils.tensor2parameter(fn(param))
                setattr(self, param_name, converted)
            return self

        def forward(self, x):
            bias_needs_cast = self.bias is not None and self.bias.dtype != x.dtype
            if bias_needs_cast:
                self.bias = utils.tensor2parameter(dequantize_tensor(self.bias).to(x.dtype))

            # Weights carrying a ``gguf_cls`` attribute are left alone here and go
            # through ``dequantize_tensor`` below instead.
            weight_needs_cast = (
                self.weight is not None
                and self.weight.dtype != x.dtype
                and getattr(self.weight, 'gguf_cls', None) is None
            )
            if weight_needs_cast:
                self.weight = utils.tensor2parameter(self.weight.to(x.dtype))

            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=None, skip_bias_dtype=True)
            with main_stream_worker(weight, bias, signal):
                return torch.nn.functional.linear(x, weight, bias)
@contextlib.contextmanager
def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None):
    """Context manager that temporarily swaps ``torch.nn`` layer classes for Forge ones.

    The function is a single-yield generator wrapped in ``try``/``finally``; it is
    decorated with ``contextlib.contextmanager`` so it can be used in a ``with``
    statement.

    Args:
        operations: class or namespace providing the replacement layers.  When
            ``None`` it is chosen from ``bnb_dtype``: ``'gguf'`` selects
            ``ForgeOperationsGGUF``, ``'nf4'``/``'fp4'`` select
            ``ForgeOperationsBNB4bits`` if available, anything else selects
            ``ForgeOperations``.
        device: stored in the module-level ``current_device``.
        dtype: stored in the module-level ``current_dtype``.
        manual_cast_enabled: stored in ``current_manual_cast_enabled``.
        bnb_dtype: stored in ``current_bnb_dtype``.

    Note:
        The four module-level values are not restored on exit; only the
        ``torch.nn`` attributes are.
    """
    global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype

    current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype

    if operations is None:
        if bnb_dtype in ['gguf']:
            operations = ForgeOperationsGGUF
        elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']:
            operations = ForgeOperationsBNB4bits
        else:
            operations = ForgeOperations

    op_names = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', 'GroupNorm', 'LayerNorm', 'Embedding']
    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}

    try:
        for op_name in op_names:
            setattr(torch.nn, op_name, getattr(operations, op_name))
        yield
    finally:
        # Always put the stock classes back, even if the body raised.
        for op_name in op_names:
            setattr(torch.nn, op_name, backups[op_name])
def shift_manual_cast(model, enabled):
    """Set ``parameters_manual_cast`` to ``enabled`` on every sub-module of ``model`` that has the flag."""
    flagged = (m for m in model.modules() if hasattr(m, 'parameters_manual_cast'))
    for sub_module in flagged:
        sub_module.parameters_manual_cast = enabled
@contextlib.contextmanager
def automatic_memory_management():
    """Context manager that moves every module created or ``.to()``-moved in its body to the CPU.

    The function is a single-yield generator wrapped in ``try``/``finally``; it is
    decorated with ``contextlib.contextmanager`` so it can be used in a ``with``
    statement.

    On entry it asks ``memory_management`` to free 3 GiB on the torch device and
    patches ``torch.nn.Module.__init__`` / ``.to`` to record each module they are
    called on.  On exit the patches are always removed.  If the body finished
    without raising, every recorded module is moved to the CPU, the cache is
    soft-emptied and a timing line is printed.
    """
    memory_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=memory_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    start = time.perf_counter()

    # De-duplicate under a new name so the list captured by the patch closures
    # is not rebound.
    unique_modules = set(module_list)
    for module in unique_modules:
        module.cpu()

    memory_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(unique_modules)} Modules in {(end - start):.2f} seconds.')
class DynamicSwapInstaller:
    """Installs/removes a per-module ``__getattr__`` hook that returns parameters and buffers on a target device.

    Installing swaps the module's class for a dynamically created subclass named
    ``'DynamicSwap_' + <original class name>``; the stored tensors themselves are
    not moved.  All methods take no ``self`` and are static, so they work when
    called on the class or on an instance.
    """

    @staticmethod
    def _install_module(module: torch.nn.Module, target_device: torch.device):
        """Replace ``module``'s class with a subclass whose attribute lookup moves tensors to ``target_device``."""
        original_class = module.__class__
        # Remember the real class so _uninstall_module can restore it.
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    # Exact class check: only plain Parameters are re-wrapped;
                    # anything else is returned as whatever its own ``.to`` yields.
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(target_device), requires_grad=p.requires_grad)
                    else:
                        return p.to(target_device)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    b = _buffers[name]
                    # Buffers may be registered as None; return that unchanged
                    # instead of failing on ``None.to(...)``.
                    if b is None:
                        return None
                    return b.to(target_device)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })
        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        """Restore ``module``'s original class if a swap hook was installed; otherwise do nothing."""
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, target_device: torch.device):
        """Install the hook on every module yielded by ``model.modules()``."""
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, target_device)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        """Remove the hook from every module yielded by ``model.modules()``."""
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return