import torch
import torch.nn as nn
import torch.nn.functional as F


def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of `k` that is >= `n` (or `n` itself when `k` is 0)."""
    if k == 0 or n % k == 0:
        return n
    return n + k - (n % k)
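
# Examples (illustrative): find_multiple(10, 8) == 16, find_multiple(16, 8) == 16,
# and find_multiple(5, 0) == 5 (k == 0 returns n unchanged).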


def pad_weight_(w: nn.Embedding | nn.Linear, multiple: int):
    """Pad the first (row) dimension of an embedding or linear weight to a multiple of `multiple`."""
    if isinstance(w, nn.Embedding):
        if w.weight.shape[0] % multiple == 0:
            return
        # F.pad with (0, 0, 0, pad) appends `pad` zero rows to a 2-D weight.
        pad = find_multiple(w.weight.shape[0], multiple) - w.weight.shape[0]
        w.weight.data = F.pad(w.weight.data, (0, 0, 0, pad))
        w.num_embeddings, w.embedding_dim = w.weight.shape
    elif isinstance(w, nn.Linear):
        if w.weight.shape[0] % multiple == 0:
            return
        pad = find_multiple(w.weight.shape[0], multiple) - w.weight.shape[0]
        w.weight.data = F.pad(w.weight.data, (0, 0, 0, pad))
        w.out_features, w.in_features = w.weight.shape
    else:
        raise ValueError(f"Unsupported weight type: {type(w)}")
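
# Usage sketch (hypothetical sizes): pad a 1000-row embedding up to the next
# multiple of 256 (find_multiple(1000, 256) == 1024); the new rows are zeros.
#   emb = nn.Embedding(1000, 64)
#   pad_weight_(emb, 256)
#   assert emb.num_embeddings == 1024 and emb.embedding_dim == 64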


def get_device() -> torch.device:
    """Return the current CUDA device if one is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    return torch.device("cpu")


DEFAULT_DEVICE = get_device()
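

# Minimal smoke test (illustrative, self-contained; run this file directly):
if __name__ == "__main__":
    emb = nn.Embedding(10, 4)
    pad_weight_(emb, 8)
    # The 10 rows pad up to find_multiple(10, 8) == 16; embedding_dim is unchanged.
    print(emb.num_embeddings, emb.embedding_dim, DEFAULT_DEVICE)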