Sin2pi committed
Commit 9465621 · verified · 1 Parent(s): f439a74

Update model_simple.py

Files changed (1)
  1. model_simple.py +38 -58
model_simple.py CHANGED
@@ -53,14 +53,15 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
-def qkvinit(dims: int, head: int):
+def qkv_init(dims: int, head: int):
     head_dim = dims // head
-    scale = head_dim ** -0.5
     q = nn.Linear(dims, dims)
     k = nn.Linear(dims, dims, bias=False)
     v = nn.Linear(dims, dims)
     o = nn.Linear(dims, dims)
-    return q, k, v, o, scale
+    lna = nn.LayerNorm(dims, bias=False)
+    lnb = nn.LayerNorm(head_dim, bias=False)
+    return q, k, v, o, lna, lnb
 
 def create_qkv(dims, head, q, k, v, x, xa):
     head_dim = dims // head
@@ -73,16 +74,15 @@ def create_qkv(dims, head, q, k, v, x, xa):
         return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
     return _shape(q), _shape(k), _shape(v)
 
-def calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True):
+def calculate_attention(q, k, v, mask=None, temperature=1.0):
     scaled_q = q
     if temperature != 1.0 and temperature > 0:
         scaled_q = q * (1.0 / temperature)**.5
 
     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
-    # out = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
     return out
 
-class LocalAttentionModule(nn.Module):
+class LocalOut(nn.Module):
     def __init__(self, head_dim: int):
         super().__init__()
         self.head_dim = head_dim
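
For orientation, a short usage sketch of the three helpers above as free functions. It assumes model_simple.py is importable as a module and that create_qkv applies the passed-in Linear projections before reshaping (only the reshape is visible in this hunk); all sizes below are invented.

import torch
from model_simple import qkv_init, create_qkv, calculate_attention      # assumes the module is importable

dims, head, batch, ctx = 512, 8, 2, 100                                  # toy sizes, not from the commit
q_lin, k_lin, v_lin, o_lin, lna, lnb = qkv_init(dims, head)              # new 6-tuple return value
x = torch.randn(batch, ctx, dims)                                        # query-side stream
xa = torch.randn(batch, ctx, dims)                                       # key/value-side stream

q, k, v = create_qkv(dims, head, q_lin, k_lin, v_lin, lna(x), lna(xa))   # (batch, head, ctx, head_dim)
out = calculate_attention(q, k, v, mask=None, temperature=1.0)           # causal only when a mask is passed
out = out.permute(0, 2, 1, 3).flatten(start_dim=2)                       # back to (batch, ctx, dims)
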
@@ -95,53 +95,41 @@ class LocalAttentionModule(nn.Module):
         return x
 
 class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int, max_iterations: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
+    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
         super(attentiona, self).__init__()
-        # self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
+        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.max_iterations = max_iterations
+        self.max_iter = max_iter
         self.threshold = nn.Parameter(torch.tensor(threshold))
         self.factor = nn.Parameter(torch.tensor(factor))
-        self.dropout = dropout
-
-        self.q = nn.Linear(dims, dims)
-        self.k = nn.Linear(dims, dims, bias=False)
-        self.v = nn.Linear(dims, dims)
-        self.o = nn.Linear(dims, dims)
-
-        self.lna = nn.LayerNorm(dims, bias=False)
-        self.lnb = nn.LayerNorm(dims, bias=False)
+        self.dropout = dropout
         self.lnc = nn.LayerNorm(self.head_dim, bias=False)
         self.lnd = nn.LayerNorm(self.head_dim, bias=False)
-        self.attn_local = LocalAttentionModule(self.head_dim)
+        self.attn_local = LocalOut(self.head_dim)
 
     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
-        q = self.q(self.lna(x))
-        k = self.k(self.lnb(x if xa is None else xa))
-        v = self.v(self.lnb(x if xa is None else xa))
-        query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        z = default(xa, x)
+        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
 
         iteration = 0
-        prev_out = torch.zeros_like(query)
-        attn_out = torch.zeros_like(query)
+        prev_out = torch.zeros_like(q)
+        attn_out = torch.zeros_like(q)
         threshold = self.threshold.item()
         factor = self.factor.item()
-        qcur = query
+        qcur = q
 
-        while iteration < self.max_iterations:
-            eff_span = min(x.shape[1], qcur.size(1), key.size(1))
+        while iteration < self.max_iter:
+            eff_span = min(x.shape[1], qcur.size(1), k.size(1))
             if xa is not None:
                 eff_span = min(eff_span, xa.shape[1])
             if eff_span == 0:
                 break
 
             qiter = qcur[:, :, :eff_span, :]
-            kiter = key[:, :, :eff_span, :]
-            viter = value[:, :, :eff_span, :]
+            kiter = k[:, :, :eff_span, :]
+            viter = v[:, :, :eff_span, :]
             q = self.attn_local.query_module(qiter)
             k = self.attn_local.key_module(kiter)
             v = self.attn_local.value_module(viter)
@@ -155,14 +143,12 @@ class attentiona(nn.Module):
 
             attn_iter = calculate_attention(
                 self.lnc(q), self.lnd(k), v,
-                mask=iter_mask,
-                is_causal=True)
+                mask=iter_mask)
 
             iter_out = torch.zeros_like(qcur)
             iter_out[:, :, :eff_span, :] = attn_iter
             diff = torch.abs(iter_out - prev_out).mean()
             dthresh = threshold + factor * diff
-
             if diff < dthresh and iteration > 0:
                 attn_out = iter_out
                 break
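
The dropped is_causal argument aside, the surrounding loop is the interesting part: _focus re-runs attention up to max_iter times and stops early once the output stops changing, using an adaptive cutoff of threshold + factor * diff. Below is a standalone sketch of just that convergence test; the decaying random update stands in for an attention pass and is not from the commit.

import torch

threshold, factor, max_iter = 0.01, 0.1, 3
prev_out = torch.zeros(2, 8, 100, 64)                          # (batch, head, ctx, head_dim), toy sizes

for iteration in range(max_iter):
    iter_out = torch.randn_like(prev_out) * 0.5 ** iteration   # stand-in for one attention pass
    diff = torch.abs(iter_out - prev_out).mean()               # mean change since the last pass
    dthresh = threshold + factor * diff                        # adaptive threshold, as in _focus
    if diff < dthresh and iteration > 0:
        break                                                  # converged: output barely changed
    prev_out = iter_out
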
@@ -175,7 +161,7 @@ class attentiona(nn.Module):
         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(output), None
 
-    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor:
+    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:
 
         batch, ctx, dims = x.size()
         output = torch.zeros_like(x)
@@ -188,17 +174,17 @@ class attentiona(nn.Module):
             if current_win_qlen == 0:
                 continue
 
-            kvstart = max(0, qend - span_len)
-            kvend = qend
+            kstart = max(0, qend - span_len)
+            kend = qend
             qwin = x[:, qstart:qend, :]
-            kwin = x[:, kvstart:kvend, :]
+            kwin = x[:, kstart:kend, :]
 
             win_mask = None
             if mask is not None:
                 if mask.dim() == 4:
-                    win_mask = mask[:, :, qstart:qend, kvstart:kvend]
+                    win_mask = mask[:, :, qstart:qend, kstart:kend]
                 elif mask.dim() == 2:
-                    win_mask = mask[qstart:qend, kvstart:kvend]
+                    win_mask = mask[qstart:qend, kstart:kend]
 
             attn_out, _ = self._focus(
                 x=qwin,
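
The kvstart/kvend to kstart/kend rename does not change the windowing arithmetic: each query window [qstart, qend) attends to at most span_len keys ending at qend. A toy trace of that bookkeeping follows; the outer loop shape is assumed from the surrounding method, since only the kstart/kend lines appear in this hunk.

ctx, win_size, span_len = 10, 4, 6                         # toy values
for qstart in range(0, ctx, win_size):                     # assumed loop structure
    qend = min(qstart + win_size, ctx)
    if qend - qstart == 0:
        continue
    kstart = max(0, qend - span_len)                       # keys span at most span_len positions, ending at qend
    kend = qend
    print(f"queries [{qstart}:{qend}) use keys [{kstart}:{kend})")
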
@@ -239,13 +225,13 @@ class Residual(nn.Module):
 
         self.lna = nn.LayerNorm(dims, bias=False)
         self.attnb = attentionb(dims, head)
-        self.attna = attentiona(dims, head, max_iterations=3)
+        self.attna = attentiona(dims, head, max_iter=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))
 
     def forward(self, x, xa = None, mask = None) -> Tensor:
-        x = x + self.attnb(self.lna(x), xa=None, mask=mask)
+        x = x + self.attnb(self.lna(x), None, mask)
         if xa is not None:
-            x = x + self.attna(self.lna(x), xa, mask=None, use_sliding_win=True, win_size=500, span_len=1500)
+            x = x + self.attna(self.lna(x), xa, None, use_sliding_win=True, win_size=500, span_len=1500)
         x = x + self.mlp(self.lna(x))
         return x
 
@@ -266,28 +252,22 @@ class processor(nn.Module):
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
         self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])
-
+
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x, xa, sequential=False) -> Tensor:
+    def forward(self, x, xa, sequential=False, modal=False) -> Tensor:
 
         x = self.token_emb(x.long()) + self.positions[:x.shape[1]]
         xa = self.audio_enc(xa).permute(0, 2, 1)
         xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
 
         for b in chain(self.bA or []):
-            xa = b(x=xa, xa=None, mask=None)
-            x = b(x=x, xa=None, mask=self.mask)
-            x = b(x=x, xa=xa, mask=None)
-            # xc = b(torch.cat([x, xa], dim=1), xa=None, mask=self.mask)
-            # x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None)
-
-            # if sequential:
-            # x = y
-            # else:
-            # a = torch.sigmoid(self.blend)
-            # x = a * y + (1 - a) * x
+            xa = b(xa, None, None)
+            x = b(x, None, self.mask)
+            x = b(x, xa, None)
+            xc = b(torch.cat([x, xa], dim=1), xa=xa, mask=self.mask) if modal else None
+            x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x
 
         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
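
Each layer applies its block three ways — audio self-attention, causal text self-attention, then text-over-audio cross-attention — and the new modal flag adds an optional joint pass over the concatenated streams. A toy trace with a stub block follows; the stub, sizes, and layer count are invented, and the real blocks are Residual modules.

import torch

def block(x, xa=None, mask=None):                  # stub standing in for Residual.forward
    return x

x = torch.randn(2, 20, 512)                        # text stream
xa = torch.randn(2, 100, 512)                      # audio stream
mask = torch.empty(20, 20).fill_(float("-inf")).triu_(1)
modal = False

for _ in range(2):                                 # toy layers
    xa = block(xa, None, None)                     # audio self-attention
    x = block(x, None, mask)                       # causal text self-attention
    x = block(x, xa, None)                         # text attends to audio
    if modal:                                      # optional joint pass over [text; audio]
        xc = block(torch.cat([x, xa], dim=1), xa=xa, mask=mask)
        x = block(xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None)
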
@@ -360,4 +340,4 @@ class Model(nn.Module):
         print("Initialization summary:")
         for module_type, count in self.init_counts.items():
             if count > 0:
-                print(f"{module_type}: {count}")
+                print(f"{module_type}: {count}")