File size: 8,947 Bytes
132eb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
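Utilities for loading a Hugging Face causal LM onto the CPU or the meta device for FSDP,
optionally quantizing its linear layers to bitsandbytes 4-bit (NF4) on the fly.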
"""
import os
import time
from typing import List, Optional, Type, Union

import safetensors
import torch
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub


def _replace_linear(
    model: nn.Module,
    linear_replacement: Type[nn.Module],
    quant_config: Union[dict, None] = None,
    skip_modules=None,
    **kwargs,
):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
    """
    if skip_modules is None:
        skip_modules = ["lm_head"]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            _replace_linear(
                module, linear_replacement, quant_config, skip_modules, **kwargs
            )

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            if issubclass(linear_replacement, Linear4bit):
                model._modules[  # pylint: disable=protected-access
                    name
                ] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported linear replacement: {type(linear_replacement)}"
                )
    return model


def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: torch.device = None,
    dtype: torch.dtype = None,
    skip_names: Optional[List[str]] = None,
    to_cpu: bool = False,
    to_meta: bool = False,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Load `value` into the attribute `name` of `module`, optionally converting it.

    `name` is a dotted path. The part before the last "." selects the submodule
    and the last component names the attribute to set.

    The target device is "meta" if `to_meta`, else "cpu" if `to_cpu`, else
    `device`. With `quant_method == "bnb"`:
      * a `Params4bit` target is rebuilt from `value` (converted to `dtype`),
        quantized by moving it to CUDA `device`, and then moved to "meta" if
        `to_meta` or to "cpu" if `to_cpu`;
      * any other registered parameter is replaced by a new parameter of the
        same type holding `value` converted to `dtype` on the target device;
      * anything that is not a registered parameter (e.g. a buffer) is
        converted to `dtype` on the target device and assigned directly.
    For any other `quant_method`, `value` is assigned unchanged.

    Names containing any entry of `skip_names` as a substring are skipped. If
    the submodule cannot be found, a message is printed and nothing is loaded.
    """

    if not skip_names:
        skip_names = []

    # Resolve the placement once in the enclosing scope. Assigning `device`
    # inside the nested helper would make it local there and raise
    # UnboundLocalError whenever neither to_meta nor to_cpu is set.
    if to_meta:
        target_device = "meta"
    elif to_cpu:
        target_device = "cpu"
    else:
        target_device = device

    def place_on_device(tensor):
        return tensor.to(device=target_device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    if quant_method == "bnb":
        # Keep the try narrow: only get_parameter signals "not a parameter".
        # An AttributeError raised during quantization must not be mistaken
        # for the buffer case.
        try:
            param = submodule.get_parameter(value_key)
        except AttributeError:
            param = None

        if param is None:
            # Not a registered parameter (e.g. a buffer).
            value = place_on_device(value)
        elif isinstance(param, Params4bit):
            # With `sync_module_states=True`, a meta device Params4bit needs to be the same
            # shape as the quantized Params4bit with an initialized quant_state. However,
            # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
            # workaround quantizes Params4bit to initialize quant_state on all ranks, then
            # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
            value = type(param)(
                value.to(device=device, dtype=dtype).data, **param.__dict__
            ).cuda(device)
            if to_meta:
                value = type(param)(value.data.to("meta"), **value.__dict__)
            elif to_cpu:
                value = type(param)(value.data.to("cpu"), **value.__dict__)
        else:
            value = type(param)(place_on_device(value).data)

    setattr(submodule, value_key, value)


def n_loading_workers(quant_method: str, param_count: float):
    """
    Heuristically pick the number of threads used to load and quantize weights.

    Returns the smaller of:
      * the CPU cores per visible CUDA device, and
      * ``base * (gpu_memory_GB / 40) * (70 / model_size_in_billions)``, where
        ``base`` is 4 for "hqq" and 8 otherwise, and GB means 1e9 bytes of the
        current device's total memory. NOTE(review): the 40 GB / 70B reference
        point looks empirically tuned — confirm before changing it.

    Args:
        quant_method: quantization backend name; only "hqq" changes the base.
        param_count: total number of model parameters (must be non-zero).

    Returns:
        int: the worker count. Both terms are truncated with ``int()``, so the
        result can be 0 for very large models or very few cores.
    """
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    # os.cpu_count() returns None when the core count cannot be determined;
    # fall back to 1 rather than failing with a TypeError on the division.
    cpu_count = os.cpu_count() or 1
    left = int(cpu_count / torch.cuda.device_count())
    model_params_b = 70
    right = int(
        (4 if quant_method == "hqq" else 8)
        * (devprops.total_memory / 1e9 / 40)
        * (model_params_b / (param_count / 1e9))
    )
    return min(left, right)


def load_sharded_model(
    model_name,
    model_config,
    cfg,
    torch_dtype=torch.bfloat16,
    low_memory=True,
):
    """
    Build the (unquantized) causal LM for FSDP.

    Behavior depends on `low_memory`:
      * If `low_memory` is set and this is not local rank 0, only an empty
        skeleton is built from `model_config` under accelerate's
        `init_empty_weights()`, so no real weight storage is allocated.
      * Otherwise the pretrained weights of `model_name` are loaded in float32.
        They are then cast to `torch_dtype` (or kept in float32 when
        `cfg.float32` is set). Finally they are moved to "cpu" when `low_memory`
        is set, or to `cfg.local_rank` when it is not.

    Args:
        model_name: hub id or local path passed to `from_pretrained`.
        model_config: config for the empty skeleton; its
            `_attn_implementation` is also forwarded to `from_pretrained`.
        cfg: run config; reads `local_rank`, `trust_remote_code` and `float32`.
        torch_dtype: desired parameter dtype.
        low_memory: materialize real weights only on local rank 0, on CPU.

    Returns:
        The constructed model.
    """
    if low_memory and cfg.local_rank != 0:
        # Non-zero ranks in low-memory mode get a weightless skeleton.
        with init_empty_weights():
            return AutoModelForCausalLM.from_config(
                model_config,
                torch_dtype=torch_dtype,
                trust_remote_code=cfg.trust_remote_code,
            )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False,
        torch_dtype=torch.float32,
        _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
        trust_remote_code=cfg.trust_remote_code,
    )
    # dtype=None leaves the float32 weights as they are.
    target_dtype = None if cfg.float32 else torch_dtype
    target_device = "cpu" if low_memory else cfg.local_rank
    model.to(dtype=target_dtype, device=target_device)
    return model


def _safetensors_shard_files(model_name):
    """
    Resolve local paths of the safetensors weight file(s) for `model_name`.

    Tries the sharded index (`SAFE_WEIGHTS_INDEX_NAME`) first. If resolving the
    index or its shards raises `OSError`, it falls back to the single-file
    checkpoint (`SAFE_WEIGHTS_NAME`). An `OSError` from that fallback (e.g. the
    model has no safetensors weights) propagates to the caller.
    """
    try:
        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
    except OSError:
        # No usable model.safetensors.index.json, presumably because the
        # checkpoint is not sharded.
        return [hub.cached_file(model_name, SAFE_WEIGHTS_NAME)]
    return files


def load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.float32,
    low_memory=True,
    verbose=False,
    loading_workers=2,
):
    """
    Build a bitsandbytes 4-bit (NF4) quantized causal LM for FSDP.

    The architecture is instantiated from `model_config` under
    `init_empty_weights()`. Its `nn.Linear` layers (under `.transformer` if
    present, otherwise `.model`) are swapped for `Linear4bit`. The safetensors
    shards of `model_name` are then read one at a time and loaded in parallel
    threads via `load_and_quantize`, which quantizes `Params4bit` on the fly.

    Args:
        model_name: hub id or local path whose safetensors weights are loaded.
        model_config: config used to build the empty model skeleton.
        cfg: run config; reads `trust_remote_code` and `local_rank`.
        compute_dtype: `compute_dtype` for every `Linear4bit`.
        quant_storage: storage dtype for `Linear4bit`. It is also the `dtype`
            every loaded tensor is converted to.
        low_memory: if True, rank 0 loads with `to_cpu=True` and other ranks
            with `to_meta=True`; otherwise both flags are False.
        verbose: on rank 0, print the worker count and load time. It is also
            forwarded to `load_and_quantize`.
        loading_workers: thread count for loading; -1 picks one with
            `n_loading_workers`.

    Returns:
        The model, with `is_loaded_in_4bit = True`.

    Raises:
        OSError: if neither a sharded nor a single-file safetensors checkpoint
            can be resolved.
    """
    # A bare `import safetensors` does not load the `safetensors.torch`
    # submodule, so import it explicitly instead of relying on another
    # library having imported it first.
    from safetensors.torch import (  # pylint: disable=import-outside-toplevel
        load_file,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=cfg.trust_remote_code,
        )
        # `.model` is the more common layout in HF transformers; some
        # architectures use `.transformer` instead.
        backbone_attr = "transformer" if hasattr(model, "transformer") else "model"
        setattr(
            model,
            backbone_attr,
            _replace_linear(
                getattr(model, backbone_attr),
                Linear4bit,
                compute_dtype=compute_dtype,
                quant_type="nf4",
                quant_storage=quant_storage,
            ),
        )
    model.is_loaded_in_4bit = True

    # Grab the safetensors files that hold the weights
    files = _safetensors_shard_files(model_name)

    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
    def load_and_quantize_parallel(name_param, model, **kwargs):
        name, param = name_param
        load_and_quantize(model, name, param, **kwargs)

    quant_method = "bnb"
    param_count = sum(p.numel() for p in model.parameters())

    n_workers = (
        n_loading_workers(quant_method, param_count)
        if loading_workers == -1
        else loading_workers
    )
    if cfg.local_rank == 0 and verbose:
        print(f"Using n_workers: {n_workers} for loading")

    start = time.time()
    for filename in tqdm(
        files,
        desc="Loading & Quantizing Model Shards",
        disable=cfg.local_rank != 0,
        position=0,
    ):
        weights = load_file(filename)
        parallel(
            load_and_quantize_parallel,
            iter(weights.items()),
            n_workers=n_workers,
            threadpool=True,
            model=model,
            dtype=quant_storage,
            device=cfg.local_rank,
            skip_names=[],
            to_cpu=(low_memory and cfg.local_rank == 0),
            to_meta=(low_memory and cfg.local_rank != 0),
            verbose=verbose,
            quant_method=quant_method,
        )
        # Release this shard's dict now so it is not still alive while the
        # next shard is being read.
        del weights

    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
    torch.cuda.empty_cache()

    return model