thuannh committed on
Commit
d54c82f
·
verified ·
1 Parent(s): b46f992

Delete custom_component.py

Browse files
Files changed (1) hide show
  1. custom_component.py +0 -172
custom_component.py DELETED
@@ -1,172 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import whisper
4
- from whisper.model import AudioEncoder, ModelDimensions
5
- from typing import Dict, Optional
6
- from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables
7
- from huggingface_hub import hf_hub_download
8
- import torch.nn.functional as F
9
- import os
10
- from typing import List, Optional, Union
11
- import io
12
- import urllib
13
- from tqdm import tqdm
14
- import torchaudio
15
# Maps a model name to the URL of its encoder-only Whisper checkpoint.
_HF_MODELS = {
    "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt",
}


def available_models() -> List[str]:
    """Return the model names registered in ``_HF_MODELS`` as a new list."""
    return [*_HF_MODELS]
21
- def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
22
- os.makedirs(root, exist_ok=True)
23
-
24
- expected_sha256 = url.split("/")[-2]
25
- download_target = os.path.join(root, os.path.basename(url))
26
-
27
- if os.path.exists(download_target) and not os.path.isfile(download_target):
28
- raise RuntimeError(f"{download_target} exists and is not a regular file")
29
-
30
- if os.path.isfile(download_target):
31
- with open(download_target, "rb") as f:
32
- model_bytes = f.read()
33
- return model_bytes if in_memory else download_target
34
-
35
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
36
- with tqdm(
37
- total=int(source.info().get("Content-Length")),
38
- ncols=80,
39
- unit="iB",
40
- unit_scale=True,
41
- unit_divisor=1024,
42
- ) as loop:
43
- while True:
44
- buffer = source.read(8192)
45
- if not buffer:
46
- break
47
-
48
- output.write(buffer)
49
- loop.update(len(buffer))
50
-
51
- model_bytes = open(download_target, "rb").read()
52
- return model_bytes if in_memory else download_target
53
class CustomWhisperEncoder(nn.Module):
    """
    Lightweight wrapper that only loads the AudioEncoder part of Whisper.

    The checkpoint is expected to be a dict with a ``"dims"`` entry (keyword
    arguments for ``ModelDimensions``) and a ``"model_state_dict"`` entry
    (weights for ``AudioEncoder``).
    """

    def __init__(
        self,
        name: str,
        device: Optional[str] = None,
        download_root: Optional[str] = None,
        in_memory: bool = False,
    ):
        """Build the encoder and load its weights.

        Args:
            name: Either a key of ``_HF_MODELS`` or a path to a local
                checkpoint file.
            device: Target device; defaults to ``"cuda"`` when available,
                else ``"cpu"``.
            download_root: Directory for downloaded checkpoints; defaults to
                the directory containing this source file.
            in_memory: Load the checkpoint bytes into memory before
                deserializing instead of reading from the file directly.

        Raises:
            RuntimeError: If ``name`` is neither a known model nor a file.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            # Checkpoints are cached next to this source file.
            download_root = os.path.dirname(os.path.realpath(__file__))

        if name in _HF_MODELS:
            checkpoint_file = _download(_HF_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            if in_memory:
                # Context manager so the handle is closed deterministically.
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        # Load weights
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from a trusted source (consider weights_only=True if compatible).
        with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        # Release the (possibly large) in-memory bytes as early as possible.
        del checkpoint_file

        dims = ModelDimensions(**checkpoint["dims"])
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )

        self.encoder.load_state_dict(checkpoint["model_state_dict"])

        if device:
            self.to(device)

        self.eval()

    def forward(self, mel: torch.Tensor):
        """Run the wrapped audio encoder on ``mel`` and return its output."""
        return self.encoder(mel)
98
-
99
class CustomRQBottleneckTransformer(RQBottleneckTransformer):
    """RQBottleneckTransformer that pairs with ``CustomWhisperEncoder``.

    Only the quantizer-related weights are restored from the checkpoint, and
    the Whisper side is provided by the encoder-only ``CustomWhisperEncoder``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                     repo_id=None, filename=None, local_filename=None):
        """Create an instance and load only the ``rq``/``mlp``/``mlp_ln`` weights.

        Args:
            ref: ``"repo_id:filename"`` on the Hugging Face Hub, or a local
                path when it contains no ``":"``. Used only when the other
                three arguments are all ``None``.
            repo_id: Hub repository to download from.
            filename: File name inside ``repo_id``.
            local_filename: Path of an already-available checkpoint.

        Returns:
            The constructed model, in eval mode.
        """
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)

        # Load the spec. Map to CPU so a checkpoint saved from a CUDA device
        # can be read on a CPU-only host; the weights are copied into the
        # freshly constructed instance below either way.
        # NOTE(review): torch.load unpickles the file; use trusted sources.
        spec = torch.load(local_filename, map_location="cpu")

        # Create instance with minimal required components
        instance = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))

        # Load only necessary state dict entries
        required_prefixes = ('rq', 'mlp', 'mlp_ln')
        filtered_state_dict = {
            k: v for k, v in spec['state_dict'].items()
            if k.startswith(required_prefixes)
        }

        # strict=False because every other entry is deliberately left out.
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()
        return instance

    def load_encoder(self, device=None):
        """Lazily create the Whisper encoder and tokenizer; no-op if loaded."""
        if self.whmodel is not None:
            return
        device = device or self.device
        # Use our custom encoder-only model
        encoder = CustomWhisperEncoder(self.whisper_model_name, device=device)
        # Stored inside a list, as in the original code.
        # NOTE(review): presumably so it is not registered as a submodule.
        self.whmodel = [encoder]
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    # The method name keeps its original spelling because callers use it.
    def optimzed_encode_mel(self, mel):
        """Encode a mel spectrogram batch and quantize the embeddings.

        Input longer than ``whisper.audio.N_FRAMES`` frames is truncated;
        shorter input is right-padded with ``-1.5`` up to that length.
        """
        assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
        self.load_encoder()
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            padded = mel[:, :, :whisper.audio.N_FRAMES]
        else:
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)
        embs = self.whmodel[0].encoder(padded)
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            # Keep only the tokens that correspond to the unpadded input.
            return stoks[:, :n // 2 // self.downsample]
        return stoks

    # override
    def encode_audio(self, audio):
        """Encode audio given as a file path or as an already-loaded tensor.

        A path is loaded with torchaudio, resampled to 16 kHz, and reduced to
        its first channel before the mel spectrogram is computed.
        """
        if isinstance(audio, str):
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.optimzed_encode_mel(self.log_mel_spectrogram(audio).to(self.device))
165
-
166
if __name__ == "__main__":
    # Smoke test: load the quantizer and its encoder. Fall back to CPU when
    # CUDA is unavailable instead of failing on a hard-coded "cuda" device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the model
    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
        "whisper-vq-stoks-v3-7lang-fixed.model"
    ).to(device)
    vqmodel.load_encoder(device)
    vqmodel.eval()