Sin2pi committed on
Commit 446e362 · verified · 1 Parent(s): 9465621

Update model_simple.py

Files changed (1)
  1. model_simple.py +68 -68
model_simple.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 from datetime import datetime
 from dataclasses import dataclass
 from torch.nn.functional import scaled_dot_product_attention
-
+from echoutils import *
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 dtype = torch.float32
 warnings.filterwarnings("ignore")
@@ -31,7 +31,7 @@ class rotary(nn.Module):
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.theta = nn.Parameter((torch.tensor(10000, device=device, dtype=dtype)), requires_grad=True)
+        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
         self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)

     def _compute_freqs_base(self):
@@ -53,6 +53,16 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)

+def shape(dims, head, q, k, v):
+    head_dim = dims // head
+    scale = head_dim ** -0.25
+    q = q * scale
+    k = k * scale
+    v = v
+    def _shape(tensor):
+        return tensor.view(*tensor.shape[:2], head, -1).permute(0, 2, 1, 3).contiguous()
+    return _shape(q), _shape(k), _shape(v)
+
 def qkv_init(dims: int, head: int):
     head_dim = dims // head
     q = nn.Linear(dims, dims)
@@ -60,60 +70,51 @@ def qkv_init(dims: int, head: int):
     v = nn.Linear(dims, dims)
     o = nn.Linear(dims, dims)
     lna = nn.LayerNorm(dims, bias=False)
-    lnb = nn.LayerNorm(head_dim, bias=False)
-    return q, k, v, o, lna, lnb
-
-def create_qkv(dims, head, q, k, v, x, xa):
-    head_dim = dims // head
-    scale = head_dim ** -0.25
-    q = q(x) * scale
-    k = k(xa) * scale
-    v = v(xa)
-    batch, ctx, dims = x.shape
-    def _shape(tensor):
-        return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
-    return _shape(q), _shape(k), _shape(v)
+    lnb = nn.LayerNorm(dims, bias=False)
+    lnc = nn.LayerNorm(head_dim, bias=False)
+    lnd = nn.LayerNorm(head_dim, bias=False)
+    return q, k, v, o, lna, lnb, lnc, lnd

-def calculate_attention(q, k, v, mask=None, temperature=1.0):
+def calculate_attention(q, k, v, mask=None, temp=1.0):
     scaled_q = q
-    if temperature != 1.0 and temperature > 0:
-        scaled_q = q * (1.0 / temperature)**.5
-
+    if temp != 1.0 and temp > 0:
+        scaled_q = q * (1.0 / temp)**.5
     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
     return out

 class LocalOut(nn.Module):
-    def __init__(self, head_dim: int):
+    def __init__(self, dims: int, head: int):
         super().__init__()
-        self.head_dim = head_dim
+        head_dim = dims // head
         self.query_module = nn.Linear(head_dim, head_dim)
         self.key_module = nn.Linear(head_dim, head_dim)
         self.value_module = nn.Linear(head_dim, head_dim)
         self.out_proj = nn.Linear(head_dim, head_dim)
-
-    def _reshape_to_output(self, x):
-        return x

-class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
-        super(attentiona, self).__init__()
-        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
+        batch, _, ctx, _ = attn_output.shape
+        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
+
+class attentionb(nn.Module):
+    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
+        super(attentionb, self).__init__()
+        self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
         self.dims = dims
         self.head = head
-        self.head_dim = dims // head
         self.max_iter = max_iter
         self.threshold = nn.Parameter(torch.tensor(threshold))
+        self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
         self.factor = nn.Parameter(torch.tensor(factor))
-        self.dropout = dropout
-        self.lnc = nn.LayerNorm(self.head_dim, bias=False)
-        self.lnd = nn.LayerNorm(self.head_dim, bias=False)
-        self.attn_local = LocalOut(self.head_dim)
+        self.alocal = LocalOut(dims, head)

     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
-        z = default(xa, x)
-        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        q, k, v = shape(self.dims, self.head, q, k, v)

         iteration = 0
+        temp = self.temp.item()
         prev_out = torch.zeros_like(q)
         attn_out = torch.zeros_like(q)
         threshold = self.threshold.item()
@@ -121,7 +122,7 @@ class attentiona(nn.Module):
         qcur = q

         while iteration < self.max_iter:
-            eff_span = min(x.shape[1], qcur.size(1), k.size(1))
+            eff_span = min(qcur.shape[1], k.shape[1])
             if xa is not None:
                 eff_span = min(eff_span, xa.shape[1])
             if eff_span == 0:
@@ -130,9 +131,9 @@
             qiter = qcur[:, :, :eff_span, :]
             kiter = k[:, :, :eff_span, :]
             viter = v[:, :, :eff_span, :]
-            q = self.attn_local.query_module(qiter)
-            k = self.attn_local.key_module(kiter)
-            v = self.attn_local.value_module(viter)
+            q = self.alocal.query_module(qiter)
+            k = self.alocal.key_module(kiter)
+            v = self.alocal.value_module(viter)

             iter_mask = None
             if mask is not None:
@@ -143,7 +144,7 @@

             attn_iter = calculate_attention(
                 self.lnc(q), self.lnd(k), v,
-                mask=iter_mask)
+                mask=iter_mask, temp=temp)

             iter_out = torch.zeros_like(qcur)
             iter_out[:, :, :eff_span, :] = attn_iter
@@ -157,21 +158,22 @@
             qcur = qcur + iter_out
             attn_out = iter_out
             iteration += 1
+            temp += 0.005

         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(output), None

     def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:

-        batch, ctx, dims = x.size()
+        batch, ctx, dims = x.shape
         output = torch.zeros_like(x)
         num_win = (ctx + win_size - 1) // win_size

         for i in range(num_win):
             qstart = i * win_size
             qend = min(qstart + win_size, ctx)
-            current_win_qlen = qend - qstart
-            if current_win_qlen == 0:
+            win_qlen = qend - qstart
+            if win_qlen == 0:
                 continue

             kstart = max(0, qend - span_len)
@@ -186,10 +188,7 @@
             elif mask.dim() == 2:
                 win_mask = mask[qstart:qend, kstart:kend]

-            attn_out, _ = self._focus(
-                x=qwin,
-                xa=kwin,
-                mask=win_mask)
+            attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask)
             output[:, qstart:qend, :] = attn_out
         return output

@@ -201,21 +200,21 @@
         output, _ = self._focus(x, xa, mask)
         return output

-class attentionb(nn.Module):
+class attentiona(nn.Module):
     def __init__(self, dims: int, head: int):
-        super(attentionb, self).__init__()
-        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+        super(attentiona, self).__init__()
+        self.q, self.k, self.v, self.o, self.lna, self.lnb, self.lnc, self.lnd = qkv_init(dims, head)
         self.dims = dims
         self.head = head
-        self.head_dim = dims // head
-        self.rope = rotary(dims=dims, head=head)
-
+        self.rope = rotary(dims=dims, head=head)
     def forward(self, x: Tensor, xa = None, mask = None):
-        z = default(xa, x)
-        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
+        q = self.q(self.lna(x))
+        k = self.k(self.lnb(x if xa is None else xa))
+        v = self.v(self.lnb(x if xa is None else xa))
+        q, k, v = shape(self.dims, self.head, q, k, v)
         q = self.rope(q, q.shape[2])
         k = self.rope(k, k.shape[2])
-        a = scaled_dot_product_attention(self.lnb(q), self.lnb(k), v, is_causal=mask is not None and q.shape[1] > 1)
+        a = scaled_dot_product_attention(self.lnc(q), self.lnd(k), v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(out)

@@ -224,14 +223,15 @@ class Residual(nn.Module):
         super().__init__()

         self.lna = nn.LayerNorm(dims, bias=False)
-        self.attnb = attentionb(dims, head)
-        self.attna = attentiona(dims, head, max_iter=3)
+        self.attna = attentiona(dims, head)
+        self.attnb = attentionb(dims, head, max_iter=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))

     def forward(self, x, xa = None, mask = None) -> Tensor:
-        x = x + self.attnb(self.lna(x), None, mask)
+        x = x + self.attna(self.lna(x), mask=mask)
         if xa is not None:
-            x = x + self.attna(self.lna(x), xa, None, use_sliding_win=True, win_size=500, span_len=1500)
+            x = x + self.attna(self.lna(x), xa, mask=None)
+            x = x + self.attnb(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=256, span_len=512)
         x = x + self.mlp(self.lna(x))
         return x

@@ -239,8 +239,9 @@ class processor(nn.Module):
     def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
         super(processor, self).__init__()

-        self.ln = nn.LayerNorm(dims)
-        self.blend = nn.Parameter(torch.tensor(0.5), requires_grad=True)
+        self.lna = nn.LayerNorm(dims)
+        self.lnb = nn.LayerNorm(dims)
+        self.lnc = nn.LayerNorm(dims)
         self.token_emb = nn.Embedding(vocab, dims)
         self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
         self.audio_emb = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
@@ -252,7 +253,7 @@
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

         self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])
-
+
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

@@ -263,14 +264,13 @@
         xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

         for b in chain(self.bA or []):
-            xa = b(xa, None, None)
-            x = b(x, None, self.mask)
-            x = b(x, xa, None)
-            xc = b(torch.cat([x, xa], dim=1), xa=xa, mask=self.mask) if modal else None
+            xa = b(self.lna(xa))
+            x = b(self.lnb(x), xa=xa, mask=self.mask)
+            xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask) if modal else None
             x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x

         x = nn.functional.dropout(x, p=0.001, training=self.training)
-        x = self.ln(x)
+        x = self.lnc(x)
         x = x @ torch.transpose(self.token_emb.weight.to(dtype), 0, 1).float()
         return x
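
For readers skimming the diff, the sketch below shows how the new module-level shape helper and the temp argument of calculate_attention are meant to work together: projections of shape (batch, ctx, dims) are pre-scaled, split into (batch, head, ctx, head_dim), and the attention distribution is flattened slightly as temp rises. This is a minimal standalone sketch assuming PyTorch 2.x for scaled_dot_product_attention; the tensor sizes and the local copies of the two helpers are illustrative only, not part of the repository.

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention

def shape(dims, head, q, k, v):
    # Split (batch, ctx, dims) projections into (batch, head, ctx, head_dim),
    # pre-scaling q and k by head_dim ** -0.25 as in the committed helper.
    head_dim = dims // head
    scale = head_dim ** -0.25
    q, k = q * scale, k * scale
    def _shape(t):
        return t.view(*t.shape[:2], head, -1).permute(0, 2, 1, 3).contiguous()
    return _shape(q), _shape(k), _shape(v)

def calculate_attention(q, k, v, mask=None, temp=1.0):
    # temp > 1 shrinks q, flattening the attention weights; temp == 1 is a no-op.
    scaled_q = q * (1.0 / temp) ** 0.5 if temp > 0 and temp != 1.0 else q
    return scaled_dot_product_attention(
        scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)

batch, ctx, dims, head = 2, 16, 64, 4   # illustrative sizes
x = torch.randn(batch, ctx, dims)
q_lin, k_lin, v_lin = (nn.Linear(dims, dims) for _ in range(3))

q, k, v = shape(dims, head, q_lin(x), k_lin(x), v_lin(x))
print(q.shape)                                        # torch.Size([2, 4, 16, 16])
print(calculate_attention(q, k, v, temp=1.2).shape)   # torch.Size([2, 4, 16, 16])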