import time
import warnings
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
from captum._utils.models.linear_model.model import LinearModel
from torch.utils.data import DataLoader


def l2_loss(x1, x2, weights=None):
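    """Half of the mean squared error between x1 and x2.

    If `weights` is given, the squared errors are averaged using the
    L1-normalized weights instead of uniformly.
    """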
    if weights is None:
        return torch.mean((x1 - x2) ** 2) / 2.0
    else:
        return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0


def sgd_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    max_epoch: int = 100,
    reduce_lr: bool = True,
    initial_lr: float = 0.01,
    alpha: float = 1.0,
    loss_fn: Callable = l2_loss,
    reg_term: Optional[int] = 1,
    patience: int = 10,
    threshold: float = 1e-4,
    running_loss_window: Optional[int] = None,
    device: Optional[str] = None,
    init_scheme: str = "zeros",
    debug: bool = False,
) -> Dict[str, float]:
    r"""
    Trains a linear model with SGD. This will continue to iterate over your
    dataloader until we have converged to a solution or until `max_epoch`
    epochs have been exhausted.

    Convergence is defined by the loss not changing by `threshold` amount for
    `patience` number of iterations.

    Args:
        model
            The model to train.
        dataloader
            The data to train it with. We will assume the dataloader produces
            either pairs or triples of the form (x, y) or (x, y, w), where x
            and y are the usual supervised-learning inputs and targets and w
            is a per-example weight vector.

            We will call `model._construct_model_params` with construct_kwargs
            and the input features set to `x.shape[1]` (`x.shape[0]` corresponds
            to the batch size). We assume that `len(x.shape) == 2`, i.e. the
            tensor is flat. The number of output features will be set to
            y.shape[1] or 1 (if `len(y.shape) == 1`); we require `len(y.shape)
            <= 2`.
        max_epoch
            The maximum number of epochs to train for.
        reduce_lr
            Whether or not to reduce the learning rate as training progresses.
            The learning rate is halved when the training loss plateaus. This
            uses torch.optim.lr_scheduler.ReduceLROnPlateau with the parameters
            `patience` and `threshold`.
        initial_lr
            The initial learning rate to use.
        alpha
            A constant for the regularization term.
        loss_fn
            The loss to optimise for. This must accept three parameters: the
            targets, the predictions and a weight vector (which may be
            `None`); it is called as `loss_fn(y, out, w)`.
        reg_term
            Regularization is defined by the `reg_term` norm of the weights.
            Please use `None` if you do not wish to use regularization.
        patience
            Defines the number of iterations in a row the loss must remain
            within `threshold` in order to be classified as converged.
        threshold
            Threshold for convergence detection.
        running_loss_window
            Used to report the training loss once we have finished training and
            to determine when we have converged (along with reducing the
            learning rate).

            The reported training loss will take the last `running_loss_window`
            iterations and average them.

            If `None` we will approximate this to be the number of examples in
            an epoch.
        device
            The device to send the model and data to. If None then no `.to`
            call will be used.
        init_scheme
            Initialization to use prior to training the linear model.
        debug
            Whether to print the loss and learning rate at each iteration.

    Returns:
        A dict with the wall-clock training time (`train_time`), the final
        training loss averaged over the last `running_loss_window` iterations
        (`train_loss`), the total number of iterations (`train_iter`) and the
        number of epochs run (`train_epoch`).
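
    Example (a minimal usage sketch; the synthetic data, batch size and
    hyperparameters below are illustrative assumptions, and `model` is
    assumed to be an already-constructed `LinearModel` instance)::

        import torch
        from torch.utils.data import DataLoader, TensorDataset

        x = torch.randn(256, 8)           # 256 flat examples with 8 features
        y = x @ torch.randn(8, 1)         # synthetic linear targets
        loader = DataLoader(TensorDataset(x, y), batch_size=32)

        stats = sgd_train_linear_model(
            model,
            loader,
            construct_kwargs={"norm_type": None},
            max_epoch=50,
        )
        print(stats["train_loss"])        # running-window training loss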
    """

    loss_window: List[torch.Tensor] = []
    min_avg_loss = None
    convergence_counter = 0
    converged = False

    def get_point(datapoint):
        if len(datapoint) == 2:
            x, y = datapoint
            w = None
        else:
            x, y, w = datapoint

        if device is not None:
            x = x.to(device)
            y = y.to(device)
            if w is not None:
                w = w.to(device)

        return x, y, w

    # get a point and construct the model
    data_iter = iter(dataloader)
    x, y, w = get_point(next(data_iter))

    model._construct_model_params(
        in_features=x.shape[1],
        out_features=y.shape[1] if len(y.shape) == 2 else 1,
        **construct_kwargs,
    )
    model.train()

    assert model.linear is not None

    if init_scheme is not None:
        assert init_scheme in ["xavier", "zeros"]

        with torch.no_grad():
            if init_scheme == "xavier":
                torch.nn.init.xavier_uniform_(model.linear.weight)
            else:
                model.linear.weight.zero_()

            if model.linear.bias is not None:
                model.linear.bias.zero_()

    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
    # Only create the scheduler when requested; otherwise leave it as None so
    # the `scheduler.step` call below is skipped (avoids a NameError when
    # `reduce_lr=False`).
    scheduler = None
    if reduce_lr:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=patience, threshold=threshold
        )

    t1 = time.time()
    epoch = 0
    i = 0
    while epoch < max_epoch:
        while True:  # for x, y, w in dataloader
            if running_loss_window is None:
                running_loss_window = x.shape[0] * len(dataloader)

            y = y.view(x.shape[0], -1)
            if w is not None:
                w = w.view(x.shape[0], -1)

            i += 1

            out = model(x)

            loss = loss_fn(y, out, w)
            if reg_term is not None:
                reg = torch.norm(model.linear.weight, p=reg_term)
                loss += reg.sum() * alpha

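            # Keep a sliding window of the most recent losses; its average is
            # used for convergence detection and to drive the LR scheduler.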
            if len(loss_window) >= running_loss_window:
                loss_window = loss_window[1:]
            loss_window.append(loss.clone().detach())
            assert len(loss_window) <= running_loss_window

            average_loss = torch.mean(torch.stack(loss_window))
            if min_avg_loss is not None:
                # if we haven't improved by at least `threshold`
                if average_loss > min_avg_loss or torch.isclose(
                    min_avg_loss, average_loss, atol=threshold
                ):
                    convergence_counter += 1
                    if convergence_counter >= patience:
                        converged = True
                        break
                else:
                    convergence_counter = 0
            if min_avg_loss is None or min_avg_loss >= average_loss:
                min_avg_loss = average_loss.clone()

            if debug:
                print(
                    f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
                    f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
                )

            loss.backward()

            optim.step()
            model.zero_grad()
            if scheduler:
                scheduler.step(average_loss)

            temp = next(data_iter, None)
            if temp is None:
                break
            x, y, w = get_point(temp)

        if converged:
            break

        epoch += 1
        data_iter = iter(dataloader)
        x, y, w = get_point(next(data_iter))

    t2 = time.time()
    return {
        "train_time": t2 - t1,
        "train_loss": torch.mean(torch.stack(loss_window)).item(),
        "train_iter": i,
        "train_epoch": epoch,
    }


class NormLayer(nn.Module):
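    """Normalizes inputs using fixed mean and standard deviation statistics.

    A small `eps` is added to the denominator for numerical stability.
    """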
    def __init__(self, mean, std, n=None, eps=1e-8) -> None:
        super().__init__()
        self.mean = mean
        self.std = std
        self.eps = eps

    def forward(self, x):
        return (x - self.mean) / (self.std + self.eps)


def sklearn_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    sklearn_trainer: str = "linear_model.Lasso",
    norm_input: bool = False,
    **fit_kwargs,
):
    r"""
    Alternative method to train with sklearn. This does introduce some slight
    overhead as we convert the tensors to numpy arrays and then copy the
    trained coefficients back into the provided `LinearModel`. However, this
    overhead should be negligible.

    Please note that this assumes:

    0. You have sklearn and numpy installed
    1. The dataset can fit into memory

    Args:
        model
            The model to train.
        dataloader
            The data to use. This will be exhausted and converted to numpy
            arrays. Therefore please do not feed an infinite dataloader.
        construct_kwargs
            Additional arguments provided to the `sklearn_trainer` constructor.
        sklearn_trainer
            The sklearn model to use to train the model. Please refer to
            sklearn.linear_model for a list of modules to use.
        norm_input
            Whether or not to normalize the input.
        fit_kwargs
            Other arguments to send to `sklearn_trainer`'s `.fit` method.

    Returns:
        A dictionary containing the wall-clock training time (`train_time`).
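
    Example (a minimal usage sketch; the synthetic data and the Lasso
    hyperparameters are illustrative assumptions, and `model` is assumed to
    be an already-constructed `LinearModel` instance)::

        import torch
        from torch.utils.data import DataLoader, TensorDataset

        x = torch.randn(256, 8)
        y = (x @ torch.randn(8, 1)).squeeze(-1)
        loader = DataLoader(TensorDataset(x, y), batch_size=64)

        stats = sklearn_train_linear_model(
            model,
            loader,
            construct_kwargs={"alpha": 1.0},
            sklearn_trainer="linear_model.Lasso",
        )
        print(stats["train_time"])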
    """
    from functools import reduce

    try:
        import numpy as np
    except ImportError:
        raise ValueError("numpy is not available. Please install numpy.")

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm
    except ImportError:
        raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

    # Lexicographic string comparison of version numbers is unreliable, so
    # compare the parsed (major, minor) components instead.
    sklearn_version = tuple(
        int(part) for part in sklearn.__version__.split(".")[:2] if part.isdigit()
    )
    if sklearn_version < (0, 23):
        warnings.warn(
            "Must have sklearn version 0.23.0 or higher to use "
            "sample_weight in Lasso regression."
        )

    num_batches = 0
    xs, ys, ws = [], [], []
    for data in dataloader:
        if len(data) == 3:
            x, y, w = data
        else:
            assert len(data) == 2
            x, y = data
            w = None

        xs.append(x.cpu().numpy())
        ys.append(y.cpu().numpy())
        if w is not None:
            ws.append(w.cpu().numpy())
        num_batches += 1

    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    if len(ws) > 0:
        w = np.concatenate(ws, axis=0)
    else:
        w = None

    if norm_input:
        mean, std = x.mean(0), x.std(0)
        x -= mean
        x /= std

    t1 = time.time()
    sklearn_model = reduce(
        lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
    )(**construct_kwargs)
    try:
        sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
    except TypeError:
        sklearn_model.fit(x, y, **fit_kwargs)
        warnings.warn(
            "Sample weight is not supported for the provided linear model!"
            " Trained model without weighting inputs. For Lasso, please"
            " upgrade sklearn to a version >= 0.23.0."
        )

    t2 = time.time()

    # Convert weights to pytorch
    classes = (
        torch.IntTensor(sklearn_model.classes_)
        if hasattr(sklearn_model, "classes_")
        else None
    )

    # extract model device
    device = model.device if hasattr(model, "device") else "cpu"

    num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
    weight_values = torch.FloatTensor(sklearn_model.coef_).to(device)  # type: ignore
    bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore
        device
    )
    model._construct_model_params(
        norm_type=None,
        weight_values=weight_values.view(num_outputs, -1),
        bias_value=bias_values.squeeze().unsqueeze(0),
        classes=classes,
    )

    if norm_input:
        model.norm = NormLayer(mean, std)

    return {"train_time": t2 - t1}