Doa-doa's picture
Upload folder using huggingface_hub
72268ee
import torch.nn as nn
def get_named_linears(module):
    """Collect every ``nn.Linear`` found while walking ``module``.

    Returns a dict keyed by the names yielded by ``module.named_modules()``,
    holding only the entries whose module is an ``nn.Linear`` instance.
    Entries appear in the order ``named_modules()`` produces them.
    """
    linears = {}
    for path, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[path] = submodule
    return linears
def get_op_by_name(module, op_name):
    """Return the submodule whose name, relative to ``module``, is ``op_name``.

    Names are compared against those yielded by ``module.named_modules()``;
    the first exact match is returned.

    Raises:
        ValueError: if no yielded name equals ``op_name``.
    """
    # A private sentinel distinguishes "no match" from any value a match
    # could legitimately be.
    not_found = object()
    match = next(
        (sub for path, sub in module.named_modules() if path == op_name),
        not_found,
    )
    if match is not_found:
        raise ValueError(f"Cannot find op {op_name} in module {module}")
    return match
def set_op_by_name(layer, name, new_module):
    """Replace the object at dotted path ``name`` under ``layer``.

    ``name`` is split on ``'.'``. Every component except the last is used to
    walk down from ``layer`` to the parent object; the last component is then
    assigned on that parent with ``setattr``.

    A purely numeric component is first tried as an integer index
    (``parent[int(component)]``). If the parent cannot be indexed that way
    (``TypeError``) or has no such integer key (``KeyError``), the component
    is looked up as an attribute instead, so a child registered under a digit
    name on a non-indexable parent is still reachable. An out-of-range index
    on an indexable parent still raises ``IndexError``.

    Args:
        layer: root object to walk from (presumably an ``nn.Module``; only
            indexing, ``getattr`` and ``setattr`` are required).
        name: dotted path of the target relative to ``layer``.
        new_module: object to install at that path.

    Raises:
        AttributeError: if an intermediate component cannot be resolved.
        IndexError: if a numeric component is out of range for an indexable
            parent.
    """
    # str.split always yields at least one element, so ``leaf`` is always
    # bound; ``parents`` is empty for a single-level name.
    *parents, leaf = name.split('.')
    owner = layer
    for component in parents:
        if component.isdigit():
            try:
                owner = owner[int(component)]
                continue
            except (TypeError, KeyError):
                # Not integer-indexable: fall through to attribute lookup.
                pass
        owner = getattr(owner, component)
    setattr(owner, leaf, new_module)
def get_op_name(module, op):
    """Return the name of ``op`` relative to ``module``.

    Walks ``module.named_modules()`` and returns the name paired with the
    first entry that is the very same object as ``op`` (identity, not
    equality).

    Raises:
        ValueError: if ``op`` is not among the yielded modules.
    """
    for path, candidate in module.named_modules():
        if candidate is not op:
            continue
        return path
    raise ValueError(f"Cannot find op {op} in module {module}")
def append_str_prefix(x, prefix):
    """Prepend ``prefix`` to every string found in ``x``.

    A string is returned as ``prefix + x``. A tuple or list is processed
    element by element, recursively, and returned as a new tuple or list
    respectively. Any other value is returned unchanged.
    """
    if isinstance(x, str):
        return prefix + x
    if isinstance(x, (tuple, list)):
        prefixed = [append_str_prefix(item, prefix) for item in x]
        # Hand back the same kind of container that came in.
        return tuple(prefixed) if isinstance(x, tuple) else prefixed
    return x