# ngme-babylm-100M / ngme.py
# Uploaded by PatrickHaller in commit "Upload NGMEForCausalLM" (revision d59f795); 4.84 kB.
import math
from typing import Optional, List
from functools import lru_cache
from itertools import chain, tee
import torch
import torch.nn.functional as F
# Fixed per-order weight tables keyed by k = 0..4; entry k holds k + 1
# weights, each table summing to 1 (up to float rounding).
n_dists = dict(
    enumerate(
        [
            [1],
            [0.4, 0.6],
            [0.2, 0.3, 0.5],
            [0.1, 0.2, 0.3, 0.4],
            [0.1, 0.15, 0.2, 0.25, 0.3],
        ]
    )
)

# Raw (unnormalized) weight functions applied to an n-gram order x >= 1;
# n_dist normalizes their outputs. NOTE: "exp" is quadratic (x**2), not
# exponential. The name is kept because callers select strategies by key.
strats = dict(
    linear=lambda x: x,
    log=lambda x: math.log(x + 1),
    exp=lambda x: x**2,
)
def pad_sequence(
    sequence,
    n,
    pad_left=False,
    pad_right=False,
    left_pad_symbol=None,
    right_pad_symbol=None,
):
    """Lazily pad ``sequence`` with ``n - 1`` symbols on either side (from NLTK).

    Args:
        sequence: any iterable of items.
        n: n-gram order; ``n - 1`` pad symbols are added per padded side.
        pad_left: prepend ``left_pad_symbol`` copies when true.
        pad_right: append ``right_pad_symbol`` copies when true.
        left_pad_symbol: value used for left padding.
        right_pad_symbol: value used for right padding.

    Returns:
        An iterator over the (possibly padded) items.
    """
    stream = iter(sequence)
    if pad_left:
        head = (left_pad_symbol,) * (n - 1)
        stream = chain(head, stream)
    if pad_right:
        tail = (right_pad_symbol,) * (n - 1)
        stream = chain(stream, tail)
    return stream
def ngrams(sequence, n, **kwargs):
    """Return a lazy iterator of ``n``-tuples sliding over ``sequence`` (from NLTK).

    ``kwargs`` are forwarded unchanged to ``pad_sequence`` (``pad_left``,
    ``pad_right``, ``left_pad_symbol``, ``right_pad_symbol``).
    """
    padded = pad_sequence(sequence, n, **kwargs)
    # Make n independent copies of the stream and advance copy k by k items;
    # zipping the copies then yields windows of n consecutive items.
    lanes = tee(padded, n)
    for offset, lane in enumerate(lanes):
        for _ in range(offset):
            next(lane, None)
    return zip(*lanes)
@lru_cache(maxsize=5)
def soft_dist(n):
    """Return a uniform distribution over ``n`` slots, i.e. ``n`` copies of ``1 / n``.

    The list is memoized by ``lru_cache`` and shared between calls with the
    same ``n``, so callers must not mutate it. ``n == 0`` raises
    ZeroDivisionError.
    """
    share = 1 / n
    return [share for _ in range(n)]
@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> list[float]:
    """Return normalized weights for n-gram orders ``1..n`` under ``strategy``.

    Order ``k`` gets raw weight ``strats[strategy](k)``; the raw weights are
    divided by their total so they sum to 1 (up to float rounding).

    The result is memoized by ``lru_cache`` and the same list object is
    returned for repeated calls with equal arguments, so treat it as read-only.

    Args:
        n: number of n-gram orders; ``n <= 0`` yields an empty list.
        strategy: key of ``strats`` ("linear", "log" or "exp").

    Returns:
        List of ``n`` floats.

    Raises:
        KeyError: if ``strategy`` is not a key of ``strats``.
    """
    weight_fn = strats[strategy]
    raw = [weight_fn(k) for k in range(1, n + 1)]
    # Compute the normalizer once rather than once per element.
    total = sum(raw)
    return [x / total for x in raw]
def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    """Encode a stack of index tensors as one weighted multi-hot tensor.

    Slice ``input[i]`` writes weight ``w_i`` at each of its indices along a
    new trailing class axis of size ``num_classes``. Because ``scatter_`` with
    a scalar value overwrites rather than accumulates, a later slice's weight
    replaces an earlier one when both hit the same class at the same position.

    Args:
        input: integer index tensor whose dim 0 enumerates the slices to
            combine (presumably one per n-gram order; TODO confirm with callers).
        num_classes: size of the trailing class dimension.
        strategy: if truthy, weights come from ``n_dist(input.size(0), strategy)``;
            otherwise every slice writes weight 1.

    Returns:
        Float tensor of shape ``input.shape[1:] + (num_classes,)`` on
        ``input.device``.
    """
    shape = list(input.size())[1:]
    shape.append(num_classes)
    # Allocate on the target device directly instead of CPU alloc + copy.
    ret = torch.zeros(shape, device=input.device)
    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)
    for t, weight in zip(input, soft_labels):
        ret.scatter_(-1, t.unsqueeze(-1), weight)
    return ret
def n_hot(t, num_clases, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
    """Build a multi-hot float encoding of token ids along a trailing class axis.

    Three modes:

    * ``ngram_sequences`` given: start from the one-hot of ``t`` and also set
      a 1 at every id of each tensor in ``ngram_sequences`` (iterated along its
      first dimension, or element-wise if it is a list). Where an n-gram id
      equals ``unk_idx``, the unigram id from ``t`` at the same position is used
      instead. This substitution is done out of place, so the caller's tensors
      are left untouched. Output shape: ``t.shape + (num_clases,)``.
    * ``t`` is 2-D and no n-grams are given: plain ``F.one_hot`` cast to float.
    * otherwise: dim 0 of ``t`` enumerates the sequences to combine; output
      shape is ``t.shape[1:] + (num_clases,)``.

    Args:
        t: integer index tensor.
        num_clases: number of classes (size of the trailing axis).
        ngram_sequences: optional extra index tensors, each shaped like ``t``.
        unk_idx: optional id marking an unknown n-gram.

    Returns:
        Float tensor on ``t.device``.
    """
    shape = list(t.size())
    if ngram_sequences is not None:
        ret = torch.zeros(shape + [num_clases], device=t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Fall back to the unigram id where the n-gram is unknown.
                # torch.where returns a new tensor instead of writing into seq.
                seq = torch.where(torch.eq(seq, unk_idx), t, seq)
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret
    if len(shape) == 2:
        return F.one_hot(t, num_classes=num_clases).float()
    # Expect that first dimension is for all n-grams
    ret = torch.zeros(shape[1:] + [num_clases], device=t.device)
    for seq in t:
        ret.scatter_(-1, seq.unsqueeze(-1), 1)
    return ret
class NGramsEmbedding(torch.nn.Embedding):
    """Embedding layer driven by n-hot (multi-hot) inputs.

    ``forward`` turns token ids (plus optional n-gram ids) into an n-hot
    matrix via ``n_hot`` and multiplies it with the embedding table, so each
    output vector is the sum of the embeddings of all active ids.

    Note: ``forward`` never calls ``F.embedding``. Options that
    ``nn.Embedding`` applies at lookup time (e.g. ``max_norm``
    renormalization, ``padding_idx`` gradient masking, sparse gradients) are
    stored by the base class but not applied here.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None
    ) -> None:
        base_options = dict(
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )
        super().__init__(num_embeddings, embedding_dim, **base_options)
        # Width of the n-hot vectors built in forward (one slot per embedding row).
        self.num_classes = num_embeddings
        # Id that marks an unknown n-gram; passed through to n_hot.
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
        """Encode ``input`` (and optional n-gram ids) as n-hot, then embed it."""
        encoded = n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        return self._forward(encoded)

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        """Multiply an n-hot tensor (..., num_embeddings) by the embedding table."""
        # F.linear(x, W) computes x @ W.T; passing weight.t() gives x @ weight.
        table_t = self.weight.t()
        return F.linear(n_hot, table_t)
def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    """Collect ``gram_2_sequence``, ``gram_3_sequence``, ... from keyword args.

    Orders are read consecutively starting at 2. Collection stops at the
    first order whose key is absent or whose value is None. Keyword
    arguments with other names are ignored.

    Returns:
        The collected values in increasing n-gram order (possibly empty).
    """
    sequences = []
    n = 2
    # Treat a missing key like None. Bounding the loop by len(kwargs) would
    # count unrelated kwargs and then fail with KeyError on a missing order.
    while (seq := kwargs.get(f"gram_{n}_sequence")) is not None:
        sequences.append(seq)
        n += 1
    return sequences
def shift_with_pad(target_tensor, n, from_tensor):
    """Drop the first ``n`` columns of ``target_tensor`` and refill from ``from_tensor``.

    With ``L = target_tensor.size(1)``, the result is ``target_tensor[:, n:]``
    (``L - n`` columns) followed by columns ``L - n .. L - 2`` of
    ``from_tensor`` (``n - 1`` columns). That is ``L - 1`` columns in total
    for ``1 <= n < L``. Presumably ``from_tensor`` holds the already-shifted
    unigram labels (length ``L - 1``); TODO confirm against callers.

    Args:
        target_tensor: 2-D tensor whose dim 1 is the sequence axis.
        n: number of leading columns to drop.
        from_tensor: tensor indexed along dim 1 to supply the padding columns.

    Returns:
        The concatenated tensor along dim 1.
    """
    shifted = target_tensor[:, n:]
    seq_size = target_tensor.size(1) - 1
    # Build the index on the target's device directly (no host alloc + copy).
    missing_idxs = torch.arange(seq_size - (n - 1), seq_size, device=target_tensor.device)
    # Pad with missing idxs from unigram tensor
    return torch.concat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )