# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
"""

import argparse
import os

import torch
from emmental.modules import ThresholdBinarizer, TopKBinarizer


def _count_mask_ones(param, pruning_method, threshold):
    """Return how many entries of the mask derived from `param` are kept.

    Args:
        param: A `mask_scores` tensor taken from the checkpoint.
        pruning_method: One of "l0", "topK" or "sigmoied_threshold".
        threshold: Value forwarded to the binarizer for "topK" and
            "sigmoied_threshold"; ignored for "l0".

    Raises:
        ValueError: If `pruning_method` is not recognised, or if `threshold`
            is missing for a method that forwards it to a binarizer.
    """
    if pruning_method == "l0":
        # Stretch the sigmoid of the scores to the interval (lower, upper),
        # clamp to [0, 1] and count the strictly positive entries.
        lower, upper = -0.1, 1.1
        stretched = torch.sigmoid(param) * (upper - lower) + lower
        mask = stretched.clamp(min=0.0, max=1.0)
        return (mask > 0.0).sum().item()

    if pruning_method not in ("topK", "sigmoied_threshold"):
        raise ValueError("Unknown pruning method")

    if threshold is None:
        # `--threshold` is optional on the command line; fail with a clear
        # message instead of an obscure error from inside the binarizer.
        raise ValueError("`threshold` is required for pruning method `{}`".format(pruning_method))

    if pruning_method == "topK":
        return TopKBinarizer.apply(param, threshold).sum().item()
    return ThresholdBinarizer.apply(param, threshold, True).sum().item()


def main(args):
    """Print per-mask and global remaining-weight statistics for the encoder.

    Loads `pytorch_model.bin` from `args.serialization_dir`, walks every entry
    whose name contains "encoder" and prints, for each `mask_scores` tensor,
    the percentage and the number of entries kept by the mask. Finally prints
    the global percentage of remaining encoder weights.

    Args:
        args: Namespace providing `serialization_dir`, `pruning_method` and
            `threshold`.

    Raises:
        ValueError: If the pruning method is unknown, if `threshold` is
            missing while a mask has to be binarized with it, or if the
            checkpoint holds no encoder parameters to compare against.
    """
    serialization_dir = args.serialization_dir
    pruning_method = args.pruning_method
    threshold = args.threshold

    # NOTE: `torch.load` unpickles the file; only point this script at
    # checkpoints coming from a trusted source.
    st = torch.load(os.path.join(serialization_dir, "pytorch_model.bin"), map_location="cpu")

    remaining_count = 0  # Number of remaining (not pruned) params in the encoder
    encoder_count = 0  # Number of params in the encoder

    print("name".ljust(60, " "), "Remaining Weights %", "Remaining Weight")
    for name, param in st.items():
        if "encoder" not in name:
            continue

        if "mask_scores" in name:
            mask_ones = _count_mask_ones(param, pruning_method, threshold)
            remaining_count += mask_ones
            print(name.ljust(60, " "), str(round(100 * mask_ones / param.numel(), 3)).ljust(20, " "), str(mask_ones))
        else:
            encoder_count += param.numel()
            # Biases and LayerNorm parameters have no mask in this loop, so
            # they are counted as remaining in full.
            if "bias" in name or "LayerNorm" in name:
                remaining_count += param.numel()

    if encoder_count == 0:
        # Avoid a bare ZeroDivisionError when nothing matched "encoder".
        raise ValueError("No encoder parameters found in the checkpoint")

    print("")
    print("Remaining Weights (global) %: ", 100 * remaining_count / encoder_count)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, topK = Movement pruning, sigmoied_threshold = Soft movement"
            " pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        # argparse %-formats help strings, so a literal percent sign must be
        # written as `%%`; the backslash is escaped so `\tau` is not read as
        # a TAB followed by "au".
        help=(
            "For `topK`, it is the level of remaining weights (in %%) in the fine-pruned model. "
            "For `sigmoied_threshold`, it is the threshold \\tau against which the (sigmoied) scores are compared. "
            "Not needed for `l0`"
        ),
    )
    parser.add_argument(
        "--serialization_dir",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )

    args = parser.parse_args()

    main(args)