Update model_simple.py
model_simple.py  CHANGED  (+38 -58)
@@ -53,14 +53,15 @@ class rotary(nn.Module):
         x1 = x1.view(orig_shape)
         return torch.cat([x1.type_as(x), x2], dim=-1)

-def
+def qkv_init(dims: int, head: int):
     head_dim = dims // head
-    scale = head_dim ** -0.5
     q = nn.Linear(dims, dims)
     k = nn.Linear(dims, dims, bias=False)
     v = nn.Linear(dims, dims)
     o = nn.Linear(dims, dims)
-
+    lna = nn.LayerNorm(dims, bias=False)
+    lnb = nn.LayerNorm(head_dim, bias=False)
+    return q, k, v, o, lna, lnb

 def create_qkv(dims, head, q, k, v, x, xa):
     head_dim = dims // head
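Note: qkv_init bundles the q/k/v/o projections with a pair of LayerNorms, replacing the per-attention copies that the old attentiona.__init__ built. A quick sanity sketch of the helper paired with the existing create_qkv, assuming the module is importable as model_simple and dims is divisible by head (values are illustrative):

import torch
from model_simple import qkv_init, create_qkv

dims, head, batch, ctx = 64, 4, 2, 10
q, k, v, o, lna, lnb = qkv_init(dims, head)            # projections plus the two norms returned above
x = torch.randn(batch, ctx, dims)
qh, kh, vh = create_qkv(dims, head, q, k, v, lna(x), lna(x))
print(qh.shape)                                        # torch.Size([2, 4, 10, 16]): (batch, head, ctx, head_dim)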
@@ -73,16 +74,15 @@ def create_qkv(dims, head, q, k, v, x, xa):
         return tensor.view(batch, ctx, head, head_dim).transpose(1, 2).contiguous()
     return _shape(q), _shape(k), _shape(v)

-def calculate_attention(q, k, v, mask=None, temperature=1.0
+def calculate_attention(q, k, v, mask=None, temperature=1.0):
     scaled_q = q
     if temperature != 1.0 and temperature > 0:
         scaled_q = q * (1.0 / temperature)**.5

     out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[1] > 1)
-    # out = scaled_dot_product_attention(scaled_q, k, v, attn_mask=attn_mask, is_causal=is_causal if attn_mask is None else False)
     return out

-class
+class LocalOut(nn.Module):
     def __init__(self, head_dim: int):
         super().__init__()
         self.head_dim = head_dim
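Note: with the commented-out attn_mask path removed, the mask argument only toggles is_causal; the mask tensor itself is never applied here. The temperature scaling multiplies q by (1/temperature)**0.5, which is the same as dividing the attention logits by sqrt(temperature) before softmax. A small check of that equivalence (illustrative shapes):

import torch
from torch.nn.functional import scaled_dot_product_attention, softmax

q = torch.randn(1, 2, 5, 8)   # (batch, head, ctx, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)
T = 2.0

out_a = scaled_dot_product_attention(q * (1.0 / T) ** 0.5, k, v)
logits = (q @ k.transpose(-2, -1)) / (8 ** 0.5)        # SDPA's default 1/sqrt(head_dim) scaling
out_b = softmax(logits / T ** 0.5, dim=-1) @ v         # logits divided by sqrt(T) directly
print(torch.allclose(out_a, out_b, atol=1e-5))         # True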
@@ -95,53 +95,41 @@ class LocalAttentionModule(nn.Module):
         return x

 class attentiona(nn.Module):
-    def __init__(self, dims: int, head: int,
+    def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1):
         super(attentiona, self).__init__()
-
+        self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.
+        self.max_iter = max_iter
         self.threshold = nn.Parameter(torch.tensor(threshold))
         self.factor = nn.Parameter(torch.tensor(factor))
-        self.dropout = dropout
-
-        self.q = nn.Linear(dims, dims)
-        self.k = nn.Linear(dims, dims, bias=False)
-        self.v = nn.Linear(dims, dims)
-        self.o = nn.Linear(dims, dims)
-
-        self.lna = nn.LayerNorm(dims, bias=False)
-        self.lnb = nn.LayerNorm(dims, bias=False)
+        self.dropout = dropout
         self.lnc = nn.LayerNorm(self.head_dim, bias=False)
         self.lnd = nn.LayerNorm(self.head_dim, bias=False)
-        self.attn_local =
+        self.attn_local = LocalOut(self.head_dim)

     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
-
-        k = self.
-        v = self.v(self.lnb(x if xa is None else xa))
-        query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
-        value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        z = default(xa, x)
+        q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))

         iteration = 0
-        prev_out = torch.zeros_like(
-        attn_out = torch.zeros_like(
+        prev_out = torch.zeros_like(q)
+        attn_out = torch.zeros_like(q)
         threshold = self.threshold.item()
         factor = self.factor.item()
-        qcur =
+        qcur = q

-        while iteration < self.
-            eff_span = min(x.shape[1], qcur.size(1),
+        while iteration < self.max_iter:
+            eff_span = min(x.shape[1], qcur.size(1), k.size(1))
             if xa is not None:
                 eff_span = min(eff_span, xa.shape[1])
             if eff_span == 0:
                 break

             qiter = qcur[:, :, :eff_span, :]
-            kiter =
-            viter =
+            kiter = k[:, :, :eff_span, :]
+            viter = v[:, :, :eff_span, :]
             q = self.attn_local.query_module(qiter)
             k = self.attn_local.key_module(kiter)
             v = self.attn_local.value_module(viter)
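Note: _focus now builds q, k, v through create_qkv and refines them with the LocalOut projections; the loop stops once the mean absolute change between iterations falls below the adaptive cutoff threshold + factor * diff, which for factor < 1 is equivalent to diff < threshold / (1 - factor). A tiny numeric sketch of the stopping rule using the defaults threshold=0.01 and factor=0.1 (the diff values are made up):

threshold, factor, max_iter = 0.01, 0.1, 3      # defaults from attentiona.__init__ above

diffs = [0.500, 0.040, 0.008]                   # hypothetical mean |iter_out - prev_out| per pass
for iteration, diff in enumerate(diffs):
    dthresh = threshold + factor * diff
    stop = diff < dthresh and iteration > 0     # same test as in _focus
    print(f"iter={iteration} diff={diff:.3f} dthresh={dthresh:.4f} stop={stop}")
    if stop or iteration + 1 >= max_iter:
        break
# iter=0 diff=0.500 dthresh=0.0600 stop=False
# iter=1 diff=0.040 dthresh=0.0140 stop=False
# iter=2 diff=0.008 dthresh=0.0108 stop=True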
@@ -155,14 +143,12 @@ class attentiona(nn.Module):

             attn_iter = calculate_attention(
                 self.lnc(q), self.lnd(k), v,
-                mask=iter_mask
-                is_causal=True)
+                mask=iter_mask)

             iter_out = torch.zeros_like(qcur)
             iter_out[:, :, :eff_span, :] = attn_iter
             diff = torch.abs(iter_out - prev_out).mean()
             dthresh = threshold + factor * diff
-
             if diff < dthresh and iteration > 0:
                 attn_out = iter_out
                 break
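Note: the explicit is_causal=True is dropped here, so causality now comes from calculate_attention's internal check mask is not None and q.shape[1] > 1, i.e. passing mask=iter_mask is what switches causal attention on. One thing worth double-checking: the tensors fed in here are (batch, head, ctx, head_dim), so q.shape[1] is the head axis rather than the query length:

import torch

q = torch.randn(2, 4, 10, 16)   # (batch, head, ctx, head_dim), as produced by create_qkv
print(q.shape[1])               # 4 -> number of heads, not the 10-token context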
@@ -175,7 +161,7 @@ class attentiona(nn.Module):
         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
         return self.o(output), None

-    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None
+    def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:

         batch, ctx, dims = x.size()
         output = torch.zeros_like(x)
@@ -188,17 +174,17 @@ class attentiona(nn.Module):
             if current_win_qlen == 0:
                 continue

-
-
+            kstart = max(0, qend - span_len)
+            kend = qend
             qwin = x[:, qstart:qend, :]
-            kwin = x[:,
+            kwin = x[:, kstart:kend, :]

             win_mask = None
             if mask is not None:
                 if mask.dim() == 4:
-                    win_mask = mask[:, :, qstart:qend,
+                    win_mask = mask[:, :, qstart:qend, kstart:kend]
                 elif mask.dim() == 2:
-                    win_mask = mask[qstart:qend,
+                    win_mask = mask[qstart:qend, kstart:kend]

             attn_out, _ = self._focus(
                 x=qwin,
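Note: each query window now attends to a key span of at most span_len positions ending at the window's right edge. A quick index sketch, assuming the windows advance by win_size as the qstart/qend context above suggests (sizes are illustrative, not the 500/1500 used below in Residual):

ctx, win_size, span_len = 12, 4, 6

for qstart in range(0, ctx, win_size):
    qend = min(qstart + win_size, ctx)
    kstart = max(0, qend - span_len)    # same rule as the hunk above
    kend = qend
    print(f"queries [{qstart}:{qend}) -> keys [{kstart}:{kend})")
# queries [0:4) -> keys [0:4)
# queries [4:8) -> keys [2:8)
# queries [8:12) -> keys [6:12)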
@@ -239,13 +225,13 @@ class Residual(nn.Module):

         self.lna = nn.LayerNorm(dims, bias=False)
         self.attnb = attentionb(dims, head)
-        self.attna = attentiona(dims, head,
+        self.attna = attentiona(dims, head, max_iter=3)
         self.mlp = nn.Sequential(Linear(dims, dims*4), get_activation(act), Linear(dims*4, dims))

     def forward(self, x, xa = None, mask = None) -> Tensor:
-        x = x + self.attnb(self.lna(x),
+        x = x + self.attnb(self.lna(x), None, mask)
         if xa is not None:
-            x = x + self.attna(self.lna(x), xa,
+            x = x + self.attna(self.lna(x), xa, None, use_sliding_win=True, win_size=500, span_len=1500)
         x = x + self.mlp(self.lna(x))
         return x

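Note: attentiona.forward is not part of this diff, so the new use_sliding_win/win_size/span_len keywords imply a dispatcher roughly like the hypothetical sketch below (not taken from the commit; the real signature may differ):

# Hypothetical attentiona.forward consistent with the call site above.
def forward(self, x, xa=None, mask=None, use_sliding_win=False, win_size=500, span_len=1500):
    if use_sliding_win:
        return self._slide_win_local(x, win_size, span_len, mask)
    out, _ = self._focus(x, xa, mask)
    return out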
@@ -266,28 +252,22 @@ class processor(nn.Module):
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

         self.bA = nn.ModuleList([Residual(dims, head, act_fn) for _ in range(layer)])
-
+
         mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

-    def forward(self, x, xa, sequential=False) -> Tensor:
+    def forward(self, x, xa, sequential=False, modal=False) -> Tensor:

         x = self.token_emb(x.long()) + self.positions[:x.shape[1]]
         xa = self.audio_enc(xa).permute(0, 2, 1)
         xa = xa + self.audio_emb(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)

         for b in chain(self.bA or []):
-            xa = b(
-            x = b(x
-            x = b(x
-
-
-
-            # if sequential:
-            # x = y
-            # else:
-            # a = torch.sigmoid(self.blend)
-            # x = a * y + (1 - a) * x
+            xa = b(xa, None, None)
+            x = b(x, None, self.mask)
+            x = b(x, xa, None)
+            xc = b(torch.cat([x, xa], dim=1), xa=xa, mask=self.mask) if modal else None
+            x = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None) if modal else x

         x = nn.functional.dropout(x, p=0.001, training=self.training)
         x = self.ln(x)
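Note: in the modal branch the text and audio streams are concatenated along the sequence axis, run through the block, then split back at the original text length. The concat/split round trip in isolation:

import torch

x = torch.randn(2, 7, 16)      # text stream  (batch, text_ctx, dims)
xa = torch.randn(2, 30, 16)    # audio stream (batch, audio_ctx, dims)

xc = torch.cat([x, xa], dim=1)                        # (2, 37, 16), as in the modal branch
x_part, xa_part = xc[:, :x.shape[1]], xc[:, x.shape[1]:]
print(x_part.shape, xa_part.shape)                    # torch.Size([2, 7, 16]) torch.Size([2, 30, 16])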
@@ -360,4 +340,4 @@ class Model(nn.Module):
         print("Initialization summary:")
         for module_type, count in self.init_counts.items():
             if count > 0:
-                print(f"{module_type}: {count}")
+                print(f"{module_type}: {count}")