#%%
import transformers
import torch as t
import torch.nn as nn
from typing import Union, List
from fancy_einsum import einsum
import attention_replication
# %%
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
if __name__ == "__main__":
print(tokenizer("hello meg"))
print(tokenizer.encode("hello meg"))
print(tokenizer.decode([31373, 17243]))
print(tokenizer.tokenize("hello meg"))
print(f"'{tokenizer.decode(17243)}'")
# %%
class Embedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = nn.Parameter(t.randn((self.num_embeddings, self.embedding_dim)))
def forward(self, x: t.LongTensor) -> t.Tensor:
'''For each integer in the input, return that row of the embedding.
'''
#return einsum('num_embeddings embedding_dim, i num_embeddings -> i embedding_dim', self.weight, nn.functional.one_hot(x, num_classes=self.num_embeddings).float())
return self.weight[x]
def extra_repr(self) -> str:
return f"{self.num_embeddings}, {self.embedding_dim}"
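if __name__ == "__main__":
    # Illustrative sanity check (not part of the original exercise): our Embedding
    # should index rows exactly like nn.Embedding when sharing the same weight.
    emb = Embedding(num_embeddings=10, embedding_dim=4)
    ref = nn.Embedding(10, 4)
    ref.weight = emb.weight  # share the parameter so outputs are identical
    idx = t.tensor([1, 3, 3, 7])
    assert t.allclose(emb(idx), ref(idx))
    print(emb)  # Embedding(10, 4)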
# %%
# Sinusoidal positional encoding (fixed sin/cos buffer, not learned)
class PositionalEncoding(nn.Module):
def __init__(self, max_seq_len: int, embedding_dim: int):
super().__init__()
# Defining our positional encoding array, with `max_seq_len` rows
# This is an advantage of using sinusoidal encoding: we can easily expand to sequences of greater length without adding more learned params
angles = t.outer(t.arange(max_seq_len), 1 / 10000 ** (2 * t.arange(embedding_dim//2) / embedding_dim))
pe = t.zeros((max_seq_len, embedding_dim))
pe[:, ::2] = t.sin(angles)
pe[:, 1::2] = t.cos(angles)
# Register array as a buffer, rather than parameter (we don't want it to be updated by gradient descent)
self.register_buffer('pe', pe)
def forward(self, x: t.Tensor) -> t.Tensor:
"""
x: shape (batch, seq_len, embedding_dim)
"""
batch, seq_len, embedding_dim = x.shape
# We slice the positional encoding, so it's the same shape as x
# This is equivalent to just using an nn.Embedding, but having the input be t.arange(seq_len)
return x + self.pe[:seq_len, :] # type: ignore
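if __name__ == "__main__":
    # Illustrative check: the buffer has shape (max_seq_len, embedding_dim), and
    # adding it to an input leaves the input's shape unchanged. Sizes here are arbitrary.
    pos_enc = PositionalEncoding(max_seq_len=32, embedding_dim=8)
    dummy = t.zeros((2, 10, 8))
    print(pos_enc.pe.shape, pos_enc(dummy).shape)  # torch.Size([32, 8]) torch.Size([2, 10, 8])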
# %%
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: Union[int, List[int]], eps: float = 1e-05, elementwise_affine: bool = True):
super().__init__()
self.normalized_shape = normalized_shape
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(t.ones(normalized_shape))
self.bias = nn.Parameter(t.zeros(normalized_shape))
def forward(self, x: t.Tensor) -> t.Tensor:
normalized_shape_dims = 1 if isinstance(self.normalized_shape, int) else len(self.normalized_shape)
        x_mean = x.mean(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True)  # mean over the trailing dims covered by normalized_shape
        x_var = x.var(dim=list(range(x.dim()))[-normalized_shape_dims:], keepdim=True, unbiased=False)  # biased variance over the same dims
x_scaled = (x - x_mean) / t.sqrt(x_var + self.eps)
if self.elementwise_affine:
return x_scaled * self.weight + self.bias
return x_scaled
    def extra_repr(self) -> str:
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
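if __name__ == "__main__":
    # Illustrative check: our LayerNorm should agree with torch's built-in version
    # (both use biased variance and the same default eps).
    x = t.randn((2, 5, 16))
    ours = LayerNorm(16)
    theirs = nn.LayerNorm(16)
    assert t.allclose(ours(x), theirs(x), atol=1e-5)
    print(ours)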
# %%
class TransformerConfig:
'''Constants used throughout your decoder-only transformer model.'''
num_layers: int
num_heads: int
vocab_size: int
hidden_size: int
max_seq_len: int
dropout: float
layer_norm_epsilon: float
def __init__(
self, num_layers, num_heads, vocab_size, hidden_size, max_seq_len,
dropout=0.1, layer_norm_epsilon=1e-5,
) -> None:
self.num_layers = num_layers
self.num_heads = num_heads
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.dropout = dropout
self.layer_norm_epsilon = layer_norm_epsilon
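if __name__ == "__main__":
    # Example config used by the demos below. The values are arbitrary small numbers
    # chosen for quick shape checks, not anything from the original exercise.
    example_config = TransformerConfig(
        num_layers=2, num_heads=4, vocab_size=100, hidden_size=64, max_seq_len=32,
    )
    print(example_config.__dict__)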
# %%
class BertMLP(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.linear1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
self.gelu = nn.GELU()
self.linear2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: t.Tensor) -> t.Tensor:
x = self.linear1(x)
x = self.gelu(x)
x = self.linear2(x)
x = self.dropout(x)
return x
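if __name__ == "__main__":
    # Illustrative shape check (using example_config from above): the MLP expands to
    # 4 * hidden_size internally but returns the original hidden dimension.
    mlp = BertMLP(example_config)
    print(mlp(t.randn((2, 5, example_config.hidden_size))).shape)  # torch.Size([2, 5, 64])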
class DecoderBlock(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.attention = attention_replication.MultiheadMaskedAttention(config.hidden_size, config.num_heads)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BertMLP(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(self, x: t.Tensor) -> t.Tensor:
y = self.attention(x)
y = self.layer_norm1(y)
x = x + y
z = self.mlp(x)
z = self.layer_norm2(z)
x = x + z
return x
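if __name__ == "__main__":
    # Illustrative shape check: a decoder block maps (batch, seq, hidden_size) to the
    # same shape. Assumes attention_replication.MultiheadMaskedAttention accepts
    # (hidden_size, num_heads) and preserves shape, as used in __init__ above.
    block = DecoderBlock(example_config)
    print(block(t.randn((2, 10, example_config.hidden_size))).shape)  # torch.Size([2, 10, 64])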
class DecoderOnlyTransformer(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.token_embedding = Embedding(config.vocab_size, config.hidden_size)
self.positional_embedding = PositionalEncoding(config.max_seq_len, config.hidden_size)
self.dropout = nn.Dropout(config.dropout)
self.bert_blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(self, x: t.Tensor) -> t.Tensor:
x = self.token_embedding(x)
x = self.positional_embedding(x)
x = self.dropout(x)
for block in self.bert_blocks:
x = block(x)
x = self.layer_norm(x)
x = einsum('num_embeddings embedding_dim,batch seq_len embedding_dim ->batch seq_len num_embeddings', self.token_embedding.weight, x)
return x
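if __name__ == "__main__":
    # Illustrative forward pass with example_config: token ids of shape (batch, seq)
    # go in, one logit per vocab entry comes out for each position.
    model = DecoderOnlyTransformer(example_config)
    tokens = t.randint(high=example_config.vocab_size, size=(2, 10))
    logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 10, 100])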
# %%
from torch.utils.data import Dataset
class CustomTextDataset(Dataset):
def __init__(self, texts, labels):
self.labels = labels
self.texts = texts
@staticmethod
def from_config(config, samples):
texts = [t.randint(high=config.vocab_size, size=(config.max_seq_len,)) for _ in range(samples)]
labels = [t.flip(text, (0,)) for text in texts]
return CustomTextDataset(texts, labels)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
text = self.texts[idx]
sample = (text, label)
return sample
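if __name__ == "__main__":
    # Illustrative usage: build a small reverse-the-sequence dataset from example_config
    # and batch it with a DataLoader. Sample count and batch size are arbitrary.
    from torch.utils.data import DataLoader
    dataset = CustomTextDataset.from_config(example_config, samples=8)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    texts, labels = next(iter(loader))
    print(texts.shape, labels.shape)  # torch.Size([4, 32]) torch.Size([4, 32])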