# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch


class IncrementalPCA:
"""
    An implementation of Incremental Principal Components Analysis (IPCA) that leverages PyTorch for GPU acceleration.

    Adapted from https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/decomposition/_incremental_pca.py

    This class provides methods to fit the model on data incrementally in batches, and to transform new data based on
    the principal components learned during the fitting process.

    Args:
        n_components (int, optional): Number of components to keep. If `None`, it is set to the minimum of the
            number of samples and the number of features. Defaults to None.
        copy (bool): If False, input data will be overwritten. Defaults to True.
        batch_size (int, optional): The number of samples to use for each batch. Only needed if `fit` is called.
            If `None`, it is inferred from the data and set to `5 * n_features`. Defaults to None.
        svd_driver (str, optional): Name of the cuSOLVER method to be used for `torch.linalg.svd`. This keyword
            argument only works on CUDA inputs. Available options are: None, `gesvd`, `gesvdj`, and `gesvda`.
            Defaults to None.
        lowrank (bool, optional): Whether to use `torch.svd_lowrank` instead of `torch.linalg.svd`, which can be
            faster but is approximate. Defaults to False.
        lowrank_q (int, optional): Rank hint passed to `torch.svd_lowrank`. If `None`, it defaults to
            `n_components * 2`, which is usually an adequate approximation for `n_components` components.
        lowrank_niter (int, optional): Number of subspace iterations to conduct for `torch.svd_lowrank`.
            Defaults to 4.
        lowrank_seed (int, optional): Seed for making results of `torch.svd_lowrank` reproducible.
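
    Examples:

        A minimal usage sketch (illustrative; the data and sizes here are made up):

        ```python
        >>> X = torch.randn(1000, 64)
        >>> ipca = IncrementalPCA(n_components=16)
        >>> ipca = ipca.fit(X)  # processes X in minibatches
        >>> X_reduced = ipca.transform(X)  # shape: (1000, 16)
        ```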
"""

    def __init__(
self,
n_components: Optional[int] = None,
copy: Optional[bool] = True,
batch_size: Optional[int] = None,
svd_driver: Optional[str] = None,
lowrank: bool = False,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
self.n_components = n_components
self.copy = copy
self.batch_size = batch_size
self.svd_driver = svd_driver
self.lowrank = lowrank
self.lowrank_q = lowrank_q
self.lowrank_niter = lowrank_niter
self.lowrank_seed = lowrank_seed
self.n_features_ = None
if self.lowrank:
self._validate_lowrank_params()

    def _validate_lowrank_params(self):
        if self.lowrank_q is None:
            if self.n_components is None:
                raise ValueError("n_components must be specified when using lowrank mode with lowrank_q=None.")
            self.lowrank_q = self.n_components * 2
        elif self.n_components is not None and self.lowrank_q < self.n_components:
            raise ValueError("lowrank_q must be greater than or equal to n_components.")

    def _svd_fn_full(self, X):
        # Full SVD; the `driver` argument is only honored on CUDA inputs.
        return torch.linalg.svd(X, full_matrices=False, driver=self.svd_driver)

    def _svd_fn_lowrank(self, X):
        # Randomized low-rank SVD. When a seed is given, the RNG is forked so the
        # global random state is left untouched.
        seed_enabled = self.lowrank_seed is not None
        with torch.random.fork_rng(enabled=seed_enabled):
            if seed_enabled:
                torch.manual_seed(self.lowrank_seed)
            U, S, V = torch.svd_lowrank(X, q=self.lowrank_q, niter=self.lowrank_niter)
        return U, S, V.mH

    def _validate_data(self, X) -> torch.Tensor:
        """
        Validates and converts the input data `X` to the appropriate tensor format.

        Args:
            X (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Validated tensor in a supported float dtype (a copy if `copy=True` and `X` was already a
                tensor).
"""
valid_dtypes = [torch.float32, torch.float64]
if not isinstance(X, torch.Tensor):
X = torch.tensor(X, dtype=torch.float32)
elif self.copy:
X = X.clone()
n_samples, n_features = X.shape
if self.n_components is None:
pass
elif self.n_components > n_features:
raise ValueError(
f"n_components={self.n_components} invalid for n_features={n_features}, "
"need more rows than columns for IncrementalPCA processing."
)
        elif self.n_components > n_samples:
            raise ValueError(
                f"n_components={self.n_components} must be less than or equal to the number of samples in the batch, {n_samples}."
            )
if X.dtype not in valid_dtypes:
X = X.to(torch.float32)
return X

    @staticmethod
    def _incremental_mean_and_var(
        X, last_mean, last_variance, last_sample_count
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Computes the incremental mean and variance for the data `X`.

        Args:
            X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
            last_mean (torch.Tensor): The previous mean tensor with shape (n_features,).
            last_variance (torch.Tensor): The previous variance tensor with shape (n_features,).
            last_sample_count (torch.Tensor): The count tensor of samples processed before the current batch.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Updated mean, variance tensors, and total sample count.
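
        Example:

            A sketch of accumulating statistics over two chunks (illustrative; the shapes are made up):

            ```python
            X = torch.randn(30, 4)
            mean, var, count = IncrementalPCA._incremental_mean_and_var(X[:20], None, None, torch.tensor([0]))
            mean, var, count = IncrementalPCA._incremental_mean_and_var(X[20:], mean, var, count)
            # mean and var now match the full-batch statistics of X up to floating-point error
            ```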
"""
if X.shape[0] == 0:
return last_mean, last_variance, last_sample_count
if last_sample_count > 0:
if last_mean is None:
raise ValueError("last_mean should not be None if last_sample_count > 0.")
if last_variance is None:
raise ValueError("last_variance should not be None if last_sample_count > 0.")
new_sample_count = torch.tensor([X.shape[0]], device=X.device)
updated_sample_count = last_sample_count + new_sample_count
if last_mean is None:
last_sum = torch.zeros(X.shape[1], dtype=torch.float64, device=X.device)
else:
last_sum = last_mean * last_sample_count
new_sum = X.sum(dim=0, dtype=torch.float64)
updated_mean = (last_sum + new_sum) / updated_sample_count
T = new_sum / new_sample_count
temp = X - T
correction = temp.sum(dim=0, dtype=torch.float64).square()
temp.square_()
new_unnormalized_variance = temp.sum(dim=0, dtype=torch.float64)
new_unnormalized_variance -= correction / new_sample_count
if last_variance is None:
updated_variance = new_unnormalized_variance / updated_sample_count
else:
last_unnormalized_variance = last_variance * last_sample_count
last_over_new_count = last_sample_count.double() / new_sample_count
updated_unnormalized_variance = (
last_unnormalized_variance
+ new_unnormalized_variance
+ last_over_new_count / updated_sample_count * (last_sum / last_over_new_count - new_sum).square()
)
updated_variance = updated_unnormalized_variance / updated_sample_count
return updated_mean, updated_variance, updated_sample_count

    @staticmethod
    def _svd_flip(u, v, u_based_decision=True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Adjusts the signs of the singular vectors from the SVD decomposition for deterministic output.

        This method ensures that the output remains consistent across different runs.

        Args:
            u (torch.Tensor): Left singular vectors tensor.
            v (torch.Tensor): Right singular vectors tensor.
            u_based_decision (bool, optional): If True, uses the left singular vectors to determine the sign flipping.
                Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Adjusted left and right singular vectors tensors.
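
        Example:

            An illustrative check of the invariant: flipping a column of `u` together with the matching row of `v`
            leaves the reconstruction unchanged.

            ```python
            A = torch.randn(6, 4)
            U, S, Vt = torch.linalg.svd(A, full_matrices=False)
            U2, Vt2 = IncrementalPCA._svd_flip(U.clone(), Vt.clone(), u_based_decision=False)
            # (U2 * S) @ Vt2 still reconstructs A
            ```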
"""
if u_based_decision:
max_abs_cols = torch.argmax(torch.abs(u), dim=0)
signs = torch.sign(u[max_abs_cols, range(u.shape[1])])
else:
max_abs_rows = torch.argmax(torch.abs(v), dim=1)
signs = torch.sign(v[range(v.shape[0]), max_abs_rows])
u *= signs[: u.shape[1]].view(1, -1)
v *= signs.view(-1, 1)
return u, v

    def fit(self, X, check_input=True):
        """
        Fits the model with data `X` using minibatches of size `batch_size`.

        Args:
            X (torch.Tensor): The input data tensor with shape (n_samples, n_features).
            check_input (bool, optional): If True, validates the input. Defaults to True.

        Returns:
            IncrementalPCA: The fitted IPCA model.
"""
if check_input:
X = self._validate_data(X)
n_samples, n_features = X.shape
if self.batch_size is None:
self.batch_size = 5 * n_features
for batch in self.gen_batches(n_samples, self.batch_size, min_batch_size=self.n_components or 0):
self.partial_fit(X[batch], check_input=False)
return self

    def partial_fit(self, X, check_input=True):
        """
        Incrementally fits the model with batch data `X`.

        Args:
            X (torch.Tensor): The batch input data tensor with shape (n_samples, n_features).
            check_input (bool, optional): If True, validates the input. Defaults to True.

        Returns:
            IncrementalPCA: The updated IPCA model after processing the batch.
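
        Example:

            A streaming sketch over four chunks (illustrative; the shapes are made up):

            ```python
            ipca = IncrementalPCA(n_components=2)
            for chunk in torch.randn(100, 5).split(25):
                ipca = ipca.partial_fit(chunk)
            # ipca.components_ now has shape (2, 5)
            ```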
"""
first_pass = not hasattr(self, "components_")
if check_input:
X = self._validate_data(X)
n_samples, n_features = X.shape
# Initialize attributes to avoid errors during the first call to partial_fit
if first_pass:
self.mean_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
self.var_ = None # Will be initialized properly in _incremental_mean_and_var based on data dimensions
self.n_samples_seen_ = torch.tensor([0], device=X.device)
self.n_features_ = n_features
if not self.n_components:
self.n_components = min(n_samples, n_features)
if n_features != self.n_features_:
raise ValueError(
"Number of features of the new batch does not match the number of features of the first batch."
)
col_mean, col_var, n_total_samples = self._incremental_mean_and_var(
X, self.mean_, self.var_, self.n_samples_seen_
)
        if first_pass:
            # Subtract out-of-place: `col_mean` is float64, and an in-place `-=` on a
            # float32 `X` would raise a dtype-promotion error.
            X = X - col_mean
else:
col_batch_mean = torch.mean(X, dim=0)
X -= col_batch_mean
mean_correction_factor = torch.sqrt((self.n_samples_seen_.double() / n_total_samples) * n_samples)
mean_correction = mean_correction_factor * (self.mean_ - col_batch_mean)
X = torch.vstack(
(
self.singular_values_.view((-1, 1)) * self.components_,
X,
mean_correction,
)
)
if self.lowrank:
U, S, Vt = self._svd_fn_lowrank(X)
else:
U, S, Vt = self._svd_fn_full(X)
U, Vt = self._svd_flip(U, Vt, u_based_decision=False)
explained_variance = S**2 / (n_total_samples - 1)
explained_variance_ratio = S**2 / torch.sum(col_var * n_total_samples)
self.n_samples_seen_ = n_total_samples
self.components_ = Vt[: self.n_components]
self.singular_values_ = S[: self.n_components]
self.mean_ = col_mean
self.var_ = col_var
self.explained_variance_ = explained_variance[: self.n_components]
self.explained_variance_ratio_ = explained_variance_ratio[: self.n_components]
if self.n_components not in (n_samples, n_features):
self.noise_variance_ = explained_variance[self.n_components :].mean()
else:
self.noise_variance_ = torch.tensor(0.0, device=X.device)
return self

    def transform(self, X) -> torch.Tensor:
        """
        Applies dimensionality reduction to `X`.

        The input data `X` is projected on the first principal components previously extracted from a training set.

        Args:
            X (torch.Tensor): New data tensor with shape (n_samples, n_features) to be transformed.

        Returns:
            torch.Tensor: Transformed data tensor with shape (n_samples, n_components).
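
        Example:

            An illustrative sketch (assumes `ipca` was already fitted on 64-dimensional data):

            ```python
            X_new = torch.randn(10, 64)
            X_proj = ipca.transform(X_new)  # shape: (10, n_components)
            ```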
"""
X = X - self.mean_
return torch.mm(X.double(), self.components_.T).to(X.dtype)

    @staticmethod
    def gen_batches(n: int, batch_size: int, min_batch_size: int = 0):
        """Generator to create slices containing `batch_size` elements from 0 to `n`.

        The last slice may contain fewer than `batch_size` elements, when `batch_size` does not divide `n`.

        Args:
            n (int): Size of the sequence.
            batch_size (int): Number of elements in each batch.
            min_batch_size (int, optional): Minimum number of elements in each batch. Defaults to 0.

        Yields:
            slice: A slice of `batch_size` elements.
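
        Example:

            An illustrative run (outputs shown as comments):

            ```python
            list(IncrementalPCA.gen_batches(7, 3))
            # [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
            list(IncrementalPCA.gen_batches(7, 3, min_batch_size=2))
            # [slice(0, 3, None), slice(3, 7, None)]
            ```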
"""
start = 0
for _ in range(int(n // batch_size)):
end = start + batch_size
if end + min_batch_size > n:
continue
yield slice(start, end)
start = end
if start < n:
yield slice(start, n)