import gc

import accelerate
import torch


def get_module_by_name_suffix(model, module_name: str):
    # Return the first submodule whose qualified name matches `module_name`
    # exactly or as a trailing ".<module_name>" path component; None otherwise.
    # Matching on a full component (rather than a bare endswith) avoids false
    # hits such as "myfc" matching the suffix "fc".
    for name, module in model.named_modules():
        if name == module_name or name.endswith("." + module_name):
            return module
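

# Usage sketch (toy model, not part of this module's API): look up a nested
# child by the tail of its qualified name rather than its full path.
#
#     toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
#     linear = get_module_by_name_suffix(toy, "0")  # matches the name "0"
#     assert linear is toy[0]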


def simple_dispatch_model(model, device_map):
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook

    # A device_map with the single root entry "" places the whole model on
    # one device; no hooks are needed.
    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    # Remember tied parameters now, because attaching hooks can untie them.
    tied_params = accelerate.utils.modeling.find_tied_parameters(model)

    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    # Chain the CPU-offloaded modules: each hook moves its module to the main
    # device for its forward pass and offloads the previous module back to CPU.
    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for n, _ in cpu_offload_group:
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(
            m, execution_device=main_device, prev_module_hook=prev_hook
        )
    # Close the chain so the first module's hook offloads the last one on the
    # next forward pass.
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook

    # Pin every non-CPU module to its assigned device and keep its inputs and
    # outputs on that device.
    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)

    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map

    return model
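

# Usage sketch (hypothetical module names; real maps typically come from
# accelerate.infer_auto_device_map): everything on GPU 0 except one block
# kept on CPU and streamed to the GPU on demand.
#
#     device_map = {
#         "model.embed_tokens": 0,
#         "model.layers.0": 0,
#         "model.layers.1": "cpu",
#         "lm_head": 0,
#     }
#     model = simple_dispatch_model(model, device_map)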


def set_module_name(model, name, value):
    # Replace the submodule at the fully qualified `name` with `value`.
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1:]
        parent = model.get_submodule(parent_name)
    else:
        # Top-level attribute: the model itself is the parent.
        parent = model
        child_name = name

    setattr(parent, child_name, value)
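

# Usage sketch (toy model): swap a Linear for a stand-in module, the typical
# call site when replacing layers with quantized equivalents.
#
#     toy = torch.nn.Sequential(torch.nn.Linear(4, 4))
#     set_module_name(toy, "0", torch.nn.Identity())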


def clear_memory(weight=None):
    # `del` only drops this function's local reference; the caller must also
    # release its own references for the tensor to become collectable.
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()
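

# Usage sketch: drop the caller-side reference first, then collect, so the
# tensor is actually freeable when the CUDA cache is emptied.
#
#     w = torch.empty(8192, 8192, device="cuda")  # some large intermediate
#     del w
#     clear_memory()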


def compute_memory_used_pct(device):
    # Note: max_memory_allocated reports the *peak* allocation since startup
    # (or the last torch.cuda.reset_peak_memory_stats), not current usage.
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    memory_pct = (
        memory_used
        / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3))
        * 100
    )
    return memory_pct
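

# Usage sketch (requires a CUDA device):
#
#     if torch.cuda.is_available():
#         x = torch.empty(1024, 1024, device="cuda")
#         print(f"peak GPU memory: {compute_memory_used_pct(0):.1f}%")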