Spaces:
Build error
Build error
File size: 11,699 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
import time
import warnings
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn as nn
from captum._utils.models.linear_model.model import LinearModel
from torch.utils.data import DataLoader
def l2_loss(x1, x2, weights=None):
    r"""
    Half of the (optionally weighted) squared error between `x1` and `x2`.

    Without `weights` this is the mean of the element-wise squared
    differences, divided by two. With `weights`, the weights are first
    divided by their L1 norm and the weighted squared differences are
    summed (then divided by two).
    """
    squared_error = (x1 - x2) ** 2
    if weights is not None:
        normalized_weights = weights / weights.norm(p=1)
        return torch.sum(normalized_weights * squared_error) / 2.0
    return torch.mean(squared_error) / 2.0
def sgd_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    max_epoch: int = 100,
    reduce_lr: bool = True,
    initial_lr: float = 0.01,
    alpha: float = 1.0,
    loss_fn: Callable = l2_loss,
    reg_term: Optional[int] = 1,
    patience: int = 10,
    threshold: float = 1e-4,
    running_loss_window: Optional[int] = None,
    device: Optional[str] = None,
    init_scheme: str = "zeros",
    debug: bool = False,
) -> Dict[str, float]:
    r"""
    Trains a linear model with SGD. This will continue to iterate your
    dataloader until we converged to a solution or alternatively until we have
    exhausted `max_epoch`.

    Convergence is defined by the running-average loss not improving by more
    than `threshold` for `patience` consecutive iterations.

    Args:
        model
            The model to train.
        dataloader
            The data to train it with. We will assume the dataloader produces
            either pairs or triples of the form (x, y) or (x, y, w), where x
            and y are typical pairs for supervised learning and w is a weight
            vector.

            We will call `model._construct_model_params` with
            construct_kwargs and the input features set to `x.shape[1]`
            (`x.shape[0]` corresponds to the batch size). We assume that
            `len(x.shape) == 2`, i.e. the tensor is flat. The number of
            output features will be set to y.shape[1] or 1 (if
            `len(y.shape) == 1`); we require `len(y.shape) <= 2`.
        max_epoch
            The maximum number of epochs to exhaust.
        reduce_lr
            Whether or not to reduce the learning rate as iterations progress.
            Halves the learning rate when the training loss does not move.
            This uses torch.optim.lr_scheduler.ReduceLROnPlateau and uses the
            parameters `patience` and `threshold`.
        initial_lr
            The initial learning rate to use.
        alpha
            A constant multiplying the regularization term.
        loss_fn
            The loss to optimise for. It is called as `loss_fn(y, out, w)`:
            labels first, model predictions second, and the (possibly `None`)
            weight tensor third.
        reg_term
            Regularization is defined by the `reg_term` norm of the weights.
            Please use `None` if you do not wish to use regularization.
        patience
            Defines the number of iterations in a row the loss must remain
            within `threshold` in order to be classified as converged.
        threshold
            Threshold for convergence detection.
        running_loss_window
            Number of most recent iteration losses that are averaged. The
            average is used to detect convergence, to drive the learning-rate
            scheduler and as the reported training loss.
            If `None`, it is set to the first batch's size times
            `len(dataloader)` (an approximation of the number of examples in
            an epoch).
        device
            The device to send the model and data to. If None then no `.to`
            call will be used.
        init_scheme
            Initialization to use prior to training the linear model: either
            "xavier" or "zeros" for the weight (the bias, if any, is always
            zeroed), or `None` to keep the constructed parameters.
        debug
            Whether to print the loss and learning rate at every iteration.

    Returns:
        A dict with "train_time" (seconds), "train_loss" (the final loss
        averaged over `running_loss_window`, or NaN if no iteration ran),
        "train_iter" (number of iterations) and "train_epoch".
    """
    loss_window: List[torch.Tensor] = []
    min_avg_loss = None
    convergence_counter = 0
    converged = False

    def get_point(datapoint):
        # Unpack an (x, y) or (x, y, w) batch and optionally move it to `device`.
        if len(datapoint) == 2:
            x, y = datapoint
            w = None
        else:
            x, y, w = datapoint
        if device is not None:
            x = x.to(device)
            y = y.to(device)
            if w is not None:
                w = w.to(device)
        return x, y, w

    # Get a point and construct the model from its shapes.
    data_iter = iter(dataloader)
    x, y, w = get_point(next(data_iter))

    model._construct_model_params(
        in_features=x.shape[1],
        out_features=y.shape[1] if len(y.shape) == 2 else 1,
        **construct_kwargs,
    )
    if device is not None:
        # The batches are moved to `device` in get_point; move the freshly
        # constructed parameters too so the forward pass does not mix devices.
        model.to(device)
    model.train()

    assert model.linear is not None

    if init_scheme is not None:
        assert init_scheme in ["xavier", "zeros"]

        with torch.no_grad():
            if init_scheme == "xavier":
                torch.nn.init.xavier_uniform_(model.linear.weight)
            else:
                model.linear.weight.zero_()

            if model.linear.bias is not None:
                model.linear.bias.zero_()

    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
    # Always bind `scheduler`; previously reduce_lr=False left it undefined and
    # the `scheduler` check inside the loop raised NameError.
    scheduler = None
    if reduce_lr:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, factor=0.5, patience=patience, threshold=threshold
        )

    t1 = time.time()
    epoch = 0
    i = 0
    while epoch < max_epoch:
        while True:  # for x, y, w in dataloader
            if running_loss_window is None:
                running_loss_window = x.shape[0] * len(dataloader)

            y = y.view(x.shape[0], -1)
            if w is not None:
                w = w.view(x.shape[0], -1)

            i += 1

            out = model(x)

            loss = loss_fn(y, out, w)
            if reg_term is not None:
                reg = torch.norm(model.linear.weight, p=reg_term)
                loss += reg.sum() * alpha

            # Keep only the most recent `running_loss_window` losses.
            if len(loss_window) >= running_loss_window:
                loss_window = loss_window[1:]
            loss_window.append(loss.clone().detach())
            assert len(loss_window) <= running_loss_window

            average_loss = torch.mean(torch.stack(loss_window))
            if min_avg_loss is not None:
                # if we haven't improved by at least `threshold`
                if average_loss > min_avg_loss or torch.isclose(
                    min_avg_loss, average_loss, atol=threshold
                ):
                    convergence_counter += 1
                    if convergence_counter >= patience:
                        converged = True
                        break
                else:
                    convergence_counter = 0
            if min_avg_loss is None or min_avg_loss >= average_loss:
                min_avg_loss = average_loss.clone()

            if debug:
                print(
                    f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
                    + f"Aloss={average_loss}, min_avg_loss={min_avg_loss}"
                )

            loss.backward()

            optim.step()
            model.zero_grad()
            if scheduler is not None:
                scheduler.step(average_loss)

            temp = next(data_iter, None)
            if temp is None:
                break
            x, y, w = get_point(temp)

        if converged:
            break

        epoch += 1
        if epoch < max_epoch:
            # Restart the dataloader only if another epoch will actually run;
            # fetching a batch after the last epoch is wasted work.
            data_iter = iter(dataloader)
            x, y, w = get_point(next(data_iter))

    t2 = time.time()
    # With max_epoch <= 0 no iteration runs and the window is empty.
    train_loss = (
        torch.mean(torch.stack(loss_window)).item() if loss_window else float("nan")
    )
    return {
        "train_time": t2 - t1,
        "train_loss": train_loss,
        "train_iter": i,
        "train_epoch": epoch,
    }
class NormLayer(nn.Module):
    r"""
    Standardizes its input as `(x - mean) / (std + eps)`.

    Args:
        mean: Value subtracted from the input.
        std: Scale the centered input is divided by (after adding `eps`).
        n: Accepted but not used.
        eps: Small constant added to `std` to avoid division by zero.
    """

    def __init__(self, mean, std, n=None, eps=1e-8) -> None:
        super().__init__()
        self.mean, self.std, self.eps = mean, std, eps

    def forward(self, x):
        centered = x - self.mean
        scale = self.std + self.eps
        return centered / scale
def sklearn_train_linear_model(
    model: LinearModel,
    dataloader: DataLoader,
    construct_kwargs: Dict[str, Any],
    sklearn_trainer: str = "Lasso",
    norm_input: bool = False,
    **fit_kwargs,
):
    r"""
    Alternative method to train with sklearn. This does introduce some slight
    overhead as we convert the tensors to numpy and then convert the resulting
    trained model to a `LinearModel` object. However, this conversion
    should be negligible.

    Please note that this assumes:

    0. You have sklearn and numpy installed
    1. The dataset can fit into memory

    Args
        model
            The model to train. Its parameters are constructed from the fitted
            estimator's `coef_`, `intercept_` and (if present) `classes_`.
        dataloader
            The data to use. Each batch must be `(x, y)` or `(x, y, w)`. This
            will be exhausted and converted to numpy arrays. Therefore please
            do not feed an infinite dataloader.
        construct_kwargs
            Additional arguments provided to the `sklearn_trainer` constructor
        sklearn_trainer
            Dotted attribute path of the estimator class, resolved first
            relative to the `sklearn` package (e.g. "linear_model.Lasso") and,
            if that fails, relative to `sklearn.linear_model` (e.g. "Lasso").
            Please refer to sklearn.linear_model for a list of modules to use.
        norm_input
            Whether or not to standardize the input. If True, the per-feature
            mean/std are stored on `model.norm` as a `NormLayer`.
        fit_kwargs
            Other arguments to send to `sklearn_trainer`'s `.fit` method

    Returns
        A dict with "train_time": seconds spent constructing and fitting the
        sklearn estimator.

    Raises
        ValueError: if numpy or sklearn cannot be imported, or if the
            dataloader yields no batches.
    """
    import re
    from functools import reduce

    try:
        import numpy as np
    except ImportError as err:
        raise ValueError("numpy is not available. Please install numpy.") from err

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm
    except ImportError as err:
        raise ValueError(
            "sklearn is not available. Please install sklearn >= 0.23"
        ) from err

    # Compare (major, minor) numerically; a plain string comparison mis-orders
    # versions such as "0.3" vs "0.23.0".
    version_match = re.match(r"(\d+)\.(\d+)", sklearn.__version__)
    if version_match is None or (
        int(version_match.group(1)),
        int(version_match.group(2)),
    ) < (0, 23):
        warnings.warn(
            "Must have sklearn version 0.23.0 or higher to use "
            "sample_weight in Lasso regression."
        )

    xs, ys, ws = [], [], []
    for data in dataloader:
        if len(data) == 3:
            x, y, w = data
        else:
            assert len(data) == 2
            x, y = data
            w = None

        xs.append(x.cpu().numpy())
        ys.append(y.cpu().numpy())
        if w is not None:
            ws.append(w.cpu().numpy())

    if not xs:
        raise ValueError("The dataloader did not produce any batches.")

    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    w = np.concatenate(ws, axis=0) if ws else None

    # Must match the eps passed to NormLayer below so the normalization used
    # for fitting is identical to the one applied at inference time.
    norm_eps = 1e-8
    if norm_input:
        mean, std = x.mean(0), x.std(0)
        # eps avoids 0/0 = NaN for constant feature columns.
        x = (x - mean) / (std + norm_eps)

    t1 = time.time()
    try:
        trainer_cls = reduce(getattr, sklearn_trainer.split("."), sklearn)
    except AttributeError:
        # Bare estimator names such as the default "Lasso" are not exported
        # at the sklearn top level; they live in sklearn.linear_model.
        trainer_cls = reduce(
            getattr, sklearn_trainer.split("."), sklearn.linear_model
        )
    sklearn_model = trainer_cls(**construct_kwargs)

    if w is None:
        # No weights were provided, so don't risk a spurious
        # "sample weight not supported" fallback/warning.
        sklearn_model.fit(x, y, **fit_kwargs)
    else:
        try:
            sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
        except TypeError:
            sklearn_model.fit(x, y, **fit_kwargs)
            warnings.warn(
                "Sample weight is not supported for the provided linear model!"
                " Trained model without weighting inputs. For Lasso, please"
                " upgrade sklearn to a version >= 0.23.0."
            )

    t2 = time.time()

    # Convert weights to pytorch
    classes = (
        torch.IntTensor(sklearn_model.classes_)
        if hasattr(sklearn_model, "classes_")
        else None
    )

    # extract model device
    device = model.device if hasattr(model, "device") else "cpu"

    num_outputs = sklearn_model.coef_.shape[0] if sklearn_model.coef_.ndim > 1 else 1
    weight_values = torch.FloatTensor(sklearn_model.coef_).to(device)  # type: ignore
    bias_values = torch.FloatTensor([sklearn_model.intercept_]).to(  # type: ignore
        device  # type: ignore
    )  # type: ignore
    model._construct_model_params(
        norm_type=None,
        weight_values=weight_values.view(num_outputs, -1),
        bias_value=bias_values.squeeze().unsqueeze(0),
        classes=classes,
    )

    if norm_input:
        # Store the statistics as float tensors on the model's device (model
        # inputs are presumably torch tensors) rather than as numpy arrays.
        model.norm = NormLayer(
            torch.as_tensor(mean, dtype=torch.float32, device=device),
            torch.as_tensor(std, dtype=torch.float32, device=device),
            eps=norm_eps,
        )

    return {"train_time": t2 - t1}
|