Dionyssos committed
Commit b399825 · 1 Parent(s): 86b9ce4
api.py CHANGED
@@ -2,7 +2,6 @@
2
  # -*- coding: utf-8 -*-
3
  import numpy as np
4
  import soundfile
5
- import audresample
6
  from Utils.text_utils import split_into_sentences
7
  import msinference
8
  import re
@@ -15,10 +14,12 @@ from flask import Flask, request, send_from_directory
15
  from moviepy.video.io.VideoFileClip import VideoFileClip
16
  from moviepy.video.VideoClip import ImageClip
17
  from audiocraft.builders import AudioGen
18
- CACHE_DIR = 'flask_cache/'
19
- NUM_SOUND_GENERATIONS = 3 # batch size to generate same text (same soundscape for long video)
20
 
21
- sound_generator = AudioGen(duration=4.74, device='cuda:0').to('cuda:0').eval()
22
 
23
  Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
24
 
@@ -57,62 +58,17 @@ def _resize(image, width=None, height=None, inter=cv2.INTER_AREA):
57
  # return the resized image
58
  return resized
59
 
60
-
61
-
62
- def _shift(x):
63
- n = x.shape[0]
64
- i = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0
65
- x = np.roll(x, i)
66
- # we can add the one or fade it and then amplify
67
- # the audio is so short 6s that is difficult to not hear the shift somewhere
68
- # Just concatenate - raw - and then shift - the longconcat audio - many times may fix it
69
- # fade_in = 1 - .5 * np.tanh(-4*(np.linspace(-10, 10, n) - 9.4)) + .5 * np.tanh(4*(np.linspace(-10, 10, n) + 9.4))
70
- return x #* fade_in # silence this
71
-
72
  def overlay(x, soundscape=None):
73
-
74
  if soundscape is not None:
75
-
76
- # SOUNDS
77
-
78
- background = sound_generator.generate(
79
- [soundscape] * NUM_SOUND_GENERATIONS
80
- ).reshape(-1).detach().cpu().numpy() # bs, 11400 @.74s
81
-
82
- # upsample 16 kHz AudioGen to 24kHZ of VITS/StyleTTS2
83
-
84
- print('Resampling') # soundscape each generation in batch differs from the other generations thus clone/shift each element in batch, finally concat w/o shift
85
-
86
-
87
- background = audresample.resample(
88
- background,
89
- original_rate=16000, # sound_generator.sample_rate,
90
- target_rate=24000)[0, :-250] # last samples have splash sounds DISCARD 25000 last samples
91
- n_repeat = len(x) // background.shape[0] + 1
104
-
105
- total = np.tile(background, n_repeat)
106
-
107
- # less periodic
108
-
109
- for _ in range(4):
110
- total = _shift(total)
111
-
112
- # amplify sounds full [-1,1]
113
-
114
- total /= np.abs(total).max() + 1e-7
115
- x = .5 * x + .5 * total[:len(x)]
116
 
117
  else:
118
 
 
2
  # -*- coding: utf-8 -*-
3
  import numpy as np
4
  import soundfile
 
5
  from Utils.text_utils import split_into_sentences
6
  import msinference
7
  import re
 
14
  from moviepy.video.io.VideoFileClip import VideoFileClip
15
  from moviepy.video.VideoClip import ImageClip
16
  from audiocraft.builders import AudioGen
 
 
17
 
18
+ CACHE_DIR = 'flask_cache/'
19
+ PIECE_OF_SOUND_DURATION = 4.74 # seconds
20
+ sound_generator = AudioGen(
21
+ duration=PIECE_OF_SOUND_DURATION
22
+ ).to('cuda:0').eval()
23
 
24
  Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
25
 
 
58
  # return the resized image
59
  return resized
60
 
61
  def overlay(x, soundscape=None):
62
+ # pre-calculate n_repeat here, then apply torchaudio.resample and repeat inside sound_gen's forward()
63
  if soundscape is not None:
64
 
65
+ background = sound_generator.generate(soundscape,
66
+ n_repeat=int(len(x) / (PIECE_OF_SOUND_DURATION * 16000)) + 1
67
+ ).detach().cpu().numpy() # bs, 11400 @.74s
68
 
69
+ # blend TTS
70
 
71
+ x = .5 * x + .5 * background[:len(x)]
72
 
73
  else:
74
 
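The rewritten overlay() above now leaves tiling, shifting and peak-normalisation to AudioGen.generate() and only blends the returned background with the TTS waveform. A minimal sketch of the new call path, assuming x is the 24 kHz TTS waveform and sound_generator is the AudioGen instance built at module load (the 16 000 factor in n_repeat mirrors the committed code, i.e. AudioGen's native sample rate):

import numpy as np

PIECE_OF_SOUND_DURATION = 4.74   # seconds per generated AudioGen piece
AUDIOGEN_SR = 16_000             # AudioGen's native sample rate

def overlay_sketch(x, soundscape, sound_generator):
    # x: 1-D float array (TTS speech); soundscape: text prompt or None
    if soundscape is None:
        return x
    # enough ~4.74 s pieces to cover the speech duration
    n_repeat = int(len(x) / (PIECE_OF_SOUND_DURATION * AUDIOGEN_SR)) + 1
    background = sound_generator.generate(
        soundscape, n_repeat=n_repeat
    ).detach().cpu().numpy()     # 1-D, peak-normalised inside generate()
    return .5 * x + .5 * background[:len(x)]   # equal-weight blend
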
audiocraft/activations.py DELETED
@@ -1,96 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import torch
8
- import torch.nn as nn
9
- from torch import Tensor
10
- from typing import Union, Callable
11
-
12
-
13
- class CustomGLU(nn.Module):
14
- """Custom Gated Linear Unit activation.
15
- Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
16
- of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
17
- function (i.e. sigmoid, swish, etc.).
18
-
19
- Args:
20
- activation (nn.Module): The custom activation to apply in the Gated Linear Unit
21
- dim (int): the dimension on which to split the input. Default: -1
22
-
23
- Shape:
24
- - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
25
- dimensions
26
- - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
27
-
28
- Examples::
29
- >>> m = CustomGLU(nn.Sigmoid())
30
- >>> input = torch.randn(4, 2)
31
- >>> output = m(input)
32
- """
33
- def __init__(self, activation: nn.Module, dim: int = -1):
34
- super(CustomGLU, self).__init__()
35
- self.dim = dim
36
- self.activation = activation
37
-
38
- def forward(self, x: Tensor):
39
- assert x.shape[self.dim] % 2 == 0 # M = N / 2
40
- a, b = torch.chunk(x, 2, dim=self.dim)
41
- return a * self.activation(b)
42
-
43
-
44
- class SwiGLU(CustomGLU):
45
- """SiLU Gated Linear Unit activation.
46
- Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
47
- the first half of the input matrices, :math:`b` is the second half.
48
-
49
- Args:
50
- dim (int): the dimension on which to split the input. Default: -1
51
- """
52
- def __init__(self, dim: int = -1):
53
- super(SwiGLU, self).__init__(nn.SiLU(), dim)
54
-
55
-
56
- class GeGLU(CustomGLU):
57
- """GeLU Gated Linear Unit activation.
58
- Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
59
- the first half of the input matrices, :math:`b` is the second half.
60
-
61
- Args:
62
- dim (int): the dimension on which to split the input. Default: -1
63
- """
64
- def __init__(self, dim: int = -1):
65
- super(GeGLU, self).__init__(nn.GELU(), dim)
66
-
67
-
68
- class ReGLU(CustomGLU):
69
- """ReLU Gated Linear Unit activation.
70
- Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
71
- the first half of the input matrices, :math:`b` is the second half.
72
-
73
- Args:
74
- dim (int): the dimension on which to split the input. Default: -1
75
- """
76
- def __init__(self, dim: int = -1):
77
- super(ReGLU, self).__init__(nn.ReLU(), dim)
78
-
79
-
80
- def get_activation_fn(
81
- activation: Union[str, Callable[[Tensor], Tensor]]
82
- ) -> Union[str, Callable[[Tensor], Tensor]]:
83
- """Helper function to map an activation string to the activation class.
84
- If the supplied activation is not a string that is recognized, the activation is passed back.
85
-
86
- Args:
87
- activation (str, or Callable[[Tensor], Tensor]): Activation to check
88
- """
89
- if isinstance(activation, str):
90
- if activation == "reglu":
91
- return ReGLU()
92
- elif activation == "geglu":
93
- return GeGLU()
94
- elif activation == "swiglu":
95
- return SwiGLU()
96
- return activation
 
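The deleted activations.py only provided the gated-linear-unit variants (a * f(b) on a channel split) plus a string-to-module lookup, which the simplified LM below no longer needs. For reference, a few lines restating what the removed CustomGLU/SwiGLU/GeGLU computed (not code from this commit):

import torch
import torch.nn as nn

def glu_gate(x, activation, dim=-1):
    # equivalent of the removed CustomGLU: split x in half along dim
    # and gate the first half with activation(second half)
    a, b = torch.chunk(x, 2, dim=dim)
    return a * activation(b)

y_swiglu = glu_gate(torch.randn(4, 8), nn.SiLU())   # SwiGLU behaviour
y_geglu = glu_gate(torch.randn(4, 8), nn.GELU())    # GeGLU behaviour
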
audiocraft/builders.py CHANGED
@@ -1,22 +1,25 @@
1
- import typing as tp
2
  import omegaconf
 
3
  from torch import nn
4
  import torch
 
5
  from huggingface_hub import hf_hub_download
6
  import os
7
- from omegaconf import OmegaConf, DictConfig
8
-
9
  from .encodec import EncodecModel
10
  from .lm import LMModel
11
  from .seanet import SEANetDecoder
12
- from .codebooks_patterns import DelayedPatternProvider
13
- from .conditioners import T5Conditioner
14
  from .vq import ResidualVectorQuantizer
15
 
 
 
 
 
 
 
 
16
 
17
-
18
-
19
- def _delete_param(cfg: DictConfig, full_name: str):
20
  parts = full_name.split('.')
21
  for part in parts[:-1]:
22
  if part in cfg:
@@ -35,48 +38,53 @@ def dict_from_config(cfg):
35
  return dct
36
 
37
 
38
-
39
-
40
-
41
-
42
-
43
-
44
- # ============================================== DEFINE AUDIOGEN
45
-
46
-
47
-
48
-
49
-
50
-
51
  class AudioGen(nn.Module):
52
 
53
  # https://huggingface.co/facebook/audiogen-medium
54
 
55
  def __init__(self,
56
- duration=0.024,
57
- device='cpu'):
58
 
59
  super().__init__()
60
- self.device = device # needed for loading & select float16 LM
61
  self.load_compression_model()
62
  self.load_lm_model()
63
  self.duration = duration
 
 
64
 
65
  @property
66
  def frame_rate(self):
67
  return self.compression_model.frame_rate
68
 
69
  def generate(self,
70
- descriptions):
 
 
71
  with torch.no_grad():
72
  gen_tokens = self.lm.generate(
73
- descriptions=descriptions,
74
  max_gen_len=int(self.duration * self.frame_rate)) # [bs, 4, 37 * self.lm.n_draw]
75
  x = self.compression_model.decode(gen_tokens, None) #[bs, 1, 11840]
76
- # print('______________\nAudioGen Tokens', gen_tokens)
 
 
 
77
 
 
 
 
 
 
 
 
78
 
79
- return x / x.abs().max(2, keepdims=True)[0] + 1e-7
 
 
 
 
 
80
 
81
  # == BUILD Fn
82
  def get_quantizer(self, quantizer, cfg, dimension):
@@ -126,58 +134,7 @@ class AudioGen(nn.Module):
126
  ).to(cfg.device)
127
  else:
128
  raise KeyError(f"Unexpected compression model {cfg.compression_model}")
129
-
130
-
131
- def get_lm_model(self, cfg):
132
- """Instantiate a transformer LM."""
133
- if cfg.lm_model in ['transformer_lm',
134
- 'transformer_lm_magnet']:
135
- kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
136
- n_q = kwargs['n_q']
137
- q_modeling = kwargs.pop('q_modeling', None)
138
- codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
139
- attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
140
- cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
141
- cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
142
-
143
-
144
-
145
- # if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
146
- kwargs['cross_attention'] = True
147
- if codebooks_pattern_cfg.modeling is None:
148
- print('Q MODELING\n=\n=><')
149
- assert q_modeling is not None, \
150
- "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
151
- codebooks_pattern_cfg = omegaconf.OmegaConf.create(
152
- {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
153
- )
154
-
155
- pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
156
- return LMModel(
157
- pattern_provider=pattern_provider,
158
- condition_provider=T5Conditioner(name='t5-large', output_dim=kwargs["dim"], device=self.device),
159
- cfg_dropout=cfg_prob,
160
- cfg_coef=cfg_coef,
161
- attribute_dropout=attribute_dropout,
162
- dtype=getattr(torch, cfg.dtype),
163
- device=self.device,
164
- **kwargs
165
- ).to(cfg.device)
166
- else:
167
- raise KeyError(f"Unexpected LM model {cfg.lm_model}")
168
-
169
-
170
- def get_codebooks_pattern_provider(self, n_q, cfg):
171
- pattern_providers = {
172
- 'delay': DelayedPatternProvider, # THIS
173
- }
174
- name = cfg.modeling
175
- kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
176
-
177
- klass = pattern_providers[name]
178
- return klass(n_q, **kwargs)
179
-
180
- # ======================
181
  def load_compression_model(self):
182
  file = hf_hub_download(
183
  repo_id='facebook/audiogen-medium',
@@ -204,24 +161,20 @@ class AudioGen(nn.Module):
204
  library_name="audiocraft",
205
  library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
206
  pkg = torch.load(file,
207
- map_location=self.device) #'cpu')
208
- cfg = OmegaConf.create(pkg['xp.cfg'])
209
- # cfg.device = 'cpu'
210
- if self.device == 'cpu':
211
- cfg.dtype = 'float32'
212
- else:
213
- cfg.dtype = 'float16'
214
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
215
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
216
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
217
- model = self.get_lm_model(cfg)
218
-
 
 
219
  _best = pkg['best_state']
220
  _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
221
  _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
222
  model.load_state_dict(pkg['best_state'])
223
- model.cfg = cfg
224
- # return model
225
  self.lm = model.to(torch.float)
226
 
227
  # def _flush(self):
 
 
1
  import omegaconf
2
+ import torchaudio
3
  from torch import nn
4
  import torch
5
+ import numpy as np
6
  from huggingface_hub import hf_hub_download
7
  import os
8
+ from omegaconf import OmegaConf
 
9
  from .encodec import EncodecModel
10
  from .lm import LMModel
11
  from .seanet import SEANetDecoder
 
 
12
  from .vq import ResidualVectorQuantizer
13
 
14
+ def _shift(x):
15
+ # [bs, samples]: circularly shift each batch element of the sound
16
+ n = x.shape[1]
17
+ for i, batch_elem in enumerate(x):
18
+ offset = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0 TBD
19
+ x[i, :] = torch.roll(batch_elem, offset, dims=0) # batch_elem = [400000, ]
20
+ return x
21
 
22
+ def _delete_param(cfg, full_name):
 
 
23
  parts = full_name.split('.')
24
  for part in parts[:-1]:
25
  if part in cfg:
 
38
  return dct
39
 
40
 
41
  class AudioGen(nn.Module):
42
 
43
  # https://huggingface.co/facebook/audiogen-medium
44
 
45
  def __init__(self,
46
+ duration=2.24, # s
47
+ ):
48
 
49
  super().__init__()
 
50
  self.load_compression_model()
51
  self.load_lm_model()
52
  self.duration = duration
53
+ # AudioGen = 16 kHz / StyleTTS2 = 24 kHz / MMS-TTS = 24 kHz
54
+ self.resample_fn = torchaudio.transforms.Resample(16000, 24000)
55
 
56
  @property
57
  def frame_rate(self):
58
  return self.compression_model.frame_rate
59
 
60
  def generate(self,
61
+ descriptions,
62
+ n_repeat=3):
63
+
64
  with torch.no_grad():
65
  gen_tokens = self.lm.generate(
66
+ descriptions=[descriptions]*3,
67
  max_gen_len=int(self.duration * self.frame_rate)) # [bs, 4, 37 * self.lm.n_draw]
68
  x = self.compression_model.decode(gen_tokens, None) #[bs, 1, 11840]
69
+
70
+ x = x[:, 0, :-250] # the last samples have splash sounds; discard the final 250 samples
71
+
72
+ # AudioGen 16 kHz / StyleTTS2 24 kHz / MMS-TTS 24 kHz
73
 
74
+ # x = self.resample_fn(x)
75
+
76
+ # batch size = different sounds for same txt
77
+
78
+ x = x.repeat(1, n_repeat)
79
+
80
+ # less periodic - shift every batch elem
81
 
82
+ for _ in range(7):
83
+ x = _shift(x)
84
+
85
+ x = x.reshape(-1)
86
+ print(x.abs().max(), 'MAX')
87
+ return x / (x.abs().max() + 1e-7)
88
 
89
  # == BUILD Fn
90
  def get_quantizer(self, quantizer, cfg, dimension):
 
134
  ).to(cfg.device)
135
  else:
136
  raise KeyError(f"Unexpected compression model {cfg.compression_model}")
137
+
138
  def load_compression_model(self):
139
  file = hf_hub_download(
140
  repo_id='facebook/audiogen-medium',
 
161
  library_name="audiocraft",
162
  library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
163
  pkg = torch.load(file,
164
+ map_location='cpu')
165
+ cfg = OmegaConf.create(pkg['xp.cfg']) # CFG inside torch bin
 
 
 
 
 
166
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
167
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
168
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
169
+ print('___________________________CFG___________________',cfg,'\n=======================')
170
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
171
+ print('___________________________Kwarg___________________',kwargs,'\n=======================')
172
+ model = LMModel().to(getattr(torch, cfg.dtype)) #.to(cfg.device)
173
  _best = pkg['best_state']
174
  _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
175
  _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
176
  model.load_state_dict(pkg['best_state'])
177
+ # model.cfg = cfg
 
178
  self.lm = model.to(torch.float)
179
 
180
  # def _flush(self):
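The circular shift that api.py previously applied to a flat numpy array now operates on the [bs, samples] tensor inside builders.py, and generate() handles repeat, shift and peak-normalisation itself. A self-contained sketch of that repeat-then-shift idea, assuming offsets drawn from 24–74 % of the length as in _shift() above:

import torch
import numpy as np

def shift_batch(x):
    # x: [bs, samples]; roll each batch element by a random offset
    n = x.shape[1]
    for i in range(x.shape[0]):
        offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
        x[i] = torch.roll(x[i], offset, dims=0)
    return x

sounds = torch.randn(3, 16000)      # 3 different pieces generated for the same prompt
tiled = sounds.repeat(1, 4)         # tile along time to cover a longer clip
for _ in range(7):                  # repeated shifts make the loop less periodic
    tiled = shift_batch(tiled)
background = tiled.reshape(-1)      # interleave the batch into one long track
background = background / (background.abs().max() + 1e-7)   # peak-normalise
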
audiocraft/conditioners.py CHANGED
@@ -25,7 +25,7 @@ class T5Conditioner(nn.Module):
25
  def __init__(self,
26
  name,
27
  output_dim,
28
- device,
29
  finetune=False):
30
  print(f'{finetune=}')
31
  assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
@@ -36,7 +36,7 @@ class T5Conditioner(nn.Module):
36
  self.device = device
37
  self.name = name
38
 
39
- self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
40
  t5 = T5EncoderModel.from_pretrained(name).eval() #.train(mode=finetune)
41
  if finetune:
42
  self.t5 = t5
@@ -65,7 +65,7 @@ class T5Conditioner(nn.Module):
65
  embeds = self.t5(input_ids=d['input_ids'],
66
  attention_mask=d['attention_mask']
67
  ).last_hidden_state # no kvcache for txt conditioning
68
- embeds = self.output_proj(embeds.to(self.output_proj.weight))
69
  embeds = (embeds * d['attention_mask'].unsqueeze(-1))
70
 
71
  return embeds # , d['attention_mask']
 
25
  def __init__(self,
26
  name,
27
  output_dim,
28
+ device='cuda:0',
29
  finetune=False):
30
  print(f'{finetune=}')
31
  assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
 
36
  self.device = device
37
  self.name = name
38
 
39
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(name, legacy=True)
40
  t5 = T5EncoderModel.from_pretrained(name).eval() #.train(mode=finetune)
41
  if finetune:
42
  self.t5 = t5
 
65
  embeds = self.t5(input_ids=d['input_ids'],
66
  attention_mask=d['attention_mask']
67
  ).last_hidden_state # no kvcache for txt conditioning
68
+ embeds = self.output_proj(embeds.to(self.output_proj.weight))
69
  embeds = (embeds * d['attention_mask'].unsqueeze(-1))
70
 
71
  return embeds # , d['attention_mask']
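Only the default device, the legacy=True tokenizer flag and the dtype cast before the output projection change here; the conditioning path itself stays the same: tokenise, encode with a frozen T5, project to the transformer width, zero out padding positions. A hedged sketch of that path (the 1536 output width is taken from the dims hard-coded in lm.py below):

import torch
from transformers import T5Tokenizer, T5EncoderModel

tok = T5Tokenizer.from_pretrained('t5-large', legacy=True)
t5 = T5EncoderModel.from_pretrained('t5-large').eval()
proj = torch.nn.Linear(t5.config.d_model, 1536)   # 1024 -> LM dim

with torch.no_grad():
    d = tok(['dogs barking in the rain'], return_tensors='pt', padding=True)
    h = t5(input_ids=d['input_ids'],
           attention_mask=d['attention_mask']).last_hidden_state
    emb = proj(h.to(proj.weight)) * d['attention_mask'].unsqueeze(-1)  # [1, T, 1536]
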
audiocraft/lm.py CHANGED
@@ -1,237 +1,45 @@
1
- from dataclasses import dataclass
2
- import logging
3
- import math
4
- import typing as tp
5
  import torch
6
  import torch.nn.functional as F
7
  from audiocraft.transformer import StreamingTransformer
8
- from dataclasses import dataclass
9
- from functools import partial
10
  from torch import nn
11
- from audiocraft.activations import get_activation_fn
 
12
  import numpy as np
13
 
14
- def _shift(x):
15
- # cyclic shift of [1, 4, seq_len] slices from [bs, 4, seq_len]
16
- print(x.shape, 'SHIFT\n= = = = = ')
17
- for i, _slice in enumerate(x):
18
- n = x.shape[2]
19
- offset = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0 TBD
20
- print(offset)
21
- x[i, :, :] = torch.roll(_slice, offset, dims=1)
22
- return x
23
-
24
-
25
-
26
- def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
27
- """LM layer initialization.
28
- Inspired from xlformers: https://github.com/fairinternal/xlformers
29
-
30
- Args:
31
- method (str): Method name for init function. Valid options are:
32
- 'gaussian', 'uniform'.
33
- input_dim (int): Input dimension of the initialized module.
34
- init_depth (int, optional): Optional init depth value used to rescale
35
- the standard deviation if defined.
36
- """
37
- # Compute std
38
- std = 1 / math.sqrt(input_dim)
39
- # Rescale with depth
40
- if init_depth is not None:
41
- std = std / math.sqrt(2 * init_depth)
42
-
43
- if method == 'gaussian':
44
- return partial(
45
- torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
46
- )
47
- elif method == 'uniform':
48
- bound = math.sqrt(3) * std # ensure the standard deviation is `std`
49
- return partial(torch.nn.init.uniform_, a=-bound, b=bound)
50
- else:
51
- raise ValueError("Unsupported layer initialization method")
52
-
53
-
54
- def init_layer(m: nn.Module,
55
- method: str,
56
- init_depth: tp.Optional[int] = None,
57
- zero_bias_init: bool = False):
58
- """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
59
-
60
- Args:
61
- m (nn.Module): Module to initialize.
62
- method (str): Method name for the init function.
63
- init_depth (int, optional): Optional init depth value used to rescale
64
- the standard deviation if defined.
65
- zero_bias_init (bool): Whether to initialize the bias to 0 or not.
66
- """
67
- if isinstance(m, nn.Linear):
68
- init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
69
- if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
70
- weight = m.weight.float()
71
- init_fn(weight)
72
- m.weight.data[:] = weight.half()
73
- else:
74
- init_fn(m.weight)
75
- if zero_bias_init and m.bias is not None:
76
- nn.init.constant_(m.bias, 0)
77
- elif isinstance(m, nn.Embedding):
78
- init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
79
- if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
80
- weight = m.weight.float()
81
- init_fn(weight)
82
- m.weight.data[:] = weight.half()
83
- else:
84
- init_fn(m.weight)
85
-
86
-
87
- class ScaledEmbedding(nn.Embedding):
88
- """Boost learning rate for embeddings (with `scale`).
89
- """
90
- def __init__(self, *args, lr=None, **kwargs):
91
- super().__init__(*args, **kwargs)
92
- self.lr = lr
93
-
94
- def make_optim_group(self):
95
- group = {"params": list(self.parameters())}
96
- if self.lr is not None:
97
- group["lr"] = self.lr
98
- return group
99
-
100
-
101
- @dataclass
102
- class LMOutput:
103
- # The logits are already re-aligned with the input codes
104
- # hence no extra shift is required, e.g. when computing CE
105
- logits: torch.Tensor # [B, K, T, card]
106
- mask: torch.Tensor # [B, K, T]
107
-
108
 
109
  class LMModel(nn.Module):
110
- """Transformer-based language model on multiple streams of codes.
111
-
112
- Args:
113
- pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
114
- condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
115
- fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
116
- n_q (int): Number of parallel streams to model.
117
- card (int): Cardinality, vocabulary size.
118
- dim (int): Dimension of the transformer encoder.
119
- num_heads (int): Number of heads for the transformer encoder.
120
- hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
121
- norm (str): Normalization method.
122
- norm_first (bool): Use pre-norm instead of post-norm.
123
- emb_lr (float, optional): Embedding-specific learning rate.
124
- bias_proj (bool): Use bias for output projections.
125
- weight_init (str, optional): Method for weight initialization.
126
- depthwise_init (str, optional): Method for depthwise weight initialization.
127
- zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
128
- cfg_dropout (float): Classifier-free guidance dropout.
129
- cfg_coef (float): Classifier-free guidance coefficient.
130
- attribute_dropout (dict): Attribute dropout probabilities.
131
- two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
132
- **kwargs: Additional parameters for the transformer encoder.
133
- """
134
  def __init__(self,
135
- pattern_provider,
136
- condition_provider,
137
- n_q: int = 8,
138
- card: int = 1024,
139
- dim: int = 128,
140
- num_heads: int = 8,
141
- hidden_scale: int = 4,
142
- norm: str = 'layer_norm',
143
- norm_first: bool = False,
144
- emb_lr: tp.Optional[float] = None,
145
- bias_proj: bool = True,
146
- weight_init: tp.Optional[str] = None,
147
- depthwise_init: tp.Optional[str] = None,
148
- zero_bias_init: bool = False, cfg_dropout: float = 0,
149
- cfg_coef: float = 1.0,
150
- two_step_cfg: bool = False,
151
- **kwargs):
152
  super().__init__()
153
- self.cfg_coef = cfg_coef
154
- self.condition_provider = condition_provider
155
  self.card = card # 2048 ?
156
  self.n_draw = 1 # replicate so many times the generation of each text in batch
 
 
157
  embed_dim = self.card + 1
158
  self.n_q = n_q
159
  self.dim = dim
160
- self.pattern_provider = pattern_provider
161
- self.two_step_cfg = two_step_cfg
162
- self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
163
- if 'activation' in kwargs:
164
- kwargs['activation'] = get_activation_fn(kwargs['activation'])
165
- # ========================================================================
166
- # {
167
- # 'dtype': torch.float16, 'device': 'cuda',
168
- # 'num_layers': 48, 'dropout': 0.0, 'activation': 'gelu',
169
- # 'bias_ff': False, 'bias_attn': False,
170
- # 'past_context': None, 'causal': True,
171
- # 'custom': False, 'memory_efficient': True,
172
- # 'attention_as_float32': False, 'positional_embedding': 'sin', 'xpos': False,
173
- # 'checkpointing': 'none', 'cross_attention': True, 'qk_layer_norm': False,
174
- # 'qk_layer_norm_cross': False, 'attention_dropout': None, 'kv_repeat': 1
175
- # }
176
- # ==========================================================================
177
- kwargs.pop('layer_scale') # nn.Indentity()
178
-
179
  self.transformer = StreamingTransformer(
180
  d_model=dim,
181
  num_heads=num_heads,
182
  dim_feedforward=int(hidden_scale * dim),
183
- norm=norm,
184
- norm_first=norm_first, **kwargs)
185
- self.out_norm: tp.Optional[nn.Module] = None
186
- if norm_first:
187
- self.out_norm = nn.LayerNorm(dim, eps=1e-5)
188
- self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
189
- self._init_weights(weight_init, depthwise_init, zero_bias_init)
190
- self._fsdp: tp.Optional[nn.Module]
191
- self.__dict__['_fsdp'] = None
192
-
193
- def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
194
- """Initialization of the transformer module weights.
195
-
196
- Args:
197
- weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
198
- depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
199
- 'current' where the depth corresponds to the current layer index or 'global' where the total number
200
- of layer is used as depth. If not set, no depthwise initialization strategy is used.
201
- zero_bias_init (bool): Whether to initialize bias to zero or not.
202
- """
203
- assert depthwise_init is None or depthwise_init in ['current', 'global']
204
- assert depthwise_init is None or weight_init is not None, \
205
- "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
206
- assert not zero_bias_init or weight_init is not None, \
207
- "If 'zero_bias_init', a 'weight_init' method should be provided"
208
-
209
- if weight_init is None:
210
- return
211
-
212
- for emb_layer in self.emb:
213
- init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
214
-
215
- for layer_idx, tr_layer in enumerate(self.transformer.layers):
216
- depth = None
217
- if depthwise_init == 'current':
218
- depth = layer_idx + 1
219
- elif depthwise_init == 'global':
220
- depth = len(self.transformer.layers)
221
- init_fn = partial(init_layer,
222
- method=weight_init,
223
- init_depth=depth,
224
- zero_bias_init=zero_bias_init)
225
- tr_layer.apply(init_fn)
226
-
227
- for linear in self.linears:
228
- init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
229
 
230
- @property
231
- def special_token_id(self) -> int:
232
- return self.card
233
-
234
-
235
 
236
  def forward(self,
237
  sequence,
@@ -293,7 +101,7 @@ class LMModel(nn.Module):
293
  max_gen_len), -1, dtype=torch.long,
294
  device=text_condition.device)
295
 
296
- gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
297
  _, _, audiodur = gen_sequence.shape # bs, 4, 7=audiodur
298
 
299
  # print(gen_sequence.shape, mask.shape, 'F') # mask has no batch = [4,audio_duration]
@@ -313,7 +121,7 @@ class LMModel(nn.Module):
313
  for offset in range(1, audiodur):
314
 
315
  # forward duplicates the query to nullcond - then cfg & returns deduplicate token
316
- next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
317
  condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
318
  token_count=offset-1) # [bs, 4, 1, 2048]
319
 
@@ -322,7 +130,7 @@ class LMModel(nn.Module):
322
 
323
  # MASK is not full 1---- HAS 4 x audioduration PATTERN
324
  m = mask[:, :, :, offset]
325
- next_token[~m] = self.special_token_id
326
  gen_sequence[:, :, :, offset] = torch.where(
327
  gen_sequence[:, :, :, offset] == -1, #unknown_token,
328
  next_token,
@@ -333,7 +141,7 @@ class LMModel(nn.Module):
333
  # 1. reshape n_draw as bs * n_draw
334
  # 2. invert all short-sequences
335
  # 3. reshape bs * n_draw -> bs, n_draw * audiodur ELONGATION
336
- out_codes, _, _ = pattern.revert_pattern_sequence(
337
  gen_sequence.reshape(bs * self.n_draw, 4, audiodur), # [3,8,4,7]
338
  special_token=-1)
339
  # print(f'{gen_sequence.shape=} {out_codes.shape=} Ha') # REVERT PATTERN REDUCES DURATION?
@@ -341,12 +149,10 @@ class LMModel(nn.Module):
341
  out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
342
  out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
343
  print(out_codes.shape, 'o')
344
- for _ in range(7):
345
- out_codes = _shift(out_codes)
346
 
347
- # Clear Transformer k/v history (Different history is kept by 48x selfattn)
348
  for lay in self.transformer.layers:
349
  lay.self_attn.k_history = None
350
  lay.self_attn.v_history = None
351
 
352
- return out_codes
 
 
 
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  from audiocraft.transformer import StreamingTransformer
 
 
4
  from torch import nn
5
+ from audiocraft.codebooks_patterns import DelayedPatternProvider
6
+ from audiocraft.conditioners import T5Conditioner
7
  import numpy as np
8
 
9
 
10
  class LMModel(nn.Module):
11
+
12
  def __init__(self,
13
+ n_q = 4,
14
+ card = 2048,
15
+ dim = 1536,
16
+ num_heads = 24,
17
+ hidden_scale = 4, # FFN of Transformer
18
+ ):
19
  super().__init__()
20
+ self.condition_provider = T5Conditioner(name='t5-large',
21
+ output_dim=dim)
22
  self.card = card # 2048 ?
23
  self.n_draw = 1 # replicate so many times the generation of each text in batch
24
+ # the batch is more expensive than n_draw as it re-runs the model bs times
25
+ # n_draw just draws more phonemes from the multinomial - after running the lm
26
  embed_dim = self.card + 1
27
  self.n_q = n_q
28
  self.dim = dim
29
+ self.pattern_provider = DelayedPatternProvider()
30
+ self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
31
  self.transformer = StreamingTransformer(
32
  d_model=dim,
33
  num_heads=num_heads,
34
  dim_feedforward=int(hidden_scale * dim),
35
+ num_layers=48,
36
+ positional_embedding='sin',
37
+ )
38
+ self.out_norm = nn.LayerNorm(dim, eps=1e-5)
39
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)]) # LINEAR DOESNT HAVE 2049
40
+ # self._init_weights(weight_init, depthwise_init, zero_bias_init)
41
+ # self.__dict__['_fsdp'] = None
42
 
 
 
 
 
 
43
 
44
  def forward(self,
45
  sequence,
 
101
  max_gen_len), -1, dtype=torch.long,
102
  device=text_condition.device)
103
 
104
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.card)
105
  _, _, audiodur = gen_sequence.shape # bs, 4, 7=audiodur
106
 
107
  # print(gen_sequence.shape, mask.shape, 'F') # mask has no batch = [4,audio_duration]
 
121
  for offset in range(1, audiodur):
122
 
123
  # forward duplicates the query to nullcond - then cfg & returns deduplicate token
124
+ next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
125
  condition_tensors=text_condition, # utilisation of the attention mask of txt condition ?
126
  token_count=offset-1) # [bs, 4, 1, 2048]
127
 
 
130
 
131
  # MASK is not full 1---- HAS 4 x audioduration PATTERN
132
  m = mask[:, :, :, offset]
133
+ next_token[~m] = self.card
134
  gen_sequence[:, :, :, offset] = torch.where(
135
  gen_sequence[:, :, :, offset] == -1, #unknown_token,
136
  next_token,
 
141
  # 1. reshape n_draw as bs * n_draw
142
  # 2. invert all short-sequences
143
  # 3. reshape bs * n_draw -> bs, n_draw * audiodur ELONGATION
144
+ out_codes = pattern.revert_pattern_sequence(
145
  gen_sequence.reshape(bs * self.n_draw, 4, audiodur), # [3,8,4,7]
146
  special_token=-1)
147
  # print(f'{gen_sequence.shape=} {out_codes.shape=} Ha') # REVERT PATTERN REDUCES DURATION?
 
149
  out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
150
  out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
151
  print(out_codes.shape, 'o')
 
 
152
 
153
+ # Clear k/v cache (Different kv is saved by every 48x selfattn)
154
  for lay in self.transformer.layers:
155
  lay.self_attn.k_history = None
156
  lay.self_attn.v_history = None
157
 
158
+ return out_codes # bs*n_draw, duration -> repeat/shift in api.py
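The trimmed LMModel keeps DelayedPatternProvider with its default delays [0, 1, 2, 3]: codebook k is generated k steps later and the gaps are filled with the special token (here self.card). A toy illustration of that interleaving, independent of the real provider (whose exact padding may differ at the sequence edges):

import torch

def delay_pattern(codes, special):
    # codes: [K, T] codebook indices -> [K, T + K - 1] delayed layout
    K, T = codes.shape
    out = torch.full((K, T + K - 1), special, dtype=codes.dtype)
    for k in range(K):
        out[k, k:k + T] = codes[k]   # codebook k starts k steps later
    return out

codes = torch.arange(8).reshape(4, 2)            # 4 codebooks, 2 timesteps
print(delay_pattern(codes, special=2048))
# tensor([[   0,    1, 2048, 2048, 2048],
#         [2048,    2,    3, 2048, 2048],
#         [2048, 2048,    4,    5, 2048],
#         [2048, 2048, 2048,    6,    7]])
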
audiocraft/transformer.py CHANGED
@@ -1,26 +1,12 @@
1
- import typing as tp
2
- from einops import rearrange
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
- from torch.utils.checkpoint import checkpoint as torch_checkpoint
7
-
8
-
9
- _efficient_attention_backend: str = 'torch'
10
-
11
-
12
-
13
-
14
- def _get_attention_time_dimension(memory_efficient: bool) -> int:
15
- if _efficient_attention_backend == 'torch' and memory_efficient:
16
- return 2
17
- else:
18
- return 1
19
-
20
-
21
 
22
- def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
23
- dtype: torch.dtype = torch.float32) -> torch.Tensor:
 
 
24
  """Create sinusoidal positional embedding, with shape `[B, T, C]`.
25
 
26
  Args:
@@ -41,256 +27,102 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
41
  return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
42
 
43
 
44
- def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
45
- """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
46
- if n_rep == 1:
47
- return x
48
- if _efficient_attention_backend == 'torch' and memory_efficient:
49
- bs, n_kv_heads, slen, head_dim = x.shape
50
- return (
51
- x[:, :, None, :, :]
52
- .expand(bs, n_kv_heads, n_rep, slen, head_dim)
53
- .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
54
- )
55
- else:
56
- bs, slen, n_kv_heads, head_dim = x.shape
57
- return (
58
- x[:, :, :, None, :]
59
- .expand(bs, slen, n_kv_heads, n_rep, head_dim)
60
- .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
61
- )
62
-
63
-
64
-
65
-
66
-
67
  class StreamingMultiheadAttention(nn.Module):
68
 
69
  def __init__(self,
70
  embed_dim,
71
- num_heads, dropout: float = 0.0, bias: bool = True,
72
- causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
73
- memory_efficient: bool = False, attention_as_float32: bool = False,
74
- cross_attention: bool = False,
75
- kv_repeat: int = 1,
76
- device=None, dtype=None):
77
  super().__init__()
78
- factory_kwargs = {'device': device, 'dtype': dtype}
79
- if past_context is not None:
80
- assert causal
81
-
82
  self.embed_dim = embed_dim
83
-
84
  self.k_history = None # previous k from the previous tokens seen in the current generation - only for selt.attn
85
- self.v_history = None # clean up IN LM after finishing GENERATION - Each 1...47 mha has different kv history
86
-
87
- self.memory_efficient = memory_efficient
88
-
89
-
90
- self.cross_attention = cross_attention
91
-
92
  self.num_heads = num_heads
93
- self.dropout = dropout
94
- self.kv_repeat = kv_repeat
95
-
96
-
97
-
98
-
99
- self.custom = True #_is_custom(custom, memory_efficient)
100
- if not self.custom:
101
- print(f'{self.custom}')
102
- if self.custom:
103
- out_dim = embed_dim
104
- assert num_heads % kv_repeat == 0
105
- assert not cross_attention or kv_repeat == 1
106
- num_kv = num_heads // kv_repeat
107
- kv_dim = (embed_dim // num_heads) * num_kv
108
- out_dim += 2 * kv_dim
109
- in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
110
- # We try to follow the default PyTorch MHA convention, to easily compare results.
111
- self.in_proj_weight = in_proj.weight
112
- self.in_proj_bias = in_proj.bias
113
- if bias:
114
- self.in_proj_bias.data.zero_() # Following Pytorch convention
115
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
116
- if bias:
117
- self.out_proj.bias.data.zero_()
118
- else:
119
- assert kv_repeat == 1
120
- self.mha = nn.MultiheadAttention(
121
- embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
122
- **factory_kwargs)
123
-
124
-
125
- def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
126
- if not self.custom:
127
- # Support compat with regular MHA
128
- keys = [n for n, _ in self.mha.named_parameters()]
129
- for key in keys:
130
- if prefix + key in state_dict:
131
- state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
132
- super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
133
-
134
-
135
 
136
-
137
-
138
-
139
-
140
- def forward(self,
141
- query,
142
- key=None, # ignores those 2 args if not self.cross_attn
143
  value=None):
144
-
145
-
146
- # time_dim = _get_attention_time_dimension(self.memory_efficient)
147
- # if time_dim == 2:
148
  layout = "b h t d"
149
- # else:
150
- # layout = "b t h d"
151
- # dtype = query.dtype
152
-
153
-
154
-
155
-
156
-
157
 
158
- if self.custom:
 
 
159
 
160
- if self.cross_attention:
161
- # Different queries, keys, values, we have to spit manually the weights
162
- # before applying the linear.
163
- dim = self.in_proj_weight.shape[0] // 3
164
- if self.in_proj_bias is None:
165
- bias_q, bias_k, bias_v = None, None, None
166
- else:
167
- bias_q = self.in_proj_bias[:dim]
168
- bias_k = self.in_proj_bias[dim: 2 * dim]
169
- bias_v = self.in_proj_bias[2 * dim:]
170
- q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
171
- # todo: when streaming, we could actually save k, v and check the shape actually match.
172
- k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
173
- v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
174
-
175
- q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
176
- # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(),'CROSS A5')
177
- else:
178
- # 1st projected makes k,v (instantaneous)
179
- # 2nd cat
180
 
181
-
182
- # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
183
-
184
- projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
185
- if self.kv_repeat == 1:
186
- # if time_dim == 2:
187
- bound_layout = "b h p t d"
188
- # else:
189
- # bound_layout = "b t p h d"
190
- packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
191
- q, k, v = packed.unbind(dim=2)
192
-
193
-
194
- if self.k_history is not None:
195
- #
196
- # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
197
- # has to be 4D with batch 1 due to single condition 3=seqlen
198
- # 24 heads 64 dimofh
199
- self.k_history = torch.cat([self.k_history, k], 2)
200
- self.v_history = torch.cat([self.v_history, v], 2)
201
 
202
- else:
203
- # init on 1st token (for all 47 transf layers)
204
- print(f'AudioGen kv cache Flush')
205
- self.k_history = k
206
- self.v_history = v
207
-
208
- k = self.k_history
209
- v = self.v_history
210
 
211
 
212
-
213
- # KV COMPLETION ONLY ON SELF ATTENTION
214
- # print('KV5', self.k_history.sum(), self.v_history.sum(), self.k_history.shape, self.v_history.shape)
215
-
216
 
217
- if self.memory_efficient:
218
- # print('EVER IN MEMORY EFFICIENT A')
219
-
220
 
221
- p = self.dropout if self.training else 0
222
- if _efficient_attention_backend == 'torch':
223
- x = torch.nn.functional.scaled_dot_product_attention(
224
- q, k, v, is_causal=False, dropout_p=p
225
- )
226
-
227
- x = x.to(q.dtype)
228
- x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
229
- x = self.out_proj(x)
230
  return x
231
 
232
 
233
- class StreamingTransformerLayer(nn.Module): #nn.TransformerEncoderLayer):
234
- # INHERITS MHA !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
235
 
236
  def __init__(self,
237
- d_model: int,
238
- num_heads: int,
239
- dim_feedforward: int = 2048,
240
- dropout: float = 0.1,
241
- bias_ff: bool = True,
242
- bias_attn: bool = True,
243
- custom: bool = False,
244
- memory_efficient: bool = False,
245
- attention_as_float32: bool = False,
246
- cross_attention: bool = False,
247
- attention_dropout: tp.Optional[float] = None,
248
- kv_repeat: int = 1,
249
- norm: str = 'layer_norm',
250
- device=None,
251
- dtype=None,
252
- **kwargs):
253
-
254
 
255
- super().__init__() #d_model, num_heads, dim_feedforward, dropout,
256
- #device=device, dtype=dtype, batch_first=True, **kwargs)
257
- # print(kwargs['activation'], 'ACTIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII\n\n\n\n')
258
- # -- EN Layer
259
- # DOES NOT INHERIT NO VARIABLE FROM nn.TransformerEncoderLayer only the _sa_block function
260
 
261
- # -- EN layer
262
 
263
- factory_kwargs = {'device': device, 'dtype': dtype}
264
- # Redefine self_attn to our streaming multi-head attention
265
- attn_kwargs: tp.Dict[str, tp.Any] = {
266
- 'embed_dim': d_model,
267
- 'num_heads': num_heads,
268
- 'dropout': dropout if attention_dropout is None else attention_dropout,
269
- 'bias': bias_attn,
270
- 'custom': custom,
271
- 'memory_efficient': memory_efficient,
272
- 'attention_as_float32': attention_as_float32,
273
- }
274
- self.self_attn = StreamingMultiheadAttention(
275
- kv_repeat=kv_repeat,
276
- **attn_kwargs,
277
- **factory_kwargs) # type: ignore
278
- # Redefine feedforward layers to expose bias parameter
279
- self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
280
- self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
281
- # print('LAYER scale', layer_scale, '\n\n\n\n\n\n\n\n\n') # always
282
-
283
-
284
- self.cross_attention= None
285
- if cross_attention:
286
- self.cross_attention = StreamingMultiheadAttention(
287
- cross_attention=True,
288
- **attn_kwargs,
289
- **factory_kwargs)
290
-
291
- self.dropout_cross = nn.Dropout(dropout)
292
-
293
- self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
294
  self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
295
  self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
296
 
@@ -316,59 +148,34 @@ class StreamingTransformerLayer(nn.Module): #nn.TransformerEncoderLayer):
316
 
317
  class StreamingTransformer(nn.Module):
318
 
319
- def __init__(self, d_model: int,
320
- num_heads: int,
321
- num_layers: int,
322
- dim_feedforward: int = 2048,
323
- dropout: float = 0.1,
324
- bias_ff: bool = True,
325
- bias_attn: bool = True,
326
- custom: bool = False,
327
- memory_efficient: bool = False,
328
- attention_as_float32: bool = False,
329
- cross_attention: bool = False,
330
  positional_embedding: str = 'sin',
331
- max_period: float = 10_000,
332
- layer_class=StreamingTransformerLayer,
333
- checkpointing: str = 'none',
334
- device=None,
335
- dtype=None,
336
- **kwargs):
337
  super().__init__()
338
  assert d_model % num_heads == 0
339
 
340
  self.positional_embedding = positional_embedding
341
  self.max_period = max_period
342
-
343
-
344
-
345
- # self._stream_off = 0 # the llm should reinitialize this at ery generate()
346
-
347
- self.checkpointing = checkpointing
348
-
349
-
350
-
351
-
352
  self.layers = nn.ModuleList()
353
  for idx in range(num_layers):
354
  self.layers.append(
355
- layer_class(
356
- d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
357
- dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
358
- custom=custom,
359
- memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
360
- cross_attention=cross_attention,
361
- device=device, dtype=dtype, **kwargs))
362
-
363
- if self.checkpointing != 'none':
364
- for layer in self.layers:
365
- # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
366
- # backward hook inside of FSDP...
367
- layer._magma_checkpointed = True # type: ignore
368
-
369
-
370
 
371
- def forward(self, x: torch.Tensor, *args, **kwargs):
 
 
 
372
 
373
  B, T, C = x.shape
374
 
@@ -376,7 +183,7 @@ class StreamingTransformer(nn.Module):
376
  if self.positional_embedding in ['sin', 'sin_rope']:
377
 
378
  positions = torch.arange(T, device=x.device).view(1, -1, 1)
379
- positions = positions + kwargs['token_count'] #offsets.view(-1, 1, 1)
380
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
381
  x = x + pos_emb
382
 
@@ -384,6 +191,6 @@ class StreamingTransformer(nn.Module):
384
 
385
  for j, lay in enumerate(self.layers):
386
  # print(f'Transf Layer{j} {pos_emb.sum()=} {pos_emb.shape=}{x.shape=}___________________')
387
- x = lay(x, cross_attention_src=kwargs["cross_attention_src"]) # cross_attention_src = txt-cond
388
  # each layer (mha) keeps history of its own k,v for all tokens
389
  return x
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
4
+ from einops import rearrange
5
 
6
+ def create_sin_embedding(positions,
7
+ dim,
8
+ max_period = 10000,
9
+ dtype = torch.float32):
10
  """Create sinusoidal positional embedding, with shape `[B, T, C]`.
11
 
12
  Args:
 
27
  return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
28
 
29
 
30
  class StreamingMultiheadAttention(nn.Module):
31
 
32
  def __init__(self,
33
  embed_dim,
34
+ num_heads,
35
+ cross_attention = False):
 
 
 
 
36
  super().__init__()
37
+ self.cross_attention = cross_attention
 
 
 
38
  self.embed_dim = embed_dim
 
39
  self.k_history = None # previous k from the previous tokens seen in the current generation - only for selt.attn
40
+ self.v_history = None # clean up IN LM after finishing GENERATION - Each 1...47 mha has different kv history
 
 
 
 
 
 
41
  self.num_heads = num_heads
42
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
43
+ self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
44
+ dtype=torch.float))
45
 
46
+ def forward(self,
47
+ query,
48
+ key=None,
 
 
 
 
49
  value=None):
 
 
 
 
50
  layout = "b h t d"
51
+ if self.cross_attention:
52
+
53
+ # Different queries, keys, values; we have to split the in_proj_weight manually
54
+
55
+ dim = self.in_proj_weight.shape[0] // 3
56
+
57
+ q = nn.functional.linear(query, self.in_proj_weight[:dim])
58
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
59
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
60
+
61
+ q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
62
+ # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(),'CROSS A5')
63
+ else:
64
+ # 1st projected makes k,v (instantaneous)
65
+ # 2nd cat
66
+
67
+
68
+ # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
69
+
70
+ projected = nn.functional.linear(query, self.in_proj_weight)
71
 
72
+ bound_layout = "b h p t d"
73
+ packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
74
+ q, k, v = packed.unbind(dim=2)
75
 
76
 
77
+ if self.k_history is not None:
78
+ #
79
+ # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
80
+ # has to be 4D with batch 1 due to single condition 3=seqlen
81
+ # 24 heads 64 dimofh
82
+ self.k_history = torch.cat([self.k_history, k], 2)
83
+ self.v_history = torch.cat([self.v_history, v], 2)
84
 
85
+ else:
86
+ # init on 1st token (for all 47 transf layers)
87
+ print(f'AudioGen kv cache Flush')
88
+ self.k_history = k
89
+ self.v_history = v
90
+
91
+ k = self.k_history
92
+ v = self.v_history
93
 
94
 
 
 
 
 
95
 
96
+ # KV COMPLETION ONLY ON SELF ATTENTION
 
 
97
 
98
+ x = torch.nn.functional.scaled_dot_product_attention(
99
+ q, k, v, is_causal=False, dropout_p=0
100
+ )
101
+
102
+ x = x.to(q.dtype)
103
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
104
+ x = self.out_proj(x)
 
 
105
  return x
106
 
107
 
108
+ class StreamingTransformerLayer(nn.Module):
 
109
 
110
  def __init__(self,
111
+ d_model,
112
+ num_heads,
113
+ dim_feedforward):
114
 
 
 
 
 
 
115
 
116
+ super().__init__()
117
 
118
+ self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
119
+ num_heads=num_heads)
120
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
121
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
122
+ self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
123
+ num_heads=num_heads,
124
+ cross_attention=True)
125
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
 
126
  self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
127
  self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
128
 
 
148
 
149
  class StreamingTransformer(nn.Module):
150
 
151
+ def __init__(self,
152
+ d_model=1536,
153
+ num_heads=24,
154
+ num_layers=48,
155
+ dim_feedforward=6144,
156
+ cross_attention = True,
 
 
 
 
 
157
  positional_embedding: str = 'sin',
158
+ max_period: float = 10_000
159
+ ):
 
 
 
 
160
  super().__init__()
161
  assert d_model % num_heads == 0
162
 
163
  self.positional_embedding = positional_embedding
164
  self.max_period = max_period
 
165
  self.layers = nn.ModuleList()
166
  for idx in range(num_layers):
167
  self.layers.append(
168
+ StreamingTransformerLayer(
169
+ d_model=d_model,
170
+ num_heads=num_heads,
171
+ dim_feedforward=dim_feedforward
172
+ )
173
+ )
174
 
175
+ def forward(self,
176
+ x,
177
+ token_count=None,
178
+ cross_attention_src=None):
179
 
180
  B, T, C = x.shape
181
 
 
183
  if self.positional_embedding in ['sin', 'sin_rope']:
184
 
185
  positions = torch.arange(T, device=x.device).view(1, -1, 1)
186
+ positions = positions + token_count #offsets.view(-1, 1, 1)
187
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
188
  x = x + pos_emb
189
 
 
191
 
192
  for j, lay in enumerate(self.layers):
193
  # print(f'Transf Layer{j} {pos_emb.sum()=} {pos_emb.shape=}{x.shape=}___________________')
194
+ x = lay(x, cross_attention_src=cross_attention_src) # cross_attention_src = txt-cond
195
  # each layer (mha) keeps history of its own k,v for all tokens
196
  return x
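Each self-attention layer now keeps its own k_history/v_history by concatenating along the time axis, and lm.py resets both to None once a generation finishes. A minimal sketch of that caching discipline with the shapes from the comments above ([B, heads, T, head_dim]); illustrative only, not the committed class:

import torch

class KVCache:
    # append-only key/value cache, one instance per self-attention layer
    def __init__(self):
        self.k = self.v = None

    def update(self, k_new, v_new):
        # k_new / v_new: [B, heads, 1, head_dim] for the current token
        if self.k is None:                       # first token: initialise ('flush')
            self.k, self.v = k_new, v_new
        else:                                    # later tokens: grow along the time dim
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

    def reset(self):                             # call after generation, as lm.py does
        self.k = self.v = None

cache = KVCache()
for _ in range(7):                               # 7 decoding steps
    k, v = cache.update(torch.randn(1, 24, 1, 64), torch.randn(1, 24, 1, 64))
print(k.shape)                                   # torch.Size([1, 24, 7, 64])
cache.reset()
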
msinference.py CHANGED
@@ -293,10 +293,41 @@ with open(f"Utils/all_langs.csv") as f:
293
 
294
 
295
 
296
- # LOAD hun / ron / serbian - rmc-script_latin / cyrillic-Carpathian (not Vlax)
297
-
298
-
299
-
 
300
 
301
  def has_cyrillic(text):
302
  # https://stackoverflow.com/questions/48255244/python-check-if-a-string-contains-cyrillic-characters
@@ -358,7 +389,7 @@ class TextForeign(object):
358
  def foreign(text=None, # list of text
359
  lang='romanian',
360
  speed=None):
361
-
362
  lang = lang.lower() # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
363
 
364
  # https://huggingface.co/spaces/mms-meta/MMS
@@ -367,11 +398,11 @@ def foreign(text=None, # list of text
367
 
368
  lang_code = 'hun'
369
 
370
- elif 'ser' in lang or 'bosn' in lang or 'macedon' in lang or 'croatia' in lang:
371
 
372
  if has_cyrillic(text[0]): # check 0-th sentence if is cyrillic
373
 
374
- lang_code = 'rmc-script_cyrillic' # romani carpathian (also has lating/cyrillic Vlax)
375
 
376
  else:
377
 
@@ -387,6 +418,11 @@ def foreign(text=None, # list of text
387
  lang_code = 'deu'
388
  speed = 1.14 if speed is None else speed
389
 
 
 
 
 
 
390
  else:
391
 
392
  lang_code = lang.split()[0].strip()
@@ -431,20 +467,29 @@ def foreign(text=None, # list of text
431
  x = []
432
 
433
  for _t in text:
434
-
435
-
436
-
437
  if is_uroman:
438
  uroman_dir = "Utils/uroman"
439
  assert os.path.exists(uroman_dir)
440
  uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
441
  _t = text_mapper.uromanize(_t, uroman_pl)
442
 
443
- _t = _t.lower().replace("ţ", "ț").replace('ț','ts').replace('î', 'u') # Parse STTS2 pronounciation on tts_mult()
 
444
 
445
  _t = text_mapper.filter_oov(_t, lang=lang)
446
 
447
- # print(f'{speed=}\n\n\n\n_______________________________ {_t}')
448
  stn_tst = text_mapper.get_text(_t, hps)
449
  with torch.no_grad():
450
  x_tst = stn_tst.unsqueeze(0).to(device)
@@ -468,14 +513,3 @@ def foreign(text=None, # list of text
468
  original_rate=16000,
469
  target_rate=24000)[0, :] # reshapes (64,) -> (1,64)
470
  return x
471
-
472
-
473
-
474
-
475
- # LANG = 'eng'
476
- # _t = 'Converts a string of text to a sequence of IDs corresponding to the symbols in the text. Args: text: string to convert to a sequence'
477
-
478
- # x = synthesize(text=_t, lang=LANG, speed=1.14)
479
- # audiofile.write('_r.wav', x, 16000) # mms-tts = 16,000
480
-
481
-
 
293
 
294
 
295
 
296
+ # LOAD hun / ron / serbian - rmc-script_latin / cyrillic-Carpathian (not Vlax)
297
+ # ==============================================================================================
298
+ import re
299
+ from num2words import num2words
300
+
301
+ PHONEME_MAP = {
302
+ 'q': 'ku',
303
+ 'w': 'aou',
304
+ 'z': 's',
305
+ "š": "s",
306
+ 'th': 'ta',
307
+ 'v': 'vv',
308
+ # "ć": "č",
309
+ # "đ": "ď",
310
+ # "lj": "ľ",
311
+ # "nj": "ň",
312
+ "ž": "z",
313
+ # "c": "č"
314
+ }
315
+
316
+ # ALLOWED_PHONEMES = set("šč_bďph`-3žt 'ľzj5yuoóx1vfnaiedt́sṁkň2rčlg")
317
+
318
+ def number_to_phonemes(match):
319
+ number = int(match.group())
320
+ words = num2words(number, lang='sr')
321
+ return fix_phones(words.lower())
322
+ # return words
323
+
324
+ def fix_phones(text):
325
+ for src, target in PHONEME_MAP.items():
326
+ text = text.replace(src, target)
327
+ # text = re.sub(r'\s+', '` `', text) #.strip() #.lower()
328
+ # text = re.sub(r'\s+', '_ _', text) # almost proper pausing
329
+
330
+ return text.replace(',', '_ _').replace('.', '_ _')
331
 
332
  def has_cyrillic(text):
333
  # https://stackoverflow.com/questions/48255244/python-check-if-a-string-contains-cyrillic-characters
 
389
  def foreign(text=None, # list of text
390
  lang='romanian',
391
  speed=None):
392
+
393
  lang = lang.lower() # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
394
 
395
  # https://huggingface.co/spaces/mms-meta/MMS
 
398
 
399
  lang_code = 'hun'
400
 
401
+ elif any([i in lang for i in ['ser', 'bosn', 'herzegov', 'montenegr', 'macedon']]):
402
 
403
  if has_cyrillic(text[0]): # check 0-th sentence if is cyrillic
404
 
405
+ lang_code = 'rmc-script_cyrillic' # romani carpathian (also has latin / cyrillic Vlax)
406
 
407
  else:
408
 
 
418
  lang_code = 'deu'
419
  speed = 1.14 if speed is None else speed
420
 
421
+ elif 'alban' in lang:
422
+
423
+ lang_code = 'sqi'
424
+ speed = 1.04 if speed is None else speed
425
+
426
  else:
427
 
428
  lang_code = lang.split()[0].strip()
 
467
  x = []
468
 
469
  for _t in text:
 
 
 
470
  if is_uroman:
471
  uroman_dir = "Utils/uroman"
472
  assert os.path.exists(uroman_dir)
473
  uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
474
  _t = text_mapper.uromanize(_t, uroman_pl)
475
 
476
+ _t = _t.lower()
477
+
478
+ if lang_code == 'rmc-script_latin':
479
+
480
+ _t = re.sub(r'\d+', number_to_phonemes, _t)
481
+ _t = fix_phones(_t)
482
+
483
+ elif lang_code == 'ron':
484
+
485
+ _t = _t.replace("ţ", "ț"
486
+ ).replace('ț','ts').replace('î', 'u')
487
+
488
+ # /data/dkounadis/.hf7/hub/models--facebook--mms-tts/snapshots/44cc7fb408064ef9ea6e7c59130d88cac1274671/models/rmc-script_latin/vocab.txt
489
 
490
  _t = text_mapper.filter_oov(_t, lang=lang)
491
 
492
+ print(f'{speed=}\n\n\n\n_______________________________ {_t}')
493
  stn_tst = text_mapper.get_text(_t, hps)
494
  with torch.no_grad():
495
  x_tst = stn_tst.unsqueeze(0).to(device)
 
513
  original_rate=16000,
514
  target_rate=24000)[0, :] # reshapes (64,) -> (1,64)
515
  return x
 
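For the rmc-script_latin voice, digits are now spelled out with num2words in Serbian and run through the same character substitutions as the rest of the text before MMS-TTS sees it. A compact sketch of that pre-processing using the PHONEME_MAP entries added above (the example sentence is illustrative):

import re
from num2words import num2words

PHONEME_MAP = {'q': 'ku', 'w': 'aou', 'z': 's', 'š': 's', 'th': 'ta', 'v': 'vv', 'ž': 'z'}

def fix_phones(text):
    for src, target in PHONEME_MAP.items():
        text = text.replace(src, target)
    return text.replace(',', '_ _').replace('.', '_ _')   # punctuation becomes pauses

def verbalise_numbers(text):
    # replace every digit run with its Serbian wording, then re-map phonemes
    return re.sub(r'\d+',
                  lambda m: fix_phones(num2words(int(m.group()), lang='sr').lower()),
                  text)

print(fix_phones(verbalise_numbers('ima 12 kuca.')))
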
requirements.txt CHANGED
@@ -2,7 +2,7 @@ torch
2
  torchaudio
3
  numpy
4
  audiofile
5
- audresample
6
  cached_path
7
  einops
8
  flask
 
2
  torchaudio
3
  numpy
4
  audiofile
5
+ num2words
6
  cached_path
7
  einops
8
  flask