AbstractPhil committed
Commit fb60b42 · verified · 1 Parent(s): 880d8a1

Update README.md

Files changed (1): README.md (+357 -1)

README.md CHANGED
@@ -119,4 +119,360 @@ def main():
  #upload_file(ckpt, ckpt, repo_id=hf_repo_id)


- pbar.close()
+ pbar.close()
+
+ ```
+
+
+ You can run inference with the test version using Stable Diffusion 1.5 as an example.
+ The CLIP_L responses fall apart when too much of the conditioning is handed over to the adapter guidance, but it's definitely a powerful first test of combining divergent systems (see the `adapter_scale` sweep after the script below).
+
+ It should run cleanly on Colab using an L4 GPU.
+
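+ The script below expects the adapter weights (`roba_adapter_step_19500.safetensors`) to sit in the working directory. If they are hosted on the Hub, something like the following sketch fetches them first; the repo id here is a placeholder, so point it at wherever the checkpoint actually lives.
+
+ ```
+ from huggingface_hub import hf_hub_download
+
+ # Placeholder repo id: replace with the repository that hosts the adapter checkpoint.
+ ckpt_path = hf_hub_download(
+     repo_id="<adapter-repo-id>",
+     filename="roba_adapter_step_19500.safetensors",
+ )
+ ```
+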
+ ```
+ # Optimized inference_adapter.py
+
+ import torch
+ import math
+ from PIL import Image
+ from torchvision.transforms import ToPILImage
+ from safetensors.torch import load_file as load_safetensors
+
+ from transformers import (
+     T5TokenizerFast, T5EncoderModel,
+     CLIPTokenizerFast, CLIPTextModel
+ )
+ from diffusers import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+     EulerAncestralDiscreteScheduler
+ )
+ from typing import Optional
+
+ # ─────────────────────────────────────────────────────────────
+ # 1) GLOBAL SETUP: load once, cast, eval, move
+ # ─────────────────────────────────────────────────────────────
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DTYPE  = torch.float16  # use fp16 for everything on GPU
+
+ # 1a) CLIP text encoder (cond + uncond)
+ clip_tok = CLIPTokenizerFast.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
+ )
+ clip_mod = CLIPTextModel.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="text_encoder",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+
+ # 1b) T5 encoder
+ t5_tok = T5TokenizerFast.from_pretrained("t5-small")
+ t5_mod = T5EncoderModel.from_pretrained(
+     "AbstractPhil/T5-Small-Human-Attentive-Try2-Pass3",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+
+ # 1c) Adapter
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class RobustVelocityAdapter(nn.Module):
+     """
+     Fixed version: manual multi-head cross-attention emits [B, heads, Q, K] scores
+     so that _add_rel_pos_bias can unpack them correctly.
+     """
+     def __init__(
+         self,
+         t5_dim: int = 512,
+         clip_dim: int = 768,
+         hidden_dim: int = 1024,
+         out_tokens: int = 64,   # now aligned with your T5 finetune
+         self_attn_layers: int = 2,
+         cross_heads: int = 8,
+         max_rel_pos: int = 128,
+     ):
+         super().__init__()
+         self.out_tokens  = out_tokens
+         self.cross_heads = cross_heads
+         self.head_dim    = t5_dim // cross_heads
+         self.max_rel_pos = max_rel_pos
+
+         # 1) Self-attention stack
+         self.self_attn = nn.ModuleList()
+         self.self_norm = nn.ModuleList()
+         for _ in range(self_attn_layers):
+             self.self_attn.append(nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True))
+             self.self_norm.append(nn.LayerNorm(t5_dim))
+
+         # 2) Residual blocks
+         def resblock():
+             return nn.Sequential(
+                 nn.LayerNorm(t5_dim),
+                 nn.Linear(t5_dim, t5_dim),
+                 nn.GELU(),
+                 nn.Linear(t5_dim, t5_dim),
+             )
+         self.res1 = resblock()
+         self.res2 = resblock()
+
+         # 3) Learned queries for cross-attn
+         self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))
+
+         # 4) Projection heads
+         self.anchor_proj = nn.Sequential(
+             nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
+         )
+         self.delta_proj = nn.Sequential(
+             nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
+         )
+         self.var_proj = nn.Sequential(
+             nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
+         )
+         self.gate_proj = nn.Sequential(
+             nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
+         )
+
+         # 5) Relative-position bias table
+         self.rel_bias = nn.Parameter(torch.zeros(2*max_rel_pos-1, cross_heads))
+
+         # 6) Norm after cross-attn
+         self.cross_norm = nn.LayerNorm(t5_dim)
+
+     def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
+         """
+         attn_scores: [B, heads, Q, K]
+         returns:     attn_scores + bias where bias is [B, heads, Q, K]
+         """
+         B, H, Q, K = attn_scores.shape
+         device = attn_scores.device
+
+         # 1) Query & key position indices
+         idx_q = torch.arange(Q, device=device)   # [Q]
+         idx_k = torch.arange(K, device=device)   # [K]
+
+         # 2) Compute relative distances for every (q, k) pair
+         #    rel[i,j] = idx_q[i] - idx_k[j]
+         rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0)   # [Q, K]
+
+         # 3) Clamp & shift into bias table range [0, 2*max_rel-2]
+         max_rel = self.max_rel_pos
+         rel = rel.clamp(-max_rel+1, max_rel-1) + (max_rel - 1)
+
+         # 4) Lookup per-head biases
+         #    self.rel_bias has shape [2*max_rel-1, H]
+         bias = self.rel_bias[rel]        # [Q, K, H]
+         bias = bias.permute(2, 0, 1)     # [H, Q, K]
+
+         # 5) Broadcast to [B, H, Q, K] and add
+         bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
+         return attn_scores + bias
+
+     def forward(self, t5_seq: torch.Tensor):
+         """
+         t5_seq: [B, L, t5_dim]
+         returns:
+             anchor: [B, out_tokens, clip_dim]
+             delta:  [B, out_tokens, clip_dim]
+             sigma:  [B, out_tokens, clip_dim]
+         """
+         x = t5_seq
+         B, L, D = x.shape
+
+         # 1) Self-attention + residual
+         for attn, norm in zip(self.self_attn, self.self_norm):
+             res, _ = attn(x, x, x)
+             x = norm(x + res)
+
+         # 2) Residual blocks
+         x = x + self.res1(x)
+         x = x + self.res2(x)
+
+         # 3) Prepare queries & split heads
+         queries = self.query_pos.unsqueeze(0).expand(B, -1, -1)   # [B, Q, D]
+         # reshape into heads
+         q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0,2,1,3)
+         k = x.view(B, L, self.cross_heads, self.head_dim).permute(0,2,1,3)
+         v = k
+
+         # 4) Scaled dot-product to get [B, heads, Q, K]
+         scores = (q @ k.transpose(-2,-1)) / math.sqrt(self.head_dim)
+         scores = self._add_rel_pos_bias(scores)
+         probs  = F.softmax(scores, dim=-1)   # [B, H, Q, K]
+
+         # 5) Attend & merge heads → [B, Q, D]
+         ctx = probs @ v                      # [B, H, Q, head_dim]
+         ctx = ctx.permute(0,2,1,3).reshape(B, self.out_tokens, D)
+         ctx = self.cross_norm(ctx)
+
+         # 6) Project to anchor, delta_mean, delta_logvar, gate
+         anchor       = self.anchor_proj(ctx)
+         delta_mean   = self.delta_proj(ctx)
+         delta_logvar = self.var_proj(ctx)
+         gate         = self.gate_proj(ctx)
+
+         # 7) Compute sigma & gated delta
+         sigma = torch.exp(0.5 * delta_logvar)
+         delta = delta_mean * gate
+
+         return anchor, delta, sigma
+
+ from transformers import CLIPTokenizer   # slow tokenizer variant used in the fp32 setup below
+
+ # 1) GLOBAL SETUP (full fp32 pipeline; reloads the text encoders and
+ #    overrides the fp16 handles defined above)
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DTYPE  = torch.float32
+
+ # 1a) CLIP tokenizer & text encoder
+ clip_tok = CLIPTokenizer.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
+ )
+ clip_mod = CLIPTextModel.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="text_encoder",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+
+ # 1b) U-Net, VAE, Scheduler
+ unet = UNet2DConditionModel.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="unet",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+ vae = AutoencoderKL.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="vae",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+ scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+ )
+
+ # 1c) T5 + Adapter
+ t5_tok = T5TokenizerFast.from_pretrained("t5-small")
+ t5_mod = T5EncoderModel.from_pretrained(
+     "AbstractPhil/T5-Small-Human-Attentive-Try2-Pass3",
+     torch_dtype=DTYPE
+ ).to(DEVICE).eval()
+
+ adapter = RobustVelocityAdapter(out_tokens=64).to(DEVICE).eval()
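+ # Checkpoint keys carry a "_orig_mod." prefix (an artifact of saving a
+ # torch.compile-wrapped model), so strip it before loading non-strictly.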
+ state = load_safetensors("roba_adapter_step_19500.safetensors", device="cpu")
+ clean = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
+ adapter.load_state_dict(clean, strict=False)
+ adapter.to(DEVICE).eval()
+
+ # 2) GENERATION FUNCTION
+ @torch.no_grad()
+ def generate_image_with_adapter(
+     prompt: str,
+     seed: int = 42,
+     steps: int = 50,
+     adapter_scale: float = 0.5,
+     guidance_scale: float = 7.5,
+     height: int = 512,
+     width: int = 512,
+ ):
+     gen = torch.Generator(device=DEVICE).manual_seed(seed)
+
+     # 2.1) CLIP embeddings
+     clip_in = clip_tok([prompt],
+                        max_length=clip_tok.model_max_length,
+                        padding="max_length", truncation=True,
+                        return_tensors="pt").to(DEVICE)
+     clip_cond = clip_mod(**clip_in).last_hidden_state        # [1,77,768]
+
+     empty_in = clip_tok([""],
+                         max_length=clip_tok.model_max_length,
+                         padding="max_length", truncation=True,
+                         return_tensors="pt").to(DEVICE)
+     clip_uncond = clip_mod(**empty_in).last_hidden_state     # [1,77,768]
+
+     # 2.2) T5 → adapter → anchor, delta, sigma (64 tokens)
+     t5_in = t5_tok(prompt,
+                    max_length=64, padding="max_length",
+                    truncation=True, return_tensors="pt").to(DEVICE)
+     t5_seq = t5_mod(**t5_in).last_hidden_state               # [1,64,512]
+     anchor, delta, sigma = adapter(t5_seq)                   # each [1,64,768]
+
+     # 2.3) Upsample to 77 tokens
+     T_clip = clip_cond.shape[1]   # 77
+     def up(x):
+         return F.interpolate(
+             x.permute(0,2,1),
+             size=T_clip, mode="linear", align_corners=False
+         ).permute(0,2,1)
+     anchor = up(anchor)
+     delta  = up(delta)
+     sigma  = up(sigma)
+
+     # 2.4) σ-based noise scaling
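+     # The adapter's mean predicted sigma (clamped to [0.1, 2.0]) nudges the initial
+     # latent noise magnitude away from 1.0 in proportion to adapter_scale: a confident
+     # adapter (sigma < 1) starts from slightly less noise, an uncertain one from more.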
+     raw_ns = sigma.mean().clamp(0.1, 2.0).item()
+     noise_scale = 1.0 + adapter_scale * (raw_ns - 1.0)
+
+     # 2.5) Initialize latents
+     latents = torch.randn(
+         (1, unet.config.in_channels, height//8, width//8),
+         generator=gen, device=DEVICE, dtype=DTYPE
+     ) * scheduler.init_noise_sigma * noise_scale
+     scheduler.set_timesteps(steps, device=DEVICE)
+
+     # 2.6) Denoising with adapter guidance
+     for i, t in enumerate(scheduler.timesteps):
+         alpha = i / (len(scheduler.timesteps)-1)
+         aw    = adapter_scale * alpha
+         cw    = 1.0 - aw
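+         # aw ramps linearly from 0 to adapter_scale over the denoising trajectory,
+         # so early steps are conditioned almost purely on CLIP and the adapter's
+         # anchor/delta take over progressively toward the end.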
+
+         # blend anchors
+         blended = clip_cond * cw + anchor * aw
+
+         # per-token confidence
+         eps  = 1e-6
+         conf = 1.0 / (sigma + eps)
+         conf = conf / conf.amax(dim=(1,2), keepdim=True)
+
+         # gated delta
+         gated_delta = delta * aw * conf
+
+         # final cond embedding
+         cond_embed = blended + gated_delta                # [1,77,768]
+
+         # UNet forward (classifier-free guidance: uncond + cond in one batch)
+         lat_in = scheduler.scale_model_input(latents, t)
+         lat_in = torch.cat([lat_in, lat_in], dim=0)
+         embeds = torch.cat([clip_uncond, cond_embed], dim=0)
+         noise  = unet(lat_in, t, encoder_hidden_states=embeds).sample
+         u, c   = noise.chunk(2)
+         guided = u + guidance_scale * (c - u)
+         latents = scheduler.step(guided, t, latents, generator=gen).prev_sample
+
+     # 2.7) Decode
+     dec_lat = latents / vae.config.scaling_factor
+     image_t = vae.decode(dec_lat).sample
+     image_t = (image_t.clamp(-1,1) + 1) / 2
+     return ToPILImage()(image_t[0].cpu())
+
+ # 3) RUN EXAMPLE
+ if __name__ == "__main__":
+     out = generate_image_with_adapter(
+         "silly dog wearing a batman costume, high resolution, studio lighting",
+         seed=1234, steps=50,
+         adapter_scale=0.5, guidance_scale=7.5
+     )
+     out.save("sd15_with_adapter.png")
+     print("Saved sd15_with_adapter.png")
+ ```
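+
+ A quick way to see where the CLIP_L conditioning starts to fall apart is to sweep `adapter_scale` with everything else fixed. This is a sketch reusing the `generate_image_with_adapter` function defined above; the prompt and step count here are arbitrary.
+
+ ```
+ # Sweep the adapter blend weight; higher values hand more of the conditioning
+ # to the T5 adapter and eventually degrade the CLIP_L guidance.
+ for scale in (0.0, 0.25, 0.5, 0.75, 1.0):
+     img = generate_image_with_adapter(
+         "silly dog wearing a batman costume, high resolution, studio lighting",
+         seed=1234, steps=30,
+         adapter_scale=scale, guidance_scale=7.5,
+     )
+     img.save(f"sd15_adapter_scale_{scale:.2f}.png")
+ ```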