File size: 1,342 Bytes
72268ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch.nn as nn
def get_named_linears(module):
    """Collect the ``nn.Linear`` layers reachable from ``module``.

    Args:
        module: Object exposing ``named_modules()``.

    Returns:
        A dict mapping each name yielded by ``module.named_modules()`` to
        its module, restricted to instances of ``nn.Linear`` and kept in
        traversal order.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears
def get_op_by_name(module, op_name):
    """Return the sub-module registered under ``op_name`` relative to ``module``.

    Args:
        module: Object exposing ``named_modules()``.
        op_name: Name to look for, compared for equality against the names
            yielded by ``module.named_modules()``.

    Returns:
        The first module whose name equals ``op_name``.

    Raises:
        ValueError: If no yielded name equals ``op_name``.
    """
    candidates = (
        submodule
        for qualified_name, submodule in module.named_modules()
        if qualified_name == op_name
    )
    # Compare against None explicitly: a found module may itself be falsy
    # (for example an empty container that defines __len__).
    found = next(candidates, None)
    if found is not None:
        return found
    raise ValueError(f"Cannot find op {op_name} in module {module}")
def set_op_by_name(layer, name, new_module):
    """Replace the attribute of ``layer`` addressed by the dotted path ``name``.

    Every component of ``name`` except the last is resolved, starting from
    ``layer``, to locate the parent object; the last component is then
    assigned on that parent with ``setattr``. A name without dots is set
    directly on ``layer``.

    Args:
        layer: Root object the path is relative to.
        name: Dot-separated path, e.g. ``"blocks.0.fc"``.
        new_module: Object to store at that path.

    Raises:
        AttributeError: If a path component cannot be resolved with
            ``getattr``. Errors other than ``TypeError``/``KeyError`` raised
            by positional indexing propagate unchanged.
    """
    *parent_path, leaf = name.split('.')
    parent = layer
    for part in parent_path:
        if part.isdigit():
            try:
                # Containers such as nn.Sequential / nn.ModuleList are
                # addressed positionally.
                parent = parent[int(part)]
            except (TypeError, KeyError):
                # The parent is not positionally indexable (e.g. a
                # ModuleDict keyed by "0", or a module whose children were
                # registered under numeric names), so resolve the numeric
                # name as an ordinary attribute instead.
                parent = getattr(parent, part)
        else:
            parent = getattr(parent, part)
    setattr(parent, leaf, new_module)
def get_op_name(module, op):
    """Return the name under which ``op`` appears in ``module.named_modules()``.

    Args:
        module: Object exposing ``named_modules()``.
        op: The object to locate; matched by identity (``is``), not equality.

    Returns:
        The name of the first yielded module that is ``op``.

    Raises:
        ValueError: If no yielded module is ``op``.
    """
    matching_names = (
        qualified_name
        for qualified_name, submodule in module.named_modules()
        if submodule is op
    )
    # The sentinel must be tested with `is None`: a legitimate result can be
    # the empty string, which is falsy.
    located = next(matching_names, None)
    if located is not None:
        return located
    raise ValueError(f"Cannot find op {op} in module {module}")
def append_str_prefix(x, prefix):
    """Prepend ``prefix`` to every string found in ``x``.

    Lists and tuples are traversed recursively and rebuilt as a new list or
    tuple respectively; strings become ``prefix + x``; any other value is
    returned unchanged.

    Args:
        x: A string, a (possibly nested) list/tuple, or any other value.
        prefix: String placed in front of each string encountered.

    Returns:
        The transformed value, as described above.
    """
    if isinstance(x, list):
        return [append_str_prefix(item, prefix) for item in x]
    if isinstance(x, tuple):
        return tuple(append_str_prefix(item, prefix) for item in x)
    if isinstance(x, str):
        return prefix + x
    return x