Spaces:
Sleeping
Sleeping
def replace_layer_recursive(model, old_layer, new_layer):
    """Replace the first occurrence of ``old_layer`` below ``model`` with ``new_layer``.

    ``model`` is presumably a ``torch.nn.Module`` (only its ``_modules``
    mapping is used). Children are visited in ``_modules`` iteration order,
    depth-first: each child is compared with ``old_layer`` and, if it is not
    the target, searched recursively before moving on to its next sibling.
    ``model`` itself is never compared, only its descendants.

    Args:
        model: Root whose submodule tree is searched and mutated in place.
        old_layer: The exact layer instance to replace (matched by identity).
        new_layer: The object stored in place of ``old_layer``.

    Returns:
        True if a replacement was made (only the first match is replaced),
        False if ``old_layer`` was not found.
    """
    for name, layer in model._modules.items():
        if layer is None:
            # ``_modules`` may hold ``None`` placeholders; they have no
            # children, and recursing into one would raise AttributeError.
            continue
        # Identity, not equality: we want this specific instance.
        if layer is old_layer:
            # Rebinding an existing key does not change the dict's size,
            # so it is safe while iterating over items().
            model._modules[name] = new_layer
            return True
        if replace_layer_recursive(layer, old_layer, new_layer):
            return True
    return False
def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
    """Replace every descendant of ``model`` that is an instance of ``old_layer_type``.

    ``model`` is presumably a ``torch.nn.Module`` (only its ``_modules``
    mapping is used). The tree is modified in place. The same ``new_layer``
    object is stored in every replaced slot; it is not copied. That suits
    stateless layers but means parameters would be shared if ``new_layer``
    has any.

    A replaced layer is not descended into. It has just been detached from
    the tree, so editing its children would only mutate an object the caller
    may still reference elsewhere.

    Args:
        model: Root whose submodule tree is mutated in place.
        old_layer_type: Class (or tuple of classes) accepted by ``isinstance``.
        new_layer: The object stored in place of each matching layer.
    """
    for name, layer in model._modules.items():
        if layer is None:
            # ``None`` placeholders have no children; recursing would raise.
            continue
        if isinstance(layer, old_layer_type):
            # Rebinding an existing key during items() iteration is safe.
            model._modules[name] = new_layer
        else:
            replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
def find_layer_types_recursive(model, layer_types):
    """Return every descendant of ``model`` whose exact type is in ``layer_types``.

    The match uses ``type(layer)``, so subclasses of a listed type are NOT
    matched. Traversal order and duplicates follow
    ``find_layer_predicate_recursive``.

    Args:
        model: Root whose submodule tree is searched.
        layer_types: An iterable of classes, or a single class.

    Returns:
        A list of the matching layers.
    """
    if isinstance(layer_types, type):
        # A bare class is not iterable, so ``in`` would raise TypeError.
        layer_types = (layer_types,)
    # Materialize once. The predicate runs for every node, so a one-shot
    # iterable (e.g. a generator) would otherwise be exhausted after the
    # first check and silently match nothing afterwards.
    wanted = tuple(layer_types)

    def predicate(layer):
        return type(layer) in wanted

    return find_layer_predicate_recursive(model, predicate)
def find_layer_predicate_recursive(model, predicate):
    """Collect every descendant of ``model`` for which ``predicate`` is truthy.

    Children are visited in ``_modules`` iteration order, pre-order
    depth-first. A matching node is appended before its own descendants are
    searched. ``model`` itself is not tested. Results are not de-duplicated:
    a layer reachable through several parents appears once per path.

    Args:
        model: Root whose submodule tree is searched; only ``_modules`` is read.
        predicate: Callable taking one layer and returning a truthy value for
            layers to collect. It is never called with ``None``.

    Returns:
        A new list of the matching layers.
    """
    result = []
    for layer in model._modules.values():
        if layer is None:
            # ``None`` placeholders are neither candidates nor containers.
            continue
        if predicate(layer):
            result.append(layer)
        result.extend(find_layer_predicate_recursive(layer, predicate))
    return result