skar0 committed
Commit 4c2c4e8 · 1 Parent(s): 8c79c01

Initial commit

Files changed (11)
  1. 100-0.txt +0 -0
  2. Dockerfile +11 -0
  3. Procfile +1 -0
  4. app.py +13 -0
  5. attention_replication.py +156 -0
  6. config.yaml +61 -0
  7. env.yaml +406 -0
  8. sampling.py +239 -0
  9. shakespeare_demo.py +105 -0
  10. transformer_replication.py +183 -0
  11. word_data.py +100 -0
100-0.txt ADDED
The diff for this file is too large to render. See raw diff
 
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ # Create environment
+ FROM mambaorg/micromamba:1.3.1
+ COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml /tmp/env.yaml
+ RUN micromamba install --yes --file /tmp/env.yaml && \
+     micromamba clean --all --yes
+
+ # Run app (CMD rather than RUN, so the server starts at container runtime instead of blocking the build)
+ COPY . /app/
+ WORKDIR /app/
+ ARG MAMBA_DOCKERFILE_ACTIVATE=1
+ CMD ["python", "app.py"]
Procfile ADDED
@@ -0,0 +1 @@
+ web: gunicorn app:app
app.py ADDED
@@ -0,0 +1,13 @@
+ from flask import Flask
+ import os
+ from shakespeare_demo import make_demo
+
+ app = Flask(__name__)
+
+ @app.route("/")
+ def hello_world():
+     return make_demo()
+
+ if __name__ == "__main__":
+     port = int(os.environ.get('PORT', 5999))
+     app.run(debug=True, host='0.0.0.0', port=port)
attention_replication.py ADDED
@@ -0,0 +1,156 @@
+ # %%
+ import torch as t
+ import torch.nn as nn
+ from fancy_einsum import einsum
+ from einops import repeat, rearrange
+ import numpy as np
+ # %%
+ def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
+     '''
+     Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).
+
+     With this function, you can ignore masking.
+
+     Q: shape (batches x seq_Q x head_size)
+     K: shape (batches x seq_K x head_size)
+     V: shape (batches x seq_K x head_size)
+
+     Return: shape (batches x seq_Q x head_size)
+     '''
+     attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
+     # No masking here; scale by sqrt(head_size) before the softmax
+     attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
+     attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
+     return attention_values
+
+ def test_single_head_attention_shape(single_head_attention):
+     Q = t.randn(1, 3, 2)
+     K = t.randn(1, 5, 2)
+     V = t.randn(1, 5, 2)
+     attention_values = single_head_attention(Q, K, V)
+     assert Q.shape == attention_values.shape
+     print("All tests in `test_single_head_attention_shape` passed.")
+
+ def test_single_head_attention(single_head_attention):
+     Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+     K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+     V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+     attention_values = single_head_attention(Q.float(), K.float(), V.float())
+     t.testing.assert_close(attention_values, t.tensor([[[9.7880e-04, 9.9902e-01, 9.7880e-04], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+     print("All tests in `test_single_head_attention` passed.")
+
+ if __name__ == "__main__":
+     test_single_head_attention_shape(single_head_attention)
+     test_single_head_attention(single_head_attention)
+ # %%
+ def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
+     '''
+     Should return the results of masked self-attention.
+
+     See "The Decoder Side" section of the Illustrated Transformer for an explanation of masking.
+
+     Q: shape (batches x seq_Q x head_size)
+     K: shape (batches x seq_K x head_size)
+     V: shape (batches x seq_K x head_size)
+
+     Return: shape (batches x seq_Q x head_size)
+     '''
+     attention_scores = einsum('batches seq_Q head_size, batches seq_K head_size -> batches seq_Q seq_K', Q, K)
+     batches, seq_Q, head_size = Q.shape
+     batches, seq_K, head_size = K.shape
+
+     # Causal mask: query position q may only attend to key positions k <= q
+     q_index = repeat(t.arange(0, seq_Q), 'q -> b q k', b=batches, k=seq_K)
+     k_index = repeat(t.arange(0, seq_K), 'k -> b q k', b=batches, q=seq_Q)
+     mask = k_index <= q_index
+     attention_scores = t.where(mask, attention_scores, -t.inf)
+     attention_probabilities = nn.functional.softmax(attention_scores / np.sqrt(Q.shape[-1]), dim=2)
+     attention_values = einsum('batches seq_Q seq_K, batches seq_K head_size -> batches seq_Q head_size', attention_probabilities, V)
+     return attention_values
+
+ def test_single_head_masked_attention(single_head_masked_attention):
+     Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+     K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+     V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+     attention_values = single_head_masked_attention(Q.float(), K.float(), V.float())
+     t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+     print("All tests in `test_single_head_masked_attention` passed.")
+
+ if __name__ == "__main__":
+     test_single_head_attention_shape(single_head_masked_attention)
+     test_single_head_masked_attention(single_head_masked_attention)
+ # %%
+ def multihead_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor, num_heads: int):
+     '''
+     Implements multihead masked attention on the matrices Q, K and V.
+
+     Q: shape (batch, seq, nheads*headsize)
+     K: shape (batch, seq, nheads*headsize)
+     V: shape (batch, seq, nheads*headsize)
+
+     returns: shape (batch, seq, nheads*headsize)
+     '''
+     # Split the last dimension into separate heads
+     new_Q = rearrange(Q, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+     new_K = rearrange(K, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+     new_V = rearrange(V, 'batch seq (nheads headsize) -> batch nheads seq headsize', nheads=num_heads)
+
+     attention_scores = einsum('batches nheads seq_Q head_size, batches nheads seq_K head_size -> batches nheads seq_Q seq_K', new_Q, new_K)
+     batches, _, seq_Q, head_size = new_Q.shape
+     batches, _, seq_K, head_size = new_K.shape
+     q_index = repeat(t.arange(0, seq_Q), 'seq_Q -> batches nheads seq_Q seq_K', batches=batches, seq_K=seq_K, nheads=num_heads)
+     k_index = repeat(t.arange(0, seq_K), 'seq_K -> batches nheads seq_Q seq_K', batches=batches, seq_Q=seq_Q, nheads=num_heads)
+     mask = k_index <= q_index
+     # Move the mask and the -inf fill value to the same device as the inputs
+     device_inf = t.tensor(-np.inf).to(Q.device)
+     device_mask = mask.to(Q.device)
+     masked_attention_scores = t.where(device_mask, attention_scores, device_inf)
+     attention_probabilities = nn.functional.softmax(masked_attention_scores / np.sqrt(head_size), dim=-1)
+     attention_values = einsum('batches nheads seq_Q seq_K, batches nheads seq_K head_size -> batches seq_Q nheads head_size', attention_probabilities, new_V)
+     # Concatenate the heads back together
+     return rearrange(attention_values, 'batches seq_Q nheads head_size -> batches seq_Q (nheads head_size)')
+
+ def test_multihead_masked_attention(multihead_masked_attention):
+     Q = t.tensor([[[7, 4, 1], [6, 3, 0], [5, 2, 1]]])
+     K = t.tensor([[[1, 3, 5], [2, 4, 6]]])
+     V = t.tensor([[[1, 0, 1], [0, 1, 0]]])
+     attention_values = multihead_masked_attention(Q.float(), K.float(), V.float(), num_heads=1)
+     t.testing.assert_close(attention_values, t.tensor([[[1, 0, 1], [5.5073e-03, 9.9449e-01, 5.5073e-03], [9.7682e-03, 9.9023e-01, 9.7682e-03]]]), rtol=0.01, atol=0.001)
+     print("All tests in `test_multihead_masked_attention` passed.")
+
+ if __name__ == "__main__":
+     test_multihead_masked_attention(multihead_masked_attention)
+ # %%
+ class MultiheadMaskedAttention(nn.Module):
+     W_QKV: nn.Linear
+     W_O: nn.Linear
+
+     def __init__(self, hidden_size: int, num_heads: int):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         assert self.hidden_size % self.num_heads == 0
+         self.W_QKV = nn.Linear(hidden_size, 3 * hidden_size)
+         self.W_O = nn.Linear(hidden_size, hidden_size)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         '''
+         x: shape (batch, seq, hidden_size)
+
+         Return: shape (batch, seq, hidden_size)
+         '''
+         # A single linear layer produces Q, K and V concatenated along the last dimension
+         QKV = self.W_QKV(x)
+         Q = QKV[..., :self.hidden_size]
+         K = QKV[..., self.hidden_size:-self.hidden_size]
+         V = QKV[..., -self.hidden_size:]
+         attention_values = multihead_masked_attention(Q, K, V, self.num_heads)
+         return self.W_O(attention_values)
+ # %%
+ def test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention):
+     mma = MultiheadMaskedAttention(1, 1)
+     x = t.randn(2, 7, 1)
+     output = mma.forward(x)
+     assert x.shape == output.shape
+     print("All tests in `test_MultiheadMaskedAttention_shape` passed.")
+
+ if __name__ == "__main__":
+     test_MultiheadMaskedAttention_shape(MultiheadMaskedAttention)
+ # %%
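+ # Illustrative extra check (a sketch, not part of the original commit's tests):
+ # the shape test above only uses num_heads=1, so this exercises the head
+ # split/merge path with num_heads > 1, assuming hidden_size is divisible by
+ # num_heads as asserted in __init__.
+ if __name__ == "__main__":
+     mma = MultiheadMaskedAttention(hidden_size=8, num_heads=2)
+     x = t.randn(3, 5, 8)
+     assert mma(x).shape == x.shape
+     print("Multi-head shape check passed.")
+ # %%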
config.yaml ADDED
@@ -0,0 +1,61 @@
+ wandb_version: 1
+
+ _wandb:
+   desc: null
+   value:
+     cli_version: 0.13.5
+     framework: huggingface
+     huggingface_version: 4.24.0
+     is_jupyter_run: true
+     is_kaggle_kernel: false
+     python_version: 3.10.6
+     start_time: 1668083783.928274
+     t:
+       1:
+       - 1
+       - 11
+       - 41
+       - 49
+       - 55
+       2:
+       - 1
+       - 11
+       - 41
+       - 49
+       - 55
+       3:
+       - 1
+       - 2
+       - 3
+       - 23
+       - 37
+       4: 3.10.6
+       5: 0.13.5
+       6: 4.24.0
+       8:
+       - 1
+       - 5
+ batch_size:
+   desc: null
+   value: 64
+ dropout:
+   desc: null
+   value: 0.1
+ epochs:
+   desc: null
+   value: 2
+ hidden_size:
+   desc: null
+   value: 512
+ lr:
+   desc: null
+   value: 0.001
+ max_seq_len:
+   desc: null
+   value: 60
+ num_heads:
+   desc: null
+   value: 8
+ num_layers:
+   desc: null
+   value: 6
env.yaml ADDED
@@ -0,0 +1,406 @@
+ name: base
+ channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_kmp_llvm
+ - aiofiles=23.1.0=pyhd8ed1ab_0
+ - aiohttp=3.8.4=py310h1fa729e_0
+ - aiosignal=1.3.1=pyhd8ed1ab_0
+ - alsa-lib=1.2.8=h166bdaf_0
+ - altair=4.2.2=pyhd8ed1ab_0
+ - anyio=3.6.2=pyhd8ed1ab_0
+ - argon2-cffi=21.3.0=pyhd8ed1ab_0
+ - argon2-cffi-bindings=21.2.0=py310h5764c6d_3
+ - arrow-cpp=11.0.0=ha770c72_4_cpu
+ - asttokens=2.2.1=pyhd8ed1ab_0
+ - async-timeout=4.0.2=pyhd8ed1ab_0
+ - attr=2.5.1=h166bdaf_1
+ - attrs=22.2.0=pyh71513ae_0
+ - aws-c-auth=0.6.24=h565b4ff_2
+ - aws-c-cal=0.5.20=h679401e_5
+ - aws-c-common=0.8.10=h0b41bf4_0
+ - aws-c-compression=0.2.16=hbe6ad0c_2
+ - aws-c-event-stream=0.2.18=h489b7ba_4
+ - aws-c-http=0.7.4=hb2c4a47_0
+ - aws-c-io=0.13.15=head7655_1
+ - aws-c-mqtt=0.8.6=haf0be06_3
+ - aws-c-s3=0.2.4=h05be983_0
+ - aws-c-sdkutils=0.1.7=hbe6ad0c_2
+ - aws-checksums=0.1.14=hbe6ad0c_2
+ - aws-crt-cpp=0.19.7=h9b63b7c_3
+ - aws-sdk-cpp=1.10.57=hd557813_3
+ - backcall=0.2.0=pyh9f0ad1d_0
+ - backports=1.0=pyhd8ed1ab_3
+ - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+ - beautifulsoup4=4.11.2=pyha770c72_0
+ - blas=2.116=mkl
+ - blas-devel=3.9.0=16_linux64_mkl
+ - bleach=6.0.0=pyhd8ed1ab_0
+ - brotli=1.0.9=h166bdaf_8
+ - brotli-bin=1.0.9=h166bdaf_8
+ - brotlipy=0.7.0=py310h5764c6d_1005
+ - bzip2=1.0.8=h7f98852_4
+ - c-ares=1.18.1=h7f98852_0
+ - ca-certificates=2022.12.7=ha878542_0
+ - cairo=1.16.0=ha61ee94_1014
+ - certifi=2022.12.7=pyhd8ed1ab_0
+ - cffi=1.15.1=py310h255011f_3
+ - charset-normalizer=2.1.1=pyhd8ed1ab_0
+ - click=8.1.3=unix_pyhd8ed1ab_2
+ - colorama=0.4.6=pyhd8ed1ab_0
+ - comm=0.1.2=pyhd8ed1ab_0
+ - contourpy=1.0.7=py310hdf3cbec_0
+ - cryptography=39.0.1=py310h34c0648_0
+ - cuda=11.6.1=0
+ - cuda-cccl=11.6.55=hf6102b2_0
+ - cuda-command-line-tools=11.6.2=0
+ - cuda-compiler=11.6.2=0
+ - cuda-cudart=11.6.55=he381448_0
+ - cuda-cudart-dev=11.6.55=h42ad0f4_0
+ - cuda-cuobjdump=11.6.124=h2eeebcb_0
+ - cuda-cupti=11.6.124=h86345e5_0
+ - cuda-cuxxfilt=11.6.124=hecbf4f6_0
+ - cuda-driver-dev=11.6.55=0
+ - cuda-gdb=12.0.140=0
+ - cuda-libraries=11.6.1=0
+ - cuda-libraries-dev=11.6.1=0
+ - cuda-memcheck=11.8.86=0
+ - cuda-nsight=12.0.140=0
+ - cuda-nsight-compute=12.0.1=0
+ - cuda-nvcc=11.6.124=hbba6d2d_0
+ - cuda-nvdisasm=12.0.140=0
+ - cuda-nvml-dev=11.6.55=haa9ef22_0
+ - cuda-nvprof=12.0.146=0
+ - cuda-nvprune=11.6.124=he22ec0a_0
+ - cuda-nvrtc=11.6.124=h020bade_0
+ - cuda-nvrtc-dev=11.6.124=h249d397_0
+ - cuda-nvtx=11.6.124=h0630a44_0
+ - cuda-nvvp=12.0.146=0
+ - cuda-runtime=11.6.1=0
+ - cuda-samples=11.6.101=h8efea70_0
+ - cuda-sanitizer-api=12.0.140=0
+ - cuda-toolkit=11.6.1=0
+ - cuda-tools=11.6.1=0
+ - cuda-visual-tools=11.6.1=0
+ - cycler=0.11.0=pyhd8ed1ab_0
+ - dataclasses=0.8=pyhc8e2a94_3
+ - datasets=2.9.0=pyhd8ed1ab_0
+ - dbus=1.13.6=h5008d03_3
+ - debugpy=1.6.6=py310heca2aa9_0
+ - decorator=5.1.1=pyhd8ed1ab_0
+ - defusedxml=0.7.1=pyhd8ed1ab_0
+ - dill=0.3.6=pyhd8ed1ab_1
+ - einops=0.6.0=pyhd8ed1ab_0
+ - entrypoints=0.4=pyhd8ed1ab_0
+ - executing=1.2.0=pyhd8ed1ab_0
+ - expat=2.5.0=h27087fc_0
+ - fastapi=0.92.0=pyhd8ed1ab_0
+ - ffmpeg=4.3=hf484d3e_0
+ - ffmpy=0.3.0=pyhb6f538c_0
+ - fftw=3.3.10=nompi_hf0379b8_106
+ - filelock=3.9.0=pyhd8ed1ab_0
+ - flask=2.2.3=pyhd8ed1ab_0
+ - flit-core=3.8.0=pyhd8ed1ab_0
+ - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
+ - font-ttf-inconsolata=3.000=h77eed37_0
+ - font-ttf-source-code-pro=2.038=h77eed37_0
+ - font-ttf-ubuntu=0.83=hab24e00_0
+ - fontconfig=2.14.2=h14ed4e7_0
+ - fonts-conda-ecosystem=1=0
+ - fonts-conda-forge=1=0
+ - fonttools=4.38.0=py310h5764c6d_1
+ - freetype=2.12.1=hca18f0e_1
+ - frozenlist=1.3.3=py310h5764c6d_0
+ - fsspec=2023.1.0=pyhd8ed1ab_0
+ - gds-tools=1.5.1.14=0
+ - gettext=0.21.1=h27087fc_0
+ - gflags=2.2.2=he1b5a44_1004
+ - glib=2.74.1=h6239696_1
+ - glib-tools=2.74.1=h6239696_1
+ - glog=0.6.0=h6f12383_0
+ - gmp=6.2.1=h58526e2_0
+ - gnutls=3.6.13=h85f3911_1
+ - gradio=3.19.1=pyhd8ed1ab_0
+ - graphite2=1.3.13=h58526e2_1001
+ - gst-plugins-base=1.22.0=h4243ec0_0
+ - gstreamer=1.22.0=h25f0c4b_0
+ - gstreamer-orc=0.4.33=h166bdaf_0
+ - h11=0.14.0=pyhd8ed1ab_0
+ - h2=4.1.0=pyhd8ed1ab_0
+ - harfbuzz=6.0.0=h8e241bc_0
+ - hpack=4.0.0=pyh9f0ad1d_0
+ - httpcore=0.16.3=pyhd8ed1ab_0
+ - httpx=0.23.3=pyhd8ed1ab_0
+ - huggingface_hub=0.12.1=pyhd8ed1ab_0
+ - hyperframe=6.0.1=pyhd8ed1ab_0
+ - icu=70.1=h27087fc_0
+ - idna=3.4=pyhd8ed1ab_0
+ - importlib-metadata=6.0.0=pyha770c72_0
+ - importlib_metadata=6.0.0=hd8ed1ab_0
+ - importlib_resources=5.12.0=pyhd8ed1ab_0
+ - ipykernel=6.21.2=pyh210e3f2_0
+ - ipython=8.10.0=pyh41d4057_0
+ - ipython_genutils=0.2.0=py_1
+ - ipywidgets=8.0.4=pyhd8ed1ab_0
+ - itsdangerous=2.1.2=pyhd8ed1ab_0
+ - jack=1.9.22=h11f4161_0
+ - jedi=0.18.2=pyhd8ed1ab_0
+ - jinja2=3.1.2=pyhd8ed1ab_1
+ - joblib=1.2.0=pyhd8ed1ab_0
+ - jpeg=9e=h0b41bf4_3
+ - jsonschema=4.17.3=pyhd8ed1ab_0
+ - jupyter=1.0.0=py310hff52083_8
+ - jupyter_client=8.0.3=pyhd8ed1ab_0
+ - jupyter_console=6.5.1=pyhd8ed1ab_0
+ - jupyter_core=5.2.0=py310hff52083_0
+ - jupyter_events=0.6.3=pyhd8ed1ab_0
+ - jupyter_server=2.3.0=pyhd8ed1ab_0
+ - jupyter_server_terminals=0.4.4=pyhd8ed1ab_1
+ - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+ - jupyterlab_widgets=3.0.5=pyhd8ed1ab_0
+ - keyutils=1.6.1=h166bdaf_0
+ - kiwisolver=1.4.4=py310hbf28c38_1
+ - krb5=1.20.1=h81ceb04_0
+ - lame=3.100=h166bdaf_1003
+ - lcms2=2.14=hfd0df8a_1
+ - ld_impl_linux-64=2.40=h41732ed_0
+ - lerc=4.0.0=h27087fc_0
+ - libabseil=20220623.0=cxx17_h05df665_6
+ - libarrow=11.0.0=hc42cb68_4_cpu
+ - libblas=3.9.0=16_linux64_mkl
+ - libbrotlicommon=1.0.9=h166bdaf_8
+ - libbrotlidec=1.0.9=h166bdaf_8
+ - libbrotlienc=1.0.9=h166bdaf_8
+ - libcap=2.66=ha37c62d_0
+ - libcblas=3.9.0=16_linux64_mkl
+ - libclang=15.0.7=default_had23c3d_1
+ - libclang13=15.0.7=default_h3e3d535_1
+ - libcrc32c=1.1.2=h9c3ff4c_0
+ - libcublas=11.9.2.110=h5e84587_0
+ - libcublas-dev=11.9.2.110=h5c901ab_0
+ - libcufft=10.7.1.112=hf425ae0_0
+ - libcufft-dev=10.7.1.112=ha5ce4c0_0
+ - libcufile=1.5.1.14=0
+ - libcufile-dev=1.5.1.14=0
+ - libcups=2.3.3=h36d4200_3
+ - libcurand=10.3.1.124=0
+ - libcurand-dev=10.3.1.124=0
+ - libcurl=7.88.1=hdc1c0ab_0
+ - libcusolver=11.3.4.124=h33c3c4e_0
+ - libcusparse=11.7.2.124=h7538f96_0
+ - libcusparse-dev=11.7.2.124=hbbe9722_0
+ - libdb=6.2.32=h9c3ff4c_0
+ - libdeflate=1.17=h0b41bf4_0
+ - libedit=3.1.20191231=he28a2e2_2
+ - libev=4.33=h516909a_1
+ - libevent=2.1.10=h28343ad_4
+ - libffi=3.4.2=h7f98852_5
+ - libflac=1.4.2=h27087fc_0
+ - libgcc-ng=12.2.0=h65d4601_19
+ - libgcrypt=1.10.1=h166bdaf_0
+ - libgfortran-ng=12.2.0=h69a702a_19
+ - libgfortran5=12.2.0=h337968e_19
+ - libglib=2.74.1=h606061b_1
+ - libgoogle-cloud=2.7.0=h21dfe5b_1
+ - libgpg-error=1.46=h620e276_0
+ - libgrpc=1.51.1=h4fad500_1
+ - libhwloc=2.8.0=h32351e8_1
+ - libiconv=1.17=h166bdaf_0
+ - liblapack=3.9.0=16_linux64_mkl
+ - liblapacke=3.9.0=16_linux64_mkl
+ - libllvm15=15.0.7=hadd5161_0
+ - libnghttp2=1.51.0=hff17c54_0
+ - libnpp=11.6.3.124=hd2722f0_0
+ - libnpp-dev=11.6.3.124=h3c42840_0
+ - libnsl=2.0.0=h7f98852_0
+ - libnvjpeg=11.6.2.124=hd473ad6_0
+ - libnvjpeg-dev=11.6.2.124=hb5906b9_0
+ - libogg=1.3.4=h7f98852_1
+ - libopus=1.3.1=h7f98852_1
+ - libpng=1.6.39=h753d276_0
+ - libpq=15.2=hb675445_0
+ - libprotobuf=3.21.12=h3eb15da_0
+ - libsndfile=1.2.0=hb75c966_0
+ - libsodium=1.0.18=h36c2ea0_1
+ - libsqlite=3.40.0=h753d276_0
+ - libssh2=1.10.0=hf14f497_3
+ - libstdcxx-ng=12.2.0=h46fd767_19
+ - libsystemd0=252=h2a991cd_0
+ - libthrift=0.16.0=he500d00_2
+ - libtiff=4.5.0=h6adf6a1_2
+ - libtool=2.4.7=h27087fc_0
+ - libudev1=252=h166bdaf_0
+ - libutf8proc=2.8.0=h166bdaf_0
+ - libuuid=2.32.1=h7f98852_1000
+ - libvorbis=1.3.7=h9c3ff4c_0
+ - libwebp-base=1.2.4=h166bdaf_0
+ - libxcb=1.13=h7f98852_1004
+ - libxkbcommon=1.5.0=h79f4944_0
+ - libxml2=2.10.3=h7463322_0
+ - libzlib=1.2.13=h166bdaf_4
+ - linkify-it-py=2.0.0=pyhd8ed1ab_0
+ - llvm-openmp=15.0.7=h0cdce71_0
+ - lz4-c=1.9.4=hcb278e6_0
+ - markdown-it-py=2.1.0=pyhd8ed1ab_0
+ - markupsafe=2.1.2=py310h1fa729e_0
+ - matplotlib-base=3.7.0=py310he60537e_0
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
+ - mdit-py-plugins=0.3.3=pyhd8ed1ab_0
+ - mdurl=0.1.0=pyhd8ed1ab_0
+ - mistune=2.0.5=pyhd8ed1ab_0
+ - mkl=2022.1.0=h84fe81f_915
+ - mkl-devel=2022.1.0=ha770c72_916
+ - mkl-include=2022.1.0=h84fe81f_915
+ - mpg123=1.31.2=hcb278e6_0
+ - multidict=6.0.4=py310h1fa729e_0
+ - multiprocess=0.70.14=py310h5764c6d_3
+ - munkres=1.1.4=pyh9f0ad1d_0
+ - mysql-common=8.0.32=ha901b37_0
+ - mysql-libs=8.0.32=hd7da12d_0
+ - nbclassic=0.5.2=pyhd8ed1ab_0
+ - nbclient=0.7.2=pyhd8ed1ab_0
+ - nbconvert=7.2.9=pyhd8ed1ab_0
+ - nbconvert-core=7.2.9=pyhd8ed1ab_0
+ - nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
+ - nbformat=5.7.3=pyhd8ed1ab_0
+ - ncurses=6.3=h27087fc_1
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
+ - nettle=3.6=he412f7d_0
+ - notebook=6.5.2=pyha770c72_1
+ - notebook-shim=0.2.2=pyhd8ed1ab_0
+ - nsight-compute=2022.4.1.6=0
+ - nspr=4.35=h27087fc_0
+ - nss=3.88=he45b914_0
+ - numpy=1.24.2=py310h8deb116_0
+ - openh264=2.1.1=h780b84a_0
+ - openjpeg=2.5.0=hfec8fc6_2
+ - openssl=3.0.8=h0b41bf4_0
+ - orc=1.8.2=hfdbbad2_2
+ - orjson=3.8.5=py310h38b9cce_1
+ - packaging=23.0=pyhd8ed1ab_0
+ - pandas=1.5.3=py310h9b08913_0
+ - pandoc=2.19.2=h32600fe_1
+ - pandocfilters=1.5.0=pyhd8ed1ab_0
+ - parquet-cpp=1.5.1=2
+ - parso=0.8.3=pyhd8ed1ab_0
+ - pcre2=10.40=hc3806b6_0
+ - pexpect=4.8.0=pyh1a96a4e_2
+ - pickleshare=0.7.5=py_1003
+ - pillow=9.4.0=py310h023d228_1
+ - pip=23.0.1=pyhd8ed1ab_0
+ - pixman=0.40.0=h36c2ea0_0
+ - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
+ - platformdirs=3.0.0=pyhd8ed1ab_0
+ - ply=3.11=py_1
+ - prometheus_client=0.16.0=pyhd8ed1ab_0
+ - prompt-toolkit=3.0.36=pyha770c72_0
+ - prompt_toolkit=3.0.36=hd8ed1ab_0
+ - psutil=5.9.4=py310h5764c6d_0
+ - pthread-stubs=0.4=h36c2ea0_1001
+ - ptyprocess=0.7.0=pyhd3deb0d_0
+ - pulseaudio=16.1=ha8d29e2_1
+ - pure_eval=0.2.2=pyhd8ed1ab_0
+ - pyarrow=11.0.0=py310h633f555_4_cpu
+ - pycparser=2.21=pyhd8ed1ab_0
+ - pycryptodome=3.16.0=py310h1419917_0
+ - pydantic=1.10.5=py310h1fa729e_0
+ - pydub=0.25.1=pyhd8ed1ab_0
+ - pygments=2.14.0=pyhd8ed1ab_0
+ - pyopenssl=23.0.0=pyhd8ed1ab_0
+ - pyparsing=3.0.9=pyhd8ed1ab_0
+ - pyqt=5.15.7=py310hab646b1_3
+ - pyqt5-sip=12.11.0=py310heca2aa9_3
+ - pyrsistent=0.19.3=py310h1fa729e_0
+ - pysocks=1.7.1=pyha2e5f31_6
+ - python=3.10.9=he550d4f_0_cpython
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
+ - python-fastjsonschema=2.16.2=pyhd8ed1ab_0
+ - python-json-logger=2.0.6=pyhd8ed1ab_0
+ - python-multipart=0.0.5=py_0
+ - python-xxhash=3.2.0=py310h1fa729e_0
+ - python_abi=3.10=3_cp310
+ - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
+ - pytorch-cuda=11.6=h867d48c_1
+ - pytorch-mutex=1.0=cuda
+ - pytz=2022.7.1=pyhd8ed1ab_0
+ - pyyaml=6.0=py310h5764c6d_5
+ - pyzmq=25.0.0=py310h059b190_0
+ - qt-main=5.15.8=h5d23da1_6
+ - qtconsole=5.4.0=pyhd8ed1ab_0
+ - qtconsole-base=5.4.0=pyha770c72_0
+ - qtpy=2.3.0=pyhd8ed1ab_0
+ - re2=2023.02.01=hcb278e6_0
+ - readline=8.1.2=h0f457ee_0
+ - regex=2022.10.31=py310h5764c6d_0
+ - requests=2.28.2=pyhd8ed1ab_0
+ - responses=0.18.0=pyhd8ed1ab_0
+ - rfc3339-validator=0.1.4=pyhd8ed1ab_0
+ - rfc3986=1.5.0=pyhd8ed1ab_0
+ - rfc3986-validator=0.1.1=pyh9f0ad1d_0
+ - s2n=1.3.35=h3358134_0
+ - sacremoses=0.0.53=pyhd8ed1ab_0
+ - send2trash=1.8.0=pyhd8ed1ab_0
+ - setuptools=67.3.2=pyhd8ed1ab_0
+ - sip=6.7.7=py310heca2aa9_0
+ - six=1.16.0=pyh6c4a22f_0
+ - snappy=1.1.9=hbd366e4_2
+ - sniffio=1.3.0=pyhd8ed1ab_0
+ - soupsieve=2.3.2.post1=pyhd8ed1ab_0
+ - stack_data=0.6.2=pyhd8ed1ab_0
+ - starlette=0.25.0=pyhd8ed1ab_0
+ - tbb=2021.7.0=h924138e_1
+ - terminado=0.17.1=pyh41d4057_0
+ - tinycss2=1.2.1=pyhd8ed1ab_0
+ - tk=8.6.12=h27826a3_0
+ - tokenizers=0.13.2=py310he1f1126_0
+ - toml=0.10.2=pyhd8ed1ab_0
+ - toolz=0.12.0=pyhd8ed1ab_0
+ - torchaudio=0.13.1=py310_cu116
+ - torchvision=0.14.1=py310_cu116
+ - tornado=6.2=py310h5764c6d_1
+ - tqdm=4.64.1=pyhd8ed1ab_0
+ - traitlets=5.9.0=pyhd8ed1ab_0
+ - transformers=4.26.1=pyhd8ed1ab_0
+ - typing-extensions=4.4.0=hd8ed1ab_0
+ - typing_extensions=4.4.0=pyha770c72_0
+ - tzdata=2022g=h191b570_0
+ - uc-micro-py=1.0.1=pyhd8ed1ab_0
+ - unicodedata2=15.0.0=py310h5764c6d_0
+ - urllib3=1.26.14=pyhd8ed1ab_0
+ - uvicorn=0.20.0=py310hff52083_1
+ - wcwidth=0.2.6=pyhd8ed1ab_0
+ - webencodings=0.5.1=py_1
+ - websocket-client=1.5.1=pyhd8ed1ab_0
+ - websockets=10.4=py310h5764c6d_1
+ - werkzeug=2.2.3=pyhd8ed1ab_0
+ - wheel=0.38.4=pyhd8ed1ab_0
+ - widgetsnbextension=4.0.5=pyhd8ed1ab_0
+ - xcb-util=0.4.0=h166bdaf_0
+ - xcb-util-image=0.4.0=h166bdaf_0
+ - xcb-util-keysyms=0.4.0=h166bdaf_0
+ - xcb-util-renderutil=0.3.9=h166bdaf_0
+ - xcb-util-wm=0.4.1=h166bdaf_0
+ - xorg-kbproto=1.0.7=h7f98852_1002
+ - xorg-libice=1.0.10=h7f98852_0
+ - xorg-libsm=1.2.3=hd9c2040_1000
+ - xorg-libx11=1.7.2=h7f98852_0
+ - xorg-libxau=1.0.9=h7f98852_0
+ - xorg-libxdmcp=1.1.3=h7f98852_0
+ - xorg-libxext=1.3.4=h7f98852_1
+ - xorg-libxrender=0.9.10=h7f98852_1003
+ - xorg-renderproto=0.11.1=h7f98852_1002
+ - xorg-xextproto=7.3.0=h7f98852_1002
+ - xorg-xproto=7.0.31=h7f98852_1007
+ - xxhash=0.8.1=h0b41bf4_0
+ - xz=5.2.6=h166bdaf_0
+ - yaml=0.2.5=h7f98852_2
+ - yarl=1.8.2=py310h5764c6d_0
+ - zeromq=4.3.4=h9c3ff4c_1
+ - zipp=3.14.0=pyhd8ed1ab_0
+ - zlib=1.2.13=h166bdaf_4
+ - zstd=1.5.2=h3eb15da_6
+ - pip:
+   - fancy-einsum==0.0.3
sampling.py ADDED
@@ -0,0 +1,239 @@
+ # %%
+ import torch as t
+ import transformers
+ import numpy as np
+
+ gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
+ tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
+
+ def apply_sampling_methods(
+     input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
+ ) -> int:
+     '''
+     Return the next token, sampled from the model's probability distribution with modifiers.
+
+     input_ids: shape (seq,)
+     '''
+     assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
+     assert temperature >= 0, "Temperature should be non-negative"
+     assert 0 <= top_p <= 1.0, "Top-p must be a probability"
+     assert 0 <= top_k, "Top-k must be non-negative"
+     assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"
+
+     if temperature == 0:
+         return greedy_search(logits)
+     if temperature != 1.0:
+         logits = apply_temperature(logits, temperature)
+     if freq_penalty != 0.0:
+         logits = apply_freq_penalty(input_ids, logits, freq_penalty)
+     if top_k > 0:
+         return sample_top_k(logits, top_k)
+     if top_p > 0:
+         return sample_top_p(logits, top_p)
+     return sample_basic(logits)
+
+ def sample_tokens(
+     model,
+     tokenizer,
+     initial_text: str,
+     max_tokens_generated: int = 30,
+     **kwargs
+ ) -> str:
+     '''
+     Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.
+
+     Return: the prompt and continuation concatenated
+     '''
+     model.eval()
+     input_ids: list = tokenizer.encode(initial_text)
+     generated = []
+     device = next(model.parameters()).device
+     for _ in range(max_tokens_generated):
+         new_input_ids = t.tensor(np.array(input_ids + generated), dtype=t.int64, device=device)
+         # Truncate the context to the model's maximum sequence length
+         new_input_ids_truncated = new_input_ids[-min(tokenizer.model_max_length, new_input_ids.shape[0]):].unsqueeze(0)
+         output = model(new_input_ids_truncated)
+         all_logits = output if isinstance(output, t.Tensor) else output.logits
+         logits = all_logits[0, -1]  # batch 0, final position -> shape (vocab_size,)
+         new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
+         generated.append(new_token)
+         if new_token == getattr(tokenizer, "eos_token_id", None):
+             break
+     return tokenizer.decode(input_ids + generated)
+
+ # %%
+ def greedy_search(logits: t.Tensor) -> int:
+     '''
+     logits: shape (vocab_size, )
+
+     Return: the most likely token (as an integer)
+     '''
+     return logits.argmax().item()
+
+ if __name__ == "__main__":
+     prompt = "Jingle bells, jingle bells, jingle all the way"
+     print("Greedy decoding with prompt: ", prompt)
+     output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
+     print(f"Your model said: {output}")
+     expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
+     assert output == expected
+
+     print("Greedy decoding a second time (should be deterministic): ")
+     output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
+     print(f"Your model said: {output}")
+     expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
+     assert output == expected
+
+     print("Tests passed!")
+ # %%
+ def sample_basic(logits: t.Tensor) -> int:
+     '''
+     logits: shape (vocab_size, ) - unnormalized log-probabilities
+
+     Return: a sampled token
+     '''
+     return t.distributions.categorical.Categorical(logits=logits).sample().item()
+
+ if __name__ == "__main__":
+     N = 20000
+     probs = t.linspace(0, 0.4, 5)
+     unnormalized_logits = probs.log() + 1.2345
+     samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
+     counts = t.bincount(samples, minlength=len(probs)) / N
+     print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+     t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
+     print("Tests passed!")
+ # %%
+ def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
+     '''
+     logits: shape (vocab_size, )
+
+     Return: shape (vocab_size, )
+     '''
+     assert temperature > 0
+     return logits / temperature
+
+ if __name__ == '__main__':
+     logits = t.tensor([1, 2]).log()
+     cold_logits = apply_temperature(logits, 0.001)
+     print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
+     t.testing.assert_close(cold_logits, 1000.0 * logits)
+     hot_logits = apply_temperature(logits, 1000.0)
+     print("A high temperature flattens the distribution: ", hot_logits)
+     t.testing.assert_close(hot_logits, 0.001 * logits)
+     print("Tests passed!")
+
+ # %%
+ def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
+     '''
+     input_ids: shape (seq, )
+     logits: shape (vocab_size, )
+
+     Return: shape (vocab_size, )
+     '''
+     # Penalise each token in proportion to how often it has already appeared
+     count = input_ids.bincount(minlength=len(logits))
+     logits -= count * freq_penalty
+     return logits
+
+ if __name__ == "__main__":
+     bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
+     input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
+     logits = t.ones(tokenizer.vocab_size)
+     penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
+     assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"
+     assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"
+     print("Tests passed!")
+ # %%
+ N_RUNS = 0
+ your_prompt = "Jingle bells, jingle bells, jingle all the way"
+ cases = [
+     ("High freq penalty", dict(freq_penalty=100.0)),
+     ("Negative freq penalty", dict(freq_penalty=-1.0)),
+     ("Too hot!", dict(temperature=2.0)),
+     ("Pleasantly cool", dict(temperature=0.7)),
+     ("Pleasantly warm", dict(temperature=0.9)),
+     ("Too cold!", dict(temperature=0.01)),
+ ]
+ for (name, kwargs) in cases:
+     for i in range(N_RUNS):
+         output = sample_tokens(gpt, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
+         print(f"Sample {i} with: {name} ({kwargs}):")
+         print(f"Your model said: {repr(output)}\n")
+ # %%
+ def sample_top_k(logits: t.Tensor, top_k: int) -> int:
+     '''
+     logits: shape (vocab_size, ) - unnormalized log-probabilities
+     top_k: only consider this many of the most likely tokens for sampling
+
+     Return: a sampled token
+     '''
+     # Sample among the k largest logits, then map back to the original vocabulary index
+     values, indices = t.topk(logits, top_k)
+     return indices[sample_basic(values)].item()
+
+ if __name__ == "__main__":
+     N = 50000
+     k = 3
+     probs = t.linspace(0, 0.4, 5)
+     unnormalized_logits = probs.log() + 1.2345
+     samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
+     counts = t.bincount(samples, minlength=len(probs)) / N
+     expected = probs.clone()
+     expected[:-k] = 0
+     expected /= expected.sum()
+     print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+     t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
+     print("Tests passed!")
+ # %%
+ if __name__ == "__main__":
+     your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
+     output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
+     print(f"Your model said: {repr(output)}")
+ # %%
+ def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
+     '''
+     logits: shape (vocab_size, ) - unnormalized log-probabilities
+
+     Return: a sampled token
+     '''
+     probs = t.exp(logits.double()) / t.exp(logits.double()).sum()
+     sorted_probs, sorted_indices = probs.sort(descending=True)
+     cum_probs = sorted_probs.cumsum(-1)
+     # Keep the smallest prefix of tokens whose cumulative probability reaches top_p
+     last_index = max(min_tokens_to_keep, t.where(cum_probs >= top_p)[0][0].item() + 1)
+     masked_probs = sorted_probs[:last_index]
+     sample = t.distributions.categorical.Categorical(probs=masked_probs).sample()
+     return sorted_indices[sample].item()
+
+ if __name__ == "__main__":
+     N = 2000
+     unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
+     samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
+     counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
+     print("top_p of 0.5 or lower should only return token 2: ", counts)
+     assert counts[0] == 0 and counts[1] == 0
+
+     N = 2000
+     unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
+     samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
+     counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
+     print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
+     assert counts[0] == 0
+
+     N = 50000
+     top_p = 0.71
+     probs = t.linspace(0, 0.4, 5)
+     unnormalized_logits = probs.log() + 1.2345
+     samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
+     counts = t.bincount(samples, minlength=len(probs)) / N
+     expected = probs.clone()
+     expected[0:2] = 0
+     expected /= expected.sum()
+     print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
+     t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)
+
+     print("All tests passed!")
+ # %%
+ if __name__ == "__main__":
+     your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
+     output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
+     print(f"Your model said: {repr(output)}")
+ # %%
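+ # Illustrative toy run of the sampling pipeline on a hand-made distribution
+ # (a sketch, not in the original commit; the toy values are arbitrary).
+ # It shows the order of operations in apply_sampling_methods: temperature,
+ # then frequency penalty, then top-k selection.
+ if __name__ == "__main__":
+     toy_logits = t.tensor([0.1, 0.2, 0.3, 0.4]).log()
+     toy_input_ids = t.tensor([3, 3, 3])  # token 3 already sampled three times
+     # Pass a clone: apply_freq_penalty modifies logits in place
+     token = apply_sampling_methods(
+         toy_input_ids, toy_logits.clone(), temperature=0.7, freq_penalty=2.0, top_k=2
+     )
+     print(f"Sampled token: {token}")
+ # %%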
shakespeare_demo.py ADDED
@@ -0,0 +1,105 @@
+ #%%
+ import yaml
+ import torch as t
+ import gradio as gr
+ import re
+ from word_data import WordData
+ import sampling
+ import transformer_replication
+ #%%
+ MAIN = __name__ == '__main__'
+ device = 'cuda' if t.cuda.is_available() else 'cpu'
+ #%%
+ shakespeare = WordData.from_file(
+     '100-0.txt', device=device, start="1\n", end='ALL’S WELL THAT ENDS WELL'
+ )
+ if MAIN:
+     print('Vocab size: ', len(shakespeare.vocab))
+ #%%
+ with open('config.yaml', 'r') as f:
+     yaml_cfg = yaml.safe_load(f)
+ #%%
+ state_dict = t.load('model_state_dict.pt', map_location=device)
+ #%%
+ base_config = transformer_replication.TransformerConfig(
+     num_layers=yaml_cfg['num_layers']['value'],
+     num_heads=yaml_cfg['num_heads']['value'],
+     vocab_size=len(shakespeare.vocab),
+     hidden_size=yaml_cfg['hidden_size']['value'],
+     max_seq_len=yaml_cfg['max_seq_len']['value'],
+     dropout=yaml_cfg['dropout']['value'],
+ )
+ shakespeare.model_max_length = yaml_cfg['max_seq_len']['value']
+ model = transformer_replication.DecoderOnlyTransformer(base_config)
+ model.load_state_dict(state_dict)
+ #%%
+ def generate(
+     text: str, max_tokens: int, temperature: float,
+     top_k: int,
+ ) -> str:
+     return sampling.sample_tokens(
+         model,
+         shakespeare,
+         text,
+         max_tokens_generated=max_tokens,
+         temperature=temperature,
+         top_k=top_k,
+     )
+
+ #%%
+ def safe_generate(
+     text: str, max_tokens: int = 300, temperature: float = 1.0,
+     top_k: int = 20,
+ ) -> str:
+     try:
+         raw = generate(
+             text, max_tokens=max_tokens, temperature=temperature, top_k=top_k,
+         )
+         # Trim the sample at the next line-number heading, if one appears
+         match = re.match(r"(?P<start>\D*)\d+\n", raw)
+         if match is None:
+             return raw
+         return match.group('start')
+     except KeyError as e:
+         return f"I'm sorry, {str(e)} is not in Shakespeare's vocabulary"
+ #%%
+ examples = [
+     ["I sang a beautiful song"],
+     ["To be free is to"],
+     ["How I love thee"],
+ ]
+ #%%
+ if MAIN:
+     print(safe_generate('How I love thee'))
+ #%%
+ def make_demo():
+     demo = gr.Interface(
+         fn=safe_generate,
+         inputs=[
+             gr.components.Textbox(lines=5, label="Input Text"),
+             gr.components.Slider(
+                 label='max tokens generated', minimum=1, maximum=1000,
+                 value=300, step=1,
+             ),
+             gr.components.Slider(
+                 label='temperature', minimum=0, maximum=2, value=1, step=0.1,
+             ),
+             gr.components.Slider(
+                 label='top_k', minimum=1, maximum=100, value=10, step=1,
+             ),
+         ],
+         outputs=gr.components.Textbox(label="Generated Text"),
+         examples=examples,
+     )
+     demo.launch()
+ # %%
+ '''
+ FIXME:
+ * deploy to heroku
+ * link from github home
+ '''
transformer_replication.py ADDED
@@ -0,0 +1,183 @@
+ #%%
+ import transformers
+ import torch as t
+ import torch.nn as nn
+ from typing import Union, List
+ from fancy_einsum import einsum
+ # %%
+ tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
+ if __name__ == "__main__":
+     print(tokenizer("hello meg"))
+     print(tokenizer.encode("hello meg"))
+     print(tokenizer.decode([31373, 17243]))
+     print(tokenizer.tokenize("hello meg"))
+     print(f"'{tokenizer.decode(17243)}'")
+ # %%
+ class Embedding(nn.Module):
+
+     def __init__(self, num_embeddings: int, embedding_dim: int):
+         super().__init__()
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.weight = nn.Parameter(t.randn((self.num_embeddings, self.embedding_dim)))
+
+     def forward(self, x: t.LongTensor) -> t.Tensor:
+         '''For each integer in the input, return that row of the embedding.'''
+         return self.weight[x]
+
+     def extra_repr(self) -> str:
+         return f"{self.num_embeddings}, {self.embedding_dim}"
+
+ # %%
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, max_seq_len: int, embedding_dim: int):
+         super().__init__()
+         # Defining our positional encoding array, with `max_seq_len` rows.
+         # This is an advantage of using sinusoidal encoding: we can easily expand to sequences of greater length without adding more learned params
+         angles = t.outer(t.arange(max_seq_len), 1 / 10000 ** (2 * t.arange(embedding_dim // 2) / embedding_dim))
+         pe = t.zeros((max_seq_len, embedding_dim))
+         pe[:, ::2] = t.sin(angles)
+         pe[:, 1::2] = t.cos(angles)
+         # Register array as a buffer, rather than parameter (we don't want it to be updated by gradient descent)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         """
+         x: shape (batch, seq_len, embedding_dim)
+         """
+         batch, seq_len, embedding_dim = x.shape
+         # We slice the positional encoding, so it's the same shape as x.
+         # This is equivalent to just using an nn.Embedding, but having the input be t.arange(seq_len)
+         return x + self.pe[:seq_len, :]  # type: ignore
+
+ # %%
+ class LayerNorm(nn.Module):
+
+     def __init__(self, normalized_shape: Union[int, List[int]], eps: float = 1e-05, elementwise_affine: bool = True):
+         super().__init__()
+         self.normalized_shape = normalized_shape
+         self.eps = eps
+         self.elementwise_affine = elementwise_affine
+
+         if self.elementwise_affine:
+             self.weight = nn.Parameter(t.ones(normalized_shape))
+             self.bias = nn.Parameter(t.zeros(normalized_shape))
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         # Normalise over the trailing `normalized_shape` dimensions
+         normalized_shape_dims = 1 if isinstance(self.normalized_shape, int) else len(self.normalized_shape)
+         norm_dims = list(range(x.dim()))[-normalized_shape_dims:]
+         x_mean = x.mean(dim=norm_dims, keepdim=True)
+         x_var = x.var(dim=norm_dims, keepdim=True, unbiased=False)
+         x_scaled = (x - x_mean) / t.sqrt(x_var + self.eps)
+         if self.elementwise_affine:
+             return x_scaled * self.weight + self.bias
+         return x_scaled
+
+     def extra_repr(self) -> str:
+         return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"
+
+ # %%
+ from dataclasses import dataclass
+
+ @dataclass(frozen=True)
+ class TransformerConfig:
+     '''Constants used throughout your decoder-only transformer model.'''
+
+     num_layers: int
+     num_heads: int
+     vocab_size: int
+     hidden_size: int
+     max_seq_len: int
+     dropout: float = 0.1
+     layer_norm_epsilon: float = 1e-05
+ # %%
+ import attention_replication
+
+ class BertMLP(nn.Module):
+     def __init__(self, config: TransformerConfig):
+         super().__init__()
+         self.linear1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
+         self.gelu = nn.GELU()
+         self.linear2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         x = self.linear1(x)
+         x = self.gelu(x)
+         x = self.linear2(x)
+         x = self.dropout(x)
+         return x
+
+ class DecoderBlock(nn.Module):
+
+     def __init__(self, config: TransformerConfig):
+         super().__init__()
+         self.attention = attention_replication.MultiheadMaskedAttention(config.hidden_size, config.num_heads)
+         self.layer_norm1 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+         self.mlp = BertMLP(config)
+         self.layer_norm2 = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         # Attention sub-layer with residual connection
+         y = self.attention(x)
+         y = self.layer_norm1(y)
+         x = x + y
+         # MLP sub-layer with residual connection
+         z = self.mlp(x)
+         z = self.layer_norm2(z)
+         x = x + z
+         return x
+
+ class DecoderOnlyTransformer(nn.Module):
+
+     def __init__(self, config: TransformerConfig):
+         super().__init__()
+         self.token_embedding = Embedding(config.vocab_size, config.hidden_size)
+         self.positional_embedding = PositionalEncoding(config.max_seq_len, config.hidden_size)
+         self.dropout = nn.Dropout(config.dropout)
+         self.bert_blocks = nn.Sequential(*[DecoderBlock(config) for _ in range(config.num_layers)])
+         self.layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_epsilon)
+
+     def forward(self, x: t.Tensor) -> t.Tensor:
+         x = self.token_embedding(x)
+         x = self.positional_embedding(x)
+         x = self.dropout(x)
+         for block in self.bert_blocks:
+             x = block(x)
+         x = self.layer_norm(x)
+         # Unembed with the (tied) token embedding matrix to get logits over the vocabulary
+         x = einsum('num_embeddings embedding_dim, batch seq_len embedding_dim -> batch seq_len num_embeddings', self.token_embedding.weight, x)
+         return x
+
+ # %%
+ from torch.utils.data import Dataset
+
+ class CustomTextDataset(Dataset):
+     def __init__(self, texts, labels):
+         self.labels = labels
+         self.texts = texts
+
+     @staticmethod
+     def from_config(config, samples):
+         # Toy task: each label is the input token sequence reversed
+         texts = [t.randint(high=config.vocab_size, size=(config.max_seq_len,)) for _ in range(samples)]
+         labels = [t.flip(text, (0,)) for text in texts]
+         return CustomTextDataset(texts, labels)
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         label = self.labels[idx]
+         text = self.texts[idx]
+         sample = (text, label)
+         return sample
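+ # %%
+ # Illustrative usage of the toy reversal dataset (a sketch, not in the
+ # original commit; the config values here are arbitrary).
+ if __name__ == "__main__":
+     toy_config = TransformerConfig(
+         num_layers=2, num_heads=2, vocab_size=10, hidden_size=16, max_seq_len=6
+     )
+     dataset = CustomTextDataset.from_config(toy_config, samples=4)
+     text, label = dataset[0]
+     assert t.equal(label, t.flip(text, (0,)))
+     print(text, label)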
word_data.py ADDED
@@ -0,0 +1,100 @@
+ import re
+ from typing import Optional, Union
+ import requests
+ from torch.utils.data import Dataset
+ import torch as t
+
+
+ class WordsDataset(Dataset):
+     def __init__(self, texts, labels):
+         self.texts = texts
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         label = self.labels[idx]
+         text = self.texts[idx]
+         sample = (text, label)
+         return sample
+
+ #%%
+ def tokenize(text):
+     # Split on word boundaries, keeping the separators as tokens
+     return re.split(r"\b", text)
+
+ def _remove_duplicates(text, string=" "):
+     if string + string in text:
+         text = text.replace(string + string, string)
+         return _remove_duplicates(text, string)
+     return text
+
+ def remove_duplicates(text):
+     text = _remove_duplicates(text, ' ')
+     text = _remove_duplicates(text, '\n')
+     return text
+
+ # %%
+ class WordData:
+     def __init__(self, text, start, end, device):
+         self.complete_text = remove_duplicates(text)
+         if start is not None and end is not None:
+             self.complete_text = self.get_excerpt(start, end)
+         self.complete_tokens = tokenize(self.complete_text)
+         self.vocab = sorted(set(self.complete_tokens))
+         self.token_to_id = dict(zip(self.vocab, list(range(len(self.vocab)))))
+         self.id_to_token = dict(zip(list(range(len(self.vocab))), self.vocab))
+         self.model_max_length = None
+         self.device = device
+
+     @staticmethod
+     def from_link(link, device, start=None, end=None):
+         return WordData(
+             requests.get(link).content.decode('utf-8'),
+             start,
+             end,
+             device=device
+         )
+
+     @staticmethod
+     def from_file(filename, device, start=None, end=None):
+         with open(filename, encoding='utf-8') as f:
+             text = f.read()
+         return WordData(text, start, end, device=device)
+
+     def get_excerpt(self, start="THE SONNETS", end="THE END", text=None):
+         if text is None:
+             text = self.complete_text
+         assert start in text, f'get_excerpt: cannot find {start} in text'
+         l_stripped = text.split(start, maxsplit=1)[1]
+         assert end in l_stripped, f'get_excerpt: cannot find {end} in text'
+         r_stripped = l_stripped.split(end, maxsplit=1)[0]
+         return r_stripped
+
+     def generate_autoregressive_dataset(self, sequence_length, text=None):
+         self.model_max_length = sequence_length
+         if text is None:
+             text = self.complete_text
+         token_ids = self.encode(text, return_tensors="pt")
+         # Each label sequence is the input sequence shifted one token to the left
+         inputs = [token_ids[i:i + sequence_length] for i in range(len(token_ids) - sequence_length)]
+         labels = [token_ids[i + 1:i + 1 + sequence_length] for i in range(len(token_ids) - sequence_length)]
+         return WordsDataset(inputs, labels)
+
+     def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
+         '''
+         Tokenizes initial_text, then returns the token ids.
+
+         Return type is list by default, but if return_tensors="pt" then it is returned as a tensor.
+         '''
+         tokens = tokenize(initial_text)
+         token_ids = [self.token_to_id[tok] for tok in tokens]
+         if return_tensors == "pt":
+             return t.tensor(token_ids, device=self.device)
+         return token_ids
+
+     def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
+         '''
+         Converts ids to a list of tokens, then joins them into a single string.
+         '''
+         tokens = [self.id_to_token[int(i)] for i in list_of_ids]
+         return "".join(tokens)
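+ # %%
+ # Illustrative encode/decode round trip (a sketch, not in the original
+ # commit; the sample text is arbitrary). tokenize splits on word boundaries
+ # and keeps the separators, so decoding the ids reproduces the input string.
+ if __name__ == "__main__":
+     data = WordData("to be or not to be", start=None, end=None, device='cpu')
+     ids = data.encode("to be or not")
+     assert data.decode(ids) == "to be or not"
+     print("Round trip OK:", ids)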