crypto-code commited on
Commit
212945c
·
1 Parent(s): c4f1082

Update llama/m2ugen.py

Browse files
Files changed (1) hide show
  1. llama/m2ugen.py +70 -25
llama/m2ugen.py CHANGED
@@ -231,9 +231,9 @@ class M2UGen(nn.Module):
231
  self.music_decoder = self.args.music_decoder.lower()
232
 
233
  # 4. prefix
234
- self.query_layer = 20
235
  self.query_len = 1
236
- self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:0")
237
 
238
  # 5. knn
239
  self.knn = knn
@@ -492,30 +492,52 @@ class M2UGen(nn.Module):
492
  h = self.llama.tok_embeddings(tokens).to("cuda:0")
493
  freqs_cis = self.llama.freqs_cis.to("cuda:0")
494
  freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
495
-
496
- feats = torch.zeros((1, 1, 4096)).to("cuda:0")
497
- if audio_feats is not None:
498
- feats += audio_feats
499
- if video_feats is not None:
500
- feats += video_feats
501
- if image_feats is not None:
502
- feats += image_feats
503
 
504
  mask = None
505
  mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
506
  mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
507
 
508
  music_output_embedding = []
509
- for layer in self.llama.layers[:-1 * self.query_layer]:
510
  h = layer(h, 0, freqs_cis, mask)
511
  music_output_embedding.append(h)
512
 
513
- prefix_query = self.prefix_query.weight.reshape(self.query_layer, 1, 4096).unsqueeze(1)
 
514
 
515
  prefix_index = 0
516
- for layer in self.llama.layers[-1 * self.query_layer:]:
517
- h = layer(h, 0, freqs_cis, mask, feats + prefix_query[prefix_index])
518
- prefix_index = prefix_index + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  h = self.llama.norm(h)
521
  output = self.llama.output(h[:, -1, :])
@@ -523,30 +545,53 @@ class M2UGen(nn.Module):
523
  return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
524
 
525
  def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
526
- feats = torch.zeros((1, 1, 4096)).to(self.device)
527
  if audios is not None:
528
- feats += self.forward_audio({'Audio': [audios, 1]})
529
  if videos is not None:
530
- feats += self.forward_video({'Video': [videos, 1]})
531
  if imgs is not None:
532
- feats += self.forward_image({'Image': [imgs, 1]})
533
  _bsz, seqlen = tokens.shape
534
 
535
  h = self.llama.tok_embeddings(tokens.to(self.device))
536
  freqs_cis = self.llama.freqs_cis.to(h.device)
537
  freqs_cis = freqs_cis[:seqlen]
538
- mask = None
539
  mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
540
  mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
541
 
542
- for layer in self.llama.layers[:-1 * self.query_layer]:
543
  h = layer(h, 0, freqs_cis, mask)
544
- prefix_query = self.prefix_query.weight.reshape(self.query_layer, 1, 4096).unsqueeze(1)
 
 
545
  prefix_index = 0
 
 
 
 
 
 
 
 
546
 
547
- for layer in self.llama.layers[-1 * self.query_layer:]:
548
- h = layer(h, 0, freqs_cis, mask, feats + prefix_query[prefix_index])
549
- prefix_index = prefix_index + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
  final_hidden = h
552
  h = self.llama.norm(h)
 
231
  self.music_decoder = self.args.music_decoder.lower()
232
 
233
  # 4. prefix
234
+ self.query_layer = 6
235
  self.query_len = 1
236
+ self.prefix_query = nn.Embedding(self.query_layer * 3 * self.query_len, self.model_args.dim).to("cuda:0")
237
 
238
  # 5. knn
239
  self.knn = knn
 
492
  h = self.llama.tok_embeddings(tokens).to("cuda:0")
493
  freqs_cis = self.llama.freqs_cis.to("cuda:0")
494
  freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
 
 
 
 
 
 
 
 
495
 
496
  mask = None
497
  mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
498
  mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
499
 
500
  music_output_embedding = []
501
+ for layer in self.llama.layers[:-3 * self.query_layer]:
502
  h = layer(h, 0, freqs_cis, mask)
503
  music_output_embedding.append(h)
504
 
505
+ prefix_query = self.prefix_query.weight.reshape(
506
+ self.query_layer * 3, 1, 4096).unsqueeze(1)
507
 
508
  prefix_index = 0
509
+ if audio_feats is not None:
510
+ for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
511
+ h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
512
+ music_output_embedding.append(h)
513
+ prefix_index = prefix_index + 1
514
+ else:
515
+ for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
516
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
517
+ music_output_embedding.append(h)
518
+ prefix_index = prefix_index + 1
519
+
520
+ if image_feats is not None:
521
+ for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
522
+ h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
523
+ music_output_embedding.append(h)
524
+ prefix_index = prefix_index + 1
525
+ else:
526
+ for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
527
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
528
+ music_output_embedding.append(h)
529
+ prefix_index = prefix_index + 1
530
+
531
+ if video_feats is not None:
532
+ for layer in self.llama.layers[-1 * self.query_layer:]:
533
+ h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
534
+ music_output_embedding.append(h)
535
+ prefix_index = prefix_index + 1
536
+ else:
537
+ for layer in self.llama.layers[-1 * self.query_layer:]:
538
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
539
+ music_output_embedding.append(h)
540
+ prefix_index = prefix_index + 1
541
 
542
  h = self.llama.norm(h)
543
  output = self.llama.output(h[:, -1, :])
 
545
  return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
546
 
547
  def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
548
+ audio_feats, video_feats, image_feats = None, None, None
549
  if audios is not None:
550
+ audio_feats = self.forward_audio({'Audio': [audios, 1]})
551
  if videos is not None:
552
+ video_feats = self.forward_video({'Video': [videos, 1]})
553
  if imgs is not None:
554
+ image_feats = self.forward_image({'Image': [imgs, 1]})
555
  _bsz, seqlen = tokens.shape
556
 
557
  h = self.llama.tok_embeddings(tokens.to(self.device))
558
  freqs_cis = self.llama.freqs_cis.to(h.device)
559
  freqs_cis = freqs_cis[:seqlen]
 
560
  mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
561
  mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
562
 
563
+ for layer in self.llama.layers[:-3 * self.query_layer]:
564
  h = layer(h, 0, freqs_cis, mask)
565
+ prefix_query = self.prefix_query.weight.reshape(
566
+ self.query_layer * 3, 1, 4096).unsqueeze(1)
567
+
568
  prefix_index = 0
569
+ if audio_feats is not None:
570
+ for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
571
+ h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
572
+ prefix_index = prefix_index + 1
573
+ else:
574
+ for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
575
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
576
+ prefix_index = prefix_index + 1
577
 
578
+ if image_feats is not None:
579
+ for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
580
+ h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
581
+ prefix_index = prefix_index + 1
582
+ else:
583
+ for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
584
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
585
+ prefix_index = prefix_index + 1
586
+
587
+ if video_feats is not None:
588
+ for layer in self.llama.layers[-1 * self.query_layer:]:
589
+ h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
590
+ prefix_index = prefix_index + 1
591
+ else:
592
+ for layer in self.llama.layers[-1 * self.query_layer:]:
593
+ h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
594
+ prefix_index = prefix_index + 1
595
 
596
  final_hidden = h
597
  h = self.llama.norm(h)