from diffusers.utils import is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)

if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            # Low-rank adapter: a down-projection to `rank` followed by an up-projection.
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            # Small normal init for the down-projection and zeros for the up-projection,
            # so the adapter contributes nothing at initialization.
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)

def get_memory_consumption_stat(model, inputs):
    # Reset the device's peak-memory counters, run a single forward pass, and
    # report the peak memory allocated (in bytes) during that pass.
    backend_reset_peak_memory_stats(torch_device)
    backend_empty_cache(torch_device)

    model(**inputs)
    max_mem_allocated = backend_max_memory_allocated(torch_device)
    return max_mem_allocated
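
# --- Illustrative usage: a minimal sketch, not part of the utilities above. ---
# Assumes an accelerator selected by `torch_device`; the toy linear layer, tensor
# sizes, and rank below are invented purely for demonstration. On CPU-only setups
# the backend memory helpers typically report 0.
if __name__ == "__main__" and is_torch_available():
    base = nn.Linear(64, 64).to(torch_device)
    wrapped = LoRALayer(base, rank=4)

    x = torch.randn(2, 64, device=torch_device)
    # The up-projection is zero-initialized, so the wrapped layer initially matches
    # the base layer's output exactly.
    assert torch.allclose(wrapped(x), base(x))

    peak_bytes = get_memory_consumption_stat(wrapped, {"input": x})
    print(f"Peak memory allocated during one forward pass: {peak_bytes} bytes")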